Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Quickly plot marginals for parameter distribution #264

Merged
merged 1 commit into from
Mar 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
SCS = "c946c3f1-0d1f-5ce8-9dea-7daa1f7e2d13"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand All @@ -33,6 +34,7 @@ julia = "1.5"
[extras]
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"

[targets]
test = ["StableRNGs", "Test"]
test = ["StableRNGs", "Test", "Plots"]
37 changes: 34 additions & 3 deletions docs/src/parameter_distributions.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,27 @@ The use case `constrained_gaussian()` addresses is when prior information is qua

The parameters of the Gaussian are chosen automatically (depending on the constraint) to reproduce the desired μ and σ — per the use case, other details of the form of the prior distribution shouldn't be important for downstream inference!

### Plotting

For quick visualization we have a plot recipe for `ParameterDistribution` types. This will plot marginal histograms for all dimensions of the parameter distribution. For example,

```@example snip1
# with values:
# e.g. lower_bound = 0.0, upper_bound = 1.0
# μ_1 = 0.5, σ_1 = 0.25
# μ_2 = 0.5, σ_2 = 0.25
using Plots
plot(prior)
```
One can also access the underlying Gaussian distributions in the unconstrained space with

```@example snip1
using Plots
plot(prior, constrained=false)
```


### Recommended constructor - Simple example

Task: We wish to create a prior for a one-dimensional parameter. Our problem dictates that this parameter is bounded between 0 and 1; domain knowledge leads us to expect it should be around 0.7. The parameter is called `point_seven`.
Expand Down Expand Up @@ -387,10 +408,10 @@ name2 = "constrained_sampled"
nothing # hide
```

The final parameter is 20-dimensional, defined as a list of i.i.d univariate distributions we make use of the `VectorOfParameterized` type
The final parameter is 4-dimensional, defined as a list of i.i.d univariate distributions we make use of the `VectorOfParameterized` type
```@example snip5
d3 = VectorOfParameterized(repeat([Beta(2,2)],20))
c3 = repeat([no_constraint()],20)
d3 = VectorOfParameterized(repeat([Beta(2,2)],4))
c3 = repeat([no_constraint()],4)
name3 = "Beta"
nothing # hide
```
Expand All @@ -413,6 +434,16 @@ u = ParameterDistribution([param_dict1, param_dict2, param_dict3])
nothing # hide
```

We can visualize the marginals of the constrained distributions,
```@example snip5
using Plots
plot(u)
```
and the unconstrained distributions similarly,
```@example snip5
using Plots
plot(u, constrained = false)
```

## ConstraintType Examples

Expand Down
2 changes: 2 additions & 0 deletions src/EnsembleKalmanProcesses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@ include("TOMLInterface.jl")
# algorithmic updates
include("EnsembleKalmanProcess.jl")

# Plot recipes
include("PlotRecipes.jl")
end # module
88 changes: 88 additions & 0 deletions src/PlotRecipes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
module PlotRecipes

using RecipesBase
using Random
using ..ParameterDistributions

export plot_marginal_hist

@recipe function plot(pd::ParameterDistribution; constrained = true, n_sample = 1e4, rng = Random.GLOBAL_RNG)
samples = sample(rng, pd, Int(n_sample))
if constrained
samples = transform_unconstrained_to_constrained(pd, samples)
end

# First attempt, make it into a samples dist and plot histograms instead
n_plots = ndims(pd)
batches = batch(pd)

rows = Int(ceil(sqrt(n_plots)))
cols = Int(floor(sqrt(n_plots)))
tfs = 16
fs = 12

# subplot attr
legend := false
framestyle := repeat([:axes], n_plots)
grid := false

layout := n_plots
size --> (rows * 400, cols * 400)
titlefontsize --> tfs
xtickfontsize --> fs
ytickfontsize --> fs
xguidefontsize --> fs
yguidefontsize --> fs

for i in 1:n_plots
batch_id = [j for j = 1:length(batches) if i ∈ batches[j]][1]
dim_in_batch = i - minimum(batches[batch_id]) + 1 # i.e. if i=5 in batch 3:6, this would be "3"
@series begin
seriestype := :histogram
normalize --> :pdf
color := batch_id
subplot := i
title := pd.name[batch_id] * " (dim " * string(dim_in_batch) * ")"
samples[i, :]
end
end

end

@recipe function plot(d::PDT; n_sample = 1e4, rng = Random.GLOBAL_RNG) where {PDT <: ParameterDistributionType}
samples = sample(rng, d, Int(n_sample))

# First attempt, make it into a samples dist and plot histograms instead
n_plots = ndims(d)

size_l = Int(ceil(sqrt(n_plots)))
tfs = 16
fs = 12

# subplot attr
legend := false
framestyle := repeat([:axes], n_plots)
grid := false
layout := n_plots
size --> (size_l * 400, size_l * 400)
titlefontsize --> tfs
xtickfontsize --> fs
ytickfontsize --> fs
xguidefontsize --> fs
yguidefontsize --> fs

for i in 1:n_plots
@series begin
seriestype := :histogram
normalize --> :pdf
subplot := i
title := "dim " * string(i)
samples[i, :]
end
end

end



end #module
47 changes: 47 additions & 0 deletions test/PlotRecipes/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
using Test
using Distributions
using LinearAlgebra
using Random
using StatsBase

using EnsembleKalmanProcesses.ParameterDistributions
using EnsembleKalmanProcesses.PlotRecipes

if TEST_PLOT_OUTPUT
@testset "PlotRecipes" begin
@testset "Plot ParameterDistribution" begin
d1 = Parameterized(MvNormal(zeros(4), 0.1 * I))
c1 = [no_constraint(), bounded_below(-1.0), bounded_above(0.4), bounded(-0.1, 0.2)]
name1 = "constrained_mvnormal"
u1 = ParameterDistribution(d1, c1, name1)

d2 = Samples([1 2 3 4])
c2 = [bounded(10, 15)]
name2 = "constrained_sampled"
u2 = ParameterDistribution(d2, c2, name2)

d3 = VectorOfParameterized(repeat([Beta(2, 2)], 3))
c3 = repeat([no_constraint()], 3)
name3 = "vector_beta"
u3 = ParameterDistribution(d3, c3, name3)

u = combine_distributions([u1, u2, u3])

# PDTs
plt1 = plot(d1)
savefig(plt1, joinpath(@__DIR__, name1 * ".png"))
plt2 = plot(d2)
savefig(plt2, joinpath(@__DIR__, name2 * ".png"))
plt3 = plot(d3)
savefig(plt3, joinpath(@__DIR__, name3 * ".png"))

# full param dist
plt_constrained = plot(u)
savefig(plt_constrained, joinpath(@__DIR__, "full_dist_constrained.png"))
plt_unconstrained = plot(u, constrained = false)
savefig(plt_unconstrained, joinpath(@__DIR__, "full_dist_unconstrained.png"))
# just a redundant test to show that the plots were created
@test 1 == 1
end
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ end
for submodule in [
"DataContainers",
"ParameterDistributions",
"PlotRecipes",
"Observations",
"EnsembleKalmanProcess",
"Localizers",
Expand Down