Skip to content

Commit

Permalink
test no_minibatching setup
Browse files Browse the repository at this point in the history
  • Loading branch information
odunbar committed Jun 21, 2024
1 parent 291f59c commit 5b4786b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
18 changes: 13 additions & 5 deletions src/Observations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,12 +231,12 @@ function FixedMinibatcher(minibatches::AV) where {AV <: AbstractVector}
return FixedMinibatcher(minibatches, def_method, def_rng)
end

function no_minibatcher() where {AV <: AbstractVector}
function no_minibatcher(size::Int = 1) #optional to provide size
# method
def_minibatch = [[1]]
def_minibatch = [collect(1:size)]
def_method = "order"
def_rng = Random.default_rng()
return FixedMinibatcher([[1]], def_method, def_rng)
return FixedMinibatcher(def_minibatch, def_method, def_rng)
end


Expand Down Expand Up @@ -411,8 +411,16 @@ function ObservationSeries(obs_vec::AV, minibatcher::MM) where {AV <: AbstractVe
return ObservationSeries(obs_vec, minibatcher, names, epoch)
end

function ObservationSeries(obs_vec::O, args...; kwargs...) where {O <: Observation}
return ObservationSeries([obs_vec], no_minibatcher(), args...; kwargs...)
function ObservationSeries(obs_vec::AV) where {AV <: AbstractVector}
len_epoch = length(obs_vec)
minibatcher = no_minibatcher(len_epoch)
names = ["series_$(string(i))" for i in 1:len_epoch]
epoch = collect(1:len_epoch)
return ObservationSeries(obs_vec, minibatcher, names, epoch)
end

function ObservationSeries(obs::O, args...; kwargs...) where {O <: Observation}
return ObservationSeries([obs], no_minibatcher(), args...; kwargs...)
end

function update_minibatch!(os::OS) where {OS <: ObservationSeries}
Expand Down
8 changes: 7 additions & 1 deletion test/Observations/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,13 @@ end
@test get_current_minibatch(observation_series) == new_epoch2[1]
@test get_minibatches(observation_series) == [new_epoch, new_epoch2]

# test the no minibatch option
observation_series_none = ObservationSeries(obs_vec)
@test get_current_minibatch(observation_series_none) == collect(1:length(obs_vec))
update_minibatch!(observation_series_none)
@test get_current_minibatch(observation_series_none) == collect(1:length(obs_vec))
@test get_current_minibatch_index(observation_series_none) == Dict("epoch" => 2, "minibatch" => 1)

# get_obs (def: build = true)
mb = new_epoch2[1]
obs_minibatch = get_obs(observation_series, build = false)
Expand All @@ -229,7 +236,6 @@ end
end
@test minibatch_covs == obs_noise_cov_minibatch_blocks


obs_noise_cov_minibatch_full = get_obs_noise_cov(observation_series)
minibatch_covs = []
for observation in obs_vec[mb]
Expand Down

0 comments on commit 5b4786b

Please sign in to comment.