Skip to content

Commit

Permalink
add particle inflation functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
costachris committed Jan 20, 2023
1 parent 5109b7f commit a282c89
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 1 deletion.
80 changes: 79 additions & 1 deletion src/EnsembleKalmanProcess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,15 @@ function get_process(ekp::EnsembleKalmanProcess)
return ekp.process
end

"""
get_localizer(ekp::EnsembleKalmanProcess)
Return localizer type of EnsembleKalmanProcess.
"""
function get_localizer(ekp::EnsembleKalmanProcess)
return Localizers.get_localizer(ekp.localizer)
end


"""
construct_initial_ensemble(
rng::AbstractRNG,
Expand Down Expand Up @@ -448,20 +457,89 @@ function get_cov_blocks(cov::AbstractMatrix{FT}, p::IT) where {FT <: Real, IT <:
return uu_cov, ug_cov, gg_cov
end

"""
multiplicative_inflation!(
ekp::EnsembleKalmanProcess;
s,
) where {FT, IT}
Applies multiplicative noise to particles.
Inputs:
- ekp :: The EnsembleKalmanProcess to update.
- s :: Scaling factor for timestep in multiplicative perturbation.
"""
function multiplicative_inflation!(ekp::EnsembleKalmanProcess; s::FT = 0.0) where {FT <: Real}

u = get_u_final(ekp)
u_mean = get_u_mean_final(ekp)
scaled_Δt = s * ekp.Δt[end]
prefactor = sqrt(1 / (1 - scaled_Δt))

if scaled_Δt >= 1.0
@warn string(
"Parameter noise time step: ",
scaled_Δt,
"is >= 1.0",
"\nChange s or EK time step.",
)
end

u_updated = u_mean .+ prefactor * (u .- u_mean)
ekp.u[end] = DataContainer(u_updated, data_are_columns = true)

end

"""
additive_inflation!(
ekp::EnsembleKalmanProcess;
s,
) where {FT, IT}
Applies additive noise to particles.
Inputs:
- ekp :: The EnsembleKalmanProcess to update.
- s :: Scaling factor for timestep in additive perturbation.
"""
function additive_inflation!(ekp::EnsembleKalmanProcess; s::FT = 0.0) where {FT <: Real}

u = get_u_final(ekp)
scaled_Δt = s * ekp.Δt[end]
noise_std = sqrt(scaled_Δt / (1 - scaled_Δt))
u_updated = u + Random.randn(ekp.N_ens) * noise_std
ekp.u[end] = DataContainer(u_updated, data_are_columns = true)

end

"""
update_ensemble!(
ekp::EnsembleKalmanProcess,
g::AbstractMatrix{FT};
multiplicative_inflation::Bool = false,
additive_inflation::Bool = false,
s::FT = 0.0,
ekp_kwargs...,
) where {FT, IT}
Updates the ensemble according to an Inversion process.
Inputs:
- ekp :: The EnsembleKalmanProcess to update.
- g :: Model outputs, they need to be stored as a `N_obs × N_ens` array (i.e data are columms).
- multiplicative_inflation :: Flag indicating whether to use multiplicative inflation.
- additive_inflation :: Flag indicating whether to use additive inflation.
- s :: Scaling factor for time step in inflation step.
- ekp_kwargs :: Keyword arguments to pass to standard ekp update_ensemble!.
"""
function update_ensemble!(ekp::EnsembleKalmanProcess, g::AbstractMatrix{FT}; ekp_kwargs...) where {FT, IT}
function update_ensemble!(
ekp::EnsembleKalmanProcess,
g::AbstractMatrix{FT};
multiplicative_inflation::Bool = false,
additive_inflation::Bool = false,
s::FT = 0.0,
ekp_kwargs...,
) where {FT, IT}

update_ensemble!(ekp, g, get_process(ekp); ekp_kwargs...)
if s > 0.0
multiplicative_inflation ? multiplicative_inflation!(ekp; s = s) : nothing
additive_inflation ? additive_inflation!(ekp; s = s) : nothing
end
end

## include the different types of Processes and their exports:
Expand Down
8 changes: 8 additions & 0 deletions src/Localizers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,4 +239,12 @@ function Localizer(localization::SECFisher, p::IT, d::IT, J::IT, T = Float64) wh
return Localizer{SECFisher, T}((cov) -> sec_fisher(cov, J))
end

"""
get_localizer(loc::Localizer)
Return localizer type.
"""
function get_localizer(loc::Localizer{T1, T2}) where {T1, T2}
return T1
end

end # module
99 changes: 99 additions & 0 deletions test/StochasticKalmanUpdate/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
using Distributions
using LinearAlgebra
using Random
using Test

using EnsembleKalmanProcesses
using EnsembleKalmanProcesses.ParameterDistributions
using EnsembleKalmanProcesses.Localizers
import EnsembleKalmanProcesses: construct_mean, construct_cov, construct_sigma_ensemble
const EKP = EnsembleKalmanProcesses


# Read inverse problem definitions
include("../EnsembleKalmanProcess/inverse_problem.jl")

n_obs = 30 # dimension of synthetic observation from G(u)
ϕ_star = [-1.0, 2.0] # True parameters in constrained space
n_par = size(ϕ_star, 1)
noise_level = 0.1 # Defining the observation noise level (std)
N_ens = 50 # number of ensemble members
N_iter = 10


obs_corrmat = Diagonal(Matrix(I, n_obs, n_obs))

prior_1 = Dict("distribution" => Parameterized(Normal(0.0, 0.5)), "constraint" => bounded(-2, 2), "name" => "cons_p")
prior_2 = Dict("distribution" => Parameterized(Normal(3.0, 0.5)), "constraint" => no_constraint(), "name" => "uncons_p")
prior = ParameterDistribution([prior_1, prior_2])
prior_mean = mean(prior)
prior_cov = cov(prior)


# Define a few inverse problems to compare algorithmic performance
rng_seed = 42
rng = Random.MersenneTwister(rng_seed)


initial_ensemble = EKP.construct_initial_ensemble(rng, prior, N_ens)

Δts = [0.5, 0.75, 0.95]

@testset "Inflation" begin

ekiobj = nothing

for Δt_i in 1:length(Δts)
Δt = Δts[Δt_i]

# Get inverse problem
y_obs, G, Γy, A =
linear_inv_problem(ϕ_star, noise_level, n_obs, prior, rng; obs_corrmat = obs_corrmat, return_matrix = true)

ekiobj = EKP.EnsembleKalmanProcess(
initial_ensemble,
y_obs,
Γy,
Inversion();
Δt = Δt,
rng = rng,
failure_handler_method = SampleSuccGauss(),
)

g_ens = G(get_u_final(ekiobj))

# EKI iterations
for i in 1:N_iter
# Check SampleSuccGauss handler
params_i = get_u_final(ekiobj)

g_ens = G(params_i)

# standard update
EKP.update_ensemble!(ekiobj, g_ens, EKP.get_process(ekiobj))
ekp_standard = deepcopy(ekiobj)
# add noise to parameters
EKP.multiplicative_inflation!(ekiobj)
ekp_stochastic = deepcopy(ekiobj)

# ensure stochastic update preserves ensemble mean
@test get_u_mean_final(ekp_standard) get_u_mean_final(ekp_stochastic) atol = 1e-10

# ensure stochastic update increases variance of ensemble
ϕ_final_standard = get_ϕ_final(prior, ekp_standard)
ϕ_final_stochastic = get_ϕ_final(prior, ekp_stochastic)
ϕ_mean_final_standard = get_ϕ_mean_final(prior, ekp_standard)
ϕ_mean_final_stochastic = get_ϕ_mean_final(prior, ekp_stochastic)
@test norm(ϕ_final_stochastic .- ϕ_mean_final_stochastic) > norm(ϕ_final_standard .- ϕ_mean_final_standard)

# ensure parameter noise is only added in final iteration
u_standard = get_u(ekp_standard)
u_stochastic = get_u(ekp_stochastic)
@test u_standard[1:(end - 1)] == u_stochastic[1:(end - 1)]
@test u_standard[end] != u_stochastic[end]

end
end
# stochastic update should not affect initial parameter ensemble (drawn from prior)
@test get_u_prior(ekiobj) == initial_ensemble
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ end
"Localizers",
"TOMLInterface",
"SparseInversion",
"Inflation",
]
if all_tests || has_submodule(submodule) || "EnsembleKalmanProcesses" in ARGS
include_test(submodule)
Expand Down

0 comments on commit a282c89

Please sign in to comment.