Skip to content

Commit

Permalink
add particle inflation functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
costachris committed Feb 2, 2023
1 parent 5109b7f commit 8b71248
Show file tree
Hide file tree
Showing 7 changed files with 250 additions and 1 deletion.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SCS = "c946c3f1-0d1f-5ce8-9dea-7daa1f7e2d13"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
Expand Down
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ pages = [
"Prior distributions" => "parameter_distributions.md",
"Internal data representation" => "internal_data_representation.md",
"Localization and SEC" => "localization.md",
"Inflation" => "inflation.md",
"Parallelism and HPC" => "parallel_hpc.md",
"Observations" => "observations.md",
"API" => api,
Expand Down
42 changes: 42 additions & 0 deletions docs/src/inflation.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Inflation
Inflation is a approach that slows ensemble collapse in Kalman methods.
Two distinct forms of inflation are implemented, in which perturbations are applied immediately following the standard Kalman update.
Multiplicative inflation expands ensemble members away from their mean in a
deterministic manner, whereas additive inflation hinges on the addition of stochastic noise to ensemble members.
For both implementations, a scaling factor ``s`` is included to extend functionality to cases with mini-batching.
The scaling factor ``s`` multiplies the artificial time step ``\Delta t`` in the inflation equations to account for sampling error. For mini-batching, the scaling factor should be:
```math
s = \frac{|B|}{|C|}
```
where `` |B| `` is the mini-batch size and `` |C| `` is the full dataset size.

## Multiplicative Inflation
Multiplicative inflation effectively scales parameter vectors in parameter space, such that the perturbed
ensemble remains in the linear span of the original ensemble. The implemented update equation follows
[Huang et al, 2022](https://arxiv.org/abs/2204.04386) eqn. 41:

```math
\begin{aligned}
m_{n+1} = m_{n} ; \qquad u^{j}_{n + 1} = m_{n+1} + \sqrt{\frac{1}{1 - s \Delta{t}}} \left(u^{j}_{n} - m_{n} \right) \qquad (1)
\end{aligned}
```
where ``m`` is the ensemble average. In this way,
the variance across parameter vector magnitudes is increased by a factor of ``\frac{1}{1 - s \Delta{t}}``, while the mean remains fixed.

Multiplicative inflation can be used by flagging the `update_ensemble!` method as follows:
```julia
EKP.update_ensemble!(ekiobj, g_ens; multiplicative_inflation = true, s = 1.0)
```

## Additive Inflation
Additive inflation is implemented by systematically adding stochastic perturbations to the parameter ensemble in the form of gaussian noise. Additive inflation breaks the linear subspace property, meaning the parameter ensemble can evolve outside of the span of the initial ensemble. In additive inflation, the ensemble is perturbed in the following manner after the standard Kalman update:

```math
u_{n+1} = u_n + \zeta_{n} \qquad (2) \\
\zeta_{n} \sim N(0, \frac{s \Delta{t} }{1 - s \Delta{t}} C_n) \qquad (3)
```
This increases the variance across vector components by `` \frac{s \Delta{t} }{1 - s \Delta{t}} ``, while the mean remains fixed.
Additive inflation can be used by flagging the `update_ensemble!` method as follows:
```julia
EKP.update_ensemble!(ekiobj, g_ens; additive_inflation = true, s = 1.0)
```
79 changes: 78 additions & 1 deletion src/EnsembleKalmanProcess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,15 @@ function get_process(ekp::EnsembleKalmanProcess)
return ekp.process
end

"""
get_localizer(ekp::EnsembleKalmanProcess)
Return localizer type of EnsembleKalmanProcess.
"""
function get_localizer(ekp::EnsembleKalmanProcess)
return Localizers.get_localizer(ekp.localizer)
end


"""
construct_initial_ensemble(
rng::AbstractRNG,
Expand Down Expand Up @@ -448,20 +457,88 @@ function get_cov_blocks(cov::AbstractMatrix{FT}, p::IT) where {FT <: Real, IT <:
return uu_cov, ug_cov, gg_cov
end

"""
multiplicative_inflation!(
ekp::EnsembleKalmanProcess;
s,
) where {FT, IT}
Applies multiplicative noise to particles.
Inputs:
- ekp :: The EnsembleKalmanProcess to update.
- s :: Scaling factor for timestep in multiplicative perturbation.
"""
function multiplicative_inflation!(ekp::EnsembleKalmanProcess; s::FT = 1.0) where {FT <: Real}

scaled_Δt = s * ekp.Δt[end]

if scaled_Δt >= 1.0
error(string("Scaled time step: ", scaled_Δt, " is >= 1.0", "\nChange s or EK time step."))
end

u = get_u_final(ekp)
u_mean = get_u_mean_final(ekp)
prefactor = sqrt(1 / (1 - scaled_Δt))
u_updated = u_mean .+ prefactor * (u .- u_mean)
ekp.u[end] = DataContainer(u_updated, data_are_columns = true)

end

"""
additive_inflation!(
ekp::EnsembleKalmanProcess;
s,
) where {FT, IT}
Applies additive noise to particles.
Inputs:
- ekp :: The EnsembleKalmanProcess to update.
- s :: Scaling factor for timestep in additive perturbation.
"""
function additive_inflation!(ekp::EnsembleKalmanProcess; s::FT = 1.0) where {FT <: Real}

scaled_Δt = s * ekp.Δt[end]

if scaled_Δt >= 1.0
error(string("Scaled time step: ", scaled_Δt, " is >= 1.0", "\nChange s or EK time step."))
end

u = get_u_final(ekp)
noise_std = sqrt(scaled_Δt / (1 - scaled_Δt))
u_updated = u + randn(size(u)) * noise_std
ekp.u[end] = DataContainer(u_updated, data_are_columns = true)
end

"""
update_ensemble!(
ekp::EnsembleKalmanProcess,
g::AbstractMatrix{FT};
multiplicative_inflation::Bool = false,
additive_inflation::Bool = false,
s::FT = 0.0,
ekp_kwargs...,
) where {FT, IT}
Updates the ensemble according to an Inversion process.
Inputs:
- ekp :: The EnsembleKalmanProcess to update.
- g :: Model outputs, they need to be stored as a `N_obs × N_ens` array (i.e data are columms).
- multiplicative_inflation :: Flag indicating whether to use multiplicative inflation.
- additive_inflation :: Flag indicating whether to use additive inflation.
- s :: Scaling factor for time step in inflation step.
- ekp_kwargs :: Keyword arguments to pass to standard ekp update_ensemble!.
"""
function update_ensemble!(ekp::EnsembleKalmanProcess, g::AbstractMatrix{FT}; ekp_kwargs...) where {FT, IT}
function update_ensemble!(
ekp::EnsembleKalmanProcess,
g::AbstractMatrix{FT};
multiplicative_inflation::Bool = false,
additive_inflation::Bool = false,
s::FT = 0.0,
ekp_kwargs...,
) where {FT, IT}

update_ensemble!(ekp, g, get_process(ekp); ekp_kwargs...)
if s > 0.0
multiplicative_inflation ? multiplicative_inflation!(ekp; s = s) : nothing
additive_inflation ? additive_inflation!(ekp; s = s) : nothing
end
end

## include the different types of Processes and their exports:
Expand Down
8 changes: 8 additions & 0 deletions src/Localizers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,4 +239,12 @@ function Localizer(localization::SECFisher, p::IT, d::IT, J::IT, T = Float64) wh
return Localizer{SECFisher, T}((cov) -> sec_fisher(cov, J))
end

"""
get_localizer(loc::Localizer)
Return localizer type.
"""
function get_localizer(loc::Localizer{T1, T2}) where {T1, T2}
return T1
end

end # module
119 changes: 119 additions & 0 deletions test/Inflation/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
using Distributions
using LinearAlgebra
using Random
using Test

using EnsembleKalmanProcesses
using EnsembleKalmanProcesses.ParameterDistributions
using EnsembleKalmanProcesses.Localizers
import EnsembleKalmanProcesses: construct_mean, construct_cov, construct_sigma_ensemble
const EKP = EnsembleKalmanProcesses


# Read inverse problem definitions
include("../EnsembleKalmanProcess/inverse_problem.jl")

n_obs = 30 # dimension of synthetic observation from G(u)
ϕ_star = [-1.0, 2.0] # True parameters in constrained space
n_par = size(ϕ_star, 1)
noise_level = 0.1 # Defining the observation noise level (std)
N_ens = 10000 # number of ensemble members
N_iter = 1 # number of EKI iterations

obs_corrmat = Diagonal(Matrix(I, n_obs, n_obs))

prior_1 = Dict("distribution" => Parameterized(Normal(0.0, 0.5)), "constraint" => bounded(-2, 2), "name" => "cons_p")
prior_2 = Dict("distribution" => Parameterized(Normal(3.0, 0.5)), "constraint" => no_constraint(), "name" => "uncons_p")
prior = ParameterDistribution([prior_1, prior_2])
prior_mean = mean(prior)
prior_cov = cov(prior)

rng_seed = 42
rng = Random.MersenneTwister(rng_seed)

initial_ensemble = EKP.construct_initial_ensemble(rng, prior, N_ens)

Δts = [0.5, 0.75, 0.95]

@testset "Inflation" begin

ekiobj = nothing

for Δt_i in 1:length(Δts)
Δt = Δts[Δt_i]

# Get inverse problem
y_obs, G, Γy, A =
linear_inv_problem(ϕ_star, noise_level, n_obs, prior, rng; obs_corrmat = obs_corrmat, return_matrix = true)

ekiobj = EKP.EnsembleKalmanProcess(
initial_ensemble,
y_obs,
Γy,
Inversion();
Δt = Δt,
rng = rng,
failure_handler_method = SampleSuccGauss(),
)

g_ens = G(get_u_final(ekiobj))

# ensure error is thrown when scaled time step >= 1
@test_throws ErrorException EKP.update_ensemble!(ekiobj, g_ens; multiplicative_inflation = true, s = 3.0)
@test_throws ErrorException EKP.update_ensemble!(ekiobj, g_ens; additive_inflation = true, s = 3.0)

# EKI iterations
for i in 1:N_iter
# Check SampleSuccGauss handler
params_i = get_u_final(ekiobj)

g_ens = G(params_i)

# standard update
EKP.update_ensemble!(ekiobj, g_ens, EKP.get_process(ekiobj))
eki_mult_inflation = deepcopy(ekiobj)
eki_add_inflation = deepcopy(ekiobj)

# multiplicative inflation after standard update
EKP.multiplicative_inflation!(eki_mult_inflation)
# additive inflation after standard update
EKP.additive_inflation!(eki_add_inflation)

# ensure multiplicative inflation preserves ensemble mean
@test get_u_mean_final(ekiobj) get_u_mean_final(eki_mult_inflation) atol = 1e-10
# ensure additive inflation approximately preserves ensemble mean
@test get_u_mean_final(ekiobj) get_u_mean_final(eki_add_inflation) rtol = 0.02

# check if ensemble is inflated as expected
u_final_standard = get_u_final(ekiobj)
u_final_mult_inflation = get_u_final(eki_mult_inflation)
u_final_add_inflation = get_u_final(eki_add_inflation)
standard_update_var = var(u_final_standard)

# expected increase in variance (of vector components) after additive inflation
expected_additive_var = Δt / (1 - Δt)
# empirical increase in variance (of vector components) after additive inflation
empirical_additive_var_increase = var(u_final_add_inflation) - standard_update_var
@test empirical_additive_var_increase expected_additive_var rtol = 0.02

# variance of parameter vector magnitudes
standard_update_mag_var = var(norm.(eachcol(u_final_standard)))
mult_update_mag_var = var(norm.(eachcol(u_final_mult_inflation)))
# expected ratio of variance of parameter vector magnitudes after/before additive inflation
expected_additive_var_ratio = 1 / (1 - Δt)
@test expected_additive_var_ratio mult_update_mag_var / standard_update_mag_var rtol = 0.01

# ensure inflation is only added in final iteration
u_standard = get_u(ekiobj)
u_mult_inflation = get_u(eki_mult_inflation)
u_add_inflation = get_u(eki_add_inflation)
@test u_standard[1:(end - 1)] == u_mult_inflation[1:(end - 1)]
@test u_standard[1:(end - 1)] == u_add_inflation[1:(end - 1)]
@test u_standard[end] != u_mult_inflation[end]
@test u_standard[end] != u_add_inflation[end]

end
end
# inflation update should not affect initial parameter ensemble (drawn from prior)
@test get_u_prior(ekiobj) == initial_ensemble
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ end
"Localizers",
"TOMLInterface",
"SparseInversion",
"Inflation",
]
if all_tests || has_submodule(submodule) || "EnsembleKalmanProcesses" in ARGS
include_test(submodule)
Expand Down

0 comments on commit 8b71248

Please sign in to comment.