Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow users to select termination time for DataMisfitController #291

Merged
merged 1 commit into from
Jun 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions docs/src/learning_rate_scheduler.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ Currently we will retain constant timestepping while we investigate further, tho

Please let us know how you get on by setting the keyword argument in EKP
```julia
scheduler = DataMisfitController() # terminating
scheduler = DataMisfitController(on_terminate = "continue") #non-terminating
scheduler = DataMisfitController() # terminating at `T=1`
scheduler = DataMisfitController(terminate_at = 10) # terminating at `T=10`
scheduler = DataMisfitController(on_terminate = "continue") # non-terminating
```

!!! warning "Ensemble Kalman Sampler"
Expand Down
28 changes: 19 additions & 9 deletions src/LearningRateSchedulers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,15 @@ end
$(TYPEDEF)

Scheduler from Iglesias, Yang, 2021, Based on Bayesian Tempering.
Terminates at `T=1`, and at this time, ensemble spread provides a (more) meaningful approximation of posterior uncertainty
Terminates at `T=1` by default, and at this time, ensemble spread provides a (more) meaningful approximation of posterior uncertainty
In particular, for parameters ``\\theta_j`` at step ``n``, to calculate the next timestep
``\\Delta t_n = \\min\\left(\\max\\left(\\frac{J}{2\\Phi}, \\sqrt{\\frac{J}{2\\langle \\Phi, \\Phi \\rangle}}\\right), 1-\\sum^{n-1}_i t_i\\right) `` where ``\\Phi_j = \\|\\Gamma^{-1}(G(\\theta_j) - y)\\|^2``.
Cannot be overriden. By default termination returns `true` from `update_ensemble!` and
Cannot be overriden by user provided timesteps.
By default termination returns `true` from `update_ensemble!` and
- if `on_terminate == "stop"`, stops further iteration.
- if `on_terminate == "continue_fixed", continues iteration with the final timestep fixed
- if `on_terminate == "continue", continues the algorithm (though no longer compares to ``1-\\sum^{n-1}_i t_i``)
The user may also change the `T` with `terminate_at` keyword.

$(TYPEDFIELDS)
"""
Expand All @@ -104,16 +106,23 @@ struct DataMisfitController{FT, M, S} <:
iteration::Vector{Int}
history::Vector{FT}
inv_sqrt_noise::Vector{M}
terminate_at::FT
on_terminate::S
end # Iglesias Yan 2021

function DataMisfitController(; on_terminate = "stop")
function DataMisfitController(; terminate_at = 1.0, on_terminate = "stop")
FT = Float64
M = Matrix{FT}
iteration = Int[]
history = FT[]
inv_sqrt_noise = M[]

if terminate_at > 0 #can be infinity
ta = FT(terminate_at)
else
ta = FT(1.0) # has a notion of posterior
end

if on_terminate ∉ ["continue", "continue_fixed", "stop"]
throw(
ArgumentError(
Expand All @@ -122,7 +131,7 @@ function DataMisfitController(; on_terminate = "stop")
)
end

return DataMisfitController{FT, M, typeof(on_terminate)}(iteration, history, inv_sqrt_noise, on_terminate)
return DataMisfitController{FT, M, typeof(on_terminate)}(iteration, history, inv_sqrt_noise, ta, on_terminate)
end

"""
Expand Down Expand Up @@ -240,6 +249,7 @@ function calculate_timestep!(


M, J = size(g)
T = scheduler.terminate_at

if isempty(ekp.Δt)
push!(scheduler.iteration, 1)
Expand All @@ -258,8 +268,8 @@ function calculate_timestep!(
sum_Δt = (n == 1) ? 0.0 : sum(ekp.Δt)
sum_Δt_min1 = (n <= 2) ? 0.0 : sum(ekp.Δt[1:(end - 1)])
# On termination condition:
if sum_Δt >= 1
if sum_Δt_min1 < 1 # "Just reached termination"
if sum_Δt >= T
if sum_Δt_min1 < T # "Just reached termination"
if scheduler.on_terminate == "stop"
@warn "Termination condition of timestepping scheme `DataMisfitController` has been exceeded. Preventing futher updates\n Set on_terminate=\"continue\" in `DataMisfitController` to ignore termination"
return true #returns a terminate call
Expand All @@ -284,16 +294,16 @@ function calculate_timestep!(

q = maximum((M / (2 * Φ_mean), sqrt(M / (2 * Φ_var))))

if sum_Δt < 1
Δt = minimum([q, 1 - sum_Δt])
if sum_Δt < T
Δt = minimum([q, T - sum_Δt])
else # when termination condition satisfied but choose to continue
Δt = q
end

# in theory the following should be the same.
push!(ekp.Δt, Δt)

if (sum_Δt < 1) && (sum_Δt + Δt >= 1)
if (sum_Δt < T) && (sum_Δt + Δt >= T)
@info "Termination condition of timestepping scheme `DataMisfitController` has been satisfied."
end
nothing
Expand Down
14 changes: 12 additions & 2 deletions test/EnsembleKalmanProcess/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,11 @@ end
@test length(dmclrs1.iteration) == 0
@test typeof(dmclrs1.inv_sqrt_noise) == Vector{Matrix{Float64}}
@test length(dmclrs1.inv_sqrt_noise) == 0
@test dmclrs1.terminate_at == Float64(1)
@test dmclrs1.on_terminate == "stop"
dmclrs2 = EKP.DataMisfitController(on_terminate = "continue")
dmclrs2 = EKP.DataMisfitController(terminate_at = 7, on_terminate = "continue")
@test dmclrs2.on_terminate == "continue"
@test dmclrs2.terminate_at == Float64(7)
dmclrs3 = EKP.DataMisfitController(on_terminate = "continue_fixed")
@test dmclrs3.on_terminate == "continue_fixed"

Expand All @@ -141,11 +143,12 @@ end
# Unscented(prior), TO BE UNCOMMENTED WHEN UKI BUG-FIXED
#Sparse inversion tests in test/SparseInversion/runtests.jl
]
T_end = 3 # (this could fail a test if N_iters is not enough to reach T_end)
for process in processes
schedulers = [
DefaultScheduler(0.05),
MutableScheduler(0.05),
DataMisfitController(),
DataMisfitController(terminate_at = T_end),
DataMisfitController(on_terminate = "continue"),
DataMisfitController(on_terminate = "continue_fixed"),
]
Expand Down Expand Up @@ -182,6 +185,13 @@ end
end
push!(init_means, vec(mean(get_u_prior(ekpobj), dims = 2)))
push!(final_means, vec(mean(get_u_final(ekpobj), dims = 2)))

# this test is fine so long as N_iter is large enough to hit the termination time
if nameof(typeof(scheduler)) == DataMisfitController
if (scheduler.terminate_at, scheduler.on_terminate) == (Float64(T_end), "stop")
@test sum(ekpobj.Δt) ≈ scheduler.terminate_at
end
end
end
for i in 1:length(final_means)
u_star = transform_constrained_to_unconstrained(prior, ϕ_star)
Expand Down
12 changes: 8 additions & 4 deletions test/SparseInversion/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ using EnsembleKalmanProcesses.ParameterDistributions
using EnsembleKalmanProcesses.Localizers
const EKP = EnsembleKalmanProcesses

TEST_PLOT_OUTPUT = false

# Read inverse problem definitions
include("../EnsembleKalmanProcess/inverse_problem.jl")

Expand Down Expand Up @@ -152,11 +150,11 @@ include("../EnsembleKalmanProcess/inverse_problem.jl")

## Repeat first test with several schedulers
y_obs, G, Γy = nl_inv_problems[1]

T_end = 3
schedulers = [
DefaultScheduler(0.1),
MutableScheduler(0.1),
DataMisfitController(),
DataMisfitController(terminate_at = T_end),
DataMisfitController(on_terminate = "continue"),
DataMisfitController(on_terminate = "continue_fixed"),
]
Expand Down Expand Up @@ -195,6 +193,12 @@ include("../EnsembleKalmanProcess/inverse_problem.jl")
end
push!(init_means, vec(mean(get_u_prior(ekiobj), dims = 2)))
push!(final_means, vec(mean(get_u_final(ekiobj), dims = 2)))
# this test is fine so long as N_iter is large enough to hit the termination time
if nameof(typeof(scheduler)) == DataMisfitController
if (scheduler.terminate_at, scheduler.on_terminate) == (Float64(T_end), "stop")
@test sum(ekiobj.Δt) ≈ scheduler.terminate_at
end
end

end
for i in 1:length(final_means)
Expand Down
Loading