Skip to content

Commit

Permalink
interface for EKP
Browse files Browse the repository at this point in the history
  • Loading branch information
odunbar committed Jun 21, 2024
1 parent 87b45b9 commit 5e0c301
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 25 deletions.
76 changes: 55 additions & 21 deletions src/EnsembleKalmanProcess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ export get_u_prior, get_u_final, get_g_final, get_ϕ_final
export get_N_iterations, get_error, get_cov_blocks
export get_u_mean, get_u_cov, get_g_mean, get_ϕ_mean
export get_u_mean_final, get_u_cov_prior, get_u_cov_final, get_g_mean_final, get_ϕ_mean_final, get_accelerator
export get_obs, get_obs_noise_cov
export get_observation_series, get_obs, get_obs_noise_cov
export compute_error!
export update_ensemble!
export sample_empirical_gaussian, split_indices_by_success
Expand Down Expand Up @@ -80,7 +80,7 @@ $(TYPEDFIELDS)
EnsembleKalmanProcess(
params::AbstractMatrix{FT},
obs,
observation_series::OS,
obs_noise_cov::Union{AbstractMatrix{FT}, UniformScaling{FT}},
process::P;
scheduler = DefaultScheduler(1),
Expand All @@ -89,13 +89,12 @@ $(TYPEDFIELDS)
failure_handler_method::FM = IgnoreFailures(),
localization_method::LM = NoLocalization(),
verbose::Bool = false,
) where {FT <: AbstractFloat, P <: Process, FM <: FailureHandlingMethod, LM <: LocalizationMethod}
) where {FT <: AbstractFloat, P <: Process, FM <: FailureHandlingMethod, LM <: LocalizationMethod, OS <: ObservationSeries}
Inputs:
- `params` :: Initial parameter ensemble
- `obs` :: Vector of observations
- `obs_noise_cov` :: Noise covariance associated with the observations `obs`
- `observation_series` :: Container for observations (and possible minibatching)
- `process` :: Algorithm used to evolve the ensemble
- `scheduler` :: Adaptive timestep calculator
- `Δt` :: Initial time step or learning rate
Expand All @@ -117,10 +116,8 @@ struct EnsembleKalmanProcess{
}
"array of stores for parameters (`u`), each of size [`N_par × N_ens`]"
u::Array{DataContainer{FT}}
"vector of the observed vector size [`N_obs`]"
obs::Vector{FT}
"covariance matrix of the observational noise, of size [`N_obs × N_obs`]"
obs_noise_cov::Union{AbstractMatrix{FT}, UniformScaling{FT}}
"Container for the observation(s) - and minibatching mechanism"
observation_series::ObservationSeries
"ensemble size"
N_ens::IT
"Array of stores for forward model outputs, each of size [`N_obs × N_ens`]"
Expand All @@ -147,8 +144,7 @@ end

function EnsembleKalmanProcess(
params::AbstractMatrix{FT},
obs,
obs_noise_cov::Union{AbstractMatrix{FT}, UniformScaling{FT}},
observation_series::OS,
process::P;
scheduler::Union{Nothing, LRS} = nothing,
accelerator::Union{Nothing, ACC} = nothing,
Expand All @@ -164,13 +160,15 @@ function EnsembleKalmanProcess(
P <: Process,
FM <: FailureHandlingMethod,
LM <: LocalizationMethod,
OS <: ObservationSeries,
}

#initial parameters stored as columns
init_params = DataContainer(params, data_are_columns = true)

# dimensionality
N_par, N_ens = size(init_params) #stored with data as columns
obs = get_obs(observation_series)
N_obs = length(obs)

IT = typeof(N_ens)
Expand Down Expand Up @@ -233,8 +231,7 @@ function EnsembleKalmanProcess(

EnsembleKalmanProcess{FT, IT, P, RS, AC}(
[init_params],
obs,
obs_noise_cov,
observation_series,
N_ens,
g,
err,
Expand All @@ -249,6 +246,28 @@ function EnsembleKalmanProcess(
)
end

function EnsembleKalmanProcess(
params::AbstractMatrix{FT},
observation::OB,
args...;
kwargs...,
) where {FT <: AbstractFloat, OB <: Observation}
observation_series = ObservationSeries(observation)
return EnsembleKalmanProcess(params, observation_series, args...; kwargs...)
end

function EnsembleKalmanProcess(
params::AbstractMatrix{FT},
obs,
obs_noise_cov::Union{AbstractMatrix{FT}, UniformScaling{FT}},
args...;
kwargs...,
) where {FT <: AbstractFloat}

observation = Observation(Dict("samples" => obs, "covariances" => obs_noise_cov, "names" => "observation"))

return EnsembleKalmanProcess(params, observation, args...; kwargs...)
end

include("LearningRateSchedulers.jl")

Expand Down Expand Up @@ -461,21 +480,32 @@ function get_accelerator(ekp::EnsembleKalmanProcess)
end

"""
get_obs_noise_cov(ekp::EnsembleKalmanProcess)
get_observation_series(ekp::EnsembleKalmanProcess)
Return `obs_noise_cov` field of EnsembleKalmanProcess.
"""
function get_obs_noise_cov(ekp::EnsembleKalmanProcess)
return ekp.obs_noise_cov
function get_observation_series(ekp::EnsembleKalmanProcess)
return ekp.observation_series
end

"""
get_obs(ekp::EnsembleKalmanProcess)
Return `obs` field of EnsembleKalmanProcess.
get_obs_noise_cov(ekp::EnsembleKalmanProcess; build=true)
convenience function to get the obs_noise_cov from the current batch in ObservationSeries
build=false:, returns a vector of blocks,
build=true: returns a block matrix,
"""
function get_obs(ekp::EnsembleKalmanProcess)
return ekp.obs
function get_obs_noise_cov(ekp::EnsembleKalmanProcess, build = true)
return get_obs_noise_cov(get_observation_series(ekp), build = build)
end

"""
get_obs(ekp::EnsembleKalmanProcess; build=true)
Get the observation from the current batch in ObservationSeries
build=false: returns a vector of vectors,
build=true: returns a concatenated vector,
"""
function get_obs(ekp::EnsembleKalmanProcess; build = true)
return get_obs(get_observation_series(ekp), build = build)
end


"""
Expand Down Expand Up @@ -504,7 +534,7 @@ The error is stored within the `EnsembleKalmanProcess`.
"""
function compute_error!(ekp::EnsembleKalmanProcess)
mean_g = dropdims(mean(get_g_final(ekp), dims = 2), dims = 2)
diff = ekp.obs - mean_g
diff = get_obs(ekp) - mean_g
X = get_obs_noise_cov(ekp) \ diff # diff: column vector
newerr = dot(diff, X)
push!(ekp.err, newerr)
Expand Down Expand Up @@ -704,6 +734,10 @@ function update_ensemble!(
else
return terminate # true if scheduler has not stepped
end

# update to next minibatch (if minibatching)
observation_series = get_observation_series(ekp)
next_minibatch = update_minibatch!(observation_series)
return nothing

end
Expand Down
12 changes: 9 additions & 3 deletions src/Observations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,13 @@ function get_obs_noise_cov(o::Observation; build = true)
covs = get_covs(o)
cov_full = zeros(maximum(indices[end]), maximum(indices[end]))
for (idx, c) in zip(indices, covs)
cov_full[idx, idx] = c
if isa(c, UniformScaling)
for idxx in idx
cov_full[idxx, idxx] = c.λ
end
else
cov_full[idx, idx] .= c
end
end

return cov_full
Expand Down Expand Up @@ -439,7 +445,7 @@ function get_obs(os::OS; build = true) where {OS <: ObservationSeries}
minibatch_length = length(minibatch)
observations_vec = get_observations(os)[minibatch] # gives observation objects
if minibatch_length == 1
return get_obs(observations_vec, build = build)
return get_obs(observations_vec[1], build = build)
end

if !build # return y as vec of vecs
Expand All @@ -461,7 +467,7 @@ function get_obs_noise_cov(os::OS; build = true) where {OS <: ObservationSeries}
minibatch_length = length(minibatch)
observations_vec = get_observations(os)[minibatch] # gives observation objects
if minibatch_length == 1 # if only 1 sample then return it
return get_obs_noise_cov(observations_vec[minibatch], build = build)
return get_obs_noise_cov(observations_vec[1], build = build)
else
minibatch_covs = []
for observation in observations_vec
Expand Down
2 changes: 1 addition & 1 deletion test/EnsembleKalmanProcess/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ end

## some getters in EKP
@test get_obs(ekiobj) == y_obs
@test get_obs_noise_cov(ekiobj) == obs_noise_cov
@test get_obs_noise_cov(ekiobj) == Γy

g_ens = G(get_ϕ_final(prior, ekiobj))
g_ens_t = permutedims(g_ens, (2, 1))
Expand Down

0 comments on commit 5e0c301

Please sign in to comment.