Skip to content

Commit

Permalink
Accelerators
Browse files Browse the repository at this point in the history
remove docstring

format
  • Loading branch information
odunbar committed Nov 3, 2023
1 parent 06c3edf commit 34f81f7
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions src/Accelerators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,15 @@ function accelerate!(
u::MA,
) where {FT <: AbstractFloat, IT <: Int, P <: Process, LRS <: LearningRateScheduler, MA <: AbstractMatrix}
## update "v" state:
#v = u .+ 2 / (get_N_iterations(ekp) + 2) * (u .- ekp.accelerator.u_prev)
Δt_prev = length(ekp.Δt) == 1 ? 1 : ekp.Δt[end - 1]
Δt_prev = length(ekp.Δt) == 1 ? ekp.Δt[end] : ekp.Δt[end - 1]
Δt = ekp.Δt[end]
θ_prev = ekp.accelerator.θ_prev

# condition θ_prev^2 * (1 - θ) * Δt \leq Δt_prev * θ^2
a = sqrt(θ_prev^2 * Δt / Δt_prev)
θ = (-a + sqrt(a^2 + 4)) / 2
# condition θ_prev^2 * (1 - θ) * h \leq h_prev * θ^2
b = θ_prev^2 * Δt / Δt_prev

θ_lowbd = (-b + sqrt(b^2 + 4 * b)) / 2
θ = min((θ_lowbd + 1) / 2, θ_lowbd + 1.0 / length(ekp.Δt)^2) # can be unstable close to the boundary, so add a quadratic penalization

v = u .+ θ * (1 / θ_prev - 1) * (u .- ekp.accelerator.u_prev)

Expand All @@ -118,14 +119,15 @@ function accelerate!(

#identical update stage as before
## update "v" state:
Δt_prev = length(uki.Δt) == 1 ? 1 : uki.Δt[end - 1]
Δt_prev = length(uki.Δt) == 1 ? uki.Δt[end] : uki.Δt[end - 1]
Δt = uki.Δt[end]
θ_prev = uki.accelerator.θ_prev


# condition θ_prev^2 * (1 - θ) * Δt \leq Δt_prev * θ^2
a = sqrt(θ_prev^2 * Δt / Δt_prev)
θ = (-a + sqrt(a^2 + 4)) / 2
# condition θ_prev^2 * (1 - θ) * h \leq h_prev * θ^2
b = θ_prev^2 * Δt / Δt_prev
θ_lowbd = (-b + sqrt(b^2 + 4 * b)) / 2
θ = min((θ_lowbd + 1) / 2, θ_lowbd + 1.0 / length(uki.Δt)^2) # can be unstable at the boundary, so add a quadratic penalization

v = u .+ θ * (1 / θ_prev - 1) * (u .- uki.accelerator.u_prev)

Expand Down

0 comments on commit 34f81f7

Please sign in to comment.