Skip to content

Commit

Permalink
reduce test time, and convert all UniformScalings to Diagonals
Browse files Browse the repository at this point in the history
  • Loading branch information
odunbar committed Jul 11, 2024
1 parent 58a5b03 commit cc33614
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 24 deletions.
54 changes: 32 additions & 22 deletions src/Observations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,25 @@ function Observation(obs_dict::Dict)
snew = samples_tmp # [[1,2,3]]
end
end

if !isa(covariances, AbstractVector) # [2 1;1 2] -> [[2 1;1 2]]
cnew = [covariances]
ctmp = [covariances]
else
T = promote_type((typeof(c) for c in covariances)...)
cnew = [convert(T, c) for c in covariances] # to re-infer eltype
ctmp = covariances
end

# additionally provide a dimension for UniformScalings for covariances
ctmp2 = []
for (id,c) in enumerate(ctmp)
if isa(c,UniformScaling)
push!(ctmp2, Diagonal(c.λ*ones(length(snew[id])))) # get dim from samples
else
push!(ctmp2, c)
end
end
# then promote
T = promote_type((typeof(c) for c in ctmp2)...)
cnew = [convert(T, c) for c in ctmp2] # to re-infer eltype

if !("inv_covariances" collect(keys(obs_dict)))
inv_covariances = []
for c in cnew # ensures its a vector
Expand All @@ -142,12 +154,22 @@ function Observation(obs_dict::Dict)
inv_covariances = obs_dict["inv_covariances"]
end
if !isa(inv_covariances, AbstractVector) # [2 1;1 2] -> [[2 1;1 2]]
icnew = [inv_covariances]
ictmp = [inv_covariances]
else
T = promote_type((typeof(c) for c in inv_covariances)...)
icnew = [convert(T, c) for c in inv_covariances] # to re-infer eltype
ictmp = inv_covariances
end

# additionally provide a dimension for UniformScalings
ictmp2=[]
for (id,c) in enumerate(ictmp)
if isa(c,UniformScaling)
push!(ictmp2, Diagonal(c.λ*ones(length(snew[id])))) # get dim from samples
else
push!(ictmp2, c)
end
end
T = promote_type((typeof(c) for c in ictmp2)...)
icnew = [convert(T, c) for c in ictmp2] # to re-infer eltype

if !isa(names, AbstractVector) # "name" -> ["name"]
nnew = [names]
else
Expand Down Expand Up @@ -249,13 +271,7 @@ function get_obs_noise_cov(o::Observation; build = true)
covs = get_covs(o)
cov_full = zeros(maximum(indices[end]), maximum(indices[end]))
for (idx, c) in zip(indices, covs)
if isa(c, UniformScaling)
for idxx in idx
cov_full[idxx, idxx] = c.λ
end
else
cov_full[idx, idx] .= c
end
cov_full[idx, idx] .= c
end

return cov_full
Expand All @@ -276,13 +292,7 @@ function get_obs_noise_cov_inv(o::Observation; build = true)
inv_covs = get_inv_covs(o)
inv_cov_full = zeros(maximum(indices[end]), maximum(indices[end]))
for (idx, c) in zip(indices, inv_covs)
if isa(c, UniformScaling)
for idxx in idx
inv_cov_full[idxx, idxx] = c.λ
end
else
inv_cov_full[idx, idx] .= c
end
inv_cov_full[idx, idx] .= c
end

return inv_cov_full
Expand Down
5 changes: 3 additions & 2 deletions test/EnsembleKalmanProcess/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -932,15 +932,16 @@ end
failure_handler_method = SampleSuccGauss(),
)
T = 0.0
for i in 1:N_iter
N_iter_new = 5
for i in 1:N_iter_new
params_i = get_ϕ_final(prior, ekiobj)
g_ens = G_test(params_i)
dt = @elapsed EKP.update_ensemble!(ekiobj, g_ens)
T += dt
end
# Skip timing of first due to precompilation
if i >= 2
@info "ETKI with $n_obs_test observations took $T seconds. (average update taking: $(T/Float64(N_iter)))"
@info "$N_iter_new iterations of ETKI with $n_obs_test observations took $T seconds. (avg update: $(T/Float64(N_iter)))"
end
end
end
Expand Down

0 comments on commit cc33614

Please sign in to comment.