Skip to content

Commit

Permalink
fea: faster data loading
Browse files Browse the repository at this point in the history
  • Loading branch information
CompRhys committed Jul 9, 2024
1 parent 82b4218 commit 4589a3e
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 21 deletions.
37 changes: 17 additions & 20 deletions aviary/wren/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,10 @@ def collate_batch(
)


# Pre-compiled once at import time so the pattern lookup is hoisted out of the
# per-element loop that uses it. Matches a single ASCII letter that is NOT
# directly preceded by a digit, i.e. a Wyckoff letter with an implicit
# multiplicity of 1. The class is spelled [A-Za-z] rather than [A-z] because
# the latter also spans the punctuation between "Z" and "a" ([ \ ] ^ _ `).
WYK_LETTER_PATTERN = re.compile(r"((?<![0-9])[A-Za-z])")


def parse_aflow_wyckoff_str(
aflow_label: str,
) -> tuple[str, list[float], list[str], list[tuple[str, ...]]]:
Expand All @@ -271,36 +275,29 @@ def parse_aflow_wyckoff_str(
wyckoff_set = []

for el, wyk_letters_per_elem in zip(elems, wyckoff_letters):
# normalize Wyckoff letters to start with 1 if missing digit
wyk_letters_normalized = re.sub(
r"((?<![0-9])[A-z])", r"1\g<1>", wyk_letters_per_elem
)
# Normalize Wyckoff letters to start with 1 if missing digit
wyk_letters_normalized = WYK_LETTER_PATTERN.sub(r"1\g<1>", wyk_letters_per_elem)

# Separate out pairs of Wyckoff letters and their number of occurrences
sep_n_wyks = [
"".join(g) for _, g in groupby(wyk_letters_normalized, str.isalpha)
]

# Add the Wyckoff letter and its multiplicity to the list
for mult, letter in zip(map(int, sep_n_wyks[0::2]), sep_n_wyks[1::2]):
# Process Wyckoff letters and multiplicities
mults = map(int, sep_n_wyks[0::2])
letters = sep_n_wyks[1::2]

for mult, letter in zip(mults, letters):
elements.extend([el] * mult)
wyckoff_set.extend([letter] * mult)
wyckoff_site_multiplicities.extend(
[float(wyckoff_multiplicity_dict[spg_num][letter])] * mult
)

# NOTE This on-the-fly augmentation of equivalent Wyckoff sets could be a source of
# high memory use. Can be turned off by commenting out the for loop and returning
# [wyckoff_set] instead of augmented_wyckoff_set. Wren should be able to learn
# anyway.
augmented_wyckoff_set = []
for trans in relab_dict[spg_num]:
# Apply translation dictionary of allowed relabelling operations in spacegroup
t = str.maketrans(trans)
augmented_wyckoff_set.append(
tuple(",".join(wyckoff_set).translate(t).split(","))
)

augmented_wyckoff_set = list(set(augmented_wyckoff_set))
# Create augmented Wyckoff set
augmented_wyckoff_set = {
tuple(",".join(wyckoff_set).translate(str.maketrans(trans)).split(","))
for trans in relab_dict[spg_num]
}

return spg_num, wyckoff_site_multiplicities, elements, augmented_wyckoff_set
return spg_num, wyckoff_site_multiplicities, elements, list(augmented_wyckoff_set)
4 changes: 3 additions & 1 deletion aviary/wrenformer/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,13 @@ def wyckoff_embedding_from_aflow_str(wyckoff_str: str) -> Tensor:
parsed_output = parse_aflow_wyckoff_str(wyckoff_str)
spg_num, wyckoff_site_multiplicities, elements, augmented_wyckoffs = parsed_output

symmetry_features = torch.tensor(
symmetry_features = np.array(
[
[sym_features[spg_num][wyk_pos] for wyk_pos in equivalent_wyckoff_set]
for equivalent_wyckoff_set in augmented_wyckoffs
]
)
symmetry_features = torch.from_numpy(symmetry_features)

n_augments = len(augmented_wyckoffs) # number of equivalent Wyckoff sets
element_features = torch.tensor([elem_features[el] for el in elements])
Expand Down Expand Up @@ -174,6 +175,7 @@ def df_to_in_mem_dataloader(
)
if targets.dtype == torch.bool:
targets = targets.long() # convert binary classification targets to 0 and 1

inputs = np.empty(len(initial_embeddings), dtype=object)
for idx, tensor in enumerate(initial_embeddings):
inputs[idx] = tensor.to(device)
Expand Down

0 comments on commit 4589a3e

Please sign in to comment.