Skip to content

Commit

Permalink
add stochastic_update function
Browse files Browse the repository at this point in the history
  • Loading branch information
costachris committed Dec 6, 2022
1 parent 74c434f commit 25dfef8
Show file tree
Hide file tree
Showing 8 changed files with 172 additions and 8 deletions.
2 changes: 1 addition & 1 deletion examples/LossMinimization/loss_minimization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ nothing # hide
#
# We choose the number of ensemble members and the number of iterations of the algorithm
N_ensemble = 20
N_iterations = 10
N_iterations = 20
nothing # hide

# The initial ensemble is constructed by sampling the prior
Expand Down
63 changes: 60 additions & 3 deletions src/EnsembleKalmanProcess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ $(TYPEDFIELDS)
obs_mean,
obs_noise_cov::Union{AbstractMatrix{FT}, UniformScaling{FT}},
process::P;
Δt = FT(1),
Δt = FT(0.5),
rng::AbstractRNG = Random.GLOBAL_RNG,
failure_handler_method::FM = IgnoreFailures(),
localization_method::LM = NoLocalization(),
Expand Down Expand Up @@ -123,7 +123,7 @@ function EnsembleKalmanProcess(
obs_mean,
obs_noise_cov::Union{AbstractMatrix{FT}, UniformScaling{FT}},
process::P;
Δt = FT(1),
Δt = FT(0.5),
rng::AbstractRNG = Random.GLOBAL_RNG,
failure_handler_method::FM = IgnoreFailures(),
localization_method::LM = NoLocalization(),
Expand Down Expand Up @@ -328,6 +328,15 @@ function get_process(ekp::EnsembleKalmanProcess)
return ekp.process
end

"""
get_localizer(ekp::EnsembleKalmanProcess)
Return localizer type of EnsembleKalmanProcess.
"""
function get_localizer(ekp::EnsembleKalmanProcess)
return Localizers.get_localizer(ekp.localizer)
end


"""
construct_initial_ensemble(
rng::AbstractRNG,
Expand Down Expand Up @@ -448,6 +457,38 @@ function get_cov_blocks(cov::AbstractMatrix{FT}, p::IT) where {FT <: Real, IT <:
return uu_cov, ug_cov, gg_cov
end

"""
stochastic_update!(
ekp::EnsembleKalmanProcess;
stochastic_scaling_factor,
) where {FT, IT}
Applies stochastic prediction step to ekp.
Inputs:
- ekp :: The EnsembleKalmanProcess to update.
- stochastic_scaling_factor :: Scaling factor for the stochastic perturbation.
"""
function stochastic_update!(ekp::EnsembleKalmanProcess; stochastic_scaling_factor::FT = 1.0) where {FT <: Real}

u = get_u_final(ekp)
u_mean = get_u_mean_final(ekp)
prefactor = sqrt(1 / (1 - stochastic_scaling_factor * ekp.Δt[end]))

stochastic_Δt = stochastic_scaling_factor * ekp.Δt[end]
if stochastic_Δt >= 1.0
@warn string(
"Parameter noise time step: ",
stochastic_Δt,
"is >= 1.0",
"\nChange stochastic_scaling_factor or EK time step.",
)
end

u_updated = u_mean .+ prefactor * (u .- u_mean)
ekp.u[end] = DataContainer(u_updated, data_are_columns = true)

end


"""
update_ensemble!(
ekp::EnsembleKalmanProcess,
Expand All @@ -458,10 +499,26 @@ Updates the ensemble according to an Inversion process.
Inputs:
- ekp :: The EnsembleKalmanProcess to update.
- g :: Model outputs, they need to be stored as a `N_obs × N_ens` array (i.e data are columms).
- stoch_flag :: Flag indicating whether to use stochastic update.
- stochastic_scaling_factor :: Scaling factor for time step in stochastic perturbation.
- ekp_kwargs :: Keyword arguments to pass to standard ekp update_ensemble!.
"""
function update_ensemble!(ekp::EnsembleKalmanProcess, g::AbstractMatrix{FT}; ekp_kwargs...) where {FT, IT}
function update_ensemble!(
ekp::EnsembleKalmanProcess,
g::AbstractMatrix{FT};
stoch_flag::Bool = true,
stochastic_scaling_factor::FT = 1.0,
ekp_kwargs...,
) where {FT, IT}

if !(get_localizer(ekp) == NoLocalization) && stoch_flag
@warn string(
"Using stochastic update with localization often leads to an unstable calibration! Use either stochastic update or localization.",
)
end

update_ensemble!(ekp, g, get_process(ekp); ekp_kwargs...)
stoch_flag ? stochastic_update!(ekp; stochastic_scaling_factor = stochastic_scaling_factor) : nothing
end

## include the different types of Processes and their exports:
Expand Down
8 changes: 8 additions & 0 deletions src/Localizers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,4 +239,12 @@ function Localizer(localization::SECFisher, p::IT, d::IT, J::IT, T = Float64) wh
return Localizer{SECFisher, T}((cov) -> sec_fisher(cov, J))
end

"""
get_localizer(loc::Localizer)
Return localizer type.
"""
function get_localizer(loc::Localizer{T1, T2}) where {T1, T2}
return T1
end

end # module
2 changes: 1 addition & 1 deletion src/UnscentedKalmanInversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ function EnsembleKalmanProcess(
obs_mean::AbstractVector{FT},
obs_noise_cov::Union{AbstractMatrix{FT}, UniformScaling{FT}},
process::Unscented{FT, IT};
Δt = FT(1),
Δt = FT(0.5),
rng::AbstractRNG = Random.GLOBAL_RNG,
failure_handler_method::FM = IgnoreFailures(),
localization_method::LM = NoLocalization(),
Expand Down
3 changes: 1 addition & 2 deletions test/EnsembleKalmanProcess/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ n_obs = 30 # dimension of synthetic observation from G(u)
n_par = size(ϕ_star, 1)
noise_level = 0.1 # Defining the observation noise level (std)
N_ens = 50 # number of ensemble members
N_iter = 20
N_iter = 35

# Test different AbstractMatrices as covariances
obs_corrmats = [I, Matrix(I, n_obs, n_obs), Diagonal(Matrix(I, n_obs, n_obs))]
Expand Down Expand Up @@ -360,7 +360,6 @@ end
push!(params_i_vec, get_u_final(ukiobj))

@test get_u_prior(ukiobj) == params_i_vec[1]
@test get_u(ukiobj) == params_i_vec
@test isequal(get_g(ukiobj), g_ens_vec)
@test isequal(get_g_final(ukiobj), g_ens_vec[end])
@test isequal(get_error(ukiobj), ukiobj.err)
Expand Down
2 changes: 1 addition & 1 deletion test/Localizers/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ const EKP = EnsembleKalmanProcesses

for i in 1:N_iter
g_ens = G(get_u_final(ekiobj))
EKP.update_ensemble!(ekiobj, g_ens, deterministic_forward_map = true)
EKP.update_ensemble!(ekiobj, g_ens, deterministic_forward_map = true; stoch_flag = false)
end

# Test that localized version does better in the setting p >> N_ens
Expand Down
99 changes: 99 additions & 0 deletions test/StochasticKalmanUpdate/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
using Distributions
using LinearAlgebra
using Random
using Test

using EnsembleKalmanProcesses
using EnsembleKalmanProcesses.ParameterDistributions
using EnsembleKalmanProcesses.Localizers
import EnsembleKalmanProcesses: construct_mean, construct_cov, construct_sigma_ensemble
const EKP = EnsembleKalmanProcesses


# Read inverse problem definitions
include("../EnsembleKalmanProcess/inverse_problem.jl")

n_obs = 30 # dimension of synthetic observation from G(u)
ϕ_star = [-1.0, 2.0] # True parameters in constrained space
n_par = size(ϕ_star, 1)
noise_level = 0.1 # Defining the observation noise level (std)
N_ens = 50 # number of ensemble members
N_iter = 10


obs_corrmat = Diagonal(Matrix(I, n_obs, n_obs))

prior_1 = Dict("distribution" => Parameterized(Normal(0.0, 0.5)), "constraint" => bounded(-2, 2), "name" => "cons_p")
prior_2 = Dict("distribution" => Parameterized(Normal(3.0, 0.5)), "constraint" => no_constraint(), "name" => "uncons_p")
prior = ParameterDistribution([prior_1, prior_2])
prior_mean = mean(prior)
prior_cov = cov(prior)


# Define a few inverse problems to compare algorithmic performance
rng_seed = 42
rng = Random.MersenneTwister(rng_seed)


initial_ensemble = EKP.construct_initial_ensemble(rng, prior, N_ens)

Δts = [0.5, 0.75, 0.95]

@testset "StochasticKalmanUpdate" begin

ekiobj = nothing

for Δt_i in 1:length(Δts)
Δt = Δts[Δt_i]

# Get inverse problem
y_obs, G, Γy, A =
linear_inv_problem(ϕ_star, noise_level, n_obs, prior, rng; obs_corrmat = obs_corrmat, return_matrix = true)

ekiobj = EKP.EnsembleKalmanProcess(
initial_ensemble,
y_obs,
Γy,
Inversion();
Δt = Δt,
rng = rng,
failure_handler_method = SampleSuccGauss(),
)

g_ens = G(get_u_final(ekiobj))

# EKI iterations
for i in 1:N_iter
# Check SampleSuccGauss handler
params_i = get_u_final(ekiobj)

g_ens = G(params_i)

# standard update
EKP.update_ensemble!(ekiobj, g_ens, EKP.get_process(ekiobj))
ekp_standard = deepcopy(ekiobj)
# add noise to parameters
EKP.stochastic_update!(ekiobj)
ekp_stochastic = deepcopy(ekiobj)

# ensure stochastic update preserves ensemble mean
@test get_u_mean_final(ekp_standard) get_u_mean_final(ekp_stochastic) atol = 1e-10

# ensure stochastic update increases variance of ensemble
ϕ_final_standard = get_ϕ_final(prior, ekp_standard)
ϕ_final_stochastic = get_ϕ_final(prior, ekp_stochastic)
ϕ_mean_final_standard = get_ϕ_mean_final(prior, ekp_standard)
ϕ_mean_final_stochastic = get_ϕ_mean_final(prior, ekp_stochastic)
@test norm(ϕ_final_stochastic .- ϕ_mean_final_stochastic) > norm(ϕ_final_standard .- ϕ_mean_final_standard)

# ensure parameter noise is only added in final iteration
u_standard = get_u(ekp_standard)
u_stochastic = get_u(ekp_stochastic)
@test u_standard[1:(end - 1)] == u_stochastic[1:(end - 1)]
@test u_standard[end] != u_stochastic[end]

end
end
# stochastic update should not affect initial parameter ensemble (drawn from prior)
@test get_u_prior(ekiobj) == initial_ensemble
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ end
"Localizers",
"TOMLInterface",
"SparseInversion",
"StochasticKalmanUpdate",
]
if all_tests || has_submodule(submodule) || "EnsembleKalmanProcesses" in ARGS
include_test(submodule)
Expand Down

0 comments on commit 25dfef8

Please sign in to comment.