add particle inflation functionality
costachris committed Feb 10, 2023
1 parent 26306e5 commit f9df438
Showing 7 changed files with 288 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/make.jl
@@ -79,6 +79,7 @@ pages = [
"Prior distributions" => "parameter_distributions.md",
"Internal data representation" => "internal_data_representation.md",
"Localization and SEC" => "localization.md",
"Inflation" => "inflation.md",
"Parallelism and HPC" => "parallel_hpc.md",
"Observations" => "observations.md",
"API" => api,
60 changes: 60 additions & 0 deletions docs/src/inflation.md
@@ -0,0 +1,60 @@
# Inflation
Inflation is an approach that slows the collapse of the ensemble in ensemble Kalman methods.
Two distinct forms of inflation are implemented in this package. Both perturb the ensemble members after the standard update rule of the chosen Kalman process has been applied.
Multiplicative inflation expands ensemble members away from their mean in a
deterministic manner, whereas additive inflation adds stochastic noise to the ensemble members.

For both implementations, a scaling factor ``s`` is included to extend functionality to cases with mini-batching.
The scaling factor ``s`` multiplies the artificial time step ``\Delta t`` in the inflation equations to account for sampling error. For mini-batching, the scaling factor should be:
```math
s = \frac{|B|}{|C|}
```
where `` |B| `` is the mini-batch size and `` |C| `` is the full dataset size.
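
As a rough illustration of the scaling factor, the sketch below computes ``s`` and the scaled time step for made-up sizes. The variable names and numbers are assumptions chosen for this example and are not part of the package.

```julia
# Hypothetical sizes, chosen only for illustration.
batch_size = 50        # |B|, mini-batch size
dataset_size = 1000    # |C|, full dataset size
Δt = 0.5               # artificial time step of the Kalman process

s = batch_size / dataset_size
scaled_Δt = s * Δt

# Both inflation routines in this commit raise an error when s * Δt >= 1.
scaled_Δt < 1 || error("scaled time step must be < 1")
```

With these numbers the scaled time step is small, so the inflation applied per iteration is correspondingly mild.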

## Multiplicative Inflation
Multiplicative inflation effectively scales parameter vectors in parameter space, such that the perturbed
ensemble remains in the linear span of the original ensemble. The implemented update equation follows
[Huang et al., 2022](https://arxiv.org/abs/2204.04386), eqn. 41:

```math
\begin{aligned}
m_{n+1} = m_{n} ; \qquad u^{j}_{n + 1} = m_{n+1} + \sqrt{\frac{1}{1 - s \Delta{t}}} \left(u^{j}_{n} - m_{n} \right) \qquad (1)
\end{aligned}
```
where ``m`` is the ensemble average. In this way,
the parameter covariance is inflated by a factor of ``\frac{1}{1 - s \Delta{t}}``, while the ensemble mean remains fixed.
```math
C_{n + 1} = \frac{1}{1 - s \Delta{t}} C_{n} \qquad (2)
```
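
The effect of eqns. (1) and (2) can be checked on a plain matrix. The following is a minimal standalone sketch, not the package implementation: `inflate_multiplicative` is a hypothetical helper, and the ensemble is random data stored with members as columns.

```julia
using Random
using Statistics

# Hypothetical helper mirroring eqn. (1); not part of the package API.
function inflate_multiplicative(u::AbstractMatrix, Δt::Real; s::Real = 1.0)
    scaled_Δt = s * Δt
    scaled_Δt < 1 || error("scaled time step must be < 1")
    u_mean = mean(u, dims = 2)                 # ensemble mean, one column
    prefactor = sqrt(1 / (1 - scaled_Δt))
    # Move every member away from the mean by the same deterministic factor.
    return u_mean .+ prefactor .* (u .- u_mean)
end

rng = MersenneTwister(42)
u = randn(rng, 2, 100)    # 2 parameters, 100 ensemble members (columns)
Δt = 0.5
u_inflated = inflate_multiplicative(u, Δt)
```

Because every member is scaled about the same mean, the sample mean is unchanged and the sample covariance is multiplied by exactly ``\frac{1}{1 - s \Delta{t}}``, up to floating-point error.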

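Multiplicative inflation is enabled by passing `multiplicative_inflation = true` to `update_ensemble!`. Inflation is applied only when `s > 0` (the default, `s = 0.0`, leaves the ensemble unchanged), and the scaled time step ``s \Delta{t}`` must be smaller than 1, otherwise an error is thrown: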
```julia
EKP.update_ensemble!(ekiobj, g_ens; multiplicative_inflation = true, s = 1.0)
```

## Additive Inflation
Additive inflation is implemented by adding stochastic perturbations, in the form of Gaussian noise, to the parameter ensemble. Additive inflation breaks the linear subspace property, meaning the parameter ensemble can evolve outside of the span of the initial ensemble. The ensemble is perturbed in the following manner after the standard Kalman update:

```math
\begin{aligned}
u_{n+1} &= u_n + \zeta_{n} \qquad (3) \\
\zeta_{n} &\sim N\left(0, \frac{s \Delta{t} }{1 - s \Delta{t}} C_n\right) \qquad (4)
\end{aligned}
```
This inflates the parameter covariance by a factor of ``\frac{1}{1 - s \Delta{t}}``, as in eqn. (2), while the ensemble mean remains fixed in expectation (the sampled noise shifts the sample mean slightly for a finite ensemble).
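
The additive update can be sketched on a plain matrix as well. This is an illustrative standalone example under stated assumptions, not the package code: `inflate_additive` is a hypothetical helper, and the noise is drawn through a Cholesky factor of the scaled covariance rather than through `Distributions.MvNormal`, so that only standard libraries are needed.

```julia
using LinearAlgebra
using Random
using Statistics

# Hypothetical helper: adds ζ ~ N(0, factor * Σ) to every ensemble member (column).
function inflate_additive(rng::AbstractRNG, u::AbstractMatrix, Σ::AbstractMatrix, Δt::Real; s::Real = 1.0)
    scaled_Δt = s * Δt
    scaled_Δt < 1 || error("scaled time step must be < 1")
    factor = scaled_Δt / (1 - scaled_Δt)
    # If z ~ N(0, I), then L * z ~ N(0, L * L') = N(0, factor * Σ).
    L = cholesky(Symmetric(factor .* Σ)).L
    ζ = L * randn(rng, size(u, 1), size(u, 2))
    return u .+ ζ
end

rng = MersenneTwister(42)
u = randn(rng, 2, 50_000)     # large ensemble so sample statistics are tight
C_n = cov(u, dims = 2)        # current sample covariance
Δt = 0.5
u_inflated = inflate_additive(rng, u, C_n, Δt)
```

For a large ensemble the sample covariance of `u_inflated` is close to ``\frac{1}{1 - s \Delta{t}} C_n``, but only approximately, since the noise is random. Passing a different matrix for `Σ`, such as the covariance of the initial ensemble, gives the prior-covariance variant described further down.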

Additive inflation is enabled by passing `additive_inflation = true` to `update_ensemble!` (with `s > 0`, since the default `s = 0.0` disables inflation):
```julia
EKP.update_ensemble!(ekiobj, g_ens; additive_inflation = true, s = 1.0)
```
Alternatively, the prior covariance matrix may be used to generate additive noise, following:
```math
\zeta_{n} \sim N(0, \frac{s \Delta{t} }{1 - s \Delta{t}} C_{0}) \qquad (5)
```
This adds ``\frac{s \Delta{t} }{1 - s \Delta{t}} C_{0}`` to the parameter covariance (in expectation over the noise), while the ensemble mean remains fixed in expectation:
```math
C_{n + 1} = C_{n} + \frac{s \Delta{t} }{1 - s \Delta{t}} C_{0} \qquad (6)
```
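
A short derivation sketch of eqn. (6), under the assumption that the noise ``\zeta_n`` is drawn independently of the ensemble ``u_n``, so that the covariances add:

```math
\begin{aligned}
\mathrm{Cov}(u_{n+1}) &= \mathrm{Cov}(u_n + \zeta_n) = \mathrm{Cov}(u_n) + \mathrm{Cov}(\zeta_n) \\
&= C_n + \frac{s \Delta{t}}{1 - s \Delta{t}} C_0
\end{aligned}
```

Replacing ``C_0`` with ``C_n`` in the last line gives ``\left(1 + \frac{s \Delta{t}}{1 - s \Delta{t}}\right) C_n = \frac{1}{1 - s \Delta{t}} C_n``, which matches the covariance inflation of eqn. (2).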

Additive inflation using the scaled prior covariance (parameter covariance of initial ensemble) can be used by flagging the `update_ensemble!` method as follows:
```julia
EKP.update_ensemble!(ekiobj, g_ens; additive_inflation = true, use_prior_cov = true, s = 1.0)
```
104 changes: 101 additions & 3 deletions src/EnsembleKalmanProcess.jl
@@ -13,7 +13,7 @@ export get_u, get_g, get_ϕ
export get_u_prior, get_u_final, get_g_final, get_ϕ_final
export get_N_iterations, get_error, get_cov_blocks
export get_u_mean, get_u_cov, get_g_mean, get_ϕ_mean
-export get_u_mean_final, get_u_cov_final, get_g_mean_final, get_ϕ_mean_final
+export get_u_mean_final, get_u_cov_prior, get_u_cov_final, get_g_mean_final, get_ϕ_mean_final
export compute_error!
export update_ensemble!
export sample_empirical_gaussian, split_indices_by_success
@@ -238,6 +238,15 @@ function get_u_cov(ekp::EnsembleKalmanProcess, iteration::IT) where {IT <: Integ
    return cov(u, u, dims = 2)
end

"""
    get_u_cov_prior(ekp::EnsembleKalmanProcess)
Returns the unconstrained parameter sample covariance for the initial ensemble.
"""
function get_u_cov_prior(ekp::EnsembleKalmanProcess)
    return cov(get_u_prior(ekp), get_u_prior(ekp), dims = 2)
end

"""
get_g_mean(ekp::EnsembleKalmanProcess, iteration::IT) where {IT <: Integer}
@@ -300,7 +309,7 @@ get_ϕ_mean_final(prior::ParameterDistribution, ekp::EnsembleKalmanProcess) =
"""
    get_u_cov_final(ekp::EnsembleKalmanProcess)
-Get the mean unconstrained parameter at the last iteration.
+Get the mean unconstrained parameter covariance at the last iteration.
"""
get_u_cov_final(ekp::EnsembleKalmanProcess) = get_u_cov(ekp, size(ekp.u, 1))

@@ -328,6 +337,15 @@ function get_process(ekp::EnsembleKalmanProcess)
    return ekp.process
end

"""
    get_localizer(ekp::EnsembleKalmanProcess)
Return localizer type of EnsembleKalmanProcess.
"""
function get_localizer(ekp::EnsembleKalmanProcess)
    return Localizers.get_localizer(ekp.localizer)
end


"""
construct_initial_ensemble(
rng::AbstractRNG,
@@ -448,20 +466,100 @@ function get_cov_blocks(cov::AbstractMatrix{FT}, p::IT) where {FT <: Real, IT <:
    return uu_cov, ug_cov, gg_cov
end

"""
    multiplicative_inflation!(
        ekp::EnsembleKalmanProcess;
        s::FT = 1.0,
    ) where {FT <: Real}
Applies multiplicative inflation to particles: each ensemble member is moved away from the
ensemble mean by a deterministic factor, so the ensemble mean is unchanged.
Inputs:
    - ekp :: The EnsembleKalmanProcess to update.
    - s :: Scaling factor for time step in multiplicative perturbation.
"""
function multiplicative_inflation!(ekp::EnsembleKalmanProcess; s::FT = 1.0) where {FT <: Real}

    scaled_Δt = s * ekp.Δt[end]

    if scaled_Δt >= 1.0
        error(string("Scaled time step: ", scaled_Δt, " is >= 1.0", "\nChange s or EK time step."))
    end

    u = get_u_final(ekp)
    u_mean = get_u_mean_final(ekp)
    # scale deviations from the mean so the covariance grows by 1 / (1 - scaled_Δt)
    prefactor = sqrt(1 / (1 - scaled_Δt))
    u_updated = u_mean .+ prefactor * (u .- u_mean)
    ekp.u[end] = DataContainer(u_updated, data_are_columns = true)

end

"""
    additive_inflation!(
        ekp::EnsembleKalmanProcess;
        use_prior_cov::Bool = false,
        s::FT = 1.0,
    ) where {FT <: Real}
Applies additive Gaussian noise to particles. Noise is drawn from normal distribution with 0 mean
and scaled parameter covariance. If use_prior_cov=false (default), scales parameter covariance matrix from
current ekp iteration. Otherwise, scales parameter covariance of initial ensemble.
Inputs:
    - ekp :: The EnsembleKalmanProcess to update.
    - s :: Scaling factor for time step in additive perturbation.
    - use_prior_cov :: Bool specifying whether to use prior covariance estimate for additive inflation.
        If false (default), parameter covariance from the current iteration is used.
"""
function additive_inflation!(ekp::EnsembleKalmanProcess; use_prior_cov::Bool = false, s::FT = 1.0) where {FT <: Real}

    scaled_Δt = s * ekp.Δt[end]

    if scaled_Δt >= 1.0
        error(string("Scaled time step: ", scaled_Δt, " is >= 1.0", "\nChange s or EK time step."))
    end

    Σ = use_prior_cov ? get_u_cov_prior(ekp) : get_u_cov_final(ekp)

    u = get_u_final(ekp)
    # add multivariate noise with 0 mean and scaled covariance
    noise_multivariate = MvNormal((scaled_Δt / (1 - scaled_Δt)) .* Σ)
    u_updated = u + rand(noise_multivariate, size(u, 2))
    ekp.u[end] = DataContainer(u_updated, data_are_columns = true)
end

"""
    update_ensemble!(
        ekp::EnsembleKalmanProcess,
        g::AbstractMatrix{FT};
        multiplicative_inflation::Bool = false,
        additive_inflation::Bool = false,
        use_prior_cov::Bool = false,
        s::FT = 0.0,
        ekp_kwargs...,
    ) where {FT, IT}
Updates the ensemble according to an Inversion process.
Inputs:
    - ekp :: The EnsembleKalmanProcess to update.
    - g :: Model outputs, they need to be stored as a `N_obs × N_ens` array (i.e., data are columns).
    - multiplicative_inflation :: Flag indicating whether to use multiplicative inflation.
    - additive_inflation :: Flag indicating whether to use additive inflation.
    - use_prior_cov :: Bool specifying whether to use prior covariance estimate for additive inflation.
        If false (default), parameter covariance from the current iteration is used.
    - s :: Scaling factor for time step in inflation step. Inflation is only applied if `s > 0` (default `0.0`).
    - ekp_kwargs :: Keyword arguments to pass to standard ekp update_ensemble!.
"""
-function update_ensemble!(ekp::EnsembleKalmanProcess, g::AbstractMatrix{FT}; ekp_kwargs...) where {FT, IT}
+function update_ensemble!(
+    ekp::EnsembleKalmanProcess,
+    g::AbstractMatrix{FT};
+    multiplicative_inflation::Bool = false,
+    additive_inflation::Bool = false,
+    use_prior_cov::Bool = false,
+    s::FT = 0.0,
+    ekp_kwargs...,
+) where {FT, IT}

    update_ensemble!(ekp, g, get_process(ekp); ekp_kwargs...)
    if s > 0.0
        multiplicative_inflation ? multiplicative_inflation!(ekp; s = s) : nothing
        additive_inflation ? additive_inflation!(ekp; use_prior_cov = use_prior_cov, s = s) : nothing
    end
end

## include the different types of Processes and their exports:
8 changes: 8 additions & 0 deletions src/Localizers.jl
@@ -239,4 +239,12 @@ function Localizer(localization::SECFisher, p::IT, d::IT, J::IT, T = Float64) wh
    return Localizer{SECFisher, T}((cov) -> sec_fisher(cov, J))
end

"""
    get_localizer(loc::Localizer)
Return localizer type.
"""
function get_localizer(loc::Localizer{T1, T2}) where {T1, T2}
    return T1
end

end # module
115 changes: 115 additions & 0 deletions test/Inflation/runtests.jl
@@ -0,0 +1,115 @@
using Distributions
using LinearAlgebra
using Random
using Test

using EnsembleKalmanProcesses
using EnsembleKalmanProcesses.ParameterDistributions
using EnsembleKalmanProcesses.Localizers
import EnsembleKalmanProcesses: construct_mean, construct_cov, construct_sigma_ensemble
const EKP = EnsembleKalmanProcesses


# Read inverse problem definitions
include("../EnsembleKalmanProcess/inverse_problem.jl")

n_obs = 30 # dimension of synthetic observation from G(u)
ϕ_star = [-1.0, 2.0] # True parameters in constrained space
n_par = size(ϕ_star, 1)
noise_level = 0.1 # Defining the observation noise level (std)
N_ens = 1000 # number of ensemble members
N_iter = 1 # number of EKI iterations

obs_corrmat = Diagonal(Matrix(I, n_obs, n_obs))

prior_1 = Dict("distribution" => Parameterized(Normal(0.0, 0.5)), "constraint" => bounded(-2, 2), "name" => "cons_p")
prior_2 = Dict("distribution" => Parameterized(Normal(3.0, 0.5)), "constraint" => no_constraint(), "name" => "uncons_p")
prior = ParameterDistribution([prior_1, prior_2])
prior_cov = cov(prior)

rng_seed = 42
rng = Random.MersenneTwister(rng_seed)

initial_ensemble = EKP.construct_initial_ensemble(rng, prior, N_ens)

Δts = [0.5, 0.75]

@testset "Inflation" begin

    ekiobj = nothing

    for Δt_i in 1:length(Δts)
        Δt = Δts[Δt_i]

        # Get inverse problem
        y_obs, G, Γy, A =
            linear_inv_problem(ϕ_star, noise_level, n_obs, prior, rng; obs_corrmat = obs_corrmat, return_matrix = true)

        ekiobj = EKP.EnsembleKalmanProcess(
            initial_ensemble,
            y_obs,
            Γy,
            Inversion();
            Δt = Δt,
            rng = rng,
            failure_handler_method = SampleSuccGauss(),
        )

        g_ens = G(get_u_final(ekiobj))

        # ensure error is thrown when scaled time step >= 1
        @test_throws ErrorException EKP.update_ensemble!(ekiobj, g_ens; multiplicative_inflation = true, s = 3.0)
        @test_throws ErrorException EKP.update_ensemble!(ekiobj, g_ens; additive_inflation = true, s = 3.0)

        # EKI iterations
        for i in 1:N_iter
            # Check SampleSuccGauss handler
            params_i = get_u_final(ekiobj)

            g_ens = G(params_i)

            # standard update
            EKP.update_ensemble!(ekiobj, g_ens, EKP.get_process(ekiobj))
            eki_mult_inflation = deepcopy(ekiobj)
            eki_add_inflation = deepcopy(ekiobj)
            eki_add_inflation_prior = deepcopy(ekiobj)

            # multiplicative inflation after standard update
            EKP.multiplicative_inflation!(eki_mult_inflation)
            # additive inflation after standard update
            EKP.additive_inflation!(eki_add_inflation)
            # additive inflation (scaling prior cov) after standard update
            EKP.additive_inflation!(eki_add_inflation_prior; use_prior_cov = true)

            # ensure multiplicative inflation approximately preserves ensemble mean
            @test get_u_mean_final(ekiobj) ≈ get_u_mean_final(eki_mult_inflation) atol = 0.1
            # ensure additive inflation approximately preserves ensemble mean
            @test get_u_mean_final(ekiobj) ≈ get_u_mean_final(eki_add_inflation) atol = 0.1
            # ensure additive inflation (scaling prior cov) approximately preserves ensemble mean
            @test get_u_mean_final(ekiobj) ≈ get_u_mean_final(eki_add_inflation_prior) atol = 0.1

            # ensure inflation expands ensemble variance as expected
            expected_var_gain = 1 / (1 - Δt)
            @test get_u_cov_final(ekiobj) .* expected_var_gain ≈ get_u_cov_final(eki_mult_inflation) atol = 1e-3
            # implemented default additive, multiplicative inflation have same effect on ensemble covariance
            @test get_u_cov_final(eki_add_inflation) ≈ get_u_cov_final(eki_mult_inflation) atol = 1e-3
            # ensure additive inflation with prior affects variance as expected
            # note: we accept a higher relative tolerance here because the 2 parameter ensemble collapses
            # note: quickly so the added noise (scaled from prior) is relatively large (difference eliminated with large ensemble)
            @test get_u_cov_final(eki_add_inflation_prior) - get_u_cov_final(ekiobj) ≈
                  (Δt * expected_var_gain) .* prior_cov atol = 0.2

            # ensure inflation is only added in final iteration
            u_standard = get_u(ekiobj)
            u_mult_inflation = get_u(eki_mult_inflation)
            u_add_inflation = get_u(eki_add_inflation)
            @test u_standard[1:(end - 1)] == u_mult_inflation[1:(end - 1)]
            @test u_standard[1:(end - 1)] == u_add_inflation[1:(end - 1)]
            @test u_standard[end] != u_mult_inflation[end]
            @test u_standard[end] != u_add_inflation[end]

        end
    end
    # inflation update should not affect initial parameter ensemble (drawn from prior)
    @test get_u_prior(ekiobj) == initial_ensemble
end
2 changes: 2 additions & 0 deletions test/Localizers/runtests.jl
@@ -63,6 +63,8 @@ const EKP = EnsembleKalmanProcesses
cov_est = cov([u_final; g_final], [u_final; g_final], dims = 2, corrected = false)
cov_localized = ekiobj.localizer.localize(cov_est)
@test rank(cov_est) < rank(cov_localized)
# Test localization getter method
@test isa(loc_method, EKP.get_localizer(ekiobj))
end

end
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -32,6 +32,7 @@ end
"Localizers",
"TOMLInterface",
"SparseInversion",
"Inflation",
]
if all_tests || has_submodule(submodule) || "EnsembleKalmanProcesses" in ARGS
include_test(submodule)