Allow InMemoryDataLoader to store tensors in RAM (#81)
* feat: allow InMemoryDataLoader to store tensors in RAM

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* lint

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
CompRhys and pre-commit-ci[bot] committed Jun 18, 2024
1 parent 2221f41 commit 28c9ab6
Showing 2 changed files with 40 additions and 4 deletions.
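The net effect of the commit: the InMemoryDataLoader's pre-built tensors can now stay in CPU RAM and are moved to the model's device one batch at a time. A minimal usage sketch, not part of the diff; the import path and the column/embedding values are assumptions for illustration, only the new device keyword comes from the changes below.

from aviary.wrenformer.data import df_to_in_mem_dataloader  # assumed import path

train_loader = df_to_in_mem_dataloader(
    train_df,                  # pandas DataFrame, one row per material (assumed)
    input_col="wyckoff",       # assumed column name
    id_col="material_id",
    embedding_type="wyckoff",  # assumed value
    device="cpu",              # new: keep the pre-built tensors in RAM
)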
26 changes: 23 additions & 3 deletions aviary/core.py
@@ -34,6 +34,7 @@ def __init__(
task_dict: dict[str, TaskType],
robust: bool,
epoch: int = 0,
device: str | None = None,
best_val_scores: dict[str, float] | None = None,
) -> None:
"""Store core model parameters.
@@ -47,17 +48,22 @@ def __init__(
loss function to attenuate the weighting of uncertain samples.
epoch (int, optional): Epoch model training will begin/resume from.
Defaults to 0.
device (str, optional): Device to store the model parameters on.
best_val_scores (dict[str, float], optional): Validation score to use for
early stopping. Defaults to None.
"""
super().__init__()
self.task_dict = task_dict
self.target_names = list(task_dict)
self.robust = robust
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = device
self.epoch = epoch
self.best_val_scores = best_val_scores or {}
self.es_patience = 0

self.to(self.device)
self.model_params: dict[str, Any] = {"task_dict": task_dict}

def fit(
@@ -256,7 +262,10 @@ def evaluate(
for inputs, targets_list, *_ in tqdm(
data_loader, disable=None if pbar else True
):
# compute output
inputs = [ # noqa: PLW2901
tensor.to(self.device) if hasattr(tensor, "to") else tensor
for tensor in inputs
]
outputs = self(*inputs)

mixed_loss: Tensor = 0 # type: ignore[assignment]
@@ -270,6 +279,7 @@
if task == "regression":
assert normalizer is not None
targets = normalizer.norm(targets).squeeze() # noqa: PLW2901
targets = targets.to(self.device) # noqa: PLW2901

if self.robust:
preds, log_std = output.unbind(dim=1)
@@ -284,6 +294,8 @@
target_metrics["MSE"].append(float(error.pow(2).mean()))

elif task == "classification":
targets = targets.to(self.device) # noqa: PLW2901

if self.robust:
pre_logits, log_std = output.chunk(2, dim=1)
logits = sampled_softmax(pre_logits, log_std)
@@ -370,6 +382,10 @@ def predict(
for inputs, targets, *batch_ids in tqdm(
data_loader, disable=True if not verbose else None
):
inputs = [ # noqa: PLW2901
tensor.to(self.device) if hasattr(tensor, "to") else tensor
for tensor in inputs
]
preds = self(*inputs) # forward pass to get model preds

test_ids.append(batch_ids)
@@ -407,8 +423,12 @@ def featurize(self, data_loader: DataLoader) -> np.ndarray:
self.eval() # ensure model is in evaluation mode
features = []

for input_, *_ in data_loader:
output = self.trunk_nn(self.material_nn(*input_)).cpu().numpy()
for inputs, *_ in data_loader:
inputs = [ # noqa: PLW2901
tensor.to(self.device) if hasattr(tensor, "to") else tensor
for tensor in inputs
]
output = self.trunk_nn(self.material_nn(*inputs)).cpu().numpy()
features.append(output)

return np.vstack(features)
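The same three-line to-device transfer now appears in evaluate(), predict(), and featurize(). As a sketch only (not part of this commit), the repeated pattern could be pulled into one small helper:

from collections.abc import Sequence
from typing import Any


def batch_to_device(inputs: Sequence[Any], device: str) -> list[Any]:
    """Move anything with a .to() method to `device`; pass other items through unchanged."""
    return [x.to(device) if hasattr(x, "to") else x for x in inputs]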
18 changes: 17 additions & 1 deletion aviary/train.py
@@ -136,6 +136,7 @@ def train_model(
else (torch.nn.NLLLoss() if robust else torch.nn.CrossEntropyLoss())
)
loss_dict = {target_col: (task_type, loss_func)}

normalizer_dict = {target_col: Normalizer() if task_type == reg_key else None}
# TODO consider actually fitting the normalizer, currently just passed into
# model.evaluate() to match function signature
@@ -272,7 +273,19 @@ def train_model(

with torch.no_grad():
preds = np.concatenate(
[inference_model(*inputs)[0].cpu().numpy() for inputs, *_ in test_loader]
[
inference_model(
*[
tensor.to(inference_model.device)
if hasattr(tensor, "to")
else tensor
for tensor in inputs
]
)[0]
.cpu()
.numpy()
for inputs, *_ in test_loader
]
).squeeze()

if test_df is None:
@@ -371,6 +384,7 @@ def train_wrenformer(
id_col: str = "material_id",
input_col: str | None = None,
model_params: dict[str, Any] | None = None,
data_loader_device: str = "cpu",
**kwargs,
) -> tuple[dict[str, float], dict[str, Any], pd.DataFrame]:
"""Train a Wrenformer model on a dataframe. This function handles the DataLoader
@@ -396,6 +410,7 @@
run_name which default to 'wyckoff' and 'composition' respectively.
model_params (dict): Passed to Wrenformer class. E.g. dict(n_attn_layers=6,
embedding_aggregation=("mean", "std")).
data_loader_device (str): Device to store the InMemoryDataLoader's tensors on.
Defaults to "cpu".
**kwargs: Additional keyword arguments are passed to train_model().
Returns:
@@ -419,6 +434,7 @@
input_col=input_col,
id_col=id_col,
embedding_type=embedding_type,
device=data_loader_device,
)
train_loader = df_to_in_mem_dataloader(
train_df,

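End to end, the new keyword is threaded from train_wrenformer() into df_to_in_mem_dataloader(). A rough call sketch follows; apart from data_loader_device, the keyword names are assumptions about train_wrenformer's existing signature, which this diff only shows in part.

test_metrics, run_params, pred_df = train_wrenformer(
    run_name="wrenformer-mp-e_form",  # assumed
    train_df=train_df,                # assumed keyword
    test_df=test_df,                  # assumed keyword
    target_col="e_form",              # assumed keyword
    task_type="regression",           # assumed keyword
    data_loader_device="cpu",         # new in this commit: dataloader tensors stay in RAM
)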