Skip to content

Commit

Permalink
modified accelerator for variable timestep
Browse files Browse the repository at this point in the history
simplify expression

format

rm typo in UKI that causes warning

needs a sqrt

remove r
  • Loading branch information
odunbar committed Oct 25, 2023
1 parent 02d8ad6 commit bdbb9e0
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 19 deletions.
34 changes: 26 additions & 8 deletions src/Accelerators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ Stores a previous state value u_prev for computational purposes (note this is di
$(TYPEDFIELDS)
"""
mutable struct NesterovAccelerator{FT <: AbstractFloat} <: Accelerator
r::FT
u_prev::Any
u_prev::Array{FT}
θ_prev::FT
end

function NesterovAccelerator(r = 3.0, initial = Float64[])
return NesterovAccelerator(r, initial)
function NesterovAccelerator(initial = Float64[])
return NesterovAccelerator(initial, 1.0)
end


Expand Down Expand Up @@ -54,11 +54,20 @@ function accelerate!(
u::MA,
) where {FT <: AbstractFloat, IT <: Int, P <: Process, LRS <: LearningRateScheduler, MA <: AbstractMatrix}
## update "v" state:
k = get_N_iterations(ekp) + 2
v = u .+ (1 - ekp.accelerator.r / k) * (u .- ekp.accelerator.u_prev)
#v = u .+ 2 / (get_N_iterations(ekp) + 2) * (u .- ekp.accelerator.u_prev)
Δt_prev = length(ekp.Δt) == 1 ? 1 : ekp.Δt[end - 1]
Δt = ekp.Δt[end]
θ_prev = ekp.accelerator.θ_prev

# condition θ_prev^2 * (1 - θ) * Δt \leq Δt_prev * θ^2
a = sqrt(θ_prev^2 * Δt / Δt_prev)
θ = (-a + sqrt(a^2 + 4)) / 2

v = u .+ θ * (1 / θ_prev - 1) * (u .- ekp.accelerator.u_prev)

## update "u" state:
ekp.accelerator.u_prev = u
ekp.accelerator.θ_prev = θ

## push "v" state to EKP object
push!(ekp.u, DataContainer(v, data_are_columns = true))
Expand All @@ -76,11 +85,20 @@ function accelerate!(

#identical update stage as before
## update "v" state:
k = get_N_iterations(uki) + 2
v = u .+ (1 - uki.accelerator.r / k) * (u .- uki.accelerator.u_prev)
Δt_prev = length(uki.Δt) == 1 ? 1 : uki.Δt[end - 1]
Δt = uki.Δt[end]
θ_prev = uki.accelerator.θ_prev


# condition θ_prev^2 * (1 - θ) * Δt \leq Δt_prev * θ^2
a = sqrt(θ_prev^2 * Δt / Δt_prev)
θ = (-a + sqrt(a^2 + 4)) / 2

v = u .+ θ * (1 / θ_prev - 1) * (u .- uki.accelerator.u_prev)

## update "u" state:
uki.accelerator.u_prev = u
uki.accelerator.θ_prev = θ

## push "v" state to UKI object
push!(uki.u, DataContainer(v, data_are_columns = true))
Expand Down
5 changes: 1 addition & 4 deletions src/UnscentedKalmanInversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -551,10 +551,7 @@ end
UKI prediction step : generate sigma points.
"""
function update_ensemble_prediction!(
process::Unscented,
Δt::FT,
) where {FT <: AbstractFloat, AV <: AbstractVector, AM <: AbstractMatrix}
function update_ensemble_prediction!(process::Unscented, Δt::FT) where {FT <: AbstractFloat}

process.iter += 1
# update evolution covariance matrix
Expand Down
27 changes: 20 additions & 7 deletions test/EnsembleKalmanProcess/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ end
@testset "Accelerators" begin
# Get an inverse problem
y_obs, G, Γy, _ = inv_problems[end - 2] # additive noise inv problem (deterministic map)
inv_sqrt_Γy = sqrt(inv(Γy))

rng = Random.MersenneTwister(rng_seed)
N_ens_tmp = 5
initial_ensemble = EKP.construct_initial_ensemble(rng, prior, N_ens_tmp)
Expand All @@ -105,20 +107,32 @@ end
@test typeof(eksobj_noacc_specified.accelerator) <: DefaultAccelerator

## test NesterovAccelerators satisfy desired ICs
@test ekiobj.accelerator.r 3.0
@test ekiobj.accelerator.u_prev == initial_ensemble
@test eksobj.accelerator.r 3.0
@test ekiobj.accelerator.θ_prev == 1.0
@test eksobj.accelerator.u_prev == initial_ensemble
@test eksobj.accelerator.θ_prev == 1.0

## test method convergence
# Note: this test only requires that the final ensemble is an improvement on the initial ensemble,
# NOT that the accelerated processes are more effective than the default, as this is not guaranteed.
# Specific cost values are printed to give an idea of acceleration.
processes = [Inversion(), TransformInversion(inv(Γy)), Unscented(prior; impose_prior = true), Sampler(prior)]
schedulers = [repeat([DefaultScheduler(0.1)], 3)..., EKSStableScheduler()]
processes = [
Inversion(),
TransformInversion(inv(Γy)),
Unscented(prior; impose_prior = true),
Inversion(),
TransformInversion(inv(Γy)),
Unscented(prior; impose_prior = true),
Sampler(prior),
]
schedulers = [
repeat([DefaultScheduler(0.1)], 3)...,
repeat([DataMisfitController(terminate_at = 100)], 3)...,
EKSStableScheduler(),
]
for (process, scheduler) in zip(processes, schedulers)
accelerators = [DefaultAccelerator(), NesterovAccelerator()]
N_iters = [5, 5, 5, 5]
N_iters = [20, 20]
init_means = []
final_means = []

Expand Down Expand Up @@ -160,9 +174,8 @@ end
push!(init_means, vec(mean(get_u_prior(ekpobj), dims = 2)))
push!(final_means, vec(mean(get_u_final(ekpobj), dims = 2)))

inv_sqrt_Γy = sqrt(inv(Γy))
cost_initial =
norm(inv_sqrt_Γy * (y_obs .- G(transform_unconstrained_to_constrained(prior, init_means[end]))))
norm(inv_sqrt_Γy * (y_obs .- G(transform_unconstrained_to_constrained(prior, initial_ensemble))))
cost_final =
norm(inv_sqrt_Γy * (y_obs .- G(transform_unconstrained_to_constrained(prior, final_means[end]))))
@info "Convergence:" cost_initial cost_final
Expand Down

0 comments on commit bdbb9e0

Please sign in to comment.