Commit

feat: extract checkpoint function
CompRhys committed Jul 25, 2024
1 parent c235b55 commit 3289c4d
Showing 1 changed file with 64 additions and 25 deletions.
89 changes: 64 additions & 25 deletions aviary/train.py
@@ -17,7 +17,14 @@
 from aviary.wrenformer.data import df_to_in_mem_dataloader
 from aviary.wrenformer.model import Wrenformer
 
+try:
+    import wandb
+except ImportError:
+    wandb = None
+
 if TYPE_CHECKING:
+    from torch import nn
+
     from aviary.data import InMemoryDataLoader
 
 __author__ = "Janosh Riebesell"
@@ -206,8 +213,6 @@ def train_model(
     print(f"{run_params=}")
 
     if wandb_path:
-        import wandb
-
         if wandb.run is None:
             wandb.login()
             wandb_entity, wandb_project = wandb_path.split("/")
@@ -290,6 +295,7 @@ def train_model(
     if test_df is None:
         assert isinstance(test_loader, DataLoader)
         test_df = test_loader.dataset.df
+
     if robust:
         preds, aleatoric_log_std = np.split(preds, 2, axis=1)
         preds = preds.squeeze()
@@ -311,34 +317,20 @@ def train_model(
 
     # save model checkpoint
     if checkpoint is not None:
-        if model_params is None:
-            raise ValueError("Must provide model_params to save checkpoint, got None")
-        checkpoint_dict = dict(
-            model_params=model_params,
-            model_state=inference_model.state_dict(),
-            optimizer_state=optimizer_instance.state_dict(),
-            scheduler_state=lr_scheduler.state_dict(),
+        checkpoint_model(
+            checkpoint_endpoint=checkpoint,
+            model_params=model_params,
+            inference_model=inference_model,
+            optimizer_instance=optimizer_instance,
+            lr_scheduler=lr_scheduler,
             loss_dict=loss_dict,
-            epoch=epochs,
-            metrics=test_metrics,
+            epochs=epochs,
+            test_metrics=test_metrics,
+            timestamp=timestamp,
             run_name=run_name,
             normalizer_dict=normalizer_dict,
-            run_params=run_params.copy(),
+            run_params=run_params,
+            scheduler_name=scheduler_name,
         )
-        if scheduler_name == "LambdaLR":
-            # exclude lr_lambda from pickled checkpoint since it causes errors when
-            # torch.load()-ing a checkpoint and the file defining lr_lambda() was
-            # renamed
-            checkpoint_dict["run_params"]["lr_scheduler"].pop("params")
-        if checkpoint == "local":
-            os.makedirs(f"{ROOT}/models", exist_ok=True)
-            checkpoint_path = f"{ROOT}/models/{timestamp}-{run_name}.pth"
-            torch.save(checkpoint_dict, checkpoint_path)
-        if checkpoint == "wandb":
-            assert (
-                wandb.run is not None
-            ), "can't save model checkpoint to Weights and Biases, wandb.run is None"
-            torch.save(checkpoint_dict, f"{wandb.run.dir}/checkpoint.pth")
 
     # record test set metrics and scatter/ROC plots to wandb
     if wandb_path:
@@ -370,6 +362,53 @@ def train_model(
     return test_metrics, run_params, test_df
 
 
+def checkpoint_model(
+    checkpoint_endpoint: str,
+    model_params: dict,
+    inference_model: nn.Module,
+    optimizer_instance: torch.optim.Optimizer,
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    loss_dict: dict,
+    epochs: int,
+    test_metrics: dict,
+    timestamp: str,
+    run_name: str,
+    normalizer_dict: dict,
+    run_params: dict,
+    scheduler_name: str,
+) -> None:
+    """Save a model checkpoint to the given endpoint ("local" or "wandb")."""
+    if model_params is None:
+        raise ValueError("Must provide model_params to save checkpoint, got None")
+
+    checkpoint_dict = dict(
+        model_params=model_params,
+        model_state=inference_model.state_dict(),
+        optimizer_state=optimizer_instance.state_dict(),
+        scheduler_state=lr_scheduler.state_dict(),
+        loss_dict=loss_dict,
+        epoch=epochs,
+        metrics=test_metrics,
+        run_name=run_name,
+        normalizer_dict=normalizer_dict,
+        run_params=run_params.copy(),
+    )
+    if scheduler_name == "LambdaLR":
+        # exclude lr_lambda from pickled checkpoint since it causes errors when
+        # torch.load()-ing a checkpoint and the file defining lr_lambda() was
+        # renamed
+        checkpoint_dict["run_params"]["lr_scheduler"].pop("params")
+    if checkpoint_endpoint == "local":
+        os.makedirs(f"{ROOT}/models", exist_ok=True)
+        checkpoint_path = f"{ROOT}/models/{timestamp}-{run_name}.pth"
+        torch.save(checkpoint_dict, checkpoint_path)
+    if checkpoint_endpoint == "wandb":
+        assert (
+            wandb.run is not None
+        ), "can't save model checkpoint to Weights and Biases, wandb.run is None"
+        torch.save(checkpoint_dict, f"{wandb.run.dir}/checkpoint.pth")
+
+
 def train_wrenformer(
     run_name: str,
     target_col: str,

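The new module-level try/except makes wandb an optional dependency: aviary/train.py now imports even when wandb is not installed, and wandb-dependent code paths are expected to check for it before use. A minimal sketch of that guard pattern; the maybe_log helper is hypothetical and not part of this commit:

try:
    import wandb
except ImportError:
    wandb = None  # fall back so the module still imports without wandb


def maybe_log(metrics: dict) -> None:
    # hypothetical helper: only touch wandb when it imported and a run is active
    if wandb is not None and wandb.run is not None:
        wandb.log(metrics)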
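For reference, a minimal sketch of restoring a checkpoint written by the "local" endpoint of checkpoint_model. Assumptions: the saved model_params are valid Wrenformer constructor kwargs, the path is illustrative, and the optimizer class matches the one used in training; none of this loading code appears in the commit:

import torch

from aviary.wrenformer.model import Wrenformer

# illustrative path; the "local" endpoint writes to f"{ROOT}/models/{timestamp}-{run_name}.pth"
checkpoint = torch.load("models/checkpoint.pth", map_location="cpu")

model = Wrenformer(**checkpoint["model_params"])  # assumes kwargs match Wrenformer
model.load_state_dict(checkpoint["model_state"])

optimizer = torch.optim.AdamW(model.parameters())  # assumed same optimizer class as in training
optimizer.load_state_dict(checkpoint["optimizer_state"])

# caveat: for LambdaLR runs, the lr_lambda params were popped from run_params before
# saving, so a scheduler must be rebuilt from the original lambda before loading
# checkpoint["scheduler_state"]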