Commit

CompRhys committed Jul 11, 2024
2 parents 6678d96 + e7fd108 commit 121f519
Showing 5 changed files with 101 additions and 49 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
@@ -5,14 +5,14 @@ ci:
 
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.3.5
+    rev: v0.5.0
     hooks:
       - id: ruff
         args: [--fix]
       - id: ruff-format
 
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.5.0
+    rev: v4.6.0
     hooks:
       - id: check-case-conflict
       - id: check-symlinks
@@ -23,14 +23,14 @@ repos:
       - id: trailing-whitespace
 
   - repo: https://github.com/codespell-project/codespell
-    rev: v2.2.6
+    rev: v2.3.0
     hooks:
       - id: codespell
         exclude_types: [json]
         args: [--check-filenames]
 
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.9.0
+    rev: v1.10.1
     hooks:
       - id: mypy
         exclude: (tests|examples)/
69 changes: 48 additions & 21 deletions aviary/predict.py
@@ -97,21 +97,42 @@ def make_ensemble_predictions(
             [model(*inputs)[0].cpu().numpy() for inputs, *_ in data_loader]
         ).squeeze()
 
+        pred_col = f"{target_col}_pred_{idx}" if target_col else f"pred_{idx}"
+
         if model.robust:
             preds, aleat_log_std = preds.T
-            df[f"aleatoric_std_{idx}"] = aleatoric_std = np.exp(aleat_log_std)
+            ale_col = (
+                f"{target_col}_aleatoric_std_{idx}"
+                if target_col
+                else f"aleatoric_std_{idx}"
+            )
+            df[pred_col] = preds
+            df[ale_col] = aleatoric_std = np.exp(aleat_log_std)
+        else:
+            df[pred_col] = preds
 
-        pred_col = f"{target_col}_pred_{idx}" if target_col else f"pred_{idx}"
-        df[pred_col] = preds
+    if len(checkpoint_paths) > 1:
+        df_preds = df.filter(regex=r"_pred_\d")
+        pred_ens_col = f"{target_col}_pred_ens" if target_col else "pred_ens"
+        df[pred_ens_col] = ensemble_preds = df_preds.mean(axis=1)
 
-    df_preds = df.filter(regex=r"_pred_\d")
-    df[f"{target_col}_pred_ens"] = ensemble_preds = df_preds.mean(axis=1)
-    df[f"{target_col}_epistemic_std_ens"] = epistemic_std = df_preds.std(axis=1)
+        pred_epi_std_ens = (
+            f"{target_col}_epistemic_std_ens" if target_col else "epistemic_std_ens"
+        )
+        df[pred_epi_std_ens] = epistemic_std = df_preds.std(axis=1)
 
-    if df.columns.str.startswith("aleatoric_std_").any():
-        aleatoric_std = df.filter(regex=r"aleatoric_std_\d").mean(axis=1)
-        df[f"{target_col}_aleatoric_std_ens"] = aleatoric_std
-        df[f"{target_col}_total_std_ens"] = (epistemic_std**2 + aleatoric_std**2) ** 0.5
+        if df.columns.str.startswith("aleatoric_std_").any():
+            pred_ale_std_ens = (
+                f"{target_col}_aleatoric_std_ens" if target_col else "aleatoric_std_ens"
+            )
+            pred_tot_std_ens = (
+                f"{target_col}_total_std_ens" if target_col else "total_std_ens"
+            )
+            df[pred_ale_std_ens] = aleatoric_std = df.filter(
+                regex=r"aleatoric_std_\d"
+            ).mean(axis=1)
+            df[pred_tot_std_ens] = (epistemic_std**2 + aleatoric_std**2) ** 0.5
 
     if target_col:
         targets = df[target_col]
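
For intuition, here is a minimal sketch of the column names the refactored loop now produces, with made-up formulas and random numbers standing in for model outputs (`e_form` is a hypothetical target column):

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
df = pd.DataFrame({"formula": ["NaCl", "MgO", "Al2O3"]})
target_col = "e_form"  # set to None to get the un-prefixed fallback names

# Mimic two robust ensemble members, each writing a prediction column
# and a per-model aleatoric-std column using the new naming scheme.
for idx in (1, 2):
    pred_col = f"{target_col}_pred_{idx}" if target_col else f"pred_{idx}"
    ale_col = (
        f"{target_col}_aleatoric_std_{idx}" if target_col else f"aleatoric_std_{idx}"
    )
    df[pred_col] = rng.normal(size=len(df))
    df[ale_col] = np.exp(rng.normal(size=len(df)))

print(df.columns.tolist())
# ['formula', 'e_form_pred_1', 'e_form_aleatoric_std_1',
#  'e_form_pred_2', 'e_form_aleatoric_std_2']
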
@@ -137,20 +158,23 @@ def make_ensemble_predictions(
 @print_walltime(end_desc="predict_from_wandb_checkpoints")
 def predict_from_wandb_checkpoints(
     runs: list[wandb.apis.public.Run], cache_dir: str, **kwargs: Any
-) -> tuple[pd.DataFrame, pd.DataFrame]:
-    """Download and cache checkpoints for an ensemble of models, then make predictions on some
-    dataset. Finally print ensemble metrics and store predictions to CSV.
+) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
+    """Download and cache checkpoints for an ensemble of models, then make
+    predictions on some dataset. Finally print ensemble metrics and store
+    predictions to CSV.
 
     Args:
-        runs (list[wandb.apis.public.Run]): List of WandB runs to download model checkpoints from
-            which are then loaded into memory to generate predictions for the input_col in df.
+        runs (list[wandb.apis.public.Run]): List of WandB runs to download model
+            checkpoints from which are then loaded into memory to generate
+            predictions for the input_col in df.
         cache_dir (str): Directory to cache downloaded checkpoints in.
         **kwargs: Additional keyword arguments to pass to make_ensemble_predictions().
 
     Returns:
-        tuple[pd.DataFrame, pd.DataFrame]: Original input dataframe with added columns for model
-            predictions and uncertainties. The 2nd dataframe holds ensemble performance metrics
-            like mean and standard deviation of MAE/RMSE.
+        pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]: Original input dataframe
+            with added columns for model predictions and uncertainties. The optional
+            2nd dataframe holds ensemble performance metrics like mean and standard
+            deviation of MAE/RMSE.
     """
     print(f"Using checkpoints from {len(runs)} run(s):")

@@ -180,7 +204,10 @@ def predict_from_wandb_checkpoints(
         if not os.path.isfile(checkpoint_path):
             run.file("checkpoint.pth").download(root=out_dir)
 
-    df, ensemble_metrics = make_ensemble_predictions(checkpoint_paths, **kwargs)
-
-    # round to save disk space and speed up cloud storage uploads
-    return df.round(6), ensemble_metrics
+    if "target_col" in kwargs:
+        df, ensemble_metrics = make_ensemble_predictions(checkpoint_paths, **kwargs)
+        # round to save disk space and speed up cloud storage uploads
+        return df.round(6), ensemble_metrics
+
+    df = make_ensemble_predictions(checkpoint_paths, **kwargs)
+    return df.round(6)
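
As a usage sketch, the return type now depends on whether `target_col` is among the kwargs, so callers can branch on the shape of the result (`runs` and `df` are placeholders for a real wandb run list and input dataframe; the kwarg names are assumed to be forwarded to make_ensemble_predictions):

# Hypothetical caller; runs would come from e.g. wandb.Api().runs("entity/project").
result = predict_from_wandb_checkpoints(
    runs, cache_dir="./checkpoints", df=df, target_col="e_form"
)
if isinstance(result, tuple):
    df_preds, ensemble_metrics = result  # target_col given -> metrics included
else:
    df_preds = result  # no target_col -> predictions only
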
37 changes: 17 additions & 20 deletions aviary/wren/data.py
@@ -250,6 +250,10 @@ def collate_batch(
     )
 
 
+# Pre-compile the regular expression
+WYK_LETTER_PATTERN = re.compile(r"((?<![0-9])[A-z])")
+
+
 def parse_aflow_wyckoff_str(
     aflow_label: str,
 ) -> tuple[str, list[float], list[str], list[tuple[str, ...]]]:
@@ -271,36 +275,29 @@ def parse_aflow_wyckoff_str(
     wyckoff_set = []
 
     for el, wyk_letters_per_elem in zip(elems, wyckoff_letters):
-        # normalize Wyckoff letters to start with 1 if missing digit
-        wyk_letters_normalized = re.sub(
-            r"((?<![0-9])[A-z])", r"1\g<1>", wyk_letters_per_elem
-        )
+        # Normalize Wyckoff letters to start with 1 if missing digit
+        wyk_letters_normalized = WYK_LETTER_PATTERN.sub(r"1\g<1>", wyk_letters_per_elem)
 
         # Separate out pairs of Wyckoff letters and their number of occurrences
         sep_n_wyks = [
             "".join(g) for _, g in groupby(wyk_letters_normalized, str.isalpha)
         ]
 
-        # Add the Wyckoff letter and its multiplicity to the list
-        for mult, letter in zip(map(int, sep_n_wyks[0::2]), sep_n_wyks[1::2]):
+        # Process Wyckoff letters and multiplicities
+        mults = map(int, sep_n_wyks[0::2])
+        letters = sep_n_wyks[1::2]
+
+        for mult, letter in zip(mults, letters):
             elements.extend([el] * mult)
             wyckoff_set.extend([letter] * mult)
             wyckoff_site_multiplicities.extend(
                 [float(wyckoff_multiplicity_dict[spg_num][letter])] * mult
            )
 
-    # NOTE This on-the-fly augmentation of equivalent Wyckoff sets could be a source of
-    # high memory use. Can be turned off by commenting out the for loop and returning
-    # [wyckoff_set] instead of augmented_wyckoff_set. Wren should be able to learn
-    # anyway.
-    augmented_wyckoff_set = []
-    for trans in relab_dict[spg_num]:
-        # Apply translation dictionary of allowed relabelling operations in spacegroup
-        t = str.maketrans(trans)
-        augmented_wyckoff_set.append(
-            tuple(",".join(wyckoff_set).translate(t).split(","))
-        )
-
-    augmented_wyckoff_set = list(set(augmented_wyckoff_set))
+    # Create augmented Wyckoff set
+    augmented_wyckoff_set = {
+        tuple(",".join(wyckoff_set).translate(str.maketrans(trans)).split(","))
+        for trans in relab_dict[spg_num]
+    }
 
-    return spg_num, wyckoff_site_multiplicities, elements, augmented_wyckoff_set
+    return spg_num, wyckoff_site_multiplicities, elements, list(augmented_wyckoff_set)
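
For intuition, a small sketch of the two behaviors the rewrite preserves: the digit-insertion regex and the deduplication the set comprehension now performs (the Wyckoff letters and relabellings below are made up):

import re

WYK_LETTER_PATTERN = re.compile(r"((?<![0-9])[A-z])")

# Insert an implicit multiplicity of 1 before any letter not preceded by a digit:
print(WYK_LETTER_PATTERN.sub(r"1\g<1>", "a2b"))  # -> "1a2b"

# Building the augmented set via a set comprehension deduplicates relabellings
# that map onto the same Wyckoff sequence:
wyckoff_set = ["a", "b", "b"]
relabellings = [{"a": "a", "b": "b"}, {"a": "b", "b": "a"}, {"a": "a", "b": "b"}]
augmented = {
    tuple(",".join(wyckoff_set).translate(str.maketrans(trans)).split(","))
    for trans in relabellings
}
print(sorted(augmented))  # [('a', 'b', 'b'), ('b', 'a', 'a')] -- duplicate collapsed
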
5 changes: 5 additions & 0 deletions aviary/wren/utils.py
@@ -564,6 +564,11 @@ def get_aflow_strs_from_iso_and_composition(
     chemical systems that can be generated from combinations of the
     input isopointal_proto and composition.
     """
+    if not isinstance(isopointal_proto, str):
+        raise TypeError(
+            f"Invalid isopointal_proto: {isopointal_proto} ({type(isopointal_proto)})"
+        )
+
     anonymous_formula, pearson, spg, *wyckoffs = isopointal_proto.split("_")
 
     ele_amt_dict = composition.get_el_amt_dict()
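
A quick sketch of what the new guard catches, e.g. a NaN that slipped in from a dataframe instead of an AFLOW prototype string:

isopointal_proto = float("nan")  # e.g. a missing value from pd.read_csv
if not isinstance(isopointal_proto, str):
    raise TypeError(
        f"Invalid isopointal_proto: {isopointal_proto} ({type(isopointal_proto)})"
    )
# TypeError: Invalid isopointal_proto: nan (<class 'float'>)
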
31 changes: 27 additions & 4 deletions aviary/wrenformer/data.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import json
+from functools import cache
 from typing import TYPE_CHECKING, Any, Literal
 
 import numpy as np
@@ -64,6 +65,25 @@ def collate_batch(
         elem_features = json.load(file)
 
 
+@cache
+def get_wyckoff_features(
+    equivalent_wyckoff_set: list[tuple], spg_num: int
+) -> np.ndarray:
+    """Get Wyckoff set features from the precomputed dictionary. The output of this
+    function is cached for speed.
+
+    Args:
+        equivalent_wyckoff_set (list[tuple]): List of Wyckoff positions in the set.
+        spg_num (int): Space group number.
+
+    Returns:
+        np.ndarray: Shape (n_wyckoff_sites, n_features) where n_features = 444.
+    """
+    return np.array(
+        tuple(sym_features[spg_num][wyk_pos] for wyk_pos in equivalent_wyckoff_set)
+    )
+
+
 def wyckoff_embedding_from_aflow_str(wyckoff_str: str) -> Tensor:
     """Concatenate Matscholar element embeddings with Wyckoff set embeddings and handle
     augmentation of equivalent Wyckoff sets.
@@ -78,12 +98,14 @@ def wyckoff_embedding_from_aflow_str(wyckoff_str: str) -> Tensor:
     parsed_output = parse_aflow_wyckoff_str(wyckoff_str)
     spg_num, wyckoff_site_multiplicities, elements, augmented_wyckoffs = parsed_output
 
-    symmetry_features = torch.tensor(
-        [
-            [sym_features[spg_num][wyk_pos] for wyk_pos in equivalent_wyckoff_set]
+    symmetry_features = np.stack(
+        tuple(
+            get_wyckoff_features(equivalent_wyckoff_set, spg_num)
             for equivalent_wyckoff_set in augmented_wyckoffs
-        ]
-    )
+        ),
+        axis=0,
+    )
+    symmetry_features = torch.from_numpy(symmetry_features)
 
     n_augments = len(augmented_wyckoffs)  # number of equivalent Wyckoff sets
     element_features = torch.tensor([elem_features[el] for el in elements])
@@ -174,6 +196,7 @@ def df_to_in_mem_dataloader(
     )
     if targets.dtype == torch.bool:
         targets = targets.long()  # convert binary classification targets to 0 and 1
+
     inputs = np.empty(len(initial_embeddings), dtype=object)
     for idx, tensor in enumerate(initial_embeddings):
         inputs[idx] = tensor.to(device)
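
One caveat worth noting: functools.cache hashes its arguments, so the Wyckoff sets must arrive as tuples (as parse_aflow_wyckoff_str now returns them), not lists. A self-contained sketch with a toy stand-in for sym_features:

from functools import cache

import numpy as np
import torch

# Toy stand-in for the precomputed sym_features dict of per-Wyckoff-position features.
FAKE_SYM_FEATURES = {225: {"4a": [0.0, 1.0], "4b": [1.0, 0.0]}}


@cache  # requires hashable args: pass tuples of Wyckoff letters, never lists
def get_features(equivalent_wyckoff_set: tuple[str, ...], spg_num: int) -> np.ndarray:
    return np.array(
        tuple(FAKE_SYM_FEATURES[spg_num][wyk] for wyk in equivalent_wyckoff_set)
    )


first = get_features(("4a", "4b"), 225)   # computed
second = get_features(("4a", "4b"), 225)  # cache hit, same object returned
assert first is second

# Stacking per-set features and converting once mirrors the new embedding path.
stacked = np.stack((first, second), axis=0)
tensor = torch.from_numpy(stacked)
print(tensor.shape)  # torch.Size([2, 2, 2])
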
