Skip to content

Commit

Permalink
Merge pull request #1082 from CliMA/glw/pickup
Browse files Browse the repository at this point in the history
New checkpointer features: set! and simulation "pickup"
  • Loading branch information
glwagner committed Oct 26, 2020
2 parents d343e0f + d018eae commit e1d1f19
Show file tree
Hide file tree
Showing 6 changed files with 252 additions and 46 deletions.
21 changes: 13 additions & 8 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,10 @@ uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
version = "1.2.4"

[[FFTW_jll]]
deps = ["Libdl", "Pkg"]
git-tree-sha1 = "6c975cd606128d45d1df432fb812d6eb10fee00b"
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "f10c3009373a2d5c4349b8a2932d8accb892892d"
uuid = "f5851436-0d7a-5f13-b9de-f02708fd171a"
version = "3.3.9+5"
version = "3.3.9+6"

[[GPUArrays]]
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"]
Expand All @@ -137,6 +137,11 @@ git-tree-sha1 = "05097d81898c527e3bf218bb083ad0ead4378e5f"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "0.6.1"

[[Glob]]
git-tree-sha1 = "4df9f7e06108728ebf00a0a11edee4b29a482bb2"
uuid = "c27321d9-0574-5035-807b-f59d2c89b15c"
version = "1.3.0"

[[IntelOpenMP_jll]]
deps = ["Libdl", "Pkg"]
git-tree-sha1 = "fb8e1c7a5594ba56f9011310790e03b5384998d6"
Expand Down Expand Up @@ -227,10 +232,10 @@ uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
version = "1.3.1"

[[OpenSpecFun_jll]]
deps = ["CompilerSupportLibraries_jll", "Libdl", "Pkg"]
git-tree-sha1 = "d51c416559217d974a1113522d5919235ae67a87"
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "9db77584158d0ab52307f8c04f8e7c08ca76b5b3"
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
version = "0.5.3+3"
version = "0.5.3+4"

[[OrderedCollections]]
git-tree-sha1 = "16c08bf5dba06609fe45e30860092d6fa41fde7b"
Expand Down Expand Up @@ -352,6 +357,6 @@ version = "1.2.0"

[[Zlib_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "ded43825988ace7a311ee7e1d0f09571822509c4"
git-tree-sha1 = "320228915c8debb12cb434c59057290f0834dbf6"
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
version = "1.2.11+17"
version = "1.2.11+18"
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Crayons = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Expand Down
75 changes: 74 additions & 1 deletion src/OutputWriters/checkpointer.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
using Glob
import Oceananigans.Fields: set!

using Oceananigans.Fields: offset_data

"""
Expand Down Expand Up @@ -85,8 +88,19 @@ function Checkpointer(model; schedule,
return Checkpointer(schedule, dir, prefix, properties, force, verbose)
end

""" Returns the full prefix (the `superprefix`) associated with `checkpointer`. """
checkpoint_superprefix(prefix) = prefix * "_iteration"

"""
checkpoint_path(iteration::Int, c::Checkpointer)
Returns the path to the `c`heckpointer file associated with model `iteration`.
"""
checkpoint_path(iteration::Int, c::Checkpointer) =
joinpath(c.dir, string(checkpoint_superprefix(c.prefix), iteration, ".jld2"))

function write_output!(c::Checkpointer, model)
filepath = joinpath(c.dir, c.prefix * "_iteration" * string(model.clock.iteration) * ".jld2")
filepath = checkpoint_path(model.clock.iteration, c)
c.verbose && @info "Checkpointing to file $filepath..."

t1 = time_ns()
Expand Down Expand Up @@ -208,3 +222,62 @@ function restore_from_checkpoint(filepath; kwargs...)

return model
end

#####
##### set! for checkpointer filepaths
#####

"""
set!(model, filepath::AbstractString)
Set data in `model.velocities`, `model.tracers`, `model.timestepper.Gⁿ`, and
`model.timestepper.G⁻` to checkpointed data stored at `filepath`.
"""
function set!(model, filepath::AbstractString)

jldopen(filepath, "r") do file

# Validate the grid
checkpointed_grid = file["grid"]
model.grid == checkpointed_grid ||
error("The grid associated with $filepath and model.grid are not the same!")

# Set model fields and tendency fields
model_fields = merge(model.velocities, model.tracers)

for name in propertynames(model_fields)
# Load data for each model field
address = name (:u, :v, :w) ? "velocities/$name" : "tracers/$name"
parent_data = file[address * "/data"]

model_field = model_fields[name]
copyto!(model_field.data.parent, parent_data)

# Load tendency data
#
# Note: this step is unecessary for models that use RungeKutta3TimeStepper and
# tendency restoration could be depcrecated in the future.

# Tendency "n"
parent_data = file["timestepper/Gⁿ/$name/data"]

tendencyⁿ_field = model.timestepper.Gⁿ[name]
copyto!(tendencyⁿ_field.data.parent, parent_data)

# Tendency "n-1"
parent_data = file["timestepper/G⁻/$name/data"]

tendency⁻_field = model.timestepper.G⁻[name]
copyto!(tendency⁻_field.data.parent, parent_data)
end

checkpointed_clock = file["clock"]

# Update model clock
model.clock.iteration = checkpointed_clock.iteration
model.clock.time = checkpointed_clock.time

end

return nothing
end
10 changes: 7 additions & 3 deletions src/OutputWriters/jld2_output_writer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,13 @@ function JLD2OutputWriter(model, outputs; prefix, schedule,
filepath = joinpath(dir, prefix * ".jld2")
force && isfile(filepath) && rm(filepath, force=true)

jldopen(filepath, "a+"; jld2_kw...) do file
init(file, model)
saveproperties!(file, model, including)
try
jldopen(filepath, "a+"; jld2_kw...) do file
init(file, model)
saveproperties!(file, model, including)
end
catch
@warn "Could not initialize $filepath: data may already be initialized."
end

return JLD2OutputWriter(filepath, outputs, schedule, field_slicer,
Expand Down
92 changes: 81 additions & 11 deletions src/Simulations/run.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
using Glob

using Oceananigans.Utils: initialize_schedule!
using Oceananigans.OutputWriters: WindowedTimeAverage
using Oceananigans.Fields: set!
using Oceananigans.OutputWriters: WindowedTimeAverage, checkpoint_superprefix
using Oceananigans.TimeSteppers: QuasiAdamsBashforth2TimeStepper, RungeKutta3TimeStepper, update_state!

import Oceananigans.OutputWriters: checkpoint_path

# Simulations are for running

function stop(sim)
Expand Down Expand Up @@ -61,20 +66,50 @@ get_Δt(simulation::Simulation) = get_Δt(simulation.Δt)
ab2_or_rk3_time_step!(model::IncompressibleModel{<:QuasiAdamsBashforth2TimeStepper}, Δt; euler) = time_step!(model, Δt, euler=euler)
ab2_or_rk3_time_step!(model::IncompressibleModel{<:RungeKutta3TimeStepper}, Δt; euler) = time_step!(model, Δt)

we_want_to_pickup(pickup::Bool) = pickup
we_want_to_pickup(pickup) = true

"""
run!(simulation)
run!(simulation; pickup=false)
Run a `simulation` until one of `simulation.stop_criteria` evaluates `true`.
The simulation will then stop.
# Picking simulations up from a checkpoint
Simulations will be "picked up" from a checkpoint if `pickup` is either `true`, a `String`,
or an `Integer` greater than 0.
Picking up a simulation sets field and tendency data to the specified checkpoint,
leaving all other model properties unchanged.
Run a `simulation` until one of the stop criteria evaluates to true. The simulation
will then stop.
Possible values for `pickup` are:
* `pickup=true` will pick a simulation up from the latest checkpoint associated with
the `Checkpointer` in simulation.output_writers`.
* `pickup=iteration::Int` will pick a simulation up from the checkpointed file associated
with `iteration` and the `Checkpointer` in simulation.output_writers`.
* `pickup=filepath::String` will pick a simulation up from checkpointer data in `filepath`.
Note that `pickup=true` and `pickup=iteration` will fail if `simulation.output_writers` contains
more than one checkpointer.
"""
function run!(sim)
function run!(sim; pickup=false)

model = sim.model
clock = model.clock

# Conservatively update the model state when run! initiates
if we_want_to_pickup(pickup)
checkpointers = filter(writer -> writer isa Checkpointer, collect(values(sim.output_writers)))
set!(model, checkpoint_path(pickup, checkpointers))
end

# Conservatively initialize the model state
update_state!(model)

# Initialization
# Output and diagnostics initialization
for writer in values(sim.output_writers)
open(writer)
initialize_schedule!(writer.schedule)
Expand All @@ -86,18 +121,19 @@ function run!(sim)
while !stop(sim)
time_before = time()

# Evaluate all diagnostics and write output at first iteration
# Evaluate all diagnostics, and then write all output at first iteration
if clock.iteration == 0
[run_diagnostic!(diag, sim.model) for diag in values(sim.diagnostics)]
[write_output!(out, sim.model) for out in values(sim.output_writers)]
[write_output!(writer, sim.model) for writer in values(sim.output_writers)]
end

for n in 1:sim.iteration_interval
euler = clock.iteration == 0 || (sim.Δt isa TimeStepWizard && n == 1)
ab2_or_rk3_time_step!(model, get_Δt(sim.Δt), euler=euler)

[ diag.schedule(model) && run_diagnostic!(diag, sim.model) for diag in values(sim.diagnostics) ]
[ writer.schedule(model) && write_output!(writer, sim.model) for writer in values(sim.output_writers) ]
# Run diagnostics, then write output
[ diag.schedule(model) && run_diagnostic!(diag, sim.model) for diag in values(sim.diagnostics)]
[writer.schedule(model) && write_output!(writer, sim.model) for writer in values(sim.output_writers)]
end

sim.progress(sim)
Expand All @@ -110,3 +146,37 @@ function run!(sim)

return nothing
end

#####
##### Util for "picking up" a simulation from a checkpoint
#####

""" Returns `filepath`. Shortcut for `run!(simulation, pickup=filepath)`. """
checkpoint_path(filepath::AbstractString, checkpointers) = filepath

function checkpoint_path(pickup, checkpointers)
length(checkpointers) == 0 && error("No checkpointers found: cannot pickup simulation!")
length(checkpointers) > 1 && error("Multiple checkpointers found: not sure which one to pickup simulation from!")
return checkpoint_path(pickup, first(checkpointers))
end

"""
checkpoint_path(pickup::Bool, checkpointer)
For `pickup=true`, parse the filenames in `checkpointer.dir` associated with
`checkpointer.prefix` and return the path to the file whose name contains
the largest iteration.
"""
function checkpoint_path(pickup::Bool, checkpointer::Checkpointer)
filepaths = glob(checkpoint_superprefix(checkpointer.prefix) * "*.jld2", checkpointer.dir)
filenames = basename.(filepaths)

# Parse filenames to find latest checkpointed iteration
leading = length(checkpoint_superprefix(checkpointer.prefix))
trailing = 5 # length(".jld2")
iterations = map(name -> parse(Int, name[leading+1:end-trailing]), filenames)

latest_iteration, idx = findmax(iterations)

return filepaths[idx]
end
Loading

0 comments on commit e1d1f19

Please sign in to comment.