Skip to content

Commit

Permalink
get u cov always corrects positive definiteness (#360)
Browse files Browse the repository at this point in the history
!posdef fix

format

correct scaled matrix

typo

format

try without get_u_cov posdefs

re-add posdefs as old tests still fail

try making tests more recent

try 1.8 LTS

add pd correct after sqrt

remove sqrt, bugfix: add rng into sample empirical gaussian

LTS 1.6 revert

remove repeated call to posdef_correct

adapt sparse tests to pass

format

 removed DMC with sparse inversion

format

remove unstable case

allow user-defined inflation matrices

format

docs typo

codecov

format
  • Loading branch information
odunbar committed Feb 6, 2024
1 parent 3428c95 commit b420d83
Show file tree
Hide file tree
Showing 10 changed files with 89 additions and 53 deletions.
27 changes: 12 additions & 15 deletions docs/src/inflation.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,25 @@ Multiplicative inflation can be used by flagging the `update_ensemble!` method a
```

## Additive Inflation
Additive inflation is implemented by systematically adding stochastic perturbations to the parameter ensemble in the form of Gaussian noise. Additive inflation breaks the linear subspace property, meaning the parameter ensemble can evolve outside of the span of the initial ensemble. In additive inflation, the ensemble is perturbed in the following manner after the standard Kalman update:
Additive inflation is implemented by systematically adding stochastic perturbations to the parameter ensemble in the form of Gaussian noise. Additive inflation is capable of breaking the linear subspace property, meaning the parameter ensemble can evolve outside of the span of the current ensemble. In additive inflation, the ensemble is perturbed in the following manner after the standard Kalman update:

```math
u_{n+1} = u_n + \zeta_{n} \qquad (3) \\
\zeta_{n} \sim N(0, \frac{s \Delta{t} }{1 - s \Delta{t}} C_n) \qquad (4)
\zeta_{n} \sim N(0, \frac{s \Delta{t} }{1 - s \Delta{t}} \Sigma) \qquad (4)
```
This can be seen as a stochastic modification of the ensemble covariance, while the mean remains fixed
```math
C_{n + 1} = C_{n} + \frac{s \Delta{t} }{1 - s \Delta{t}} \Sigma \qquad (5)
```
This inflates the parameter covariance by a factor of ``\frac{1}{1 - s \Delta{t}}`` as in eqn. 2 , while the ensemble mean remains fixed.

Additive inflation can be used by flagging the `update_ensemble!` method as follows:
For example, if ``\Sigma = C_{n}`` we see inflation that is statistically equivalent to scaling the parameter covariance by a factor of ``\frac{1}{1 - s \Delta{t}}`` as in eqn. 2.

Additive inflation, by default takes ``\Sigma = C_0`` (the prior covariance), and can be used by flagging the `update_ensemble!` method as follows:
```julia
EKP.update_ensemble!(ekiobj, g_ens; additive_inflation = true, s = 1.0)
```
Alternatively, the prior covariance matrix may be used to generate additive noise, following:
```math
\zeta_{n} \sim N(0, \frac{s \Delta{t} }{1 - s \Delta{t}} C_{0}) \qquad (5)
```
This results in an additive increase in the parameter covariance by `` \frac{s \Delta{t} }{1 - s \Delta{t}} * C_{0}`` , while the mean remains fixed.
```math
C_{n + 1} = C_{n} + \frac{s \Delta{t} }{1 - s \Delta{t}} C_{0} \qquad (6)
```

Additive inflation using the scaled prior covariance (parameter covariance of initial ensemble) can be used by flagging the `update_ensemble!` method as follows:
Any positive semi-definite matrix (or uniform scaling) ``\Sigma`` may be provided to generate additive noise to the ensemble by flagging the `update_ensemble!` method as follows:
```julia
EKP.update_ensemble!(ekiobj, g_ens; additive_inflation = true, use_prior_cov = true, s = 1.0)
Σ = 0.01*I # user defined inflation
EKP.update_ensemble!(ekiobj, g_ens; additive_inflation = true, additive_inflation_cov = Σ, s = 1.0)
```
4 changes: 2 additions & 2 deletions src/EnsembleKalmanInversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function FailureHandler(process::Inversion, method::SampleSuccGauss)
u[:, successful_ens] =
eki_update(ekp, u[:, successful_ens], g[:, successful_ens], y[:, successful_ens], obs_noise_cov)
if !isempty(failed_ens)
u[:, failed_ens] = sample_empirical_gaussian(u[:, successful_ens], n_failed)
u[:, failed_ens] = sample_empirical_gaussian(ekp.rng, u[:, successful_ens], n_failed)
end
return u
end
Expand Down Expand Up @@ -123,7 +123,7 @@ function update_ensemble!(

# Scale noise using Δt
scaled_obs_noise_cov = ekp.obs_noise_cov / ekp.Δt[end]
noise = rand(ekp.rng, MvNormal(zeros(N_obs), scaled_obs_noise_cov), ekp.N_ens)
noise = sqrt(scaled_obs_noise_cov) * rand(ekp.rng, MvNormal(zeros(N_obs), I), ekp.N_ens)

# Add obs_mean (N_obs) to each column of noise (N_obs × N_ens) if
# G is deterministic
Expand Down
51 changes: 32 additions & 19 deletions src/EnsembleKalmanProcess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -502,15 +502,16 @@ get_error(ekp::EnsembleKalmanProcess) = ekp.err

"""
sample_empirical_gaussian(
rng::AbstractRNG,
u::AbstractMatrix{FT},
n::IT;
inflation::Union{FT, Nothing} = nothing,
) where {FT <: Real, IT <: Int}
Returns `n` samples from an empirical Gaussian based on point estimates `u`, adding inflation
if the covariance is singular.
Returns `n` samples from an empirical Gaussian based on point estimates `u`, adding inflation if the covariance is singular.
"""
function sample_empirical_gaussian(
rng::AbstractRNG,
u::AbstractMatrix{FT},
n::IT;
inflation::Union{FT, Nothing} = nothing,
Expand All @@ -525,9 +526,18 @@ function sample_empirical_gaussian(
cov_u_new = cov_u_new + inflation * I
end
mean_u_new = mean(u, dims = 2)
return rand(MvNormal(mean_u_new[:], cov_u_new), n)
return mean_u_new .+ sqrt(cov_u_new) * rand(rng, MvNormal(zeros(length(mean_u_new[:])), I), n)
end

function sample_empirical_gaussian(
u::AbstractMatrix{FT},
n::IT;
inflation::Union{FT, Nothing} = nothing,
) where {FT <: Real, IT <: Int}
return sample_empirical_gaussian(Random.GLOBAL_RNG, u, n, inflation = inflation)
end


"""
split_indices_by_success(g::AbstractMatrix{FT}) where {FT <: Real}
Expand Down Expand Up @@ -588,33 +598,35 @@ end

"""
additive_inflation!(
ekp::EnsembleKalmanProcess;
use_prior_cov::Bool = false,
ekp::EnsembleKalmanProcess
inflation_cov::AM;
s::FT = 1.0,
) where {FT <: Real}
Applies additive Gaussian noise to particles. Noise is drawn from normal distribution with 0 mean
and scaled parameter covariance. If use_prior_cov=false (default), scales parameter covariance matrix from
current ekp iteration. Otherwise, scales parameter covariance of initial ensemble.
and scaled parameter covariance. The original parameter covariance is a provided matrix, assumed positive semi-definite.
Inputs:
- ekp :: The EnsembleKalmanProcess to update.
- s :: Scaling factor for time step in additive perturbation.
- use_prior_cov :: Bool specifying whether to use prior covariance estimate for additive inflation.
If false (default), parameter covariance from the current iteration is used.
- inflation_cov :: AbstractMatrix provide a N_par x N_par matrix to use.
"""
function additive_inflation!(ekp::EnsembleKalmanProcess; use_prior_cov::Bool = false, s::FT = 1.0) where {FT <: Real}
function additive_inflation!(
ekp::EnsembleKalmanProcess,
inflation_cov::MorUS;
s::FT = 1.0,
) where {FT <: Real, MorUS <: Union{AbstractMatrix, UniformScaling}}

scaled_Δt = s * ekp.Δt[end]

if scaled_Δt >= 1.0
error(string("Scaled time step: ", scaled_Δt, " is >= 1.0", "\nChange s or EK time step."))
end

Σ = use_prior_cov ? get_u_cov_prior(ekp) : get_u_cov_final(ekp)

u = get_u_final(ekp)

Σ_sqrt = sqrt(scaled_Δt / (1 - scaled_Δt) .* inflation_cov)

# add multivariate noise with 0 mean and scaled covariance
noise_multivariate = MvNormal((scaled_Δt / (1 - scaled_Δt)) .* Σ)
u_updated = u + rand(noise_multivariate, size(u, 2))
u_updated = u .+ Σ_sqrt * rand(ekp.rng, MvNormal(zeros(size(u, 1)), I), size(u, 2))
ekp.u[end] = DataContainer(u_updated, data_are_columns = true)
end

Expand All @@ -627,7 +639,7 @@ end
g::AbstractMatrix{FT};
multiplicative_inflation::Bool = false,
additive_inflation::Bool = false,
use_prior_cov::Bool = false,
additive_inflation_cov::MorUS = get_u_cov_prior(ekp),
s::FT = 0.0,
ekp_kwargs...,
) where {FT, IT}
Expand All @@ -637,7 +649,7 @@ Inputs:
- g :: Model outputs, they need to be stored as a `N_obs × N_ens` array (i.e data are columms).
- multiplicative_inflation :: Flag indicating whether to use multiplicative inflation.
- additive_inflation :: Flag indicating whether to use additive inflation.
- use_prior_cov :: Bool specifying whether to use prior covariance estimate for additive inflation.
- additive_inflation_cov :: specifying an additive inflation matrix (default is the prior covariance) assumed positive semi-definite
If false (default), parameter covariance from the current iteration is used.
- s :: Scaling factor for time step in inflation step.
- ekp_kwargs :: Keyword arguments to pass to standard ekp update_ensemble!.
Expand All @@ -647,11 +659,11 @@ function update_ensemble!(
g::AbstractMatrix{FT};
multiplicative_inflation::Bool = false,
additive_inflation::Bool = false,
use_prior_cov::Bool = false,
additive_inflation_cov::MorUS = get_u_cov_prior(ekp),
s::FT = 0.0,
Δt_new::NFT = nothing,
ekp_kwargs...,
) where {FT, NFT <: Union{Nothing, AbstractFloat}}
) where {FT, NFT <: Union{Nothing, AbstractFloat}, MorUS <: Union{AbstractMatrix, UniformScaling}}

#catch works when g non-square
if !(size(g)[2] == ekp.N_ens)
Expand All @@ -668,8 +680,9 @@ function update_ensemble!(
accelerate!(ekp, u)
if s > 0.0
multiplicative_inflation ? multiplicative_inflation!(ekp; s = s) : nothing
additive_inflation ? additive_inflation!(ekp; use_prior_cov = use_prior_cov, s = s) : nothing
additive_inflation ? additive_inflation!(ekp, additive_inflation_cov, s = s) : nothing
end

else
return terminate # true if scheduler has not stepped
end
Expand Down
4 changes: 2 additions & 2 deletions src/EnsembleKalmanSampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,13 @@ function eks_update(
# Default: Δt = 1 / (norm(D) + eps(FT))
Δt = ekp.Δt[end]

noise = MvNormal(u_cov)
noise = MvNormal(zeros(size(u_cov, 1)), I)

implicit =
(1 * Matrix(I, size(u)[2], size(u)[2]) + Δt * (ekp.process.prior_cov' \ u_cov')') \
(u' .- Δt * (u' .- u_mean) * D .+ Δt * u_cov * (ekp.process.prior_cov \ ekp.process.prior_mean))

u = implicit' + sqrt(2 * Δt) * rand(ekp.rng, noise, ekp.N_ens)'
u = implicit' + sqrt(2 * Δt) * (sqrt(u_cov) * rand(ekp.rng, noise, ekp.N_ens))'

return u
end
Expand Down
2 changes: 1 addition & 1 deletion src/EnsembleTransformKalmanInversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ function FailureHandler(process::TransformInversion, method::SampleSuccGauss)
n_failed = length(failed_ens)
u[:, successful_ens] = etki_update(ekp, u[:, successful_ens], g[:, successful_ens], y, obs_noise_cov)
if !isempty(failed_ens)
u[:, failed_ens] = sample_empirical_gaussian(u[:, successful_ens], n_failed)
u[:, failed_ens] = sample_empirical_gaussian(ekp.rng, u[:, successful_ens], n_failed)
end
return u
end
Expand Down
8 changes: 5 additions & 3 deletions src/LearningRateSchedulers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,11 @@ function posdef_correct(mat::AbstractMatrix; tol::Real = 1e8 * eps())
out = mat
end

nugget = abs(minimum(eigvals(out)))
for i in 1:size(out, 1)
out[i, i] += nugget + tol #add to diag
if !isposdef(out)
nugget = abs(minimum(eigvals(out)))
for i in 1:size(out, 1)
out[i, i] += nugget + tol # add to diag
end
end
return out
end
Expand Down
4 changes: 2 additions & 2 deletions src/SparseEnsembleKalmanInversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ function FailureHandler(process::SparseInversion, method::SampleSuccGauss)
u[:, successful_ens] =
sparse_eki_update(ekp, u[:, successful_ens], g[:, successful_ens], y[:, successful_ens], obs_noise_cov)
if !isempty(failed_ens)
u[:, failed_ens] = sample_empirical_gaussian(u[:, successful_ens], n_failed)
u[:, failed_ens] = sample_empirical_gaussian(ekp.rng, u[:, successful_ens], n_failed)
end
return u
end
Expand Down Expand Up @@ -206,7 +206,7 @@ function update_ensemble!(

# Scale noise using Δt
scaled_obs_noise_cov = ekp.obs_noise_cov / ekp.Δt[end]
noise = rand(ekp.rng, MvNormal(zeros(N_obs), scaled_obs_noise_cov), ekp.N_ens)
noise = sqrt(scaled_obs_noise_cov) * rand(ekp.rng, MvNormal(zeros(N_obs), I), ekp.N_ens)

# Add obs_mean (N_obs) to each column of noise (N_obs × N_ens) if
# G is deterministic
Expand Down
26 changes: 23 additions & 3 deletions test/EnsembleKalmanProcess/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,18 @@ end


@testset "LearningRateSchedulers" begin

# Utility
X = [2 1; 1.1 2] # correct with symmetrisation
@test isposdef(posdef_correct(X))
@test posdef_correct(X) 0.5 * (X + permutedims(X, (2, 1))) atol = 1e-8
Y = [0 1; -1 0]
tol = 1e-8
@test isposdef(posdef_correct(Y, tol = tol)) # symmetrize and add to diagonal
@test posdef_correct(Y, tol = tol) tol * I(2) atol = 1e-8



# Default
Δt = 3
dlrs1 = EKP.DefaultScheduler()
Expand Down Expand Up @@ -944,15 +956,23 @@ end
end
@test_logs (:warn, r"More than 50% of runs produced NaNs") match_mode = :any split_indices_by_success(g)


rng = Random.MersenneTwister(rng_seed)

u = rand(10, 4)
@test_logs (:warn, r"Sample covariance matrix over ensemble is singular.") match_mode = :any sample_empirical_gaussian(
u,
2,
)
@test_throws PosDefException sample_empirical_gaussian(u, 2, inflation = 0.0)

# Initial ensemble construction
rng = Random.MersenneTwister(rng_seed)
u2 = rand(rng, 5, 20)
@test all(
isapprox.(
sample_empirical_gaussian(copy(rng), u2, 2),
sample_empirical_gaussian(copy(rng), u2, 2, inflation = 0.0);
atol = 1e-8,
),
)

### sanity check on rng:
d = Parameterized(Normal(0, 1))
Expand Down
9 changes: 7 additions & 2 deletions test/Inflation/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,20 +71,25 @@ initial_ensemble = EKP.construct_initial_ensemble(rng, prior, N_ens)
eki_mult_inflation = deepcopy(ekiobj)
eki_add_inflation = deepcopy(ekiobj)
eki_add_inflation_prior = deepcopy(ekiobj)
eki_add_inflation_I = deepcopy(ekiobj)

# multiplicative inflation after standard update
EKP.multiplicative_inflation!(eki_mult_inflation)
# additive inflation after standard update
EKP.additive_inflation!(eki_add_inflation)
EKP.additive_inflation!(eki_add_inflation, get_u_cov_final(eki_add_inflation))
# additive inflation (scaling prior cov) after standard update
EKP.additive_inflation!(eki_add_inflation_prior; use_prior_cov = true)
EKP.additive_inflation!(eki_add_inflation_prior, get_u_cov_prior(eki_add_inflation_prior))
# additive inflation (scaling prior cov) after standard update
EKP.additive_inflation!(eki_add_inflation_I, I)

# ensure multiplicative inflation approximately preserves ensemble mean
@test get_u_mean_final(ekiobj) get_u_mean_final(eki_mult_inflation) atol = 0.2
# ensure additive inflation approximately preserves ensemble mean
@test get_u_mean_final(ekiobj) get_u_mean_final(eki_add_inflation) atol = 0.2
# ensure additive inflation (scaling prior cov) approximately preserves ensemble mean
@test get_u_mean_final(ekiobj) get_u_mean_final(eki_add_inflation_prior) atol = 0.2
# ensure additive inflation approximately preserves ensemble mean
@test get_u_mean_final(ekiobj) get_u_mean_final(eki_add_inflation_I) atol = 0.2

# ensure inflation expands ensemble variance as expected
expected_var_gain = 1 / (1 - Δt)
Expand Down
7 changes: 3 additions & 4 deletions test/SparseInversion/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,14 @@ include("../EnsembleKalmanProcess/inverse_problem.jl")

## Repeat first test with several schedulers
y_obs, G, Γy = nl_inv_problems[1]

T_end = 3
schedulers = [
DefaultScheduler(0.1),
MutableScheduler(0.1),
DataMisfitController(terminate_at = T_end),
DataMisfitController(on_terminate = "continue"),
DataMisfitController(on_terminate = "continue_fixed"),
# DataMisfitController(terminate_at = T_end), # This test can be unstable
]
N_iters = [10, 10, 50, 50, 50]
N_iters = [10, 10]# ..., 20]

final_ensembles = []
init_means = []
Expand Down

0 comments on commit b420d83

Please sign in to comment.