Skip to content

Commit

Permalink
fea: use cache to speed up embedding construction
Browse files Browse the repository at this point in the history
  • Loading branch information
CompRhys committed Jul 9, 2024
1 parent 4589a3e commit e7fd108
Showing 1 changed file with 25 additions and 4 deletions.
29 changes: 25 additions & 4 deletions aviary/wrenformer/data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import json
from functools import cache
from typing import TYPE_CHECKING, Any, Literal

import numpy as np
Expand Down Expand Up @@ -64,6 +65,25 @@ def collate_batch(
elem_features = json.load(file)


@cache
def get_wyckoff_features(
    equivalent_wyckoff_set: tuple[Any, ...], spg_num: int
) -> np.ndarray:
    """Get Wyckoff set features from the precomputed dictionary.

    The output of this function is cached for speed. Because ``functools.cache``
    hashes its arguments, the Wyckoff set must be passed as a hashable sequence
    (i.e. a tuple); passing a list raises ``TypeError``.

    Args:
        equivalent_wyckoff_set (tuple): Wyckoff positions in the set. Each entry
            is used as a key into ``sym_features[spg_num]``.
        spg_num (int): Space group number.

    Returns:
        np.ndarray: Shape (n_wyckoff_sites, n_features) where n_features = 444.
            The array is marked read-only because the very same object is handed
            back on every cache hit; copy it (``arr.copy()``, ``np.stack`` ...)
            before modifying.

    Raises:
        TypeError: If an argument is not hashable (e.g. a list).
        KeyError: Presumably, if the space group or a Wyckoff position is missing
            from ``sym_features`` (assumes it is a nested dict -- confirm).
    """
    features = np.array(
        [sym_features[spg_num][wyk_pos] for wyk_pos in equivalent_wyckoff_set]
    )
    # The cache returns this exact object to every future caller with the same
    # arguments, so an in-place write anywhere would silently corrupt all later
    # lookups. Freezing the array turns that mistake into an immediate error.
    features.flags.writeable = False
    return features


def wyckoff_embedding_from_aflow_str(wyckoff_str: str) -> Tensor:
"""Concatenate Matscholar element embeddings with Wyckoff set embeddings and handle
augmentation of equivalent Wyckoff sets.
Expand All @@ -78,11 +98,12 @@ def wyckoff_embedding_from_aflow_str(wyckoff_str: str) -> Tensor:
parsed_output = parse_aflow_wyckoff_str(wyckoff_str)
spg_num, wyckoff_site_multiplicities, elements, augmented_wyckoffs = parsed_output

symmetry_features = np.array(
[
[sym_features[spg_num][wyk_pos] for wyk_pos in equivalent_wyckoff_set]
symmetry_features = np.stack(
tuple(
get_wyckoff_features(equivalent_wyckoff_set, spg_num)
for equivalent_wyckoff_set in augmented_wyckoffs
]
),
axis=0,
)
symmetry_features = torch.from_numpy(symmetry_features)

Expand Down

0 comments on commit e7fd108

Please sign in to comment.