Skip to content

Commit

Permalink
Merge #338
Browse files Browse the repository at this point in the history
338: Prevent DataMisfitController scheduler from modifying EKP obs_noise_cov r=costachris a=costachris

`posdef_correct` was modifying inputs in-place due to the for loop over diagonal elements in the symmetric matrix case. This resulted in the DMC scheduler directly modifying the EKP obs_noise_cov matrix in the first iteration. This PR:
1. Uses deepcopy to prevent in-place modification of the EKP obs_noise_cov and modify tests to detect this.
2. Exports `posdef_correct` for external use.

Note: examples/LearningRateSchedulers is using `UniformScaling` and remains unaffected. 

Co-authored-by: costachris <[email protected]>
  • Loading branch information
bors[bot] and costachris committed Oct 23, 2023
2 parents ad23cac + 04b7a42 commit 8f8539e
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/LearningRateSchedulers.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# included in EnsembleKalmanProcess.jl

export DefaultScheduler, MutableScheduler, EKSStableScheduler, DataMisfitController
export calculate_timestep!
export calculate_timestep!, posdef_correct

# default unless user overrides

Expand Down Expand Up @@ -222,6 +222,7 @@ $(DocStringExtensions.TYPEDSIGNATURES)
Makes square matrix `mat` positive definite, by symmetrizing and bounding the minimum eigenvalue below by `tol`
"""
function posdef_correct(mat::AbstractMatrix; tol::Real = 1e8 * eps())
mat = deepcopy(mat)
if !issymmetric(mat)
out = 0.5 * (mat + permutedims(mat, (2, 1))) #symmetrize
if isposdef(out)
Expand Down
5 changes: 5 additions & 0 deletions test/EnsembleKalmanProcess/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ end
else #no initial ensemble for UKI
ekpobj = EKP.EnsembleKalmanProcess(y_obs, Γy, process, rng = copy(rng), scheduler = scheduler)
end
initial_obs_noise_cov = deepcopy(ekpobj.obs_noise_cov)
for i in 1:N_iter
params_i = get_ϕ_final(prior, ekpobj)
g_ens = G(params_i)
Expand All @@ -290,9 +291,13 @@ end
if !isnothing(terminated)
break
end
# ensure Δt is updated
@test length(ekpobj.Δt) == i
end
push!(init_means, vec(mean(get_u_prior(ekpobj), dims = 2)))
push!(final_means, vec(mean(get_u_final(ekpobj), dims = 2)))
# ensure obs_noise_cov matrix remains unchanged
@test initial_obs_noise_cov == ekpobj.obs_noise_cov

# this test is fine so long as N_iter is large enough to hit the termination time
if nameof(typeof(scheduler)) == DataMisfitController
Expand Down

0 comments on commit 8f8539e

Please sign in to comment.