Skip to content

Commit

Permalink
added storage of observation inverses
Browse files Browse the repository at this point in the history
  • Loading branch information
odunbar committed Jun 25, 2024
1 parent 74a7710 commit 1b4889f
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 14 deletions.
102 changes: 94 additions & 8 deletions src/Observations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ using Random
export Observation, Minibatcher, FixedMinibatcher, RandomFixedSizeMinibatcher, ObservationSeries
export get_samples,
get_covs,
get_inv_covs,
get_names,
get_indices,
combine_observations,
get_obs_noise_cov,
get_obs_noise_cov_inv,
get_obs,
create_new_epoch!,
get_minibatches,
Expand All @@ -36,33 +38,43 @@ Structure that contains a (possibly stacked) observation. Defined by sample(s),
$(TYPEDFIELDS)
"""
struct Observation{AV1 <: AbstractVector, AV2 <: AbstractVector, AV3 <: AbstractVector, AV4 <: AbstractVector}
struct Observation{
AV1 <: AbstractVector,
AV2 <: AbstractVector,
AV3 <: AbstractVector,
AV4 <: AbstractVector,
AV5 <: AbstractVector,
}
"A (vector of) observation vectors"
samples::AV1
"A (vector of) observation covariance matrices"
covs::AV2
"A (vector of) inverses of observation covariance matrices"
inv_covs::AV3
"A (vector of) name strings"
names::AV3
names::AV4
"A (vector of) indices of the contained observation blocks"
indices::AV4
indices::AV5
end

"""
    get_samples(o::Observation)

Return the `samples` field of `o`.
"""
function get_samples(o::Observation)
    return o.samples
end

"""
    get_covs(o::Observation)

Return the `covs` field of `o`.
"""
function get_covs(o::Observation)
    return o.covs
end

"""
    get_inv_covs(o::Observation)

Return the `inv_covs` field of `o`.
"""
function get_inv_covs(o::Observation)
    return o.inv_covs
end

"""
    get_names(o::Observation)

Return the `names` field of `o`.
"""
function get_names(o::Observation)
    return o.names
end

"""
    get_indices(o::Observation)

Return the `indices` field of `o`.
"""
function get_indices(o::Observation)
    return o.indices
end

function Observation(obs_dict::Dict)
if !all(["samples", "names", "covariances"] .∈ [collect(keys(obs_dict))])
throw(
ArgumentError(
"input dictionaries must contain the keys: \"samples\", \"names\", \"covariances\". Got $(keys(obs_dict))",
"input dictionaries must contain the keys: \"samples\", \"names\", \"covariances\", and optionally: \"inv_covariances\". Got $(keys(obs_dict))",
),
)
end
samples = obs_dict["samples"]
covariances = obs_dict["covariances"]
names = obs_dict["names"]

if !isa(samples, AbstractVector) # 1 -> [[1]]
snew = [[samples]]
else
Expand All @@ -80,17 +92,33 @@ function Observation(obs_dict::Dict)
T = promote_type((typeof(c) for c in covariances)...)
cnew = [convert(T, c) for c in covariances] # to re-infer eltype
end

if !("inv_covariances" ∈ collect(keys(obs_dict)))
inv_covariances = []
for c in cnew # ensures its a vector
push!(inv_covariances, inv(c))
end
else
inv_covariances = obs_dict["inv_covariances"]
end
if !isa(inv_covariances, AbstractVector) # [2 1;1 2] -> [[2 1;1 2]]
icnew = [inv_covariances]
else
T = promote_type((typeof(c) for c in inv_covariances)...)
icnew = [convert(T, c) for c in inv_covariances] # to re-infer eltype
end

if !isa(names, AbstractVector) # "name" -> ["name"]
nnew = [names]
else
T = promote_type((typeof(n) for n in names)...)
nnew = [convert(T, n) for n in names] # to re-infer eltype
end

if !all([length(snew) == length(cnew), length(nnew) == length(cnew)])
if !all([length(snew) == length(cnew), length(nnew) == length(cnew), length(icnew) == length(cnew)])
throw(
ArgumentError(
"input dictionaries must contain the same number of objects. Got $(length(snew)) samples, $(length(cnew)) covs, and $(length(nnew)) names.",
"input dictionaries must contain the same number of objects. Got $(length(snew)) samples, $(length(cnew)) covs, $(length(icnew)) inv_covs, and $(length(nnew)) names.",
),
)
end
Expand All @@ -103,7 +131,7 @@ function Observation(obs_dict::Dict)
end
end

return Observation(snew, cnew, nnew, indices)
return Observation(snew, cnew, icnew, nnew, indices)

end

Expand All @@ -113,13 +141,15 @@ function combine_observations(obs_vec::AV) where {AV <: AbstractVector}

snew = []
cnew = []
icnew = []
nnew = []
inew = []
shift = [0] # running shift to add to indexing
for obs in obs_vec
@assert(nameof(typeof(obs)) == :Observation) # check it's a vector of Observations
append!(snew, get_samples(obs))
append!(cnew, get_covs(obs))
append!(icnew, get_inv_covs(obs))
append!(nnew, get_names(obs))
indices = get_indices(obs)
shifted_indices = [ind .+ shift[1] for ind in get_indices(obs)]
Expand All @@ -132,12 +162,14 @@ function combine_observations(obs_vec::AV) where {AV <: AbstractVector}
snew2 = [convert(T, s) for s in snew]
T = promote_type((typeof(c) for c in cnew)...)
cnew2 = [convert(T, c) for c in cnew]
T = promote_type((typeof(c) for c in icnew)...)
icnew2 = [convert(T, c) for c in icnew]
T = promote_type((typeof(n) for n in nnew)...)
nnew2 = [convert(T, n) for n in nnew]
T = promote_type((typeof(i) for i in inew)...)
inew2 = [convert(T, i) for i in inew]

return Observation(snew2, cnew2, nnew2, inew2)
return Observation(snew2, cnew2, icnew2, nnew2, inew2)
end

function get_obs(o::Observation; build = true)
Expand Down Expand Up @@ -176,6 +208,28 @@ function get_obs_noise_cov(o::Observation; build = true)
end
end

"""
    get_obs_noise_cov_inv(o::Observation; build = true)

Return the inverse observation-noise covariance stored in `o`.

With `build = false` the stored vector of inverse-covariance blocks is returned
unchanged. With `build = true` (the default) the blocks are written onto the
diagonal of a dense zero matrix, each at the index range stored alongside it.
"""
function get_obs_noise_cov_inv(o::Observation; build = true)
    blocks = get_inv_covs(o)

    # caller asked for the raw blocks: nothing to assemble
    build || return blocks

    block_ranges = get_indices(o)
    dim = maximum(block_ranges[end]) # last index of the final block sets the full size
    assembled = zeros(dim, dim)

    for (rng, blk) in zip(block_ranges, blocks)
        if blk isa UniformScaling
            # a UniformScaling has no size, so write its scalar onto the diagonal entries
            for k in rng
                assembled[k, k] = blk.λ
            end
        else
            assembled[rng, rng] .= blk
        end
    end

    return assembled
end

function Base.:(==)(ob_a::OB1, ob_b::OB2) where {OB1 <: Observation, OB2 <: Observation}
fn = unique([fieldnames(OB1)...; fieldnames(OB2)...])
x = [false for f in fn]
Expand Down Expand Up @@ -549,6 +603,38 @@ function get_obs_noise_cov(os::OS; build = true) where {OS <: ObservationSeries}

end

"""
    get_obs_noise_cov_inv(os::ObservationSeries; build = true)

Return the inverse observation-noise covariance for the current minibatch of `os`.

If the minibatch holds a single observation, the result for that observation is
returned directly. Otherwise the per-observation results are combined: with
`build = false` they are concatenated with `vcat`; with `build = true` (the default)
they are placed as diagonal blocks of a dense zero matrix.
"""
function get_obs_noise_cov_inv(os::OS; build = true) where {OS <: ObservationSeries}
    batch = get_current_minibatch(os) # indices of the observations in the minibatch
    batch_obs = get_observations(os)[batch] # the corresponding observation objects

    # a single observation needs no stacking
    if length(batch) == 1
        return get_obs_noise_cov_inv(first(batch_obs), build = build)
    end

    per_obs = Any[get_obs_noise_cov_inv(ob, build = build) for ob in batch_obs]

    # caller asked for the raw blocks: concatenate the per-observation lists
    build || return reduce(vcat, per_obs)

    dims = size.(per_obs, 1) # the matrices are square, so one dimension suffices
    total = sum(dims)
    stacked = zeros(total, total)
    offset = 0 # running position of the next diagonal block
    for (k, blk) in enumerate(per_obs)
        span = (offset + 1):(offset + dims[k])
        stacked[span, span] .= blk
        offset += dims[k]
    end

    return stacked
end




function Base.:(==)(os_a::OS1, os_b::OS2) where {OS1 <: ObservationSeries, OS2 <: ObservationSeries}
fn = unique([fieldnames(OS1)...; fieldnames(OS2)...])
x = [false for f in fn]
Expand Down
64 changes: 58 additions & 6 deletions test/Observations/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Test
using Random
using Statistics

using LinearAlgebra
using EnsembleKalmanProcesses


Expand All @@ -12,9 +12,25 @@ using EnsembleKalmanProcesses
n_samples = length(sample_sizes)
samples = []
covariances = []
inv_covariances = []
for i in 1:n_samples
push!(samples, vec(i * ones(sample_sizes[i])))
push!(covariances, i * ones(sample_sizes[i], sample_sizes[i]))
if (i == 3)
X = I
elseif (i == 4)
X = I
else
X = randn(sample_sizes[i], sample_sizes[i])
end

push!(covariances, i * X' * X)
if !(i == 3) # take inverse if not == 3
ic = inv(i * X' * X)
else # here submit a user-defined inverse covariance (not true inverse)
ic = I
end
push!(inv_covariances, ic)

end
names = ["d$(string(i))" for i in 1:n_samples]

Expand All @@ -32,21 +48,29 @@ using EnsembleKalmanProcesses
observation_1 = Observation(obs_dict)
@test get_samples(observation_1) == [samples[1]] # all stored as a vec
@test get_covs(observation_1) == [covariances[1]]
@test all(isapprox.(get_inv_covs(observation_1)[1], inv_covariances[1], atol = 1e-10)) # inversion approximate
@test get_names(observation_1) == [names[1]]
@test get_indices(observation_1) == [indices[1]]

# 2) via a dict [vec]
obs_dict = Dict("samples" => samples[2:4], "covariances" => covariances[2:4], "names" => names[2:4])
# 2) via a dict [vec], pass in inv_covs
obs_dict = Dict(
"samples" => samples[2:4],
"covariances" => covariances[2:4],
"inv_covariances" => inv_covariances[2:4],
"names" => names[2:4],
)
observation_2_4 = Observation(obs_dict)
@test get_samples(observation_2_4) == samples[2:4]
@test get_covs(observation_2_4) == covariances[2:4]
@test get_inv_covs(observation_2_4) == inv_covariances[2:4]
@test get_names(observation_2_4) == names[2:4]
@test get_indices(observation_2_4) == [id .- maximum(indices[1]) for id in indices[2:4]] # shifted

# 2) via combining Observations
observation = combine_observations([observation_1, observation_2_4])
@test get_samples(observation) == samples
@test get_covs(observation) == covariances
@test all(isapprox.(get_inv_covs(observation), inv_covariances, atol = 1e-10))
@test get_names(observation) == names
@test get_indices(observation) == indices # correctly shifted back

Expand All @@ -63,9 +87,34 @@ using EnsembleKalmanProcesses

full = zeros(maximum(indices[end]), maximum(indices[end]))
for (idx, c) in zip(indices, covariances)
full[idx, idx] = c
if isa(c, UniformScaling)
for idxx in idx
full[idxx, idxx] = c.λ
end
else
full[idx, idx] .= c
end
end
@test onc_full == full

# get_obs_noise_cov_inv
onci_block = get_obs_noise_cov_inv(observation, build = false)
onci_full = get_obs_noise_cov_inv(observation) # default build=true
@test onci_block == inv_covariances

full = zeros(maximum(indices[end]), maximum(indices[end]))
for (idx, c) in zip(indices, inv_covariances)
if isa(c, UniformScaling)
for idxx in idx
full[idxx, idxx] = c.λ
end
else
full[idx, idx] .= c
end
end
@test onci_full == full


end

@testset "Minibatching" begin
Expand Down Expand Up @@ -168,7 +217,10 @@ end

for i in 1:n_samples
push!(samples, vec(sid .+ i * ones(sample_sizes[i])))
push!(covariances, sid .+ i * ones(sample_sizes[i], sample_sizes[i]))

X = randn(sample_sizes[i], sample_sizes[i])

push!(covariances, sid .+ i * X' * X)
end
names = ["d$(string(i))" for i in 1:n_samples]

Expand Down

0 comments on commit 1b4889f

Please sign in to comment.