Skip to content

Commit

Permalink
Add constrained_gaussian TOML parser
Browse files Browse the repository at this point in the history
  • Loading branch information
nefrathenrici committed Nov 13, 2023
1 parent 4bcb401 commit 4e11583
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 1 deletion.
34 changes: 33 additions & 1 deletion src/TOMLInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ using TOML
using Distributions
using EnsembleKalmanProcesses.ParameterDistributions


# Exports
export path_to_ensemble_member
export get_parameter_distribution
Expand Down Expand Up @@ -56,6 +55,10 @@ function get_parameter_distribution(param_dict::Dict, name::AbstractString)
# Constructing a parameter distribution requires a prior distribution,
# a constraint, and a name.
prior = construct_prior(param_dict[name])
# If constrained_gaussian, then prior is already a ParameterDistribution
if prior isa ParameterDistribution
return prior
end
constraint = construct_constraint(param_dict[name])

return ParameterDistribution(prior, constraint, name)
Expand Down Expand Up @@ -243,6 +246,35 @@ function get_distribution_from_expr(d::Expr)

return dist_type(dist_args)

elseif dist_type_symb == Symbol("constrained_gaussian")

function parse_kwargs(args)
kwargs = []
for arg in args
# Only parse repeats kwarg for now
arg.args[1] != :repeats &&
throw(ArgumentError("Keyword argument $(arg.args[1]) can not be parsed from TOML."))
push!(kwargs, arg.args[1] => parse(Int64, string(arg.args[2])))
end
return kwargs
end

kwargs = []
index = 2
# Non-positional kwargs are the second argument
if d.args[2] isa Expr && d.args[2].args[1].head == :kw
kwargs = parse_kwargs(d.args[2].args)
index += 1
# Positional kwargs
elseif length(d.args) > 6 && d.args[7].head == :kw
kwargs = parse_kwargs([d.args[7]])
end

name, dist_mean, dist_std, lb, ub = d.args[index:(index + 4)]
name = string(name)
lower_bound, upper_bound = parse.(Float64, string.((lb, ub)))
return constrained_gaussian(name, dist_mean, dist_std, lower_bound, upper_bound; kwargs...)

else
throw(ArgumentError("Unknown distribution type from symbol: $(dist_type_symb)"))
end
Expand Down
21 changes: 21 additions & 0 deletions test/TOMLInterface/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,26 @@ const EKP = EnsembleKalmanProcesses
[no_constraint(), no_constraint(), bounded_below(-5.0)],
"uq_param_8",
),
"uq_param_9" =>
ParameterDistribution(Parameterized(Normal(4.0, 0.17881264846405112)), [bounded(0, Inf)], "uq_param_9"),
"uq_param_10" => ParameterDistribution(
VectorOfParameterized([
Normal(4.0, 0.17881264846405112),
Normal(4.0, 0.17881264846405112),
Normal(4.0, 0.17881264846405112),
]),
[bounded(0, Inf), bounded(0, Inf), bounded(0, Inf)],
"uq_param_10",
),
"uq_param_11" => ParameterDistribution(
VectorOfParameterized([
Normal(4.0, 0.17881264846405112),
Normal(4.0, 0.17881264846405112),
Normal(4.0, 0.17881264846405112),
]),
[bounded(0, Inf), bounded(0, Inf), bounded(0, Inf)],
"uq_param_11",
),
)

# Get all `ParameterDistribution`s. We also add dummy (key, value) pairs
Expand All @@ -65,6 +85,7 @@ const EKP = EnsembleKalmanProcesses

@test_throws ArgumentError get_parameter_distribution(bad_param_dict, "uq_param_baddist")
@test_throws ArgumentError get_regularization(bad_param_dict, "uq_param_badL")
@test_throws ArgumentError get_parameter_distribution(bad_param_dict, "uq_param_bad_constrain_gauss")

for param_name in uq_param_names
param_dict[param_name]["description"] = param_name * descr
Expand Down
3 changes: 3 additions & 0 deletions test/TOMLInterface/toml/bad_param.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,6 @@ L2 = 3.5
prior = "NotParameterized(Normal(-100.0, 20.0))"
constraint = "no_constraint()"
L1 = 1.5

[uq_param_bad_constrain_gauss]
prior = "constrained_gaussian(uq_param_11, 55.47802418037957, 10, 0, Inf; repeats = 3, optim_algo = NelderMead())"
13 changes: 13 additions & 0 deletions test/TOMLInterface/toml/uq_test_parameters.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,19 @@ constraint = "repeat([no_constraint()], 3)"
prior = "VectorOfParameterized([Gamma(2.0, 3.0), LogNormal(0.1, 0.1), Normal(0.0, 10.0)])"
constraint = "[no_constraint(), no_constraint(), bounded_below(-5.0)]"

# Test for `constrained_gaussian` constructor
# Basic
[uq_param_9]
prior = "constrained_gaussian(uq_param_9, 55.47802418037957, 10, 0, Inf)"

# Non-positional kwarg
[uq_param_10]
prior = "constrained_gaussian(uq_param_10, 55.47802418037957, 10, 0, Inf; repeats = 3)"

# Positional kwarg
[uq_param_11]
prior = "constrained_gaussian(uq_param_11, 55.47802418037957, 10, 0, Inf, repeats = 3)"

# The six parameters below are interpreted as "regular" (non-UQ) parameters, as they
# they either have no key "prior", or a key "prior" that is set to "fixed"
[mean_sea_level_pressure]
Expand Down

0 comments on commit 4e11583

Please sign in to comment.