Skip to content

Commit

Permalink
Add constrained_gaussian TOML parser
Browse files Browse the repository at this point in the history
  • Loading branch information
nefrathenrici committed Nov 11, 2023
1 parent 4bcb401 commit 924e4f9
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 1 deletion.
14 changes: 13 additions & 1 deletion src/TOMLInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ module TOMLInterface
using TOML
using Distributions
using EnsembleKalmanProcesses.ParameterDistributions

using Optim

# Exports
export path_to_ensemble_member
Expand Down Expand Up @@ -56,6 +56,10 @@ function get_parameter_distribution(param_dict::Dict, name::AbstractString)
# Constructing a parameter distribution requires a prior distribution,
# a constraint, and a name.
prior = construct_prior(param_dict[name])
# If constrained_gaussian, then prior is already a ParameterDistribution
if prior isa ParameterDistribution
return prior
end
constraint = construct_constraint(param_dict[name])

return ParameterDistribution(prior, constraint, name)
Expand Down Expand Up @@ -243,6 +247,14 @@ function get_distribution_from_expr(d::Expr)

return dist_type(dist_args)

elseif dist_type_symb == Symbol("constrained_gaussian")
# Check for non-positional kwargs
if d.args[2] isa Expr && d.args[2].args[1].head == :kw
d.args[3] = string(d.args[3])
else
d.args[2] = string(d.args[2])
end
return eval(d)
else
throw(ArgumentError("Unknown distribution type from symbol: $(dist_type_symb)"))
end
Expand Down
20 changes: 20 additions & 0 deletions test/TOMLInterface/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,26 @@ const EKP = EnsembleKalmanProcesses
[no_constraint(), no_constraint(), bounded_below(-5.0)],
"uq_param_8",
),
"uq_param_9" =>
ParameterDistribution(Parameterized(Normal(4.0, 0.17881264846405112)), [bounded(0, Inf)], "uq_param_9"),
"uq_param_10" => ParameterDistribution(
VectorOfParameterized([
Normal(4.0, 0.17881264846405112),
Normal(4.0, 0.17881264846405112),
Normal(4.0, 0.17881264846405112),
]),
[bounded(0, Inf), bounded(0, Inf), bounded(0, Inf)],
"uq_param_10",
),
"uq_param_11" => ParameterDistribution(
VectorOfParameterized([
Normal(4.0, 0.17881264846405112),
Normal(4.0, 0.17881264846405112),
Normal(4.0, 0.17881264846405112),
]),
[bounded(0, Inf), bounded(0, Inf), bounded(0, Inf)],
"uq_param_11",
),
)

# Get all `ParameterDistribution`s. We also add dummy (key, value) pairs
Expand Down
13 changes: 13 additions & 0 deletions test/TOMLInterface/toml/uq_test_parameters.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,19 @@ constraint = "repeat([no_constraint()], 3)"
prior = "VectorOfParameterized([Gamma(2.0, 3.0), LogNormal(0.1, 0.1), Normal(0.0, 10.0)])"
constraint = "[no_constraint(), no_constraint(), bounded_below(-5.0)]"

# Test for `constrained_gaussian` constructor
# Basic
[uq_param_9]
prior = "constrained_gaussian(uq_param_9, 55.47802418037957, 10, 0, Inf)"

# Non-positional kwarg
[uq_param_10]
prior = "constrained_gaussian(uq_param_10, 55.47802418037957, 10, 0, Inf; repeats = 3, optim_algorithm = SimulatedAnnealing())"

# Positional kwarg
[uq_param_11]
prior = "constrained_gaussian(uq_param_11, 55.47802418037957, 10, 0, Inf, repeats = 3, optim_algorithm = SimulatedAnnealing())"

# The six parameters below are interpreted as "regular" (non-UQ) parameters, as they
# they either have no key "prior", or a key "prior" that is set to "fixed"
[mean_sea_level_pressure]
Expand Down

0 comments on commit 924e4f9

Please sign in to comment.