Skip to content

Commit

Permalink
modified accelerator for variable timestep
Browse files Browse the repository at this point in the history
simplify expression

format

rm typo in UKI that causes warning
  • Loading branch information
odunbar committed Oct 25, 2023
1 parent 02d8ad6 commit 03e08e4
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 12 deletions.
41 changes: 35 additions & 6 deletions src/Accelerators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ $(TYPEDFIELDS)
mutable struct NesterovAccelerator{FT <: AbstractFloat} <: Accelerator
# Floating-point parameter of type FT; the convenience constructor defaults it to 3.0.
# NOTE(review): the θ-based update in accelerate! does not appear to read this field — confirm whether it is still needed.
r::FT
# Previously stored `u` state; accelerate! overwrites it with the current `u` after computing the accelerated state.
u_prev::Any
# θ coefficient from the previous accelerate! call; overwritten with the newly computed θ on every call.
θ_prev::Any
end

"""
    NesterovAccelerator(r = 3.0, initial = Float64[])

Convenience constructor for a `NesterovAccelerator`.

# Arguments
- `r`: floating-point accelerator parameter (defaults to `3.0`).
- `initial`: initial stored state (defaults to an empty `Float64` vector); it seeds
  both the `u_prev` and `θ_prev` fields.

# Returns
A `NesterovAccelerator` with fields `(r, initial, initial)`.
"""
function NesterovAccelerator(r = 3.0, initial = Float64[])
    # Must call the three-field constructor: a two-argument call would dispatch
    # straight back to this method and recurse without terminating.
    # NOTE(review): `θ_prev` starts as the same object as `u_prev`; accelerate!
    # replaces it with a scalar θ — confirm nothing reads it before that.
    return NesterovAccelerator(r, initial, initial)
end


Expand Down Expand Up @@ -54,12 +55,26 @@ function accelerate!(
u::MA,
) where {FT <: AbstractFloat, IT <: Int, P <: Process, LRS <: LearningRateScheduler, MA <: AbstractMatrix}
## update "v" state:
k = get_N_iterations(ekp) + 2
v = u .+ (1 - ekp.accelerator.r / k) * (u .- ekp.accelerator.u_prev)
#v = u .+ 2 / (get_N_iterations(ekp) + 2) * (u .- ekp.accelerator.u_prev)
if length(ekp.Δt) == 1
θ_prev = 1
Δt_prev = 1
Δt = ekp.Δt[end]
else
θ_prev = ekp.accelerator.θ_prev
Δt_prev = ekp.Δt[end - 1]
Δt = ekp.Δt[end]
end

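# condition θ_prev^2 * (1 - θ) * Δt ≤ Δt_prev * θ^2
# NOTE(review): with a = θ_prev^2 * Δt / Δt_prev, equality in this condition reads
# θ^2 + a * θ - a = 0, whose positive root is (-a + sqrt(a^2 + 4 * a)) / 2.
# The expression below uses sqrt(a^2 + 4), which agrees only when a == 1 — confirm which is intended.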
a = θ_prev^2 * Δt / Δt_prev
θ = (-a + sqrt(a^2 + 4)) / 2

v = u .+ θ * (1 / θ_prev - 1) * (u .- ekp.accelerator.u_prev)

## update "u" state:
ekp.accelerator.u_prev = u

ekp.accelerator.θ_prev = θ
## push "v" state to EKP object
push!(ekp.u, DataContainer(v, data_are_columns = true))
end
Expand All @@ -76,11 +91,25 @@ function accelerate!(

#identical update stage as before
## update "v" state:
k = get_N_iterations(uki) + 2
v = u .+ (1 - uki.accelerator.r / k) * (u .- uki.accelerator.u_prev)
if length(uki.Δt) == 1
θ_prev = 1
Δt_prev = 1
Δt = uki.Δt[end]
else
θ_prev = uki.accelerator.θ_prev
Δt_prev = uki.Δt[end - 1]
Δt = uki.Δt[end]
end

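# condition θ_prev^2 * (1 - θ) * Δt ≤ Δt_prev * θ^2
# NOTE(review): with a = θ_prev^2 * Δt / Δt_prev, equality in this condition reads
# θ^2 + a * θ - a = 0, whose positive root is (-a + sqrt(a^2 + 4 * a)) / 2.
# The expression below uses sqrt(a^2 + 4), which agrees only when a == 1 — confirm which is intended.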
a = θ_prev^2 * Δt / Δt_prev
θ = (-a + sqrt(a^2 + 4)) / 2

v = u .+ θ * (1 / θ_prev - 1) * (u .- uki.accelerator.u_prev)

## update "u" state:
uki.accelerator.u_prev = u
uki.accelerator.θ_prev = θ

## push "v" state to UKI object
push!(uki.u, DataContainer(v, data_are_columns = true))
Expand Down
2 changes: 1 addition & 1 deletion src/UnscentedKalmanInversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ UKI prediction step : generate sigma points.
function update_ensemble_prediction!(
process::Unscented,
Δt::FT,
) where {FT <: AbstractFloat, AV <: AbstractVector, AM <: AbstractMatrix}
) where {FT <: AbstractFloat}

process.iter += 1
# update evolution covariance matrix
Expand Down
23 changes: 18 additions & 5 deletions test/EnsembleKalmanProcess/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ end
@testset "Accelerators" begin
# Get an inverse problem
y_obs, G, Γy, _ = inv_problems[end - 2] # additive noise inv problem (deterministic map)
inv_sqrt_Γy = sqrt(inv(Γy))

rng = Random.MersenneTwister(rng_seed)
N_ens_tmp = 5
initial_ensemble = EKP.construct_initial_ensemble(rng, prior, N_ens_tmp)
Expand Down Expand Up @@ -114,11 +116,23 @@ end
# Note: this test only requires that the final ensemble is an improvement on the initial ensemble,
# NOT that the accelerated processes are more effective than the default, as this is not guaranteed.
# Specific cost values are printed to give an idea of acceleration.
processes = [Inversion(), TransformInversion(inv(Γy)), Unscented(prior; impose_prior = true), Sampler(prior)]
schedulers = [repeat([DefaultScheduler(0.1)], 3)..., EKSStableScheduler()]
processes = [
Inversion(),
TransformInversion(inv(Γy)),
Unscented(prior; impose_prior = true),
Inversion(),
TransformInversion(inv(Γy)),
Unscented(prior; impose_prior = true),
Sampler(prior),
]
schedulers = [
repeat([DefaultScheduler(0.1)], 3)...,
repeat([DataMisfitController(terminate_at = 100)], 3)...,
EKSStableScheduler(),
]
for (process, scheduler) in zip(processes, schedulers)
accelerators = [DefaultAccelerator(), NesterovAccelerator()]
N_iters = [5, 5, 5, 5]
N_iters = [5, 5]
init_means = []
final_means = []

Expand Down Expand Up @@ -160,9 +174,8 @@ end
push!(init_means, vec(mean(get_u_prior(ekpobj), dims = 2)))
push!(final_means, vec(mean(get_u_final(ekpobj), dims = 2)))

inv_sqrt_Γy = sqrt(inv(Γy))
cost_initial =
norm(inv_sqrt_Γy * (y_obs .- G(transform_unconstrained_to_constrained(prior, init_means[end]))))
norm(inv_sqrt_Γy * (y_obs .- G(transform_unconstrained_to_constrained(prior, initial_ensemble))))
cost_final =
norm(inv_sqrt_Γy * (y_obs .- G(transform_unconstrained_to_constrained(prior, final_means[end]))))
@info "Convergence:" cost_initial cost_final
Expand Down

0 comments on commit 03e08e4

Please sign in to comment.