Skip to content

Commit

Permalink
Try #264:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] committed Mar 22, 2023
2 parents 69bcca5 + 9bccc03 commit 60fd1e4
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 3 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
SCS = "c946c3f1-0d1f-5ce8-9dea-7daa1f7e2d13"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand Down
37 changes: 34 additions & 3 deletions docs/src/parameter_distributions.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,27 @@ The use case `constrained_gaussian()` addresses is when prior information is qua

The parameters of the Gaussian are chosen automatically (depending on the constraint) to reproduce the desired μ and σ — per the use case, other details of the form of the prior distribution shouldn't be important for downstream inference!

### Plotting

For quick visualization we have a plot recipe for `ParameterDistribution` types. This will plot marginal histograms for all dimensions of the parameter distribution. For example,

```@example snip1
# with values:
# e.g. lower_bound = 0.0, upper_bound = 1.0
# μ_1 = 0.5, σ_1 = 0.25
# μ_2 = 0.5, σ_2 = 0.25
using Plots
plot(prior)
```
One can also access the underlying Gaussian distributions in the unconstrained space with

```@example snip1
using Plots
plot(prior, constrained=false)
```


### Recommended constructor - Simple example

Task: We wish to create a prior for a one-dimensional parameter. Our problem dictates that this parameter is bounded between 0 and 1; domain knowledge leads us to expect it should be around 0.7. The parameter is called `point_seven`.
Expand Down Expand Up @@ -387,10 +408,10 @@ name2 = "constrained_sampled"
nothing # hide
```

The final parameter is 20-dimensional, defined as a list of i.i.d univariate distributions we make use of the `VectorOfParameterized` type
The final parameter is 4-dimensional, defined as a list of i.i.d univariate distributions we make use of the `VectorOfParameterized` type
```@example snip5
d3 = VectorOfParameterized(repeat([Beta(2,2)],20))
c3 = repeat([no_constraint()],20)
d3 = VectorOfParameterized(repeat([Beta(2,2)],4))
c3 = repeat([no_constraint()],4)
name3 = "Beta"
nothing # hide
```
Expand All @@ -413,6 +434,16 @@ u = ParameterDistribution([param_dict1, param_dict2, param_dict3])
nothing # hide
```

We can visualize the marginals of the constrained distributions,
```@example snip5
using Plots
plot(u)
```
and the unconstrained distributions similarly,
```@example snip5
using Plots
plot(u, constrained = false)
```

## ConstraintType Examples

Expand Down
2 changes: 2 additions & 0 deletions src/EnsembleKalmanProcesses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@ include("TOMLInterface.jl")
# algorithmic updates
include("EnsembleKalmanProcess.jl")

# Plot recipes
include("PlotRecipes.jl")
end # module
86 changes: 86 additions & 0 deletions src/PlotRecipes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
module PlotRecipes

using RecipesBase
using Random
using ..ParameterDistributions

export plot_marginal_hist

@recipe function plot(pd::ParameterDistribution; constrained = true, n_sample = 1e4, rng = Random.GLOBAL_RNG)
samples = sample(rng, pd, Int(n_sample))
if constrained
samples = transform_unconstrained_to_constrained(pd, samples)
end

# First attempt, make it into a samples dist and plot histograms instead
n_plots = ndims(pd)
batches = batch(pd)

rows = Int(ceil(sqrt(n_plots)))
cols = Int(floor(sqrt(n_plots)))
tfs = 16
fs = 14

# subplot attr
legend := false
framestyle := repeat([:axes], n_plots)
grid := false

layout := n_plots
size --> (rows * 400, cols * 400)
titlefontsize --> tfs
xtickfontsize --> fs
ytickfontsize --> fs
xguidefontsize --> fs
yguidefontsize --> fs

for i in 1:n_plots
batch_id = [j for j = 1:length(batches) if i batches[j]][1]
dim_in_batch = i - minimum(batches[batch_id]) + 1 # i.e. if i=5 in batch 3:6, this would be "3"
@series begin
seriestype := :histogram
color := batch_id
subplot := i
title := pd.name[batch_id] * " (dim " * string(dim_in_batch) * ")"
samples[i, :]
end
end

end

@recipe function plot(d::PDT; n_sample = 1e4, rng = Random.GLOBAL_RNG) where {PDT <: ParameterDistributionType}
samples = sample(rng, d, Int(n_sample))

# First attempt, make it into a samples dist and plot histograms instead
n_plots = ndims(d)

size_l = Int(ceil(sqrt(n_plots)))
tfs = 16
fs = 14

# subplot attr
legend := false
framestyle := repeat([:axes], n_plots)
grid := false
layout := n_plots
size --> (size_l * 400, size_l * 400)
titlefontsize --> tfs
xtickfontsize --> fs
ytickfontsize --> fs
xguidefontsize --> fs
yguidefontsize --> fs

for i in 1:n_plots
@series begin
seriestype := :histogram
subplot := i
title := "dim " * string(i)
samples[i, :]
end
end

end



end #module
47 changes: 47 additions & 0 deletions test/PlotRecipes/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
using Test
using Distributions
using LinearAlgebra
using Random
using StatsBase

using EnsembleKalmanProcesses.ParameterDistributions
using EnsembleKalmanProcesses.PlotRecipes

if TEST_PLOT_OUTPUT
@testset "PlotRecipes" begin
@testset "Plot ParameterDistribution" begin
d1 = Parameterized(MvNormal(zeros(4), 0.1 * I))
c1 = [no_constraint(), bounded_below(-1.0), bounded_above(0.4), bounded(-0.1, 0.2)]
name1 = "constrained_mvnormal"
u1 = ParameterDistribution(d1, c1, name1)

d2 = Samples([1 2 3 4])
c2 = [bounded(10, 15)]
name2 = "constrained_sampled"
u2 = ParameterDistribution(d2, c2, name2)

d3 = VectorOfParameterized(repeat([Beta(2, 2)], 3))
c3 = repeat([no_constraint()], 3)
name3 = "vector_beta"
u3 = ParameterDistribution(d3, c3, name3)

u = combine_distributions([u1, u2, u3])

# PDTs
plt1 = plot(d1)
savefig(plt1, joinpath(@__DIR__, name1 * ".png"))
plt2 = plot(d2)
savefig(plt2, joinpath(@__DIR__, name2 * ".png"))
plt3 = plot(d3)
savefig(plt3, joinpath(@__DIR__, name3 * ".png"))

# full param dist
plt_constrained = plot(u)
savefig(plt_constrained, joinpath(@__DIR__, "full_dist_constrained.png"))
plt_unconstrained = plot(u, constrained = false)
savefig(plt_unconstrained, joinpath(@__DIR__, "full_dist_unconstrained.png"))
# just a redundant test to show that the plots were created
@test 1 == 1
end
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ end
for submodule in [
"DataContainers",
"ParameterDistributions",
"PlotRecipes",
"Observations",
"EnsembleKalmanProcess",
"Localizers",
Expand Down

0 comments on commit 60fd1e4

Please sign in to comment.