Skip to content

Commit

Permalink
optimize posterior state storage in SG-MCMC methods (awslabs#68)
Browse files Browse the repository at this point in the history
* optimize posterior state storage in SG-MCMC methods

* lint the code

* address review feedback

* simplify MCMC checkpoints handling and fix GPU compatibility
  • Loading branch information
master committed May 24, 2023
1 parent 9681d30 commit fcf4335
Show file tree
Hide file tree
Showing 12 changed files with 286 additions and 81 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
from typing import Optional
import pathlib

from fortuna.training.train_state import TrainState
from fortuna.training.callback import Callback
from fortuna.training.train_state_repository import TrainStateRepository
from fortuna.training.trainer import TrainerABC
from fortuna.typing import Path

from fortuna.prob_model.posterior.sgmcmc.sgmcmc_sampling_callback import (
SGMCMCSamplingCallback,
Expand All @@ -24,7 +20,6 @@ def __init__(
trainer: TrainerABC,
state_repository: TrainStateRepository,
keep_top_n_checkpoints: int,
save_checkpoint_dir: Optional[Path] = None,
):
"""
Cyclical Stochastic Gradient Langevin Dynamics (SGLD) callback that collects samples
Expand Down Expand Up @@ -53,14 +48,11 @@ def __init__(
An instance of the state repository.
keep_top_n_checkpoints: int
Number of past checkpoint files to keep.
save_checkpoint_dir: Optional[Path]
The optional path to save checkpoints.
"""
super().__init__(
trainer=trainer,
state_repository=state_repository,
keep_top_n_checkpoints=keep_top_n_checkpoints,
save_checkpoint_dir=save_checkpoint_dir,
)

self._do_sample = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from fortuna.prob_model.posterior.run_preliminary_map import (
run_preliminary_map,
)
from fortuna.prob_model.posterior.posterior_multi_state_repository import (
PosteriorMultiStateRepository,
from fortuna.prob_model.posterior.sgmcmc.sgmcmc_posterior_state_repository import (
SGMCMCPosteriorStateRepository,
)
from fortuna.prob_model.posterior.sgmcmc.sgmcmc_posterior import (
SGMCMCPosterior,
Expand Down Expand Up @@ -138,13 +138,26 @@ def fit(
else:
state = self._init_map_state(map_state, train_data_loader, fit_config)

if fit_config.optimizer.freeze_fun is not None:
which_params = get_trainable_paths(
params=state.params, freeze_fun=fit_config.optimizer.freeze_fun
)
else:
which_params = None

state = CyclicalSGLDState.convert_from_map_state(
map_state=state,
optimizer=fit_config.optimizer.method,
which_params=which_params,
)

state = super()._freeze_optimizer_in_state(state, fit_config)

self.state = PosteriorMultiStateRepository(
self.state = SGMCMCPosteriorStateRepository(
size=self.posterior_approximator.n_samples,
checkpoint_dir=fit_config.checkpointer.save_checkpoint_dir
if fit_config.checkpointer.dump_state is True
else None,
checkpoint_dir=fit_config.checkpointer.save_checkpoint_dir,
which_params=which_params,
all_params=state.params if which_params else None,
)

cyclical_sampling_callback = CyclicalSGLDSamplingCallback(
Expand All @@ -157,16 +170,8 @@ def fit(
trainer=trainer,
state_repository=self.state,
keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints,
save_checkpoint_dir=fit_config.checkpointer.save_checkpoint_dir,
)

state = CyclicalSGLDState.convert_from_map_state(
map_state=state,
optimizer=fit_config.optimizer.method,
)

state = super()._freeze_optimizer_in_state(state, fit_config)

logging.info(f"Run CyclicalSGLD.")
state, status = trainer.train(
rng=self.rng.get(),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
from __future__ import annotations

from typing import (
Dict,
List,
Tuple,
Optional,
)

import jax.numpy as jnp

from fortuna.prob_model.posterior.state import PosteriorState
from fortuna.utils.strings import convert_string_to_jnp_array
from fortuna.utils.strings import (
convert_string_to_jnp_array,
encode_tuple_of_lists_of_strings_to_numpy,
)
from fortuna.prob_model.posterior.map.map_state import MAPState
from fortuna.typing import OptaxOptimizer
from fortuna.typing import (
AnyKey,
Array,
OptaxOptimizer,
)


class CyclicalSGLDState(PosteriorState):
Expand All @@ -17,10 +31,14 @@ class CyclicalSGLDState(PosteriorState):
"""

encoded_name: jnp.ndarray = convert_string_to_jnp_array("CyclicalSGLDState")
_encoded_which_params: Optional[Dict[str, List[Array]]] = None

@classmethod
def convert_from_map_state(
cls, map_state: MAPState, optimizer: OptaxOptimizer
cls,
map_state: MAPState,
optimizer: OptaxOptimizer,
which_params: Tuple[List[AnyKey], ...],
) -> CyclicalSGLDState:
"""
Convert a MAP state into a CyclicalSGLDState state.
Expand All @@ -31,16 +49,20 @@ def convert_from_map_state(
A MAP posterior state.
optimizer: OptaxOptimizer
An Optax optimizer.
which_params: Tuple[List[AnyKey], ...]
Sequences of keys pointing to the stochastic parameters.
Returns
-------
SGHMCState
An SGHMC state.
CyclicalSGLDState
A Cyclical SGLD state.
"""
return CyclicalSGLDState.init(
_encoded_which_params = encode_tuple_of_lists_of_strings_to_numpy(which_params)
return cls.init(
params=map_state.params,
mutable=map_state.mutable,
optimizer=optimizer,
calib_params=map_state.calib_params,
calib_mutable=map_state.calib_mutable,
_encoded_which_params=_encoded_which_params,
)
8 changes: 0 additions & 8 deletions fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_callback.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
from typing import Optional
import pathlib

from fortuna.training.train_state import TrainState
from fortuna.training.callback import Callback
from fortuna.training.train_state_repository import TrainStateRepository
from fortuna.training.trainer import TrainerABC
from fortuna.typing import Path
from fortuna.prob_model.posterior.sgmcmc.sgmcmc_sampling_callback import (
SGMCMCSamplingCallback,
)
Expand All @@ -22,7 +18,6 @@ def __init__(
trainer: TrainerABC,
state_repository: TrainStateRepository,
keep_top_n_checkpoints: int,
save_checkpoint_dir: Optional[Path] = None,
):
"""
Stochastic Gradient Hamiltonian Monte Carlo (SGHMC) callback that collects samples
Expand All @@ -46,14 +41,11 @@ def __init__(
An instance of the state repository.
keep_top_n_checkpoints: int
Number of past checkpoint files to keep.
save_checkpoint_dir: Optional[Path]
The optional path to save checkpoints.
"""
super().__init__(
trainer=trainer,
state_repository=state_repository,
keep_top_n_checkpoints=keep_top_n_checkpoints,
save_checkpoint_dir=save_checkpoint_dir,
)

self._do_sample = (
Expand Down
24 changes: 15 additions & 9 deletions fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_posterior.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import logging
from typing import Optional
from itertools import cycle
import pathlib
from flax.core import FrozenDict

from flax.core import FrozenDict
from fortuna.utils.freeze import get_trainable_paths
from fortuna.utils.nested_dicts import nested_set, nested_get
from fortuna.data.loader import DataLoader
Expand All @@ -19,8 +18,8 @@
)
from fortuna.prob_model.posterior.map.map_posterior import MAPPosterior
from fortuna.prob_model.posterior.map.map_state import MAPState
from fortuna.prob_model.posterior.posterior_multi_state_repository import (
PosteriorMultiStateRepository,
from fortuna.prob_model.posterior.sgmcmc.sgmcmc_posterior_state_repository import (
SGMCMCPosteriorStateRepository,
)
from fortuna.prob_model.posterior.sgmcmc.sgmcmc_posterior import (
SGMCMCPosterior,
Expand Down Expand Up @@ -136,18 +135,26 @@ def fit(
else:
state = self._init_map_state(map_state, train_data_loader, fit_config)

if fit_config.optimizer.freeze_fun is not None:
which_params = get_trainable_paths(
params=state.params, freeze_fun=fit_config.optimizer.freeze_fun
)
else:
which_params = None

state = SGHMCState.convert_from_map_state(
map_state=state,
optimizer=fit_config.optimizer.method,
which_params=which_params,
)

state = super()._freeze_optimizer_in_state(state, fit_config)

self.state = PosteriorMultiStateRepository(
self.state = SGMCMCPosteriorStateRepository(
size=self.posterior_approximator.n_samples,
checkpoint_dir=fit_config.checkpointer.save_checkpoint_dir
if fit_config.checkpointer.dump_state is True
else None,
checkpoint_dir=fit_config.checkpointer.save_checkpoint_dir,
which_params=which_params,
all_params=state.params if which_params else None,
)

sghmc_sampling_callback = SGHMCSamplingCallback(
Expand All @@ -159,7 +166,6 @@ def fit(
trainer=trainer,
state_repository=self.state,
keep_top_n_checkpoints=fit_config.checkpointer.keep_top_n_checkpoints,
save_checkpoint_dir=fit_config.checkpointer.save_checkpoint_dir,
)

logging.info(f"Run SGHMC.")
Expand Down
30 changes: 26 additions & 4 deletions fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_state.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
from __future__ import annotations

from typing import (
Dict,
List,
Tuple,
Optional,
)

import jax.numpy as jnp

from fortuna.prob_model.posterior.state import PosteriorState
from fortuna.utils.strings import convert_string_to_jnp_array
from fortuna.utils.strings import (
convert_string_to_jnp_array,
encode_tuple_of_lists_of_strings_to_numpy,
)
from fortuna.prob_model.posterior.map.map_state import MAPState
from fortuna.typing import OptaxOptimizer
from fortuna.typing import (
AnyKey,
Array,
OptaxOptimizer,
)


class SGHMCState(PosteriorState):
Expand All @@ -17,10 +31,14 @@ class SGHMCState(PosteriorState):
"""

encoded_name: jnp.ndarray = convert_string_to_jnp_array("SGHMCState")
_encoded_which_params: Optional[Dict[str, List[Array]]] = None

@classmethod
def convert_from_map_state(
cls, map_state: MAPState, optimizer: OptaxOptimizer
cls,
map_state: MAPState,
optimizer: OptaxOptimizer,
which_params: Tuple[List[AnyKey], ...],
) -> SGHMCState:
"""
Convert a MAP state into an SGHMC state.
Expand All @@ -31,16 +49,20 @@ def convert_from_map_state(
A MAP posterior state.
optimizer: OptaxOptimizer
An Optax optimizer.
which_params: Tuple[List[AnyKey], ...]
Sequences of keys pointing to the stochastic parameters.
Returns
-------
SGHMCState
An SGHMC state.
"""
return SGHMCState.init(
_encoded_which_params = encode_tuple_of_lists_of_strings_to_numpy(which_params)
return cls.init(
params=map_state.params,
mutable=map_state.mutable,
optimizer=optimizer,
calib_params=map_state.calib_params,
calib_mutable=map_state.calib_mutable,
_encoded_which_params=_encoded_which_params,
)
15 changes: 11 additions & 4 deletions fortuna/prob_model/posterior/sgmcmc/sgmcmc_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
from fortuna.prob_model.fit_config.base import FitConfig
from fortuna.prob_model.joint.state import JointState
from fortuna.prob_model.posterior.map.map_state import MAPState
from fortuna.prob_model.posterior.posterior_multi_state_repository import (
PosteriorMultiStateRepository,
from fortuna.prob_model.posterior.sgmcmc.sgmcmc_posterior_state_repository import (
SGMCMCPosteriorStateRepository,
)
from fortuna.typing import Path
from fortuna.utils.strings import decode_encoded_tuple_of_lists_of_strings_to_array


class SGMCMCPosterior(Posterior):
Expand Down Expand Up @@ -43,6 +44,7 @@ def sample(
self.state.get(i=0),
random.choice(rng, self.posterior_approximator.n_samples),
)

return JointState(
params=state.params,
mutable=state.mutable,
Expand All @@ -52,14 +54,19 @@ def sample(

def load_state(self, checkpoint_dir: Path) -> None:
try:
self.restore_checkpoint(pathlib.Path(checkpoint_dir) / "0")
state = self.restore_checkpoint(pathlib.Path(checkpoint_dir) / "c")
except ValueError:
raise ValueError(
f"No checkpoint was found in `checkpoint_dir={checkpoint_dir}`."
)
self.state = PosteriorMultiStateRepository(
which_params = decode_encoded_tuple_of_lists_of_strings_to_array(
state._encoded_which_params
)
self.state = SGMCMCPosteriorStateRepository(
size=self.posterior_approximator.n_samples,
checkpoint_dir=checkpoint_dir,
which_params=which_params,
all_params=state.params if which_params else None,
)

def save_state(self, checkpoint_dir: Path, keep_top_n_checkpoints: int = 1) -> None:
Expand Down
Loading

0 comments on commit fcf4335

Please sign in to comment.