Add convenient method for minibatching data #382

Closed
Tracked by #373
odunbar opened this issue Jun 17, 2024 · 0 comments · Fixed by #384
odunbar (Collaborator) commented Jun 17, 2024

Allow users to provide a list of observations with their corresponding noise, together with a mechanism for batching them.

Solution (Updated)

Overhaul the existing Observation object and construct a new ObservationSeries object.

  • Observation objects can be combined and stacked to build new Observation objects (akin to how ParameterDistribution objects are combined).
  • An ObservationSeries then takes a vector of such Observations and a MiniBatcher that samples epochs of minibatches from the provided series.
  • EKP now builds and stores Observation quantities (like obs_mean and obs_noise_cov) inside an ObservationSeries object, so all calls to .obs_mean and .obs_noise_cov are replaced by getter functions on the ObservationSeries object, which return the values for the current minibatch.
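A minimal sketch of what stacking could look like, assuming the stacked sample is the concatenation of the component samples and the stacked noise covariance is block-diagonal (independent noise across components), with `indices` recording where each component sits. `StackedObs` and `stack_observations` are hypothetical simplifications for illustration, not the EKP API.

```julia
# Hypothetical, simplified stand-in for a stacked Observation (not the EKP type)
struct StackedObs
    sample::Vector{Float64}          # concatenated observation sample
    cov::Matrix{Float64}             # block-diagonal noise covariance
    names::Vector{String}            # name of each stacked component
    indices::Vector{UnitRange{Int}}  # where each component sits in `sample`
end

# Hypothetical helper: stack component samples, covariances and names into one StackedObs
function stack_observations(samples::Vector{Vector{Float64}},
                            covs::Vector{Matrix{Float64}},
                            names::Vector{String})
    @assert length(samples) == length(covs) == length(names)
    # record the index range occupied by each component in the stacked vector
    indices = UnitRange{Int}[]
    offset = 0
    for s in samples
        push!(indices, (offset + 1):(offset + length(s)))
        offset += length(s)
    end
    sample = reduce(vcat, samples)
    # assumption: noise in different components is independent, so off-diagonal blocks are zero
    cov = zeros(offset, offset)
    for (r, c) in zip(indices, covs)
        cov[r, r] = c
    end
    return StackedObs(sample, cov, names, indices)
end
```

For example, stacking a 2-dimensional and a 1-dimensional observation yields a length-3 sample, a 3×3 covariance, and `indices == [1:2, 3:3]`.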

The key objects and methods for this solution:

  • The individual (possibly stacked) Observation
```julia
struct Observation
    samples::Vector{Vector} # list of observation samples
    covs::Vector{Matrix} # noise covariance for each observation sample
    names::Vector{String} # names of the observation samples
    indices::Vector{UnitRange} # index range of each sample within the stacked Observation (necessary?)
end
```
  • The series of Observations for batching
```julia
struct ObservationSeries{MB <: MiniBatcher}
    observations::Vector{Observation} # vector of observations
    batches::Vector{Vector} # batches grouped by epoch
    current_batch_index::Dict # holds 2 indices (epoch number, batch number) of the latest batch
    batcher::MB # object that constructs the next epoch of batches
    names::Vector{String} # holds ids for the observation samples
end
```
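  • Methods such as
```julia
get_obs(os::ObservationSeries)           # observations for the current minibatch
get_obs_noise_cov(os::ObservationSeries) # noise covariance for the current minibatch
update_minibatch!(os::ObservationSeries) # move to the next minibatch, starting a new epoch when needed
```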
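A minimal sketch of how the getters could read the current minibatch, assuming each observation holds a single sample vector and that the minibatch noise covariance is block-diagonal across observations. `SimpleObservationSeries`, `get_current_minibatch`, and the `"epoch"`/`"minibatch"` keys of `current_batch_index` are hypothetical simplifications, not the EKP API.

```julia
# Hypothetical, simplified ObservationSeries: each observation is one sample with its noise covariance
struct SimpleObservationSeries
    samples::Vector{Vector{Float64}}       # one sample per observation
    covs::Vector{Matrix{Float64}}          # noise covariance per observation
    batches::Vector{Vector{Vector{Int}}}   # batches[epoch][batch] = indices of observations
    current_batch_index::Dict{String, Int} # "epoch" and "minibatch" of the current batch
end

# indices of the observations in the current minibatch
function get_current_minibatch(os::SimpleObservationSeries)
    e = os.current_batch_index["epoch"]
    m = os.current_batch_index["minibatch"]
    return os.batches[e][m]
end

# observation vector for the current minibatch: samples stacked in batch order
get_obs(os::SimpleObservationSeries) = reduce(vcat, os.samples[get_current_minibatch(os)])

# noise covariance for the current minibatch: block-diagonal (assumes independent noise across observations)
get_obs_noise_cov(os::SimpleObservationSeries) = cat(os.covs[get_current_minibatch(os)]...; dims = (1, 2))
```

Because the getters only consult `current_batch_index`, any code calling them automatically sees the current minibatch once the index is advanced.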

Some forms of the Minibatcher

The batches are created by the following abstract type and its method:

```julia
abstract type MiniBatcher end

# every concrete MiniBatcher implements this, returning the batches for one epoch
function generate_batches_for_epoch end
```

A simple instance for fixed batches provided by the user:

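```julia
struct DeterministicMiniBatcher <: MiniBatcher
    batches::Vector{Vector} # fixed, user-provided batches
end

get_batches(b::DeterministicMiniBatcher) = b.batches
# every epoch reuses the same user-provided batches
generate_batches_for_epoch(b::DeterministicMiniBatcher) = get_batches(b)
```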

A simple instance that randomly splits the sample indices into batches of size batch_size:

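```julia
using Random

struct RandomMiniBatcher{RNG <: AbstractRNG} <: MiniBatcher
    batch_size::Int
    n_samples_per_epoch::Int
    rng::RNG
end

get_batch_size(b::RandomMiniBatcher) = b.batch_size
get_n_samples_per_epoch(b::RandomMiniBatcher) = b.n_samples_per_epoch
get_rng(b::RandomMiniBatcher) = b.rng

function generate_batches_for_epoch(b::RandomMiniBatcher)
    N = get_n_samples_per_epoch(b)
    rng = get_rng(b)
    indices = shuffle(rng, collect(1:N)) # random permutation of the sample indices
    bs = get_batch_size(b)
    n_batches = max(div(N, bs), 1) # at least one batch, even if N < bs
    batches = Vector{Vector{Int}}(undef, n_batches)
    for i in 1:n_batches
        first_idx = (i - 1) * bs + 1
        # bs-sized batches, except the final one, which also takes the N % bs leftovers (size < 2*bs)
        last_idx = i < n_batches ? i * bs : N
        batches[i] = indices[first_idx:last_idx]
    end
    return batches
end
```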
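The splitting rule can be checked with a little arithmetic: with n_batches = div(N, bs), the first n_batches - 1 batches hold bs samples and the final one holds bs + N % bs, which is at least bs and below 2*bs. A sketch of a hypothetical helper `batch_sizes` that computes only these sizes, assuming (as in the random batcher) that at least one batch is produced when N < bs:

```julia
# Hypothetical helper: sizes of the batches produced by the splitting rule
function batch_sizes(N::Int, bs::Int)
    n_batches = max(div(N, bs), 1) # at least one batch
    # all batches have bs samples except the last, which absorbs the remainder
    return [i < n_batches ? bs : N - (n_batches - 1) * bs for i in 1:n_batches]
end
```

For N = 10 and bs = 3 this gives batches of sizes 3, 3 and 4, so every sample is used exactly once per epoch.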
  • The ObservationSeries object will update the batch at the end of each EKP update via an explicit call (update_minibatch! above) that either advances current_batch_index to the next batch of the current epoch, or calls the MiniBatcher to generate a new epoch once the current one is exhausted.
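A minimal sketch of the update logic described above, assuming batches are stored per epoch and that a freshly generated epoch is appended when the current one runs out. `BatchState`, `current_batch`, the `new_epoch` generator (standing in for the MiniBatcher call), and the `"epoch"`/`"minibatch"` keys are hypothetical stand-ins for the ObservationSeries fields, not the EKP implementation.

```julia
# Hypothetical, simplified batching state (not the EKP ObservationSeries)
struct BatchState{F}
    batches::Vector{Vector{Vector{Int}}}   # batches[epoch] = list of minibatches for that epoch
    current_batch_index::Dict{String, Int} # "epoch" and "minibatch" of the current batch
    new_epoch::F                           # callable returning the minibatches of a fresh epoch
end

# Advance to the next minibatch; if the current epoch is exhausted, generate and start a new one
function update_minibatch!(s::BatchState)
    e = s.current_batch_index["epoch"]
    m = s.current_batch_index["minibatch"]
    if m < length(s.batches[e])
        s.current_batch_index["minibatch"] = m + 1
    else
        push!(s.batches, s.new_epoch())
        s.current_batch_index["epoch"] = e + 1
        s.current_batch_index["minibatch"] = 1
    end
    return s
end

# the minibatch currently in use
current_batch(s::BatchState) = s.batches[s.current_batch_index["epoch"]][s.current_batch_index["minibatch"]]
```

With a generator returning `[[1, 2], [3, 4]]`, two calls to `update_minibatch!` move from batch `[1, 2]` to `[3, 4]` and then into epoch 2, starting again at `[1, 2]`.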