Skip to content

Commit

Permalink
plot recipes for marginal histograms
Browse files Browse the repository at this point in the history
added plots example into docs

format

add RecipesBase

actually add file...

formatting...

rm "using plots"

format...
  • Loading branch information
odunbar committed Mar 22, 2023
1 parent 69bcca5 commit 9bccc03
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 3 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
SCS = "c946c3f1-0d1f-5ce8-9dea-7daa1f7e2d13"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand Down
37 changes: 34 additions & 3 deletions docs/src/parameter_distributions.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,27 @@ The use case `constrained_gaussian()` addresses is when prior information is qua

The parameters of the Gaussian are chosen automatically (depending on the constraint) to reproduce the desired μ and σ — per the use case, other details of the form of the prior distribution shouldn't be important for downstream inference!

### Plotting

For quick visualization we have a plot recipe for `ParameterDistribution` types. This will plot marginal histograms for all dimensions of the parameter distribution. For example,

```@example snip1
# with values:
# e.g. lower_bound = 0.0, upper_bound = 1.0
# μ_1 = 0.5, σ_1 = 0.25
# μ_2 = 0.5, σ_2 = 0.25
using Plots
plot(prior)
```
One can also access the underlying Gaussian distributions in the unconstrained space with

```@example snip1
using Plots
plot(prior, constrained=false)
```


### Recommended constructor - Simple example

Task: We wish to create a prior for a one-dimensional parameter. Our problem dictates that this parameter is bounded between 0 and 1; domain knowledge leads us to expect it should be around 0.7. The parameter is called `point_seven`.
Expand Down Expand Up @@ -387,10 +408,10 @@ name2 = "constrained_sampled"
nothing # hide
```

The final parameter is 20-dimensional, defined as a list of i.i.d univariate distributions we make use of the `VectorOfParameterized` type
The final parameter is 4-dimensional, defined as a list of i.i.d univariate distributions we make use of the `VectorOfParameterized` type
```@example snip5
d3 = VectorOfParameterized(repeat([Beta(2,2)],20))
c3 = repeat([no_constraint()],20)
d3 = VectorOfParameterized(repeat([Beta(2,2)],4))
c3 = repeat([no_constraint()],4)
name3 = "Beta"
nothing # hide
```
Expand All @@ -413,6 +434,16 @@ u = ParameterDistribution([param_dict1, param_dict2, param_dict3])
nothing # hide
```

We can visualize the marginals of the constrained distributions,
```@example snip5
using Plots
plot(u)
```
and the unconstrained distributions similarly,
```@example snip5
using Plots
plot(u, constrained = false)
```

## ConstraintType Examples

Expand Down
2 changes: 2 additions & 0 deletions src/EnsembleKalmanProcesses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@ include("TOMLInterface.jl")
# algorithmic updates
include("EnsembleKalmanProcess.jl")

# Plot recipes
include("PlotRecipes.jl")
end # module
86 changes: 86 additions & 0 deletions src/PlotRecipes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
module PlotRecipes

using RecipesBase
using Random
using ..ParameterDistributions

export plot_marginal_hist

@recipe function plot(pd::ParameterDistribution; constrained = true, n_sample = 1e4, rng = Random.GLOBAL_RNG)
samples = sample(rng, pd, Int(n_sample))
if constrained
samples = transform_unconstrained_to_constrained(pd, samples)
end

# First attempt, make it into a samples dist and plot histograms instead
n_plots = ndims(pd)
batches = batch(pd)

rows = Int(ceil(sqrt(n_plots)))
cols = Int(floor(sqrt(n_plots)))
tfs = 16
fs = 14

# subplot attr
legend := false
framestyle := repeat([:axes], n_plots)
grid := false

layout := n_plots
size --> (rows * 400, cols * 400)
titlefontsize --> tfs
xtickfontsize --> fs
ytickfontsize --> fs
xguidefontsize --> fs
yguidefontsize --> fs

for i in 1:n_plots
batch_id = [j for j = 1:length(batches) if i batches[j]][1]
dim_in_batch = i - minimum(batches[batch_id]) + 1 # i.e. if i=5 in batch 3:6, this would be "3"
@series begin
seriestype := :histogram
color := batch_id
subplot := i
title := pd.name[batch_id] * " (dim " * string(dim_in_batch) * ")"
samples[i, :]
end
end

end

@recipe function plot(d::PDT; n_sample = 1e4, rng = Random.GLOBAL_RNG) where {PDT <: ParameterDistributionType}
samples = sample(rng, d, Int(n_sample))

# First attempt, make it into a samples dist and plot histograms instead
n_plots = ndims(d)

size_l = Int(ceil(sqrt(n_plots)))
tfs = 16
fs = 14

# subplot attr
legend := false
framestyle := repeat([:axes], n_plots)
grid := false
layout := n_plots
size --> (size_l * 400, size_l * 400)
titlefontsize --> tfs
xtickfontsize --> fs
ytickfontsize --> fs
xguidefontsize --> fs
yguidefontsize --> fs

for i in 1:n_plots
@series begin
seriestype := :histogram
subplot := i
title := "dim " * string(i)
samples[i, :]
end
end

end



end #module
47 changes: 47 additions & 0 deletions test/PlotRecipes/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
using Test
using Distributions
using LinearAlgebra
using Random
using StatsBase

using EnsembleKalmanProcesses.ParameterDistributions
using EnsembleKalmanProcesses.PlotRecipes

if TEST_PLOT_OUTPUT
@testset "PlotRecipes" begin
@testset "Plot ParameterDistribution" begin
d1 = Parameterized(MvNormal(zeros(4), 0.1 * I))
c1 = [no_constraint(), bounded_below(-1.0), bounded_above(0.4), bounded(-0.1, 0.2)]
name1 = "constrained_mvnormal"
u1 = ParameterDistribution(d1, c1, name1)

d2 = Samples([1 2 3 4])
c2 = [bounded(10, 15)]
name2 = "constrained_sampled"
u2 = ParameterDistribution(d2, c2, name2)

d3 = VectorOfParameterized(repeat([Beta(2, 2)], 3))
c3 = repeat([no_constraint()], 3)
name3 = "vector_beta"
u3 = ParameterDistribution(d3, c3, name3)

u = combine_distributions([u1, u2, u3])

# PDTs
plt1 = plot(d1)
savefig(plt1, joinpath(@__DIR__, name1 * ".png"))
plt2 = plot(d2)
savefig(plt2, joinpath(@__DIR__, name2 * ".png"))
plt3 = plot(d3)
savefig(plt3, joinpath(@__DIR__, name3 * ".png"))

# full param dist
plt_constrained = plot(u)
savefig(plt_constrained, joinpath(@__DIR__, "full_dist_constrained.png"))
plt_unconstrained = plot(u, constrained = false)
savefig(plt_unconstrained, joinpath(@__DIR__, "full_dist_unconstrained.png"))
# just a redundant test to show that the plots were created
@test 1 == 1
end
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ end
for submodule in [
"DataContainers",
"ParameterDistributions",
"PlotRecipes",
"Observations",
"EnsembleKalmanProcess",
"Localizers",
Expand Down

0 comments on commit 9bccc03

Please sign in to comment.