Skip to content

Commit

Permalink
add particle inflation functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
costachris committed Jan 23, 2023
1 parent 5109b7f commit 40d53ee
Show file tree
Hide file tree
Showing 7 changed files with 246 additions and 1 deletion.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SCS = "c946c3f1-0d1f-5ce8-9dea-7daa1f7e2d13"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
Expand Down
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ pages = [
"Prior distributions" => "parameter_distributions.md",
"Internal data representation" => "internal_data_representation.md",
"Localization and SEC" => "localization.md",
"Inflation" => "inflation.md",
"Parallelism and HPC" => "parallel_hpc.md",
"Observations" => "observations.md",
"API" => api,
Expand Down
45 changes: 45 additions & 0 deletions docs/src/inflation.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Inflation
Inflation is a method which systematically adds stochastic perturbations to the parameter ensemble in the form of gaussian noise.
Following the standard Kalman update, the parameter ensemble is perturbed in the following manner:
```math
u_{n+1} = u_n + \zeta_{n} \qquad (1)
```

where `` \zeta_{n} `` is noise drawn from a gaussian distribution. Two distinct forms of inflation are implemented.
For both implementations, a scaling factor ``s`` is included to extend usage to cases with mini-batching. For mini-batching,
the scaling factor should be:
```math
s = \frac{|B|}{|C|}
```
where `` |B| `` is the mini-batch size and `` |C| `` is the full dataset size, to account for sampling error.

## Multiplicative Inflation
Multiplicative effectively scales parameter vectors in parameter space, such that the perturbed
ensemble remains in the linear span of the original ensemble. The implementated update equation follows
[Huang et al, 2022](https://arxiv.org/abs/2204.04386) eqn. 41:

```math
\begin{aligned}
m_{n+1} = m_{n} ; \qquad u^{j}_{n + 1} = m_{n+1} + \sqrt{\frac{1}{1 - s \Delta{t}}} \left(u^{j}_{n} - m_{n} \right) \qquad (2)
\end{aligned}
```
where ``m`` is the ensemble average which remains unchanged after the perturbation.

Multiplicative inflation can be used by flagging the `update_ensemble!` method as follows:
```julia
EKP.update_ensemble!(ekiobj, g_ens; multiplicative_inflation = true, s = 1.0)
```

## Additive Inflation
Additive inflation breaks the linear subspace property, meaning the parameter ensemble can evolve
outside of the span of the initial ensemble. In additive inflation, the update equation draws
noise from the following distribution.

```math
\zeta_{n} \sim N(0, \frac{s \Delta{t} }{1 - s \Delta{t}} C_n) \qquad (3)
```

Additive inflation can be used by flagging the `update_ensemble!` method as follows:
```julia
EKP.update_ensemble!(ekiobj, g_ens; additive_inflation = true, s = 1.0)
```
79 changes: 78 additions & 1 deletion src/EnsembleKalmanProcess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,15 @@ function get_process(ekp::EnsembleKalmanProcess)
return ekp.process
end

"""
get_localizer(ekp::EnsembleKalmanProcess)
Return localizer type of EnsembleKalmanProcess.
"""
function get_localizer(ekp::EnsembleKalmanProcess)
return Localizers.get_localizer(ekp.localizer)
end


"""
construct_initial_ensemble(
rng::AbstractRNG,
Expand Down Expand Up @@ -448,20 +457,88 @@ function get_cov_blocks(cov::AbstractMatrix{FT}, p::IT) where {FT <: Real, IT <:
return uu_cov, ug_cov, gg_cov
end

"""
multiplicative_inflation!(
ekp::EnsembleKalmanProcess;
s,
) where {FT, IT}
Applies multiplicative noise to particles.
Inputs:
- ekp :: The EnsembleKalmanProcess to update.
- s :: Scaling factor for timestep in multiplicative perturbation.
"""
function multiplicative_inflation!(ekp::EnsembleKalmanProcess; s::FT = 1.0) where {FT <: Real}

scaled_Δt = s * ekp.Δt[end]

if scaled_Δt >= 1.0
error(string("Scaled time step: ", scaled_Δt, " is >= 1.0", "\nChange s or EK time step."))
end

u = get_u_final(ekp)
u_mean = get_u_mean_final(ekp)
prefactor = sqrt(1 / (1 - scaled_Δt))
u_updated = u_mean .+ prefactor * (u .- u_mean)
ekp.u[end] = DataContainer(u_updated, data_are_columns = true)

end

"""
additive_inflation!(
ekp::EnsembleKalmanProcess;
s,
) where {FT, IT}
Applies additive noise to particles.
Inputs:
- ekp :: The EnsembleKalmanProcess to update.
- s :: Scaling factor for timestep in additive perturbation.
"""
function additive_inflation!(ekp::EnsembleKalmanProcess; s::FT = 1.0) where {FT <: Real}

scaled_Δt = s * ekp.Δt[end]

if scaled_Δt >= 1.0
error(string("Scaled time step: ", scaled_Δt, " is >= 1.0", "\nChange s or EK time step."))
end

u = get_u_final(ekp)
noise_std = sqrt(scaled_Δt / (1 - scaled_Δt))
u_updated = u + randn(size(u)) * noise_std
ekp.u[end] = DataContainer(u_updated, data_are_columns = true)
end

"""
update_ensemble!(
ekp::EnsembleKalmanProcess,
g::AbstractMatrix{FT};
multiplicative_inflation::Bool = false,
additive_inflation::Bool = false,
s::FT = 0.0,
ekp_kwargs...,
) where {FT, IT}
Updates the ensemble according to an Inversion process.
Inputs:
- ekp :: The EnsembleKalmanProcess to update.
- g :: Model outputs, they need to be stored as a `N_obs × N_ens` array (i.e data are columms).
- multiplicative_inflation :: Flag indicating whether to use multiplicative inflation.
- additive_inflation :: Flag indicating whether to use additive inflation.
- s :: Scaling factor for time step in inflation step.
- ekp_kwargs :: Keyword arguments to pass to standard ekp update_ensemble!.
"""
function update_ensemble!(ekp::EnsembleKalmanProcess, g::AbstractMatrix{FT}; ekp_kwargs...) where {FT, IT}
function update_ensemble!(
ekp::EnsembleKalmanProcess,
g::AbstractMatrix{FT};
multiplicative_inflation::Bool = false,
additive_inflation::Bool = false,
s::FT = 0.0,
ekp_kwargs...,
) where {FT, IT}

update_ensemble!(ekp, g, get_process(ekp); ekp_kwargs...)
if s > 0.0
multiplicative_inflation ? multiplicative_inflation!(ekp; s = s) : nothing
additive_inflation ? additive_inflation!(ekp; s = s) : nothing
end
end

## include the different types of Processes and their exports:
Expand Down
8 changes: 8 additions & 0 deletions src/Localizers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,4 +239,12 @@ function Localizer(localization::SECFisher, p::IT, d::IT, J::IT, T = Float64) wh
return Localizer{SECFisher, T}((cov) -> sec_fisher(cov, J))
end

"""
get_localizer(loc::Localizer)
Return localizer type.
"""
function get_localizer(loc::Localizer{T1, T2}) where {T1, T2}
return T1
end

end # module
112 changes: 112 additions & 0 deletions test/Inflation/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
using Distributions
using LinearAlgebra
using Random
using Test

using EnsembleKalmanProcesses
using EnsembleKalmanProcesses.ParameterDistributions
using EnsembleKalmanProcesses.Localizers
import EnsembleKalmanProcesses: construct_mean, construct_cov, construct_sigma_ensemble
const EKP = EnsembleKalmanProcesses


# Read inverse problem definitions
include("../EnsembleKalmanProcess/inverse_problem.jl")

n_obs = 30 # dimension of synthetic observation from G(u)
ϕ_star = [-1.0, 2.0] # True parameters in constrained space
n_par = size(ϕ_star, 1)
noise_level = 0.1 # Defining the observation noise level (std)
N_ens = 5000 # number of ensemble members
N_iter = 5


obs_corrmat = Diagonal(Matrix(I, n_obs, n_obs))

prior_1 = Dict("distribution" => Parameterized(Normal(0.0, 0.5)), "constraint" => bounded(-2, 2), "name" => "cons_p")
prior_2 = Dict("distribution" => Parameterized(Normal(3.0, 0.5)), "constraint" => no_constraint(), "name" => "uncons_p")
prior = ParameterDistribution([prior_1, prior_2])
prior_mean = mean(prior)
prior_cov = cov(prior)

rng_seed = 42
rng = Random.MersenneTwister(rng_seed)

initial_ensemble = EKP.construct_initial_ensemble(rng, prior, N_ens)

Δts = [0.5, 0.75, 0.95]

@testset "Inflation" begin

ekiobj = nothing

for Δt_i in 1:length(Δts)
Δt = Δts[Δt_i]

# Get inverse problem
y_obs, G, Γy, A =
linear_inv_problem(ϕ_star, noise_level, n_obs, prior, rng; obs_corrmat = obs_corrmat, return_matrix = true)

ekiobj = EKP.EnsembleKalmanProcess(
initial_ensemble,
y_obs,
Γy,
Inversion();
Δt = Δt,
rng = rng,
failure_handler_method = SampleSuccGauss(),
)

g_ens = G(get_u_final(ekiobj))

# ensure error is thrown when scaled time step >= 1
@test_throws ErrorException EKP.update_ensemble!(ekiobj, g_ens; multiplicative_inflation = true, s = 3.0)
@test_throws ErrorException EKP.update_ensemble!(ekiobj, g_ens; additive_inflation = true, s = 3.0)

# EKI iterations
for i in 1:N_iter
# Check SampleSuccGauss handler
params_i = get_u_final(ekiobj)

g_ens = G(params_i)

# standard update
EKP.update_ensemble!(ekiobj, g_ens, EKP.get_process(ekiobj))
eki_mult_inflation = deepcopy(ekiobj)
eki_add_inflation = deepcopy(ekiobj)

# multiplicative inflation after standard update
EKP.multiplicative_inflation!(eki_mult_inflation)
# additive inflation after standard update
EKP.additive_inflation!(eki_add_inflation)

# ensure multiplicative inflation preserves ensemble mean
@test get_u_mean_final(ekiobj) get_u_mean_final(eki_mult_inflation) atol = 1e-10
# ensure additive inflation approximately preserves ensemble mean
@test get_u_mean_final(ekiobj) get_u_mean_final(eki_add_inflation) rtol = 0.1

# ensure inflation increases variance of ensemble
ϕ_final_standard = get_ϕ_final(prior, ekiobj)
ϕ_final_mult_inflation = get_ϕ_final(prior, eki_mult_inflation)
ϕ_final_add_inflation = get_ϕ_final(prior, eki_add_inflation)
ϕ_mean_final_standard = get_ϕ_mean_final(prior, ekiobj)
ϕ_mean_final_mult_inflation = get_ϕ_mean_final(prior, eki_mult_inflation)
ϕ_mean_final_add_inflation = get_ϕ_mean_final(prior, eki_add_inflation)
standard_update_variance = norm(ϕ_final_standard .- ϕ_mean_final_standard)
@test norm(ϕ_final_mult_inflation .- ϕ_mean_final_mult_inflation) > standard_update_variance
@test norm(ϕ_final_add_inflation .- ϕ_mean_final_add_inflation) > standard_update_variance

# ensure inflation is only added in final iteration
u_standard = get_u(ekiobj)
u_mult_inflation = get_u(eki_mult_inflation)
u_add_inflation = get_u(eki_add_inflation)
@test u_standard[1:(end - 1)] == u_mult_inflation[1:(end - 1)]
@test u_standard[1:(end - 1)] == u_add_inflation[1:(end - 1)]
@test u_standard[end] != u_mult_inflation[end]
@test u_standard[end] != u_add_inflation[end]

end
end
# inflation update should not affect initial parameter ensemble (drawn from prior)
@test get_u_prior(ekiobj) == initial_ensemble
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ end
"Localizers",
"TOMLInterface",
"SparseInversion",
"Inflation",
]
if all_tests || has_submodule(submodule) || "EnsembleKalmanProcesses" in ARGS
include_test(submodule)
Expand Down

0 comments on commit 40d53ee

Please sign in to comment.