Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: adding momentum-inspired accelerators to EKP #322

Merged
merged 1 commit into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,4 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["StableRNGs", "Test", "Plots"]
test = ["StableRNGs", "Test", "Plots"]
94 changes: 94 additions & 0 deletions src/Accelerators.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# included in EnsembleKalmanProcess.jl

# Public accelerator API.
# `set_ICs!` is the initializer actually defined in this file; the name
# `set_initial_acceleration!` that used to be listed here has no definition in
# this file, so exporting it only produced an undefined export.
export DefaultAccelerator, NesterovAccelerator
export update_state!, set_ICs!

"""
$(TYPEDEF)

Default accelerator provides no acceleration, runs traditional EKI
"""
struct DefaultAccelerator <: Accelerator
    # Field-less marker type: it carries no state and exists so that
    # `update_state!` methods can be selected on it by dispatch.
end

"""
$(TYPEDEF)

Accelerator that adapts Nesterov's momentum method for EKI.
Stores a previous state value u_prev for computational purposes (note this is distinct from state returned as "ensemble value")

$(TYPEDFIELDS)
"""
mutable struct NesterovAccelerator{FT <: AbstractFloat} <: Accelerator
    "momentum parameter; the state update weights the previous step by `1 - r / k`"
    r::FT
    "parameter state kept from the previous update (untyped; replaced on every state update)"
    u_prev::Any
end

"""
    NesterovAccelerator(r = 3.0, initial = Float64[])

Construct a `NesterovAccelerator` with momentum parameter `r` and initial
previous-state value `initial`.

`r` may be any number accepted by `float` (e.g. `3` or `3.0`); it is stored as
the corresponding floating-point type, which becomes the type parameter `FT`.
"""
function NesterovAccelerator(r = 3.0, initial = Float64[])
    # Name the float type explicitly. Forwarding to `NesterovAccelerator(r, initial)`
    # re-enters this same untyped method whenever `r` is not an `AbstractFloat`
    # (e.g. `NesterovAccelerator(3)`), recursing until the stack overflows.
    FT = typeof(float(r))
    return NesterovAccelerator{FT}(r, initial)
end


"""
Sets u_prev to the initial parameter values
"""
function set_ICs!(
    accelerator::NesterovAccelerator{FT},
    u::MA,
) where {FT <: AbstractFloat, MA <: AbstractMatrix{FT}}
    # Rebinds the field to `u` itself (no copy is made), then hands `u` back,
    # which is also the value of the assignment expression.
    accelerator.u_prev = u
    return u
end


"""
Performs traditional state update with no momentum.
"""
function update_state!(
    ekp::EnsembleKalmanProcess{FT, IT, P, LRS, DefaultAccelerator},
    u::MA,
) where {FT <: AbstractFloat, IT <: Int, P <: Process, LRS <: LearningRateScheduler, MA <: AbstractMatrix{FT}}
    # No momentum: wrap the incoming ensemble unchanged and append it to the
    # stored parameter history. The result of `push!` is returned.
    next_state = DataContainer(u, data_are_columns = true)
    return push!(ekp.u, next_state)
end

"""
Performs state update with modified Nesterov momentum approach.
"""
function update_state!(
    ekp::EnsembleKalmanProcess{FT, IT, P, LRS, NesterovAccelerator{FT}},
    u::MA,
) where {FT <: AbstractFloat, IT <: Int, P <: Process, LRS <: LearningRateScheduler, MA <: AbstractMatrix{FT}}
    nesterov = ekp.accelerator

    # momentum weight for this step; the step index is the iteration count plus 2
    k = get_N_iterations(ekp) + 2
    weight = 1 - nesterov.r / k

    # "v" state: `u` extrapolated along the direction from the previous `u`
    v = u .+ weight * (u .- nesterov.u_prev)

    # keep the un-extrapolated `u` for the next call
    nesterov.u_prev = u

    # the extrapolated "v" state is what gets stored on the EKP object
    extrapolated = DataContainer(v, data_are_columns = true)
    return push!(ekp.u, extrapolated)
end


"""
State update method for UKI with no acceleration.
The Accelerator framework has not yet been integrated with UKI process;
UKI tracks its own states, so this method is empty.
"""
function update_state!(
    ekp::EnsembleKalmanProcess{FT, IT, P, LRS, DefaultAccelerator},
    u::MA,
) where {FT <: AbstractFloat, IT <: Int, P <: Unscented, LRS <: LearningRateScheduler, MA <: AbstractMatrix{FT}}
    # Deliberate no-op: nothing is appended to `ekp.u` by this method.
    return nothing
end

"""
Placeholder state update method for UKI with Nesterov Accelerator.
The Accelerator framework has not yet been integrated with UKI process, so this
method throws an error.
"""
function update_state!(
    ekp::EnsembleKalmanProcess{FT, IT, P, LRS, NesterovAccelerator{FT}},
    u::MA,
) where {FT <: AbstractFloat, IT <: Int, P <: Unscented, LRS <: LearningRateScheduler, MA <: AbstractMatrix{FT}}
    # This combination is unsupported; always raise rather than update anything.
    msg = "option `accelerator = NesterovAccelerator` is not implemented for UKI, please use `DefaultAccelerator`"
    throw(ArgumentError(msg))
end
6 changes: 3 additions & 3 deletions src/EnsembleKalmanInversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,17 +138,17 @@ function update_ensemble!(

u = fh.failsafe_update(ekp, u, g, y, scaled_obs_noise_cov, failed_ens)

# store new parameters (and model outputs)
push!(ekp.u, DataContainer(u, data_are_columns = true))
push!(ekp.g, DataContainer(g, data_are_columns = true))

# Store error
compute_error!(ekp)

# Diagnostics
cov_new = cov(get_u_final(ekp), dims = 2)
cov_new = cov(u, dims = 2)

if ekp.verbose
@info "Covariance-weighted error: $(get_error(ekp)[end])\nCovariance trace: $(tr(cov_new))\nCovariance trace ratio (current/previous): $(tr(cov_new)/tr(cov_init))"
end

return u
end
52 changes: 46 additions & 6 deletions src/EnsembleKalmanProcess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
export get_u_prior, get_u_final, get_g_final, get_ϕ_final
export get_N_iterations, get_error, get_cov_blocks
export get_u_mean, get_u_cov, get_g_mean, get_ϕ_mean
export get_u_mean_final, get_u_cov_prior, get_u_cov_final, get_g_mean_final, get_ϕ_mean_final
export get_u_mean_final, get_u_cov_prior, get_u_cov_final, get_g_mean_final, get_ϕ_mean_final, get_accelerator
export compute_error!
export update_ensemble!
export sample_empirical_gaussian, split_indices_by_success
Expand All @@ -29,6 +29,9 @@
# Failure handlers
abstract type FailureHandlingMethod end

# Accelerators
abstract type Accelerator end



"Failure handling method that ignores forward model failures"
Expand Down Expand Up @@ -104,7 +107,13 @@

$(METHODLIST)
"""
struct EnsembleKalmanProcess{FT <: AbstractFloat, IT <: Int, P <: Process, LRS <: LearningRateScheduler}
struct EnsembleKalmanProcess{
FT <: AbstractFloat,
IT <: Int,
P <: Process,
LRS <: LearningRateScheduler,
ACC <: Accelerator,
}
"array of stores for parameters (`u`), each of size [`N_par × N_ens`]"
u::Array{DataContainer{FT}}
"vector of the observed vector size [`N_obs`]"
Expand All @@ -119,6 +128,8 @@
err::Vector{FT}
"Scheduler to calculate the timestep size in each EK iteration"
scheduler::LRS
"accelerator object that informs EK update steps, stores additional state variables as needed"
accelerator::ACC
"stored vector of timesteps used in each EK iteration"
Δt::Vector{FT}
"the particular EK process (`Inversion` or `Sampler` or `Unscented` or `TransformInversion` or `SparseInversion`)"
Expand All @@ -139,6 +150,7 @@
obs_noise_cov::Union{AbstractMatrix{FT}, UniformScaling{FT}},
process::P;
scheduler::Union{Nothing, LRS} = nothing,
accelerator::Union{Nothing, ACC} = nothing,
Δt = nothing,
rng::AbstractRNG = Random.GLOBAL_RNG,
failure_handler_method::FM = IgnoreFailures(),
Expand All @@ -147,6 +159,7 @@
) where {
FT <: AbstractFloat,
LRS <: LearningRateScheduler,
ACC <: Accelerator,
P <: Process,
FM <: FailureHandlingMethod,
LM <: LocalizationMethod,
Expand Down Expand Up @@ -193,23 +206,39 @@
# timestep store
Δt = FT[]

# set up accelerator
if isnothing(accelerator)
acc = DefaultAccelerator()
else
acc = accelerator
end
AC = typeof(acc)

if AC <: NesterovAccelerator
set_ICs!(acc, params)
if P <: Sampler
@warn "Acceleration is experimental for Sampler processes and may affect convergence."
end
end

# failure handler
fh = FailureHandler(process, failure_handler_method)
# localizer
loc = Localizer(localization_method, N_par, N_obs, N_ens, FT)

if verbose
@info "Initializing ensemble Kalman process of type $(nameof(typeof(process)))\nNumber of ensemble members: $(N_ens)\nLocalization: $(nameof(typeof(localization_method)))\nFailure handler: $(nameof(typeof(failure_handler_method)))\nScheduler: $(nameof(typeof(lrs)))"
@info "Initializing ensemble Kalman process of type $(nameof(typeof(process)))\nNumber of ensemble members: $(N_ens)\nLocalization: $(nameof(typeof(localization_method)))\nFailure handler: $(nameof(typeof(failure_handler_method)))\nScheduler: $(nameof(typeof(lrs)))\nAccelerator: $(nameof(typeof(acc)))"

Check warning on line 230 in src/EnsembleKalmanProcess.jl

View check run for this annotation

Codecov / codecov/patch

src/EnsembleKalmanProcess.jl#L230

Added line #L230 was not covered by tests
odunbar marked this conversation as resolved.
Show resolved Hide resolved
end

EnsembleKalmanProcess{FT, IT, P, RS}(
EnsembleKalmanProcess{FT, IT, P, RS, AC}(
[init_params],
obs_mean,
obs_noise_cov,
N_ens,
g,
err,
lrs,
acc,
Δt,
process,
rng,
Expand All @@ -222,7 +251,6 @@

include("LearningRateSchedulers.jl")


"""
get_u(ekp::EnsembleKalmanProcess, iteration::IT; return_array=true) where {IT <: Integer}

Expand Down Expand Up @@ -423,6 +451,14 @@
return ekp.scheduler
end

"""
get_accelerator(ekp::EnsembleKalmanProcess)
Return the accelerator object stored in the EnsembleKalmanProcess.
"""
function get_accelerator(ekp::EnsembleKalmanProcess)
    # Hands back the object held in the `accelerator` field of `ekp`
    # (the accelerator itself, not just its type).
    return ekp.accelerator
end


"""
construct_initial_ensemble(
Expand Down Expand Up @@ -628,7 +664,8 @@

terminate = calculate_timestep!(ekp, g, Δt_new)
if isnothing(terminate)
update_ensemble!(ekp, g, get_process(ekp); ekp_kwargs...)
u = update_ensemble!(ekp, g, get_process(ekp); ekp_kwargs...)
update_state!(ekp, u)
if s > 0.0
multiplicative_inflation ? multiplicative_inflation!(ekp; s = s) : nothing
additive_inflation ? additive_inflation!(ekp; use_prior_cov = use_prior_cov, s = s) : nothing
Expand Down Expand Up @@ -664,3 +701,6 @@
export Gaussian_2d
export construct_initial_ensemble, construct_mean, construct_cov
include("UnscentedKalmanInversion.jl")

# struct Accelerator
include("Accelerators.jl")
5 changes: 3 additions & 2 deletions src/EnsembleKalmanSampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,17 +139,18 @@ function update_ensemble!(
u = fh.failsafe_update(ekp, u_old, g, failed_ens)

# store new parameters (and model outputs)
push!(ekp.u, DataContainer(u, data_are_columns = true))
push!(ekp.g, DataContainer(g, data_are_columns = true))
# u_old is N_ens × N_par, g is N_ens × N_obs,
# but stored in data container with N_ens as the 2nd dim

compute_error!(ekp)

# Diagnostics
cov_new = get_u_cov_final(ekp)
cov_new = cov(u, dims = 2)

if ekp.verbose
@info "Covariance-weighted error: $(get_error(ekp)[end])\nCovariance trace: $(tr(cov_new))\nCovariance trace ratio (current/previous): $(tr(cov_new)/tr(cov_init))"
end

return u
end
5 changes: 3 additions & 2 deletions src/EnsembleTransformKalmanInversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,17 @@ function update_ensemble!(
u = fh.failsafe_update(ekp, u, g, y, scaled_obs_noise_cov, failed_ens)

# store new parameters (and model outputs)
push!(ekp.u, DataContainer(u, data_are_columns = true))
push!(ekp.g, DataContainer(g, data_are_columns = true))

# Store error
compute_error!(ekp)

# Diagnostics
cov_new = cov(get_u_final(ekp), dims = 2)
cov_new = cov(u, dims = 2)

if ekp.verbose
@info "Covariance-weighted error: $(get_error(ekp)[end])\nCovariance trace: $(tr(cov_new))\nCovariance trace ratio (current/previous): $(tr(cov_new)/tr(cov_init))"
end

return u
end
3 changes: 2 additions & 1 deletion src/SparseEnsembleKalmanInversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,13 @@ function update_ensemble!(
u = fh.failsafe_update(ekp, u, g, y, scaled_obs_noise_cov, failed_ens)

# store new parameters (and model outputs)
push!(ekp.u, DataContainer(u, data_are_columns = true))
push!(ekp.g, DataContainer(g, data_are_columns = true))

# Store error
compute_error!(ekp)

# Check convergence
cov_new = cov(get_u_final(ekp), dims = 2)
sydneyvernon marked this conversation as resolved.
Show resolved Hide resolved

return u
end
Loading
Loading