Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add particle inflation functionality
1 parent
5109b7f
commit 40d53ee
Showing 7 changed files with 246 additions and 1 deletion.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Inflation | ||
Inflation is a method that systematically adds stochastic perturbations to the parameter ensemble in the form of Gaussian noise. | |
Following the standard Kalman update, the parameter ensemble is perturbed in the following manner: | ||
```math | ||
u_{n+1} = u_n + \zeta_{n} \qquad (1) | ||
``` | ||
|
||
where `` \zeta_{n} `` is noise drawn from a Gaussian distribution. Two distinct forms of inflation are implemented. | |
For both implementations, a scaling factor ``s`` is included to extend usage to cases with mini-batching. For mini-batching, | ||
the scaling factor should be: | ||
```math | ||
s = \frac{|B|}{|C|} | ||
``` | ||
where `` |B| `` is the mini-batch size and `` |C| `` is the full dataset size, to account for sampling error. | ||
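
As a small illustration of the scaling factor, the sketch below computes ``s`` for a mini-batch. The helper name `minibatch_scaling` and the sizes used are hypothetical and are not part of the package. The final check reflects the requirement that the scaled time step ``s \Delta t`` stays below 1.

```julia
# Hypothetical helper (not part of EnsembleKalmanProcesses): s = |B| / |C|
function minibatch_scaling(batch_size::Integer, dataset_size::Integer)
    0 < batch_size <= dataset_size || error("need 0 < |B| <= |C|")
    return batch_size / dataset_size
end

# Illustrative sizes only
s = minibatch_scaling(32, 1024)

# Both inflation forms divide by (1 - s * Δt), so the scaled step must be < 1
Δt = 0.5
@assert s * Δt < 1
```

With the full dataset used in every iteration, `|B| == |C|` and the factor reduces to `s = 1.0`.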
|
||
## Multiplicative Inflation | ||
Multiplicative inflation scales each ensemble member's deviation from the ensemble mean, such that the perturbed | |
ensemble remains in the linear span of the original ensemble. The implemented update equation follows | |
[Huang et al., 2022](https://arxiv.org/abs/2204.04386), eqn. 41: | |
|
||
```math | ||
\begin{aligned} | ||
m_{n+1} = m_{n} ; \qquad u^{j}_{n + 1} = m_{n+1} + \sqrt{\frac{1}{1 - s \Delta{t}}} \left(u^{j}_{n} - m_{n} \right) \qquad (2) | ||
\end{aligned} | ||
``` | ||
where ``m`` is the ensemble mean, which remains unchanged by the perturbation. | |
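
Equation (2) can be sketched in a few lines of plain Julia. This is a minimal illustration and not the package implementation: the function name `multiplicative_inflation_sketch` and the layout of `u` (one ensemble member per column) are assumptions made for the example.

```julia
using Statistics

# Sketch of equation (2); `u` is assumed to hold one ensemble member per column.
function multiplicative_inflation_sketch(u::AbstractMatrix, Δt::Real; s::Real = 1.0)
    s * Δt < 1 || error("scaled time step s * Δt must be less than 1")
    m = mean(u, dims = 2)               # ensemble mean, left unchanged by the update
    factor = sqrt(1 / (1 - s * Δt))     # inflation factor applied to the deviations
    return m .+ factor .* (u .- m)
end

u = [1.0 2.0 3.0; 0.0 2.0 4.0]          # 2 parameters, 3 ensemble members
u_inflated = multiplicative_inflation_sketch(u, 0.75)

# The mean is preserved, while every deviation from it is doubled (factor = 2)
@assert mean(u_inflated, dims = 2) ≈ mean(u, dims = 2)
```

Because each new member is the mean plus a rescaled deviation, the inflated ensemble stays in the span of the original one.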
|
||
Multiplicative inflation can be enabled by passing the `multiplicative_inflation` keyword to `update_ensemble!`, as follows: | |
```julia | ||
EKP.update_ensemble!(ekiobj, g_ens; multiplicative_inflation = true, s = 1.0) | ||
``` | ||
|
||
## Additive Inflation | ||
Additive inflation breaks the linear subspace property, meaning the parameter ensemble can evolve | |
outside of the span of the initial ensemble. In additive inflation, the noise `` \zeta_{n} `` in equation (1) | |
is drawn from the following distribution: | |
|
||
```math | ||
\zeta_{n} \sim N(0, \frac{s \Delta{t} }{1 - s \Delta{t}} C_n) \qquad (3) | ||
``` | ||
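
The draw in equation (3) can be sketched as follows. This is an illustration under stated assumptions and not the package implementation: the function name `additive_inflation_sketch` is hypothetical, `u` is assumed to hold one ensemble member per column, ``C_n`` is taken to be the sample covariance of the ensemble, and an independent noise vector is drawn for each member.

```julia
using LinearAlgebra
using Random
using Statistics

# Sketch of equations (1) and (3); `u` is assumed to hold one member per column.
function additive_inflation_sketch(rng::AbstractRNG, u::AbstractMatrix, Δt::Real; s::Real = 1.0)
    s * Δt < 1 || error("scaled time step s * Δt must be less than 1")
    C = cov(u, dims = 2)                    # assumption: C_n is the sample covariance
    scale = s * Δt / (1 - s * Δt)
    # Factor scale * C = L * L' so that L * z has covariance scale * C for z ~ N(0, I).
    # This requires C to be positive definite (more members than parameters).
    L = cholesky(Symmetric(scale * C)).L
    ζ = L * randn(rng, size(u)...)          # assumption: independent draw per member
    return u .+ ζ
end

rng = MersenneTwister(42)
u = randn(rng, 2, 2000)                     # 2 parameters, 2000 ensemble members
u_inflated = additive_inflation_sketch(rng, u, 0.5)
```

Unlike the multiplicative form, the ensemble mean is only approximately preserved here, since the sample mean of the noise is not exactly zero.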
|
||
Additive inflation can be enabled by passing the `additive_inflation` keyword to `update_ensemble!`, as follows: | |
```julia | ||
EKP.update_ensemble!(ekiobj, g_ens; additive_inflation = true, s = 1.0) | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
using Distributions | ||
using LinearAlgebra | ||
using Random | ||
using Test | ||
|
||
using EnsembleKalmanProcesses | ||
using EnsembleKalmanProcesses.ParameterDistributions | ||
using EnsembleKalmanProcesses.Localizers | ||
import EnsembleKalmanProcesses: construct_mean, construct_cov, construct_sigma_ensemble | ||
const EKP = EnsembleKalmanProcesses | ||
|
||
|
||
# Read inverse problem definitions | ||
include("../EnsembleKalmanProcess/inverse_problem.jl") | ||
|
||
n_obs = 30 # dimension of synthetic observation from G(u) | ||
ϕ_star = [-1.0, 2.0] # True parameters in constrained space | ||
n_par = size(ϕ_star, 1) | ||
noise_level = 0.1 # Defining the observation noise level (std) | ||
N_ens = 5000 # number of ensemble members | ||
N_iter = 5 | ||
|
||
|
||
obs_corrmat = Diagonal(Matrix(I, n_obs, n_obs)) | ||
|
||
prior_1 = Dict("distribution" => Parameterized(Normal(0.0, 0.5)), "constraint" => bounded(-2, 2), "name" => "cons_p") | ||
prior_2 = Dict("distribution" => Parameterized(Normal(3.0, 0.5)), "constraint" => no_constraint(), "name" => "uncons_p") | ||
prior = ParameterDistribution([prior_1, prior_2]) | ||
prior_mean = mean(prior) | ||
prior_cov = cov(prior) | ||
|
||
rng_seed = 42 | ||
rng = Random.MersenneTwister(rng_seed) | ||
|
||
initial_ensemble = EKP.construct_initial_ensemble(rng, prior, N_ens) | ||
|
||
Δts = [0.5, 0.75, 0.95] | ||
|
||
@testset "Inflation" begin | ||
|
||
ekiobj = nothing | ||
|
||
for Δt_i in 1:length(Δts) | ||
Δt = Δts[Δt_i] | ||
|
||
# Get inverse problem | ||
y_obs, G, Γy, A = | ||
linear_inv_problem(ϕ_star, noise_level, n_obs, prior, rng; obs_corrmat = obs_corrmat, return_matrix = true) | ||
|
||
ekiobj = EKP.EnsembleKalmanProcess( | ||
initial_ensemble, | ||
y_obs, | ||
Γy, | ||
Inversion(); | ||
Δt = Δt, | ||
rng = rng, | ||
failure_handler_method = SampleSuccGauss(), | ||
) | ||
|
||
g_ens = G(get_u_final(ekiobj)) | ||
|
||
# ensure error is thrown when scaled time step >= 1 | ||
@test_throws ErrorException EKP.update_ensemble!(ekiobj, g_ens; multiplicative_inflation = true, s = 3.0) | ||
@test_throws ErrorException EKP.update_ensemble!(ekiobj, g_ens; additive_inflation = true, s = 3.0) | ||
|
||
# EKI iterations | ||
for i in 1:N_iter | ||
# Get the current parameter ensemble | |
params_i = get_u_final(ekiobj) | ||
|
||
g_ens = G(params_i) | ||
|
||
# standard update | ||
EKP.update_ensemble!(ekiobj, g_ens, EKP.get_process(ekiobj)) | ||
eki_mult_inflation = deepcopy(ekiobj) | ||
eki_add_inflation = deepcopy(ekiobj) | ||
|
||
# multiplicative inflation after standard update | ||
EKP.multiplicative_inflation!(eki_mult_inflation) | ||
# additive inflation after standard update | ||
EKP.additive_inflation!(eki_add_inflation) | ||
|
||
# ensure multiplicative inflation preserves ensemble mean | ||
@test get_u_mean_final(ekiobj) ≈ get_u_mean_final(eki_mult_inflation) atol = 1e-10 | ||
# ensure additive inflation approximately preserves ensemble mean | ||
@test get_u_mean_final(ekiobj) ≈ get_u_mean_final(eki_add_inflation) rtol = 0.1 | ||
|
||
# ensure inflation increases variance of ensemble | ||
ϕ_final_standard = get_ϕ_final(prior, ekiobj) | ||
ϕ_final_mult_inflation = get_ϕ_final(prior, eki_mult_inflation) | ||
ϕ_final_add_inflation = get_ϕ_final(prior, eki_add_inflation) | ||
ϕ_mean_final_standard = get_ϕ_mean_final(prior, ekiobj) | ||
ϕ_mean_final_mult_inflation = get_ϕ_mean_final(prior, eki_mult_inflation) | ||
ϕ_mean_final_add_inflation = get_ϕ_mean_final(prior, eki_add_inflation) | ||
standard_update_variance = norm(ϕ_final_standard .- ϕ_mean_final_standard) | ||
@test norm(ϕ_final_mult_inflation .- ϕ_mean_final_mult_inflation) > standard_update_variance | ||
@test norm(ϕ_final_add_inflation .- ϕ_mean_final_add_inflation) > standard_update_variance | ||
|
||
# ensure inflation is only added in final iteration | ||
u_standard = get_u(ekiobj) | ||
u_mult_inflation = get_u(eki_mult_inflation) | ||
u_add_inflation = get_u(eki_add_inflation) | ||
@test u_standard[1:(end - 1)] == u_mult_inflation[1:(end - 1)] | ||
@test u_standard[1:(end - 1)] == u_add_inflation[1:(end - 1)] | ||
@test u_standard[end] != u_mult_inflation[end] | ||
@test u_standard[end] != u_add_inflation[end] | ||
|
||
end | ||
end | ||
# inflation update should not affect initial parameter ensemble (drawn from prior) | ||
@test get_u_prior(ekiobj) == initial_ensemble | ||
end |