Skip to content

Commit

Permalink
fea: don't calculate ens values if only one learner
Browse files Browse the repository at this point in the history
  • Loading branch information
CompRhys committed Jul 8, 2024
1 parent f2c41cd commit 82b4218
Showing 1 changed file with 31 additions and 10 deletions.
41 changes: 31 additions & 10 deletions aviary/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,21 +97,42 @@ def make_ensemble_predictions(
[model(*inputs)[0].cpu().numpy() for inputs, *_ in data_loader]
).squeeze()

pred_col = f"{target_col}_pred_{idx}" if target_col else f"pred_{idx}"

if model.robust:
preds, aleat_log_std = preds.T
df[f"aleatoric_std_{idx}"] = aleatoric_std = np.exp(aleat_log_std)
ale_col = (
f"{target_col}_aleatoric_std_{idx}"
if target_col
else f"aleatoric_std_{idx}"
)
df[pred_col] = preds
df[ale_col] = aleatoric_std = np.exp(aleat_log_std)
else:
df[pred_col] = preds

pred_col = f"{target_col}_pred_{idx}" if target_col else f"pred_{idx}"
df[pred_col] = preds
if len(checkpoint_paths) > 1:
df_preds = df.filter(regex=r"_pred_\d")

df_preds = df.filter(regex=r"_pred_\d")
df[f"{target_col}_pred_ens"] = ensemble_preds = df_preds.mean(axis=1)
df[f"{target_col}_epistemic_std_ens"] = epistemic_std = df_preds.std(axis=1)
pred_ens_col = f"{target_col}_pred_ens" if target_col else "pred_ens"
df[pred_ens_col] = ensemble_preds = df_preds.mean(axis=1)

if df.columns.str.startswith("aleatoric_std_").any():
aleatoric_std = df.filter(regex=r"aleatoric_std_\d").mean(axis=1)
df[f"{target_col}_aleatoric_std_ens"] = aleatoric_std
df[f"{target_col}_total_std_ens"] = (epistemic_std**2 + aleatoric_std**2) ** 0.5
pred_epi_std_ens = (
f"{target_col}_epistemic_std_ens" if target_col else "epistemic_std_ens"
)
df[pred_epi_std_ens] = epistemic_std = df_preds.std(axis=1)

if df.columns.str.startswith("aleatoric_std_").any():
pred_ale_std_ens = (
f"{target_col}_aleatoric_std_ens" if target_col else "aleatoric_std_ens"
)
pred_tot_std_ens = (
f"{target_col}_total_std_ens" if target_col else "total_std_ens"
)
df[pred_ale_std_ens] = aleatoric_std = df.filter(
regex=r"aleatoric_std_\d"
).mean(axis=1)
df[pred_tot_std_ens] = (epistemic_std**2 + aleatoric_std**2) ** 0.5

if target_col:
targets = df[target_col]
Expand Down

0 comments on commit 82b4218

Please sign in to comment.