Skip to content

Commit

Permalink
Allow checkpoint name to be set explicitly. (#87)
Browse files Browse the repository at this point in the history
* rename single-letter variables

* ruff auto format

* fix checkpoint_model call in train_model

mypy v1.11.0 caught following issues
aviary/train.py:320: error: Unexpected keyword argument "model" for "checkpoint_model"  [call-arg]
aviary/train.py:320: error: Unexpected keyword argument "epoch" for "checkpoint_model"; did you mean "epochs"?  [call-arg]
aviary/train.py:328: error: Argument "timestamp" to "checkpoint_model" has incompatible type "str | None"; expected "str"  [arg-type]

* fix: remove strict from zip

* checkpoint every 10

* fea: working save every 10

* fea: checkpoint filename

---------

Co-authored-by: Janosh Riebesell <[email protected]>
  • Loading branch information
CompRhys and janosh committed Aug 1, 2024
1 parent aa120d4 commit 71f1c62
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions aviary/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,10 @@ def make_ensemble_predictions(

@print_walltime(end_desc="predict_from_wandb_checkpoints")
def predict_from_wandb_checkpoints(
runs: list[wandb.apis.public.Run], cache_dir: str, **kwargs: Any
runs: list[wandb.apis.public.Run],
checkpoint_filename: str = "checkpoint.pth",
cache_dir: str = "./checkpoint_cache",
**kwargs: Any,
) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
"""Download and cache checkpoints for an ensemble of models, then make
predictions on some dataset. Finally print ensemble metrics and store
Expand All @@ -167,6 +170,7 @@ def predict_from_wandb_checkpoints(
runs (list[wandb.apis.public.Run]): List of WandB runs to download model
checkpoints from which are then loaded into memory to generate
predictions for the input_col in df.
checkpoint_filename (str): Name of the checkpoint file to download.
cache_dir (str): Directory to cache downloaded checkpoints in.
**kwargs: Additional keyword arguments to pass to make_ensemble_predictions().
Expand Down Expand Up @@ -194,15 +198,15 @@ def predict_from_wandb_checkpoints(
out_dir = f"{cache_dir}/{run_path}"
os.makedirs(out_dir, exist_ok=True)

checkpoint_path = f"{out_dir}/checkpoint.pth"
checkpoint_path = f"{out_dir}/{checkpoint_filename}"
checkpoint_paths.append(checkpoint_path)
print(f"{idx:>3}/{len(runs)}: {run.url}\n\t{checkpoint_path}\n")

with open(f"{out_dir}/run.md", "w") as md_file:
md_file.write(f"[{run.name}]({run.url})\n")

if not os.path.isfile(checkpoint_path):
run.file("checkpoint.pth").download(root=out_dir)
run.file(f"{checkpoint_filename}").download(root=out_dir)

if target_col in kwargs:
df, ensemble_metrics = make_ensemble_predictions(checkpoint_paths, **kwargs)
Expand Down

0 comments on commit 71f1c62

Please sign in to comment.