Skip to content

Commit

Permalink
bugfix logpdf broadcasting (#364)
Browse files Browse the repository at this point in the history
* removed broadcasting bug for logpdfs

* rm prints

* sep Par. and Vec.Of.Par. tests

* add univariate eval case

* format

* codecov
  • Loading branch information
odunbar committed Jan 24, 2024
1 parent 8c360ea commit 3428c95
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 9 deletions.
30 changes: 23 additions & 7 deletions src/ParameterDistributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -649,19 +649,33 @@ sample(d::VectorOfParameterized) = sample(Random.GLOBAL_RNG, d, 1)
Obtains the independent logpdfs of the parameter distributions at `xarray`
(non-Samples Distributions only), and returns their sum.
"""
logpdf(d::Parameterized, xarray::AbstractVector{FT}) where {FT <: Real} = logpdf.(d.distribution, xarray)
logpdf(d::Parameterized, x::FT) where {FT <: Real} = logpdf(d, [x]) # make into 1D array

function logpdf(d::VectorOfParameterized, xarray::AbstractVector{FT}) where {FT <: Real}
function logpdf(d::Parameterized, xarray::VV) where {VV <: AbstractVector}
dimension = ndims(d)
if dimension != length(xarray)
throw(
DimensionMismatch("cannot evaluate logpdf with distribution $dimension on array length $(length(xarray))"),
)
end
if dimension == 1 # if univariate, requires scalar evaluation
return logpdf(d.distribution, xarray[1])
else
return logpdf(d.distribution, xarray)
end
end

function logpdf(d::VectorOfParameterized, xarray::VV) where {VV <: AbstractVector}
# get the index of xarray chunks to give to the different distributions.
batches = batch(d)
dimensions = get_dimensions(d)
lpdfsum = 0.0
# perform the logpdf of each of the distributions, and returns their sum
for (i, dd) in enumerate(d.distribution)
if dimensions[i] == 1
lpdfsum += logpdf.(dd, xarray[batches[i]])[1]
for (i, dd, dimen, batch) in zip(1:length(d.distribution), d.distribution, dimensions, batches)
if dimen == 1
lpdfsum += logpdf(dd, xarray[batch][1]) # needs to be eval on a scalar
else
lpdfsum += logpdf(dd, xarray[batches[i]])
lpdfsum += logpdf(dd, xarray[batch])
end
end
return lpdfsum
Expand All @@ -672,7 +686,7 @@ function logpdf(pd::ParameterDistribution, xarray::AbstractVector{FT}) where {FT
if any(isa.(pd.distribution, Samples))
throw(
ErrorException(
"Cannot compute logpdf of Samples distributions. Consider using a Parameterized type for your prior.",
"No implementation of logpdf of Samples distributions. Consider using a Parameterized type for your prior.",
),
)
end
Expand All @@ -692,6 +706,8 @@ function logpdf(pd::ParameterDistribution, xarray::AbstractVector{FT}) where {FT
return sum(sum(logpdf(d, xarray[batches[i]])) for (i, d) in enumerate(pd.distribution))
end

logpdf(pd::ParameterDistribution, x::FT) where {FT <: Real} = logpdf(pd, [x])

#extending StatsBase cov,var
"""
var(pd::ParameterDistribution)
Expand Down
19 changes: 17 additions & 2 deletions test/ParameterDistributions/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -538,11 +538,26 @@ using EnsembleKalmanProcesses.ParameterDistributions
@test_throws ErrorException logpdf(u, zeros(ndims(u)))
x_in_bd = [0.5, 0.5, 0.5]
Random.seed!(seed)
lpdf3 = sum([logpdf(Beta(2, 2), x_in_bd[1])[1], logpdf(MvNormal(zeros(2), 0.1 * I), x_in_bd[2:3])[1]]) #throws deprecated warning without "."

# for VectorOfParameterized
lpdf3 = sum([logpdf(Beta(2, 2), x_in_bd[1])[1], logpdf(MvNormal(zeros(2), 0.1 * I), x_in_bd[2:3])[1]]) #throws deprecated warning without "."
Random.seed!(seed)
@test isapprox(logpdf(u3, x_in_bd) - lpdf3, 0.0; atol = 1e-6)
@test_throws DimensionMismatch logpdf(u3, [0.5, 0.5])
# for Parameterized Multivar
x_in_bd = [0.0, 0.0, 0.0, 0.0]
@test isapprox(logpdf(u1, x_in_bd) - logpdf(MvNormal(zeros(4), 0.1 * I), x_in_bd)[1], 0.0, atol = 1e-6)
@test_throws DimensionMismatch logpdf(u1, [1])
# for Parameterized Univar
u5 = constrained_gaussian("u5", 3.0, 1.0, -Inf, Inf)
x_in_bd = 0.0
@test isapprox(logpdf(u5, x_in_bd) - logpdf(Normal(3.0, 1.0), x_in_bd)[1], 0.0, atol = 1e-6)
@test_throws DimensionMismatch logpdf(u1, [1, 1])
@test isapprox(
logpdf(Parameterized(Normal(3.0, 1.0)), x_in_bd) - logpdf(Normal(3.0, 1.0), x_in_bd)[1],
0.0,
atol = 1e-6,
)
@test_throws DimensionMismatch logpdf(Parameterized(Normal(3.0, 1.0)), [1, 1])

#Test for cov, var
block_cov = cat([cov(d1), var(d2), cov(d3), cov(d4)]..., dims = (1, 2))
Expand Down

0 comments on commit 3428c95

Please sign in to comment.