Skip to content

Commit

Permalink
fix: don't hard code the validation batch size
Browse files Browse the repository at this point in the history
  • Loading branch information
CompRhys committed Jul 25, 2024
1 parent 2c76739 commit c235b55
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions aviary/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,9 +304,7 @@ def train_model(
targets = test_df[target_col]
# preds can have shape (n_samples, n_classes) if doing multi-class classification so
# use df to merge all columns into test_df
df_preds = pd.DataFrame(preds, index=test_df.index).add_prefix(
f"{target_col}_pred_"
)
df_preds = pd.DataFrame(preds, index=test_df.index).add_prefix(f"{target_col}_pred_")
test_df[df_preds.columns] = df_preds # requires shuffle=False for test_loader

test_metrics = get_metrics(targets, preds, task_type)
Expand Down Expand Up @@ -379,6 +377,7 @@ def train_wrenformer(
train_df: pd.DataFrame,
test_df: pd.DataFrame,
batch_size: int = 128,
inference_multiplier: int = 4,
embedding_type: str | None = None,
id_col: str = "material_id",
input_col: str | None = None,
Expand All @@ -400,6 +399,8 @@ def train_wrenformer(
train_df (pd.DataFrame): Training set dataframe.
test_df (pd.DataFrame): Test set dataframe.
batch_size (int, optional): Batch size for training. Defaults to 128.
inference_multiplier (int, optional): Multiplier for the test set data loader
batch size. Defaults to 1.
embedding_type ('wyckoff' | 'composition', optional): Type of embedding to use.
Defaults to None meaning auto-detect based on 'wren'/'roost' in run_name.
id_col (str, optional): Column name in train_df and test_df containing unique
Expand Down Expand Up @@ -444,7 +445,7 @@ def train_wrenformer(

test_loader = df_to_in_mem_dataloader(
test_df,
batch_size=512,
batch_size=batch_size * inference_multiplier,
shuffle=False,
**data_loader_kwargs, # type: ignore[arg-type]
)
Expand Down

0 comments on commit c235b55

Please sign in to comment.