-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add particle inflation functionality
- Loading branch information
1 parent
5109b7f
commit 8b71248
Showing
7 changed files
with
250 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Inflation | ||
Inflation is a approach that slows ensemble collapse in Kalman methods. | ||
Two distinct forms of inflation are implemented, in which perturbations are applied immediately following the standard Kalman update. | ||
Multiplicative inflation expands ensemble members away from their mean in a | ||
deterministic manner, whereas additive inflation hinges on the addition of stochastic noise to ensemble members. | ||
For both implementations, a scaling factor ``s`` is included to extend functionality to cases with mini-batching. | ||
The scaling factor ``s`` multiplies the artificial time step ``\Delta t`` in the inflation equations to account for sampling error. For mini-batching, the scaling factor should be: | ||
```math | ||
s = \frac{|B|}{|C|} | ||
``` | ||
where `` |B| `` is the mini-batch size and `` |C| `` is the full dataset size. | ||
|
||
## Multiplicative Inflation | ||
Multiplicative inflation effectively scales parameter vectors in parameter space, such that the perturbed | ||
ensemble remains in the linear span of the original ensemble. The implemented update equation follows | ||
[Huang et al, 2022](https://arxiv.org/abs/2204.04386) eqn. 41: | ||
|
||
```math | ||
\begin{aligned} | ||
m_{n+1} = m_{n} ; \qquad u^{j}_{n + 1} = m_{n+1} + \sqrt{\frac{1}{1 - s \Delta{t}}} \left(u^{j}_{n} - m_{n} \right) \qquad (1) | ||
\end{aligned} | ||
``` | ||
where ``m`` is the ensemble average. In this way, | ||
the variance across parameter vector magnitudes is increased by a factor of ``\frac{1}{1 - s \Delta{t}}``, while the mean remains fixed. | ||
|
||
Multiplicative inflation can be used by flagging the `update_ensemble!` method as follows: | ||
```julia | ||
EKP.update_ensemble!(ekiobj, g_ens; multiplicative_inflation = true, s = 1.0) | ||
``` | ||
|
||
## Additive Inflation | ||
Additive inflation is implemented by systematically adding stochastic perturbations to the parameter ensemble in the form of gaussian noise. Additive inflation breaks the linear subspace property, meaning the parameter ensemble can evolve outside of the span of the initial ensemble. In additive inflation, the ensemble is perturbed in the following manner after the standard Kalman update: | ||
|
||
```math | ||
u_{n+1} = u_n + \zeta_{n} \qquad (2) \\ | ||
\zeta_{n} \sim N(0, \frac{s \Delta{t} }{1 - s \Delta{t}} C_n) \qquad (3) | ||
``` | ||
This increases the variance across vector components by `` \frac{s \Delta{t} }{1 - s \Delta{t}} ``, while the mean remains fixed. | ||
Additive inflation can be used by flagging the `update_ensemble!` method as follows: | ||
```julia | ||
EKP.update_ensemble!(ekiobj, g_ens; additive_inflation = true, s = 1.0) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
using Distributions | ||
using LinearAlgebra | ||
using Random | ||
using Test | ||
|
||
using EnsembleKalmanProcesses | ||
using EnsembleKalmanProcesses.ParameterDistributions | ||
using EnsembleKalmanProcesses.Localizers | ||
import EnsembleKalmanProcesses: construct_mean, construct_cov, construct_sigma_ensemble | ||
const EKP = EnsembleKalmanProcesses | ||
|
||
|
||
# Read inverse problem definitions | ||
include("../EnsembleKalmanProcess/inverse_problem.jl") | ||
|
||
n_obs = 30 # dimension of synthetic observation from G(u) | ||
ϕ_star = [-1.0, 2.0] # True parameters in constrained space | ||
n_par = size(ϕ_star, 1) | ||
noise_level = 0.1 # Defining the observation noise level (std) | ||
N_ens = 10000 # number of ensemble members | ||
N_iter = 1 # number of EKI iterations | ||
|
||
obs_corrmat = Diagonal(Matrix(I, n_obs, n_obs)) | ||
|
||
prior_1 = Dict("distribution" => Parameterized(Normal(0.0, 0.5)), "constraint" => bounded(-2, 2), "name" => "cons_p") | ||
prior_2 = Dict("distribution" => Parameterized(Normal(3.0, 0.5)), "constraint" => no_constraint(), "name" => "uncons_p") | ||
prior = ParameterDistribution([prior_1, prior_2]) | ||
prior_mean = mean(prior) | ||
prior_cov = cov(prior) | ||
|
||
rng_seed = 42 | ||
rng = Random.MersenneTwister(rng_seed) | ||
|
||
initial_ensemble = EKP.construct_initial_ensemble(rng, prior, N_ens) | ||
|
||
Δts = [0.5, 0.75, 0.95] | ||
|
||
@testset "Inflation" begin | ||
|
||
ekiobj = nothing | ||
|
||
for Δt_i in 1:length(Δts) | ||
Δt = Δts[Δt_i] | ||
|
||
# Get inverse problem | ||
y_obs, G, Γy, A = | ||
linear_inv_problem(ϕ_star, noise_level, n_obs, prior, rng; obs_corrmat = obs_corrmat, return_matrix = true) | ||
|
||
ekiobj = EKP.EnsembleKalmanProcess( | ||
initial_ensemble, | ||
y_obs, | ||
Γy, | ||
Inversion(); | ||
Δt = Δt, | ||
rng = rng, | ||
failure_handler_method = SampleSuccGauss(), | ||
) | ||
|
||
g_ens = G(get_u_final(ekiobj)) | ||
|
||
# ensure error is thrown when scaled time step >= 1 | ||
@test_throws ErrorException EKP.update_ensemble!(ekiobj, g_ens; multiplicative_inflation = true, s = 3.0) | ||
@test_throws ErrorException EKP.update_ensemble!(ekiobj, g_ens; additive_inflation = true, s = 3.0) | ||
|
||
# EKI iterations | ||
for i in 1:N_iter | ||
# Check SampleSuccGauss handler | ||
params_i = get_u_final(ekiobj) | ||
|
||
g_ens = G(params_i) | ||
|
||
# standard update | ||
EKP.update_ensemble!(ekiobj, g_ens, EKP.get_process(ekiobj)) | ||
eki_mult_inflation = deepcopy(ekiobj) | ||
eki_add_inflation = deepcopy(ekiobj) | ||
|
||
# multiplicative inflation after standard update | ||
EKP.multiplicative_inflation!(eki_mult_inflation) | ||
# additive inflation after standard update | ||
EKP.additive_inflation!(eki_add_inflation) | ||
|
||
# ensure multiplicative inflation preserves ensemble mean | ||
@test get_u_mean_final(ekiobj) ≈ get_u_mean_final(eki_mult_inflation) atol = 1e-10 | ||
# ensure additive inflation approximately preserves ensemble mean | ||
@test get_u_mean_final(ekiobj) ≈ get_u_mean_final(eki_add_inflation) rtol = 0.02 | ||
|
||
# check if ensemble is inflated as expected | ||
u_final_standard = get_u_final(ekiobj) | ||
u_final_mult_inflation = get_u_final(eki_mult_inflation) | ||
u_final_add_inflation = get_u_final(eki_add_inflation) | ||
standard_update_var = var(u_final_standard) | ||
|
||
# expected increase in variance (of vector components) after additive inflation | ||
expected_additive_var = Δt / (1 - Δt) | ||
# empirical increase in variance (of vector components) after additive inflation | ||
empirical_additive_var_increase = var(u_final_add_inflation) - standard_update_var | ||
@test empirical_additive_var_increase ≈ expected_additive_var rtol = 0.02 | ||
|
||
# variance of parameter vector magnitudes | ||
standard_update_mag_var = var(norm.(eachcol(u_final_standard))) | ||
mult_update_mag_var = var(norm.(eachcol(u_final_mult_inflation))) | ||
# expected ratio of variance of parameter vector magnitudes after/before additive inflation | ||
expected_additive_var_ratio = 1 / (1 - Δt) | ||
@test expected_additive_var_ratio ≈ mult_update_mag_var / standard_update_mag_var rtol = 0.01 | ||
|
||
# ensure inflation is only added in final iteration | ||
u_standard = get_u(ekiobj) | ||
u_mult_inflation = get_u(eki_mult_inflation) | ||
u_add_inflation = get_u(eki_add_inflation) | ||
@test u_standard[1:(end - 1)] == u_mult_inflation[1:(end - 1)] | ||
@test u_standard[1:(end - 1)] == u_add_inflation[1:(end - 1)] | ||
@test u_standard[end] != u_mult_inflation[end] | ||
@test u_standard[end] != u_add_inflation[end] | ||
|
||
end | ||
end | ||
# inflation update should not affect initial parameter ensemble (drawn from prior) | ||
@test get_u_prior(ekiobj) == initial_ensemble | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters