Skip to content

Commit

Permalink
Checkpoint every 10 (#86)
Browse files Browse the repository at this point in the history
* rename single-letter variables

* ruff auto format

* fix checkpoint_model call in train_model

mypy v1.11.0 caught following issues
aviary/train.py:320: error: Unexpected keyword argument "model" for "checkpoint_model"  [call-arg]
aviary/train.py:320: error: Unexpected keyword argument "epoch" for "checkpoint_model"; did you mean "epochs"?  [call-arg]
aviary/train.py:328: error: Argument "timestamp" to "checkpoint_model" has incompatible type "str | None"; expected "str"  [arg-type]

* fix: remove strict from zip

* checkpoint every 10

* fea: working save every 10

---------

Co-authored-by: Janosh Riebesell <[email protected]>
  • Loading branch information
CompRhys and janosh committed Jul 27, 2024
1 parent d1bbfe1 commit aa120d4
Showing 1 changed file with 39 additions and 8 deletions.
47 changes: 39 additions & 8 deletions aviary/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from __future__ import annotations

import os
from datetime import datetime
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Literal

import numpy as np
Expand Down Expand Up @@ -56,6 +56,7 @@ def train_model(
train_loader: DataLoader | InMemoryDataLoader,
test_loader: DataLoader | InMemoryDataLoader,
checkpoint: Literal["local", "wandb"] | None = None,
checkpoint_frequency: int = 10,
learning_rate: float = 1e-4,
model_params: dict[str, Any] | None = None,
run_params: dict[str, Any] | None = None,
Expand Down Expand Up @@ -90,6 +91,7 @@ def train_model(
checkpoint = wandb.restore("checkpoint.pth", run_path)
torch.load(checkpoint.name)
```
checkpoint_frequency (int): How often to save a checkpoint. Defaults to 10.
learning_rate (float): The optimizer's learning rate. Defaults to 1e-4.
model_params (dict): Arguments passed to model class. E.g. dict(n_attn_layers=6,
embedding_aggregation=("mean", "std")) for Wrenformer.
Expand Down Expand Up @@ -227,7 +229,7 @@ def train_model(
**wandb_kwargs or {},
)

for epoch in tqdm(range(epochs), disable=None, desc="Training epoch"):
for epoch in tqdm(range(1, epochs + 1), disable=None, desc="Training epoch"):
train_metrics = model.evaluate(
train_loader,
loss_dict,
Expand Down Expand Up @@ -265,6 +267,25 @@ def train_model(
if wandb_path:
wandb.log({"training": train_metrics, "validation": val_metrics})

if epoch % checkpoint_frequency == 0 and epoch < epochs:
inference_model = swa_model if swa_start else model
inference_model.eval()
checkpoint_model(
checkpoint_endpoint=checkpoint,
model_params=model_params,
inference_model=inference_model,
optimizer_instance=optimizer_instance,
lr_scheduler=lr_scheduler,
loss_dict=loss_dict,
epochs=epoch,
test_metrics=val_metrics,
timestamp=timestamp,
run_name=run_name,
normalizer_dict=normalizer_dict,
run_params=run_params,
scheduler_name=scheduler_name,
)

# get test set predictions
if swa_start is not None:
n_swa_epochs = int((1 - swa_start) * epochs)
Expand Down Expand Up @@ -327,7 +348,7 @@ def train_model(
loss_dict=loss_dict,
epochs=epochs,
test_metrics=test_metrics,
timestamp=timestamp or datetime.now().astimezone().strftime("%Y%m%d-%H%M%S"),
timestamp=timestamp,
run_name=run_name,
normalizer_dict=normalizer_dict,
run_params=run_params,
Expand Down Expand Up @@ -365,21 +386,24 @@ def train_model(


def checkpoint_model(
checkpoint_endpoint: str,
checkpoint_endpoint: Literal["local", "wandb"] | None,
model_params: dict | None,
inference_model: nn.Module,
optimizer_instance: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
loss_dict: dict,
epochs: int,
test_metrics: dict,
timestamp: str,
timestamp: str | None,
run_name: str,
normalizer_dict: dict,
run_params: dict,
scheduler_name: str,
):
"""Save model checkpoint to different endpoints."""
if checkpoint_endpoint is None:
return

if model_params is None:
raise ValueError("Must provide model_params to save checkpoint, got None")

Expand All @@ -393,22 +417,29 @@ def checkpoint_model(
metrics=test_metrics,
run_name=run_name,
normalizer_dict=normalizer_dict,
run_params=run_params.copy(),
run_params=deepcopy(run_params),
)
if scheduler_name == "LambdaLR":
# exclude lr_lambda from pickled checkpoint since it causes errors when
# torch.load()-ing a checkpoint and the file defining lr_lambda() was
# renamed
checkpoint_dict["run_params"]["lr_scheduler"].pop("params")

if checkpoint_endpoint == "local":
os.makedirs(f"{ROOT}/models", exist_ok=True)
checkpoint_path = f"{ROOT}/models/{timestamp}-{run_name}.pth"
checkpoint_path = (
f"{ROOT}/models/{timestamp+'-' if timestamp else ''}{run_name}-{epochs}.pth"
)
torch.save(checkpoint_dict, checkpoint_path)

if checkpoint_endpoint == "wandb":
assert (
wandb.run is not None
), "can't save model checkpoint to Weights and Biases, wandb.run is None"
torch.save(checkpoint_dict, f"{wandb.run.dir}/checkpoint.pth")
torch.save(
checkpoint_dict,
f"{wandb.run.dir}/{timestamp+'-' if timestamp else ''}{run_name}-{epochs}.pth",
)


def train_wrenformer(
Expand Down

0 comments on commit aa120d4

Please sign in to comment.