Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tests for splitting output files using TimeInterval #3523

Merged
merged 12 commits into from
Mar 28, 2024
10 changes: 5 additions & 5 deletions src/OutputWriters/output_writer_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ using Oceananigans.Utils: AbstractSchedule
##### Output writer utilities
#####

struct NoFileSplitting end
(::NoFileSplitting)(model) = false
Base.summary(::NoFileSplitting) = "NoFileSplitting"
Base.show(io::IO, nfs::NoFileSplitting) = print(io, summary(nfs))

mutable struct FileSizeLimit <: AbstractSchedule
size_limit :: Float64
path :: String
Expand Down Expand Up @@ -47,11 +52,6 @@ function update_file_splitting_schedule!(schedule::FileSizeLimit, filepath)
return nothing
end

struct NoFileSplitting end
(::NoFileSplitting)(model) = false
Base.summary(::NoFileSplitting) = "NoFileSplitting"
Base.show(io::IO, nfs::NoFileSplitting) = print(io, summary(nfs))

"""
ext(ow)

Expand Down
56 changes: 51 additions & 5 deletions test/test_jld2_output_writer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ function jld2_sliced_field_output(model, outputs=model.velocities)
return size(u₁) == (2, 2, 4) && size(v₁) == (2, 2, 4) && size(w₁) == (2, 2, 5)
end

function test_jld2_file_splitting(arch)
function test_jld2_size_file_splitting(arch)
grid = RectilinearGrid(arch, size=(16, 16, 16), extent=(1, 1, 1), halo=(1, 1, 1))
model = NonhydrostaticModel(; grid, buoyancy=SeawaterBuoyancy(), tracers=(:T, :S))
simulation = Simulation(model, Δt=1, stop_iteration=10)
Expand Down Expand Up @@ -88,6 +88,51 @@ function test_jld2_file_splitting(arch)
return nothing
end

function test_jld2_time_file_splitting(arch)
grid = RectilinearGrid(arch, size=(16, 16, 16), extent=(1, 1, 1), halo=(1, 1, 1))
model = NonhydrostaticModel(; grid, buoyancy=SeawaterBuoyancy(), tracers=(:T, :S))
simulation = Simulation(model, Δt=1, stop_iteration=10)

function fake_bc_init(file, model)
file["boundary_conditions/fake"] = π
end
ow = JLD2OutputWriter(model, (; u=model.velocities.u);
dir = ".",
filename = "test.jld2",
schedule = IterationInterval(1),
init = fake_bc_init,
including = [:grid],
array_type = Array{Float64},
with_halos = true,
file_splitting = TimeInterval(3seconds),
overwrite_existing = true)

push!(simulation.output_writers, ow)

run!(simulation)

for n in string.(1:3)
filename = "test_part$n.jld2"
jldopen(filename, "r") do file
# Test to make sure all files contain structs from `including`.
@test file["grid/Nx"] == 16

# Test to make sure all files contain the same number of snapshots.
dimlength = length(file["timeseries/t"])
@test dimlength == 3

# Test to make sure all files contain info from `init` function.
@test file["boundary_conditions/fake"] == π
end

# Leave test directory clean.
rm(filename)
end
rm("test_part4.jld2")

return nothing
end

function test_jld2_time_averaging_of_horizontal_averages(model)

model.clock.iteration = 0
Expand Down Expand Up @@ -266,11 +311,12 @@ for arch in archs
test_field_slicing("sliced_funcs_jld2_test.jld2", ("u", "v", "w"), (4, 4, 4), (4, 4, 4), (4, 4, 5))
test_field_slicing("sliced_func_fields_jld2_test.jld2", ("αt", "background_u"), (2, 4, 4), (2, 4, 4))

#####
##### File splitting
#####
####
#### File splitting
####

test_jld2_file_splitting(arch)
test_jld2_size_file_splitting(arch)
test_jld2_time_file_splitting(arch)

#####
##### Time-averaging
Expand Down
45 changes: 43 additions & 2 deletions test/test_netcdf_output_writer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ function test_DateTime_netcdf_output(arch)
return nothing
end

function test_netcdf_file_splitting(arch)
function test_netcdf_size_file_splitting(arch)
grid = RectilinearGrid(arch, size=(16, 16, 16), extent=(1, 1, 1), halo=(1, 1, 1))
model = NonhydrostaticModel(; grid, buoyancy=SeawaterBuoyancy(), tracers=(:T, :S))
simulation = Simulation(model, Δt=1, stop_iteration=10)
Expand Down Expand Up @@ -90,6 +90,45 @@ function test_netcdf_file_splitting(arch)
return nothing
end

function test_netcdf_time_file_splitting(arch)
grid = RectilinearGrid(arch, size=(16, 16, 16), extent=(1, 1, 1), halo=(1, 1, 1))
model = NonhydrostaticModel(; grid, buoyancy=SeawaterBuoyancy(), tracers=(:T, :S))
simulation = Simulation(model, Δt=1, stop_iteration=12seconds)

fake_attributes = Dict("fake_attribute"=>"fake_attribute")

ow = NetCDFOutputWriter(model, (; u=model.velocities.u);
dir = ".",
filename = "test.nc",
schedule = IterationInterval(2),
array_type = Array{Float64},
with_halos = true,
global_attributes = fake_attributes,
file_splitting = TimeInterval(4seconds),
overwrite_existing = true)

push!(simulation.output_writers, ow)

run!(simulation)

for n in string.(1:3)
filename = "test_part$n.nc"
ds = NCDataset(filename,"r")
dimlength = length(ds["time"])
# Test that all files contain the same dimensions.
@test dimlength == 2
# Test that all files contain the user defined attributes.
@test ds.attrib["fake_attribute"] == "fake_attribute"

# Leave test directory clean.
close(ds)
rm(filename)
end
rm("test_part4.nc")

return nothing
end

function test_TimeDate_netcdf_output(arch)
grid = RectilinearGrid(arch, size=(1, 1, 1), extent=(1, 1, 1))
clock = Clock(time=TimeDate(2021, 1, 1))
Expand Down Expand Up @@ -880,7 +919,8 @@ for arch in archs
@testset "NetCDF output writer [$(typeof(arch))]" begin
@info " Testing NetCDF output writer [$(typeof(arch))]..."
test_DateTime_netcdf_output(arch)
test_netcdf_file_splitting(arch)
test_netcdf_size_file_splitting(arch)
test_netcdf_time_file_splitting(arch)
test_TimeDate_netcdf_output(arch)
test_thermal_bubble_netcdf_output(arch)
test_thermal_bubble_netcdf_output_with_halos(arch)
Expand All @@ -892,3 +932,4 @@ for arch in archs
test_netcdf_regular_lat_lon_grid_output(arch)
end
end