Skip to content

Commit

Permalink
Merge #272
Browse files Browse the repository at this point in the history
272: Diagnostics r=eviatarbach a=eviatarbach

Closes #250

Co-authored-by: Eviatar Bach <[email protected]>
  • Loading branch information
bors[bot] and eviatarbach committed Apr 6, 2023
2 parents dd4673d + 4d88688 commit 5455fff
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 27 deletions.
26 changes: 14 additions & 12 deletions src/EnsembleKalmanInversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ end
ekp::EnsembleKalmanProcess{FT, IT, Inversion},
g::AbstractMatrix{FT},
process::Inversion;
cov_threshold::Real = 0.01,
Δt_new::Union{Nothing, FT} = nothing,
deterministic_forward_map::Bool = true,
failed_ens = nothing,
Expand All @@ -132,7 +131,6 @@ Inputs:
- ekp :: The EnsembleKalmanProcess to update.
- g :: Model outputs, they need to be stored as a `N_obs × N_ens` array (i.e data are columms).
- process :: Type of the EKP.
- cov_threshold :: Threshold below which the reduction in covariance determinant results in a warning.
- Δt_new :: Time step to be used in the current update.
- deterministic_forward_map :: Whether output `g` comes from a deterministic model.
- failed_ens :: Indices of failed particles. If nothing, failures are computed as columns of `g` with NaN entries.
Expand All @@ -141,7 +139,6 @@ function update_ensemble!(
ekp::EnsembleKalmanProcess{FT, IT, Inversion},
g::AbstractMatrix{FT},
process::Inversion;
cov_threshold::Real = 0.01,
Δt_new::Union{Nothing, FT} = nothing,
deterministic_forward_map::Bool = true,
failed_ens = nothing,
Expand All @@ -161,6 +158,16 @@ function update_ensemble!(
u = get_u_final(ekp)
N_obs = size(g, 1)
cov_init = cov(u, dims = 2)

if ekp.verbose
if get_N_iterations(ekp) == 0
@info "Iteration 0 (prior)"
@info "Covariance trace: $(tr(cov_init))"
end

@info "Iteration $(get_N_iterations(ekp)+1) (T=$(sum(ekp.Δt)))"
end

set_Δt!(ekp, Δt_new)
fh = ekp.failure_handler

Expand Down Expand Up @@ -188,15 +195,10 @@ function update_ensemble!(
# Store error
compute_error!(ekp)

# Check convergence
# Diagnostics
cov_new = cov(get_u_final(ekp), dims = 2)
cov_ratio = det(cov_new) / det(cov_init)
if cov_ratio < cov_threshold
@warn string(
"New ensemble covariance determinant is less than ",
cov_threshold,
" times its previous value.",
"\nConsider reducing the EK time step.",
)

if ekp.verbose
@info "Covariance-weighted error: $(get_error(ekp)[end])\nCovariance trace: $(tr(cov_new))\nCovariance trace ratio (current/previous): $(tr(cov_new)/tr(cov_init))"
end
end
26 changes: 24 additions & 2 deletions src/EnsembleKalmanProcess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ $(TYPEDFIELDS)
rng::AbstractRNG = Random.GLOBAL_RNG,
failure_handler_method::FM = IgnoreFailures(),
localization_method::LM = NoLocalization(),
verbose::Bool = false,
) where {FT <: AbstractFloat, P <: Process, FM <: FailureHandlingMethod, LM <: LocalizationMethod}
Inputs:
Expand All @@ -88,6 +89,7 @@ Inputs:
- `rng` :: Random number generator
- `failure_handler_method` :: Method used to handle particle failures
- `localization_method` :: Method used to localize sample covariances
- `verbose` :: Whether to print diagnostic information
# Other constructors:
Expand Down Expand Up @@ -116,6 +118,8 @@ struct EnsembleKalmanProcess{FT <: AbstractFloat, IT <: Int, P <: Process}
failure_handler::FailureHandler
"Localization kernel, implemented for (`Inversion`, `SparseInversion`, `Unscented`)"
localizer::Localizer
"Whether to print diagnostics for each EK iteration"
verbose::Bool
end

function EnsembleKalmanProcess(
Expand All @@ -127,6 +131,7 @@ function EnsembleKalmanProcess(
rng::AbstractRNG = Random.GLOBAL_RNG,
failure_handler_method::FM = IgnoreFailures(),
localization_method::LM = NoLocalization(),
verbose::Bool = false,
) where {FT <: AbstractFloat, P <: Process, FM <: FailureHandlingMethod, LM <: LocalizationMethod}

#initial parameters stored as columns
Expand All @@ -148,7 +153,24 @@ function EnsembleKalmanProcess(
# localizer
loc = Localizer(localization_method, N_par, N_obs, N_ens, FT)

EnsembleKalmanProcess{FT, IT, P}([init_params], obs_mean, obs_noise_cov, N_ens, g, err, Δt, process, rng, fh, loc)
if verbose
@info "Initializing ensemble Kalman process of type $(nameof(typeof(process)))\nNumber of ensemble members: $(N_ens)\nLocalization: $(nameof(typeof(localization_method)))\nFailure handler: $(nameof(typeof(failure_handler_method)))"
end

EnsembleKalmanProcess{FT, IT, P}(
[init_params],
obs_mean,
obs_noise_cov,
N_ens,
g,
err,
Δt,
process,
rng,
fh,
loc,
verbose,
)
end


Expand Down Expand Up @@ -423,7 +445,7 @@ function sample_empirical_gaussian(
) where {FT <: Real, IT <: Int}
cov_u_new = Symmetric(cov(u, dims = 2))
if !isposdef(cov_u_new)
@warn string("Sample covariance matrix over ensemble is singular.", "\n Appplying variance inflation.")
@warn string("Sample covariance matrix over ensemble is singular.", "\n Applying variance inflation.")
if isnothing(inflation)
# Reduce condition number to 1/sqrt(eps(FT))
inflation = eigmax(cov_u_new) * sqrt(eps(FT))
Expand Down
19 changes: 18 additions & 1 deletion src/EnsembleKalmanSampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ function update_ensemble!(
process::Sampler{FT};
failed_ens = nothing,
) where {FT, IT}

#catch works when g non-square
if !(size(g)[2] == ekp.N_ens)
throw(
Expand All @@ -124,8 +123,19 @@ function update_ensemble!(
# u: N_ens × N_par
# g: N_ens × N_obs
u_old = get_u_final(ekp)
cov_init = get_u_cov_final(ekp)

fh = ekp.failure_handler

if ekp.verbose
if get_N_iterations(ekp) == 0
@info "Iteration 0 (prior)"
@info "Covariance trace: $(tr(cov_init))"
end

@info "Iteration $(get_N_iterations(ekp)+1) (T=$(sum(ekp.Δt)))"
end

if isnothing(failed_ens)
_, failed_ens = split_indices_by_success(g)
end
Expand All @@ -143,4 +153,11 @@ function update_ensemble!(
# but stored in data container with N_ens as the 2nd dim

compute_error!(ekp)

# Diagnostics
cov_new = get_u_cov_final(ekp)

if ekp.verbose
@info "Covariance-weighted error: $(get_error(ekp)[end])\nCovariance trace: $(tr(cov_new))\nCovariance trace ratio (current/previous): $(tr(cov_new)/tr(cov_init))"
end
end
12 changes: 0 additions & 12 deletions src/SparseEnsembleKalmanInversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ end
ekp::EnsembleKalmanProcess{FT, IT, SparseInversion{FT}},
g::AbstractMatrix{FT},
process::SparseInversion{FT};
cov_threshold::Real = 0.01,
Δt_new = nothing,
deterministic_forward_map = true,
failed_ens = nothing,
Expand All @@ -187,7 +186,6 @@ Inputs:
- `ekp` :: The EnsembleKalmanProcess to update.
- `g` :: Model outputs, they need to be stored as a `N_obs × N_ens` array (i.e data are columms).
- `process` :: Type of the EKP.
- `cov_threshold` :: Threshold below which the reduction in covariance determinant results in a warning.
- `Δt_new` :: Time step to be used in the current update.
- `deterministic_forward_map` :: Whether output `g` comes from a deterministic model.
- `failed_ens` :: Indices of failed particles. If nothing, failures are computed as columns of `g`
Expand All @@ -197,7 +195,6 @@ function update_ensemble!(
ekp::EnsembleKalmanProcess{FT, IT, SparseInversion{FT}},
g::AbstractMatrix{FT},
process::SparseInversion{FT};
cov_threshold::Real = 0.01,
Δt_new = nothing,
deterministic_forward_map = true,
failed_ens = nothing,
Expand Down Expand Up @@ -246,13 +243,4 @@ function update_ensemble!(

# Check convergence
cov_new = cov(get_u_final(ekp), dims = 2)
cov_ratio = det(cov_new) / det(cov_init)
if cov_ratio < cov_threshold
@warn string(
"New ensemble covariance determinant is less than ",
cov_threshold,
" times its previous value.",
"\nConsider reducing the EK time step.",
)
end
end
22 changes: 22 additions & 0 deletions src/UnscentedKalmanInversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ function EnsembleKalmanProcess(
rng::AbstractRNG = Random.GLOBAL_RNG,
failure_handler_method::FM = IgnoreFailures(),
localization_method::LM = NoLocalization(),
verbose::Bool = false,
) where {FT <: AbstractFloat, IT <: Int, FM <: FailureHandlingMethod, LM <: LocalizationMethod}

#initial parameters stored as columns
Expand All @@ -210,6 +211,10 @@ function EnsembleKalmanProcess(
# localizer
loc = Localizer(localization_method, N_par, N_obs, N_ens, FT)

if verbose
@info "Initializing ensemble Kalman process of type $(nameof(typeof(process)))\nNumber of ensemble members: $(N_ens)\nLocalization: $(nameof(typeof(localization_method)))\nFailure handler: $(nameof(typeof(failure_handler_method)))"
end

EnsembleKalmanProcess{FT, IT, Unscented}(
init_params,
obs_mean,
Expand All @@ -222,6 +227,7 @@ function EnsembleKalmanProcess(
rng,
fh,
loc,
verbose,
)
end

Expand Down Expand Up @@ -652,6 +658,17 @@ function update_ensemble!(

u_p_old = get_u_final(uki)

if uki.verbose
cov_init = get_u_cov_final(uki)

if get_N_iterations(uki) == 0
@info "Iteration 0 (prior)"
@info "Covariance trace: $(tr(cov_init))"
end

@info "Iteration $(get_N_iterations(uki)+1) (T=$(sum(uki.Δt)))"
end

set_Δt!(uki, Δt_new)
fh = uki.failure_handler

Expand All @@ -666,6 +683,11 @@ function update_ensemble!(

push!(uki.u, DataContainer(u_p, data_are_columns = true))

if uki.verbose
cov_new = get_u_cov_final(uki)
@info "Covariance-weighted error: $(get_error(uki)[end])\nCovariance trace: $(tr(cov_new))\nCovariance trace ratio (current/previous): $(tr(cov_new)/tr(cov_init))"
end

return u_p
end

Expand Down

0 comments on commit 5455fff

Please sign in to comment.