Skip to content

Commit

Permalink
modified accelerator for variable timestep
Browse files Browse the repository at this point in the history
simplify expression

format

rm typo in UKI that causes warning

needs a sqrt
  • Loading branch information
odunbar committed Oct 25, 2023
1 parent 02d8ad6 commit b6259ca
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 15 deletions.
31 changes: 25 additions & 6 deletions src/Accelerators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ $(TYPEDFIELDS)
mutable struct NesterovAccelerator{FT <: AbstractFloat} <: Accelerator
r::FT
u_prev::Any
θ_prev::Any
end

function NesterovAccelerator(r = 3.0, initial = Float64[])
return NesterovAccelerator(r, initial)
function NesterovAccelerator(r = 3.0, initial = Float64[], θ_initial = 1.0)
return NesterovAccelerator(r, initial, θ_initial)
end


Expand Down Expand Up @@ -54,11 +55,20 @@ function accelerate!(
u::MA,
) where {FT <: AbstractFloat, IT <: Int, P <: Process, LRS <: LearningRateScheduler, MA <: AbstractMatrix}
## update "v" state:
k = get_N_iterations(ekp) + 2
v = u .+ (1 - ekp.accelerator.r / k) * (u .- ekp.accelerator.u_prev)
#v = u .+ 2 / (get_N_iterations(ekp) + 2) * (u .- ekp.accelerator.u_prev)
Δt_prev = length(ekp.Δt) == 1 ? 1 : ekp.Δt[end - 1]
Δt = ekp.Δt[end]
θ_prev = ekp.accelerator.θ_prev

# condition θ_prev^2 * (1 - θ) * Δt \leq Δt_prev * θ^2
a = sqrt(θ_prev^2 * Δt / Δt_prev)
θ = (-a + sqrt(a^2 + 4)) / 2

v = u .+ θ * (1 / θ_prev - 1) * (u .- ekp.accelerator.u_prev)

## update "u" state:
ekp.accelerator.u_prev = u
ekp.accelerator.θ_prev = θ

## push "v" state to EKP object
push!(ekp.u, DataContainer(v, data_are_columns = true))
Expand All @@ -76,11 +86,20 @@ function accelerate!(

#identical update stage as before
## update "v" state:
k = get_N_iterations(uki) + 2
v = u .+ (1 - uki.accelerator.r / k) * (u .- uki.accelerator.u_prev)
Δt_prev = length(uki.Δt) == 1 ? 1 : uki.Δt[end - 1]
Δt = uki.Δt[end]
θ_prev = uki.accelerator.θ_prev


# condition θ_prev^2 * (1 - θ) * Δt \leq Δt_prev * θ^2
a = sqrt(θ_prev^2 * Δt / Δt_prev)
θ = (-a + sqrt(a^2 + 4)) / 2

v = u .+ θ * (1 / θ_prev - 1) * (u .- uki.accelerator.u_prev)

## update "u" state:
uki.accelerator.u_prev = u
uki.accelerator.θ_prev = θ

## push "v" state to UKI object
push!(uki.u, DataContainer(v, data_are_columns = true))
Expand Down
5 changes: 1 addition & 4 deletions src/UnscentedKalmanInversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -551,10 +551,7 @@ end
UKI prediction step : generate sigma points.
"""
function update_ensemble_prediction!(
process::Unscented,
Δt::FT,
) where {FT <: AbstractFloat, AV <: AbstractVector, AM <: AbstractMatrix}
function update_ensemble_prediction!(process::Unscented, Δt::FT) where {FT <: AbstractFloat}

process.iter += 1
# update evolution covariance matrix
Expand Down
25 changes: 20 additions & 5 deletions test/EnsembleKalmanProcess/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ end
@testset "Accelerators" begin
# Get an inverse problem
y_obs, G, Γy, _ = inv_problems[end - 2] # additive noise inv problem (deterministic map)
inv_sqrt_Γy = sqrt(inv(Γy))

rng = Random.MersenneTwister(rng_seed)
N_ens_tmp = 5
initial_ensemble = EKP.construct_initial_ensemble(rng, prior, N_ens_tmp)
Expand All @@ -107,18 +109,32 @@ end
## test NesterovAccelerators satisfy desired ICs
@test ekiobj.accelerator.r 3.0
@test ekiobj.accelerator.u_prev == initial_ensemble
@test ekiobj.accelerator.θ_prev == 1.0
@test eksobj.accelerator.r 3.0
@test eksobj.accelerator.u_prev == initial_ensemble
@test eksobj.accelerator.θ_prev == 1.0

## test method convergence
# Note: this test only requires that the final ensemble is an improvement on the initial ensemble,
# NOT that the accelerated processes are more effective than the default, as this is not guaranteed.
# Specific cost values are printed to give an idea of acceleration.
processes = [Inversion(), TransformInversion(inv(Γy)), Unscented(prior; impose_prior = true), Sampler(prior)]
schedulers = [repeat([DefaultScheduler(0.1)], 3)..., EKSStableScheduler()]
processes = [
Inversion(),
TransformInversion(inv(Γy)),
Unscented(prior; impose_prior = true),
Inversion(),
TransformInversion(inv(Γy)),
Unscented(prior; impose_prior = true),
Sampler(prior),
]
schedulers = [
repeat([DefaultScheduler(0.1)], 3)...,
repeat([DataMisfitController(terminate_at = 100)], 3)...,
EKSStableScheduler(),
]
for (process, scheduler) in zip(processes, schedulers)
accelerators = [DefaultAccelerator(), NesterovAccelerator()]
N_iters = [5, 5, 5, 5]
N_iters = [20, 20]
init_means = []
final_means = []

Expand Down Expand Up @@ -160,9 +176,8 @@ end
push!(init_means, vec(mean(get_u_prior(ekpobj), dims = 2)))
push!(final_means, vec(mean(get_u_final(ekpobj), dims = 2)))

inv_sqrt_Γy = sqrt(inv(Γy))
cost_initial =
norm(inv_sqrt_Γy * (y_obs .- G(transform_unconstrained_to_constrained(prior, init_means[end]))))
norm(inv_sqrt_Γy * (y_obs .- G(transform_unconstrained_to_constrained(prior, initial_ensemble))))
cost_final =
norm(inv_sqrt_Γy * (y_obs .- G(transform_unconstrained_to_constrained(prior, final_means[end]))))
@info "Convergence:" cost_initial cost_final
Expand Down

0 comments on commit b6259ca

Please sign in to comment.