Skip to content

Commit

Permalink
Ensemble tranform Kalman inversion
Browse files Browse the repository at this point in the history
Different error in old Julia versions
  • Loading branch information
eviatarbach committed Oct 1, 2023
1 parent 7bad162 commit 9301478
Show file tree
Hide file tree
Showing 4 changed files with 322 additions and 10 deletions.
12 changes: 10 additions & 2 deletions src/EnsembleKalmanProcess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,11 @@ struct EnsembleKalmanProcess{FT <: AbstractFloat, IT <: Int, P <: Process, LRS <
scheduler::LRS
"stored vector of timesteps used in each EK iteration"
Δt::Vector{FT}
"the particular EK process (`Inversion` or `Sampler` or `Unscented` or `SparseInversion`)"
"the particular EK process (`Inversion` or `Sampler` or `Unscented` or `TransformInversion` or `SparseInversion`)"
process::P
"Random number generator object (algorithm + seed) used for sampling and noise, for reproducibility. Defaults to `Random.GLOBAL_RNG`."
rng::AbstractRNG
"struct storing failsafe update directives, implemented for (`Inversion`, `SparseInversion`, `Unscented`)"
"struct storing failsafe update directives, implemented for (`Inversion`, `SparseInversion`, `Unscented`, `TransformInversion`)"
failure_handler::FailureHandler
"Localization kernel, implemented for (`Inversion`, `SparseInversion`, `Unscented`)"
localizer::Localizer
Expand Down Expand Up @@ -165,6 +165,10 @@ function EnsembleKalmanProcess(
# error store
err = FT[]

if (typeof(process) <: TransformInversion) & !(typeof(localization_method) == NoLocalization)
throw(ArgumentError("`TransformInversion` cannot currently be used with localization."))

Check warning on line 169 in src/EnsembleKalmanProcess.jl

View check run for this annotation

Codecov / codecov/patch

src/EnsembleKalmanProcess.jl#L169

Added line #L169 was not covered by tests
end

# set the timestep methods (being cautious of EKS scheduler)
if isnothing(scheduler)
if !(isnothing(Δt))
Expand Down Expand Up @@ -643,6 +647,10 @@ end
export Inversion
include("EnsembleKalmanInversion.jl")

# struct TransformInversion
export TransformInversion
include("EnsembleTransformKalmanInversion.jl")

# struct SparseInversion
export SparseInversion
include("SparseEnsembleKalmanInversion.jl")
Expand Down
139 changes: 139 additions & 0 deletions src/EnsembleTransformKalmanInversion.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
#Ensemble Transform Kalman Inversion: specific structures and function definitions

"""
TransformInversion <: Process
An ensemble transform Kalman inversion process.
# Fields
$(TYPEDFIELDS)
"""
struct TransformInversion{FT <: AbstractFloat} <: Process
"Inverse of the observation error covariance matrix"
Γ_inv::Union{AbstractMatrix{FT}, UniformScaling{FT}}
end

function FailureHandler(process::TransformInversion, method::IgnoreFailures)
failsafe_update(ekp, u, g, y, obs_noise_cov, failed_ens) = etki_update(ekp, u, g, y, obs_noise_cov)
return FailureHandler{TransformInversion, IgnoreFailures}(failsafe_update)
end

"""
FailureHandler(process::TransformInversion, method::SampleSuccGauss)
Provides a failsafe update that
- updates the successful ensemble according to the ETKI update,
- updates the failed ensemble by sampling from the updated successful ensemble.
"""
function FailureHandler(process::TransformInversion, method::SampleSuccGauss)
function failsafe_update(ekp, u, g, y, obs_noise_cov, failed_ens)
successful_ens = filter(x -> !(x in failed_ens), collect(1:size(g, 2)))
n_failed = length(failed_ens)
u[:, successful_ens] = etki_update(ekp, u[:, successful_ens], g[:, successful_ens], y, obs_noise_cov)
if !isempty(failed_ens)
u[:, failed_ens] = sample_empirical_gaussian(u[:, successful_ens], n_failed)
end
return u
end
return FailureHandler{TransformInversion, SampleSuccGauss}(failsafe_update)
end

"""
etki_update(
ekp::EnsembleKalmanProcess{FT, IT, TransformInversion},
u::AbstractMatrix{FT},
g::AbstractMatrix{FT},
y::AbstractVector{FT},
obs_noise_cov::Union{AbstractMatrix{CT}, UniformScaling{CT}},
) where {FT <: Real, IT, CT <: Real}
Returns the updated parameter vectors given their current values and
the corresponding forward model evaluations.
"""
function etki_update(
ekp::EnsembleKalmanProcess{FT, IT, TransformInversion{FT}},
u::AbstractMatrix{FT},
g::AbstractMatrix{FT},
y::AbstractVector{FT},
obs_noise_cov::Union{AbstractMatrix{CT}, UniformScaling{CT}},
) where {FT <: Real, IT, CT <: Real}
m = size(u, 2)
Γ_inv = ekp.process.Γ_inv

X = FT.((u .- mean(u, dims = 2)) / sqrt(m - 1))
Y = FT.((g .- mean(g, dims = 2)) / sqrt(m - 1))
Ω = inv(I + Y' * Γ_inv * Y)
w = FT.(Ω * Y' * Γ_inv * (y .- mean(g, dims = 2)))

return mean(u, dims = 2) .+ X * (w .+ sqrt(m - 1) * real(sqrt(Ω))) # [N_par × N_ens]
end

"""
update_ensemble!(
ekp::EnsembleKalmanProcess{FT, IT, TransformInversion},
g::AbstractMatrix{FT},
process::TransformInversion;
failed_ens = nothing,
) where {FT, IT}
Updates the ensemble according to a TransformInversion process.
Inputs:
- ekp :: The EnsembleKalmanProcess to update.
- g :: Model outputs, they need to be stored as a `N_obs × N_ens` array (i.e data are columms).
- process :: Type of the EKP.
- failed_ens :: Indices of failed particles. If nothing, failures are computed as columns of `g` with NaN entries.
"""
function update_ensemble!(
ekp::EnsembleKalmanProcess{FT, IT, TransformInversion{FT}},
g::AbstractMatrix{FT},
process::TransformInversion{FT};
failed_ens = nothing,
) where {FT, IT}

# u: N_par × N_ens
# g: N_obs × N_ens
u = get_u_final(ekp)
N_obs = size(g, 1)
cov_init = cov(u, dims = 2)

if ekp.verbose
if get_N_iterations(ekp) == 0
@info "Iteration 0 (prior)"
@info "Covariance trace: $(tr(cov_init))"

Check warning on line 104 in src/EnsembleTransformKalmanInversion.jl

View check run for this annotation

Codecov / codecov/patch

src/EnsembleTransformKalmanInversion.jl#L102-L104

Added lines #L102 - L104 were not covered by tests
end

@info "Iteration $(get_N_iterations(ekp)+1) (T=$(sum(ekp.Δt)))"

Check warning on line 107 in src/EnsembleTransformKalmanInversion.jl

View check run for this annotation

Codecov / codecov/patch

src/EnsembleTransformKalmanInversion.jl#L107

Added line #L107 was not covered by tests
end

fh = ekp.failure_handler

# Scale noise using Δt
scaled_obs_noise_cov = ekp.obs_noise_cov / ekp.Δt[end]

y = ekp.obs_mean

if isnothing(failed_ens)
_, failed_ens = split_indices_by_success(g)
end
if !isempty(failed_ens)
@info "$(length(failed_ens)) particle failure(s) detected. Handler used: $(nameof(typeof(fh).parameters[2]))."
end

u = fh.failsafe_update(ekp, u, g, y, scaled_obs_noise_cov, failed_ens)

# store new parameters (and model outputs)
push!(ekp.u, DataContainer(u, data_are_columns = true))
push!(ekp.g, DataContainer(g, data_are_columns = true))

# Store error
compute_error!(ekp)

# Diagnostics
cov_new = cov(get_u_final(ekp), dims = 2)

if ekp.verbose
@info "Covariance-weighted error: $(get_error(ekp)[end])\nCovariance trace: $(tr(cov_new))\nCovariance trace ratio (current/previous): $(tr(cov_new)/tr(cov_init))"

Check warning on line 137 in src/EnsembleTransformKalmanInversion.jl

View check run for this annotation

Codecov / codecov/patch

src/EnsembleTransformKalmanInversion.jl#L137

Added line #L137 was not covered by tests
end
end
165 changes: 165 additions & 0 deletions test/EnsembleKalmanProcess/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,171 @@ end
end
end

@testset "EnsembleTransformKalmanInversion" begin

# Seed for pseudo-random number generator
rng = Random.MersenneTwister(rng_seed)

initial_ensemble = EKP.construct_initial_ensemble(rng, prior, N_ens)

ekiobj = nothing
eki_final_result = nothing
iters_with_failure = [5, 8, 9, 15]

for (i_prob, inv_problem) in enumerate(inv_problems)

# Get inverse problem
y_obs, G, Γy, A = inv_problem
if i_prob == 1
scheduler = DataMisfitController(on_terminate = "continue")
else
scheduler = DefaultScheduler()
end

ekiobj = EKP.EnsembleKalmanProcess(
initial_ensemble,
y_obs,
Γy,
TransformInversion(inv(Γy));
rng = rng,
failure_handler_method = SampleSuccGauss(),
scheduler = scheduler,
)

ekiobj_unsafe = EKP.EnsembleKalmanProcess(
initial_ensemble,
y_obs,
Γy,
TransformInversion(inv(Γy));
rng = rng,
failure_handler_method = IgnoreFailures(),
scheduler = scheduler,
)


g_ens = G(get_ϕ_final(prior, ekiobj))
g_ens_t = permutedims(g_ens, (2, 1))

@test size(g_ens) == (n_obs, N_ens)

# ETKI iterations
u_i_vec = Array{Float64, 2}[]
g_ens_vec = Array{Float64, 2}[]
for i in 1:N_iter
params_i = get_ϕ_final(prior, ekiobj)
push!(u_i_vec, get_u_final(ekiobj))
g_ens = G(params_i)

# Add random failures
if i in iters_with_failure
g_ens[:, 1] .= NaN
end

EKP.update_ensemble!(ekiobj, g_ens)
push!(g_ens_vec, g_ens)
if i == 1
if !(size(g_ens, 1) == size(g_ens, 2))
g_ens_t = permutedims(g_ens, (2, 1))
@test_throws DimensionMismatch EKP.update_ensemble!(ekiobj, g_ens_t)
end
end

# Correct handling of failures
@test !any(isnan.(params_i))

# Check IgnoreFailures handler
if i <= iters_with_failure[1]
params_i_unsafe = get_ϕ_final(prior, ekiobj_unsafe)
g_ens_unsafe = G(params_i_unsafe)
if i < iters_with_failure[1]
EKP.update_ensemble!(ekiobj_unsafe, g_ens_unsafe)
elseif i == iters_with_failure[1]
g_ens_unsafe[:, 1] .= NaN
#inconsistent behaviour before/after v1.9 regarding NaNs in matrices
if (VERSION.major >= 1) && (VERSION.minor >= 9)
# new versions the NaNs break LinearAlgebra.jl
@test_throws ArgumentError EKP.update_ensemble!(ekiobj_unsafe, g_ens_unsafe)
end
end
end
end

push!(u_i_vec, get_u_final(ekiobj))

@test get_u_prior(ekiobj) == u_i_vec[1]
@test get_u(ekiobj) == u_i_vec
@test isequal(get_g(ekiobj), g_ens_vec)
@test isequal(get_g_final(ekiobj), g_ens_vec[end])
@test isequal(get_error(ekiobj), ekiobj.err)

# ETKI results: Test if ensemble has collapsed toward the true parameter
# values
eki_init_result = vec(mean(get_u_prior(ekiobj), dims = 2))
eki_final_result = get_u_mean_final(ekiobj)
eki_init_spread = tr(get_u_cov(ekiobj, 1))
eki_final_spread = tr(get_u_cov_final(ekiobj))

g_mean_init = get_g_mean(ekiobj, 1)
g_mean_final = get_g_mean_final(ekiobj)

@test eki_init_result == get_u_mean(ekiobj, 1)
@test eki_final_result == vec(mean(get_u_final(ekiobj), dims = 2))

@test eki_final_spread < 2 * eki_init_spread # we wouldn't expect the spread to increase much in any one dimension

ϕ_final_mean = get_ϕ_mean_final(prior, ekiobj)
ϕ_init_mean = get_ϕ_mean(prior, ekiobj, 1)

if nameof(typeof(ekiobj.localizer)) == EKP.Localizers.NoLocalization
@test norm(ϕ_star - ϕ_final_mean) < norm(ϕ_star - ϕ_init_mean)
@test norm(y_obs .- G(eki_final_result))^2 < norm(y_obs .- G(eki_init_result))^2
@test norm(y_obs .- g_mean_final)^2 < norm(y_obs .- g_mean_init)^2
end

if i_prob <= n_lin_inv_probs && nameof(typeof(ekiobj.localizer)) == EKP.Localizers.NoLocalization

posterior_cov_inv = (A' * (Γy \ A) + 1 * Matrix(I, n_par, n_par) / prior_cov)
ols_mean = (A' * (Γy \ A)) \ (A' * (Γy \ y_obs))
posterior_mean = posterior_cov_inv \ ((A' * (Γy \ A)) * ols_mean + (prior_cov \ prior_mean))

# ETKI provides a solution closer to the ordinary Least Squares estimate
@test norm(ols_mean - ϕ_final_mean) < norm(ols_mean - ϕ_init_mean)
end

# Plot evolution of the ETKI particles
if TEST_PLOT_OUTPUT
plot_inv_problem_ensemble(prior, ekiobj, joinpath(@__DIR__, "ETKI_test_$(i_prob).png"))
end
end

for (i, n_obs_test) in enumerate([10, 10, 100, 1000, 10000])
initial_ensemble = EKP.construct_initial_ensemble(rng, prior, N_ens)

y_obs_test, G_test, Γ_test, A_test =
linear_inv_problem(ϕ_star, noise_level, n_obs_test, rng; return_matrix = true)

ekiobj = EKP.EnsembleKalmanProcess(
initial_ensemble,
y_obs_test,
Γ_test,
TransformInversion(inv(Γ_test));
rng = rng,
failure_handler_method = SampleSuccGauss(),
)
T = 0.0
for i in 1:N_iter
params_i = get_ϕ_final(prior, ekiobj)
g_ens = G_test(params_i)

dt = @elapsed EKP.update_ensemble!(ekiobj, g_ens)
T += dt
end
# Skip timing of first due to precompilation
if i >= 2
@info "ETKI with $n_obs_test observations took $T seconds."
end
end
end

@testset "EnsembleKalmanProcess utils" begin
# Success/failure splitting
Expand Down
16 changes: 8 additions & 8 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ end
end

for submodule in [
"DataContainers",
"ParameterDistributions",
"PlotRecipes",
"Observations",
# "DataContainers",
# "ParameterDistributions",
# "PlotRecipes",
# "Observations",
"EnsembleKalmanProcess",
"Localizers",
"TOMLInterface",
"SparseInversion",
"Inflation",
# "Localizers",
# "TOMLInterface",
# "SparseInversion",
# "Inflation",
]
if all_tests || has_submodule(submodule) || "EnsembleKalmanProcesses" in ARGS
include_test(submodule)
Expand Down

0 comments on commit 9301478

Please sign in to comment.