Skip to content

Commit

Permalink
Add save_parameter_samples (#370)
Browse files Browse the repository at this point in the history
  • Loading branch information
nefrathenrici committed Mar 19, 2024
1 parent b420d83 commit 8cdeeab
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 2 deletions.
44 changes: 43 additions & 1 deletion src/TOMLInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ export save_parameter_ensemble
export get_admissible_parameters
export get_regularization
export write_log_file

export save_parameter_samples

"""
get_parameter_values(param_dict, names)
Expand Down Expand Up @@ -301,6 +301,48 @@ function construct_2d_array(arr::Expr)
end


"""
save_parameter_samples(
distribution::ParameterDistribution{Samples},
default_param_data,
num_samples,
save_path;
save_file = "parameters.toml",
rng = Random.MersenneTwister(1234),
)
Takes samples from the given `distribution` and saves them to individual TOML files
in the folder specified by `save_path`
Arguments:
- `distribution` - ParameterDistribution{Samples} to sample from
- `default_param_data` - Dict of default parameters to be combined and saved with
the parameters in `param_array` into a toml file
- `save_path` - Folder where the parameters will be saved
- `save_file` - Name of the toml files to be generated
- `rng` - Random number generator used in sampling
- `pad_zeros` - Amount of digits to pad to
"""
function save_parameter_samples(
distribution::ParameterDistribution,
default_param_data,
num_samples,
save_path;
save_file = "parameters.toml",
rng = Random.MersenneTwister(1234),
pad_zeros = 3,
)

save_parameter_ensemble(
sample(rng, distribution, num_samples),
distribution,
default_param_data::Dict,
save_path,
save_file;
pad_zeros,
apply_constraints = true,
)
end

"""
save_parameter_ensemble(
param_array,
Expand Down
39 changes: 38 additions & 1 deletion test/TOMLInterface/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,5 +277,42 @@ const EKP = EnsembleKalmanProcesses

end


# Test `save_parameter_samples`
uq_param_4_samples = [
[14.412514158048534, -9.990904722898303],
[14.877561432702601, -9.979758088554195],
[14.601045096347011, -9.988891003461758],
[14.877561432702601, -9.979758088554195],
[14.899607236135727, -9.995483419057388],
[14.412514158048534, -9.990904722898303],
[14.601045096347011, -9.988891003461758],
[14.899607236135727, -9.995483419057388],
[14.899607236135727, -9.995483419057388],
[14.877561432702601, -9.979758088554195],
]
uq_param_5_samples = [
[1.0, 5.0, 8101.083927575384, 19.999997739670594],
[1.0, 5.0, 8101.083927575384, 19.999997739670594],
[3.0, 7.0, 59872.14171519782, 19.999999694097678],
[3.0, 7.0, 59872.14171519782, 19.999999694097678],
[1.0, 5.0, 8101.083927575384, 19.999997739670594],
[3.0, 7.0, 59872.14171519782, 19.999999694097678],
[3.0, 7.0, 59872.14171519782, 19.999999694097678],
[3.0, 7.0, 59872.14171519782, 19.999999694097678],
[1.0, 5.0, 8101.083927575384, 19.999997739670594],
[1.0, 5.0, 8101.083927575384, 19.999997739670594],
]
mktempdir(@__DIR__) do save_path
# Uncomment the line below to debug if the tests fail
# save_path = "sample_tests"
save_file = "parameters.toml"
pd = get_parameter_distribution(param_dict, ["uq_param_4", "uq_param_5"])
save_parameter_samples(pd, param_dict, 10, save_path; rng = Random.MersenneTwister(1234), save_file)
for (i, fpath) in enumerate(readdir(save_path))
toml_file = joinpath(save_path, fpath, save_file)
param_dict = TOML.parsefile(toml_file)
@test uq_param_4_samples[i] == param_dict["uq_param_4"]["value"]
@test uq_param_5_samples[i] == param_dict["uq_param_5"]["value"]
end
end
end

0 comments on commit 8cdeeab

Please sign in to comment.