Skip to content

Commit

Permalink
Merge #291
Browse files Browse the repository at this point in the history
291: Allow users to select termination time for DataMisfitController r=odunbar a=odunbar

<!--- THESE LINES ARE COMMENTED -->
## Purpose 
<!--- One sentence to describe the purpose of this PR, refer to any linked issues:
#14 -- this will link to issue 14
Closes #2 -- this will automatically close issue 2 on PR merge
-->
Closes #290 

## Content
<!---  specific tasks that are currently complete 
- Solution implemented
-->
- Added a keyword argument `terminate_at` for DMC
- Added tests to show it stops at the correct time
- Added a line into the docs 

<!---
Review checklist

I have:
- followed the codebase contribution guide: https://clima.github.io/ClimateMachine.jl/latest/Contributing/
- followed the style guide: https://clima.github.io/ClimateMachine.jl/latest/DevDocs/CodeStyle/
- followed the documentation policy: https://github.com/CliMA/policies/wiki/Documentation-Policy
- checked that this PR does not duplicate an open PR.

In the Content, I have included 
- relevant unit tests, and integration tests, 
- appropriate docstrings on all functions, structs, and modules, and included relevant documentation.

-->

----
- [x] I have read and checked the items on the review checklist.


Co-authored-by: odunbar <[email protected]>
  • Loading branch information
bors[bot] and odunbar committed Jun 8, 2023
2 parents cabe5da + 57eb6f5 commit 7e64a5e
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 17 deletions.
5 changes: 3 additions & 2 deletions docs/src/learning_rate_scheduler.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ Currently we will retain constant timestepping while we investigate further, tho

Please let us know how you get on by setting the keyword argument in EKP
```julia
scheduler = DataMisfitController() # terminating
scheduler = DataMisfitController(on_terminate = "continue") #non-terminating
scheduler = DataMisfitController() # terminating at `T=1`
scheduler = DataMisfitController(terminate_at = 10) # terminating at `T=10`
scheduler = DataMisfitController(on_terminate = "continue") # non-terminating
```

!!! warning "Ensemble Kalman Sampler"
Expand Down
28 changes: 19 additions & 9 deletions src/LearningRateSchedulers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,15 @@ end
$(TYPEDEF)
Scheduler from Iglesias, Yang, 2021, Based on Bayesian Tempering.
Terminates at `T=1`, and at this time, ensemble spread provides a (more) meaningful approximation of posterior uncertainty
Terminates at `T=1` by default, and at this time, ensemble spread provides a (more) meaningful approximation of posterior uncertainty
In particular, for parameters ``\\theta_j`` at step ``n``, to calculate the next timestep
``\\Delta t_n = \\min\\left(\\max\\left(\\frac{J}{2\\Phi}, \\sqrt{\\frac{J}{2\\langle \\Phi, \\Phi \\rangle}}\\right), 1-\\sum^{n-1}_i t_i\\right) `` where ``\\Phi_j = \\|\\Gamma^{-1}(G(\\theta_j) - y)\\|^2``.
Cannot be overriden. By default termination returns `true` from `update_ensemble!` and
Cannot be overriden by user provided timesteps.
By default termination returns `true` from `update_ensemble!` and
- if `on_terminate == "stop"`, stops further iteration.
- if `on_terminate == "continue_fixed", continues iteration with the final timestep fixed
- if `on_terminate == "continue", continues the algorithm (though no longer compares to ``1-\\sum^{n-1}_i t_i``)
The user may also change the `T` with `terminate_at` keyword.
$(TYPEDFIELDS)
"""
Expand All @@ -104,16 +106,23 @@ struct DataMisfitController{FT, M, S} <:
iteration::Vector{Int}
history::Vector{FT}
inv_sqrt_noise::Vector{M}
terminate_at::FT
on_terminate::S
end # Iglesias Yan 2021

function DataMisfitController(; on_terminate = "stop")
function DataMisfitController(; terminate_at = 1.0, on_terminate = "stop")
FT = Float64
M = Matrix{FT}
iteration = Int[]
history = FT[]
inv_sqrt_noise = M[]

if terminate_at > 0 #can be infinity
ta = FT(terminate_at)
else
ta = FT(1.0) # has a notion of posterior
end

if on_terminate ["continue", "continue_fixed", "stop"]
throw(
ArgumentError(
Expand All @@ -122,7 +131,7 @@ function DataMisfitController(; on_terminate = "stop")
)
end

return DataMisfitController{FT, M, typeof(on_terminate)}(iteration, history, inv_sqrt_noise, on_terminate)
return DataMisfitController{FT, M, typeof(on_terminate)}(iteration, history, inv_sqrt_noise, ta, on_terminate)
end

"""
Expand Down Expand Up @@ -240,6 +249,7 @@ function calculate_timestep!(


M, J = size(g)
T = scheduler.terminate_at

if isempty(ekp.Δt)
push!(scheduler.iteration, 1)
Expand All @@ -258,8 +268,8 @@ function calculate_timestep!(
sum_Δt = (n == 1) ? 0.0 : sum(ekp.Δt)
sum_Δt_min1 = (n <= 2) ? 0.0 : sum(ekp.Δt[1:(end - 1)])
# On termination condition:
if sum_Δt >= 1
if sum_Δt_min1 < 1 # "Just reached termination"
if sum_Δt >= T
if sum_Δt_min1 < T # "Just reached termination"
if scheduler.on_terminate == "stop"
@warn "Termination condition of timestepping scheme `DataMisfitController` has been exceeded. Preventing futher updates\n Set on_terminate=\"continue\" in `DataMisfitController` to ignore termination"
return true #returns a terminate call
Expand All @@ -284,16 +294,16 @@ function calculate_timestep!(

q = maximum((M / (2 * Φ_mean), sqrt(M / (2 * Φ_var))))

if sum_Δt < 1
Δt = minimum([q, 1 - sum_Δt])
if sum_Δt < T
Δt = minimum([q, T - sum_Δt])
else # when termination condition satisfied but choose to continue
Δt = q
end

# in theory the following should be the same.
push!(ekp.Δt, Δt)

if (sum_Δt < 1) && (sum_Δt + Δt >= 1)
if (sum_Δt < T) && (sum_Δt + Δt >= T)
@info "Termination condition of timestepping scheme `DataMisfitController` has been satisfied."
end
nothing
Expand Down
14 changes: 12 additions & 2 deletions test/EnsembleKalmanProcess/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,11 @@ end
@test length(dmclrs1.iteration) == 0
@test typeof(dmclrs1.inv_sqrt_noise) == Vector{Matrix{Float64}}
@test length(dmclrs1.inv_sqrt_noise) == 0
@test dmclrs1.terminate_at == Float64(1)
@test dmclrs1.on_terminate == "stop"
dmclrs2 = EKP.DataMisfitController(on_terminate = "continue")
dmclrs2 = EKP.DataMisfitController(terminate_at = 7, on_terminate = "continue")
@test dmclrs2.on_terminate == "continue"
@test dmclrs2.terminate_at == Float64(7)
dmclrs3 = EKP.DataMisfitController(on_terminate = "continue_fixed")
@test dmclrs3.on_terminate == "continue_fixed"

Expand All @@ -141,11 +143,12 @@ end
# Unscented(prior), TO BE UNCOMMENTED WHEN UKI BUG-FIXED
#Sparse inversion tests in test/SparseInversion/runtests.jl
]
T_end = 3 # (this could fail a test if N_iters is not enough to reach T_end)
for process in processes
schedulers = [
DefaultScheduler(0.05),
MutableScheduler(0.05),
DataMisfitController(),
DataMisfitController(terminate_at = T_end),
DataMisfitController(on_terminate = "continue"),
DataMisfitController(on_terminate = "continue_fixed"),
]
Expand Down Expand Up @@ -182,6 +185,13 @@ end
end
push!(init_means, vec(mean(get_u_prior(ekpobj), dims = 2)))
push!(final_means, vec(mean(get_u_final(ekpobj), dims = 2)))

# this test is fine so long as N_iter is large enough to hit the termination time
if nameof(typeof(scheduler)) == DataMisfitController
if (scheduler.terminate_at, scheduler.on_terminate) == (Float64(T_end), "stop")
@test sum(ekpobj.Δt) scheduler.terminate_at
end
end
end
for i in 1:length(final_means)
u_star = transform_constrained_to_unconstrained(prior, ϕ_star)
Expand Down
12 changes: 8 additions & 4 deletions test/SparseInversion/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ using EnsembleKalmanProcesses.ParameterDistributions
using EnsembleKalmanProcesses.Localizers
const EKP = EnsembleKalmanProcesses

TEST_PLOT_OUTPUT = false

# Read inverse problem definitions
include("../EnsembleKalmanProcess/inverse_problem.jl")

Expand Down Expand Up @@ -152,11 +150,11 @@ include("../EnsembleKalmanProcess/inverse_problem.jl")

## Repeat first test with several schedulers
y_obs, G, Γy = nl_inv_problems[1]

T_end = 3
schedulers = [
DefaultScheduler(0.1),
MutableScheduler(0.1),
DataMisfitController(),
DataMisfitController(terminate_at = T_end),
DataMisfitController(on_terminate = "continue"),
DataMisfitController(on_terminate = "continue_fixed"),
]
Expand Down Expand Up @@ -195,6 +193,12 @@ include("../EnsembleKalmanProcess/inverse_problem.jl")
end
push!(init_means, vec(mean(get_u_prior(ekiobj), dims = 2)))
push!(final_means, vec(mean(get_u_final(ekiobj), dims = 2)))
# this test is fine so long as N_iter is large enough to hit the termination time
if nameof(typeof(scheduler)) == DataMisfitController
if (scheduler.terminate_at, scheduler.on_terminate) == (Float64(T_end), "stop")
@test sum(ekiobj.Δt) scheduler.terminate_at
end
end

end
for i in 1:length(final_means)
Expand Down

0 comments on commit 7e64a5e

Please sign in to comment.