Skip to content

Commit

Permalink
respond to review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
haakon-e committed Jul 5, 2023
1 parent c97f0d2 commit 938b39a
Showing 1 changed file with 8 additions and 16 deletions.
24 changes: 8 additions & 16 deletions src/ParameterDistributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -790,11 +790,7 @@ function transform_constrained_to_unconstrained(pd::ParameterDistribution, x::Ab
pd_batch_idxs = batch(pd, function_parameter_opt = "eval") # e.g. [collect(1:2), collect(3:3), collect(5:9)]
pd_constraints = get_all_constraints(pd, return_dict = true)

x_out = if eltype(pd.distribution) == GaussianRandomFieldInterface
zeros((sum(n_eval_pts.(pd.distribution)), length(axes(x, 2))))
else
similar(x)
end
x_out = similar(x, ndims(pd; function_parameter_opt="eval"), length(axes(x, 2)))
for (name, idxs, d) in zip(param_names, pd_batch_idxs, pd.distribution)
view(x_out, idxs, :) .= transform_constrained_to_unconstrained(d, pd_constraints[name], view(x, idxs, :))
end
Expand Down Expand Up @@ -841,11 +837,11 @@ end


function transform_unconstrained_to_constrained(
d::PDT,
d::ParameterDistributionType,
constraints::AbstractVector,
x::AbstractArray{FT};
x::AbstractArray{<:Real};
kwargs...,
) where {FT <: Real, PDT <: ParameterDistributionType}
)
x_out = similar(x)
for (out, in, c) in zip(eachrow(x_out), eachrow(x), constraints)
out .= c.unconstrained_to_constrained.(in)
Expand All @@ -863,22 +859,18 @@ Each column of `x` is a sample, and each row is a parameter.
"""
function transform_unconstrained_to_constrained(
pd::ParameterDistribution,
x::AbstractArray{FT};
x::AbstractArray{<:Real};
build_flag::Bool = true,
) where {FT <: Real}
)
param_names = get_name(pd)
pd_constraints = get_all_constraints(pd, return_dict = true)
eval_batch_idxs = batch(pd; function_parameter_opt = "eval")

# naive function parameter check, is x a dof vector, or the unconstrained evaluated function?
function_parameter_opt = build_flag ? "dof" : "eval"
pd_batch_idxs = batch(pd; function_parameter_opt)

x_out_size, eval_batch_idxs = if build_flag && (eltype(pd.distribution) == GaussianRandomFieldInterface)
(sum(n_eval_pts.(pd.distribution)), length(axes(x, 2))), batch(pd; function_parameter_opt = "eval")
else
size(x), pd_batch_idxs
end
x_out = similar(x, x_out_size)
x_out = similar(x, ndims(pd; function_parameter_opt="eval"), length(axes(x, 2)))
for (name, eval_idx, pd_idxs, d) in zip(param_names, eval_batch_idxs, pd_batch_idxs, pd.distribution)
view(x_out, eval_idx, :) .=
transform_unconstrained_to_constrained(d, pd_constraints[name], view(x, pd_idxs, :); build_flag)
Expand Down

0 comments on commit 938b39a

Please sign in to comment.