Skip to content

Commit

Permalink
fix: predict_from_wandb_checkpoints allow to return when target col i…
Browse files Browse the repository at this point in the history
…s none
  • Loading branch information
CompRhys committed Jul 2, 2024
1 parent 3f66c73 commit 4f10075
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions aviary/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,20 +137,23 @@ def make_ensemble_predictions(
@print_walltime(end_desc="predict_from_wandb_checkpoints")
def predict_from_wandb_checkpoints(
runs: list[wandb.apis.public.Run], cache_dir: str, **kwargs: Any
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""Download and cache checkpoints for an ensemble of models, then make predictions on some
dataset. Finally print ensemble metrics and store predictions to CSV.
) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
"""Download and cache checkpoints for an ensemble of models, then make
predictions on some dataset. Finally print ensemble metrics and store
predictions to CSV.
Args:
runs (list[wandb.apis.public.Run]): List of WandB runs to download model checkpoints from
which are then loaded into memory to generate predictions for the input_col in df.
runs (list[wandb.apis.public.Run]): List of WandB runs to download model
checkpoints from which are then loaded into memory to generate
predictions for the input_col in df.
cache_dir (str): Directory to cache downloaded checkpoints in.
**kwargs: Additional keyword arguments to pass to make_ensemble_predictions().
Returns:
tuple[pd.DataFrame, pd.DataFrame]: Original input dataframe with added columns for model
predictions and uncertainties. The 2nd dataframe holds ensemble performance metrics
like mean and standard deviation of MAE/RMSE.
pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]: Original input dataframe
with added columns for model predictions and uncertainties. The optional
2nd dataframe holds ensemble performance metrics like mean and standard
deviation of MAE/RMSE.
"""
print(f"Using checkpoints from {len(runs)} run(s):")

Expand Down Expand Up @@ -180,7 +183,10 @@ def predict_from_wandb_checkpoints(
if not os.path.isfile(checkpoint_path):
run.file("checkpoint.pth").download(root=out_dir)

df, ensemble_metrics = make_ensemble_predictions(checkpoint_paths, **kwargs)
if target_col in kwargs:
df, ensemble_metrics = make_ensemble_predictions(checkpoint_paths, **kwargs)
# round to save disk space and speed up cloud storage uploads
return df.round(6), ensemble_metrics

# round to save disk space and speed up cloud storage uploads
return df.round(6), ensemble_metrics
df = make_ensemble_predictions(checkpoint_paths, **kwargs)
return df.round(6)

0 comments on commit 4f10075

Please sign in to comment.