Skip to content

Commit

Permalink
fea: add count sites method
Browse files Browse the repository at this point in the history
  • Loading branch information
CompRhys committed Jul 17, 2024
1 parent e92ec53 commit 3e06cd8
Show file tree
Hide file tree
Showing 3 changed files with 1,785 additions and 1,747 deletions.
50 changes: 38 additions & 12 deletions aviary/wren/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def count_values_for_wyckoff(
lookup_dict: dict[str, dict[str, int]],
):
"""Count values from a lookup table and scale by wyckoff multiplicities."""
return sum(float(n) * lookup_dict[spg][k] for n, k in zip(multiplicity, wyckoff))
return sum(int(n) * lookup_dict[spg][k] for n, k in zip(multiplicity, wyckoff))


def get_aflow_label_from_aflow(
Expand Down Expand Up @@ -267,16 +267,12 @@ def get_aflow_label_from_spg_analyzer(
for el, g in groupby(
equivalent_wyckoff_labels, key=lambda x: x[1]
): # sort alphabetically by element
lg = list(g)
elem_dict[el] = sum(
float(wyckoff_multiplicity_dict[str(spg_num)][e[2]]) for e in lg
)
lg = list(g) # NOTE create a list from the iterator so that we can reuse it
elem_dict[el] = sum(wyckoff_multiplicity_dict[str(spg_num)][e[2]] for e in lg)
wyks = ""
for wyk, w in groupby(
lg, key=lambda x: x[2]
): # sort alphabetically by wyckoff letter
lw = list(w)
wyks += f"{len(lw)}{wyk}"
# sort groups alphabetically by wyckoff letter
for wyk, w in groupby(lg, key=lambda x: x[2]):
wyks += f"{len(list(w))}{wyk}"
elem_wyks.append(wyks)

# canonicalize the possible wyckoff letter sequences
Expand Down Expand Up @@ -438,7 +434,7 @@ def count_wyckoff_positions(aflow_label: str) -> int:


def count_crystal_dof(aflow_label: str) -> int:
"""Count number of free parameters coarse-grained in Wyckoff representation: how
"""Count number of free parameters in coarse-grained Wyckoff representation: how
many degrees of freedom would remain to optimize during a crystal structure
relaxation.
Expand Down Expand Up @@ -471,6 +467,36 @@ def count_crystal_dof(aflow_label: str) -> int:
return n_params


def count_crystal_sites(aflow_label: str) -> int:
    """Count number of sites from Wyckoff representation.

    For every element, the number of occupied Wyckoff positions of each letter
    is scaled by that letter's multiplicity in the given space group, and the
    results are summed over all elements.

    Args:
        aflow_label (str): AFLOW-style prototype label with appended chemical system

    Returns:
        int: Number of sites in given prototype

    Raises:
        ValueError: If aflow_label does not contain exactly one ":" separator or
            has fewer than three "_"-separated fields in front of it.
    """
    n_sites = 0

    aflow_label, _ = aflow_label.split(":")  # chop off chemical system
    # only the space group and the per-element Wyckoff strings are needed here;
    # the first two fields of the label are ignored
    _, _, spg, *wyks = aflow_label.split("_")

    for wyk_letters_per_elem in wyks:
        # normalize Wyckoff letters to start with 1 if missing digit
        wyk_letters_normalized = re.sub(
            RE_WYCKOFF_NO_PREFIX, RE_SUBST_ONE_PREFIX, wyk_letters_per_elem
        )
        sep_el_wyks = split_alpha_numeric(wyk_letters_normalized)
        # look up the multiplicity of each letter and scale it by its count
        n_sites += count_values_for_wyckoff(
            sep_el_wyks["alpha"],
            sep_el_wyks["numeric"],
            spg,
            wyckoff_multiplicity_dict,
        )

    # cast in case the lookup helper returns a non-int numeric total
    return int(n_sites)


def get_isopointal_proto_from_aflow(aflow_label: str) -> str:
"""Get a canonicalized string for the prototype.
Expand Down Expand Up @@ -668,7 +694,7 @@ def get_random_structure_for_protostructure(protostructure: str, **kwargs) -> St

species_counts = [
sum(
int(wyckoff_multiplicity_dict[spg][w]) * int(m)
wyckoff_multiplicity_dict[spg][w] * int(m)
for m, w in zip(d["numeric"], d["alpha"])
)
for d in sep_el_wyks
Expand Down
Loading

0 comments on commit 3e06cd8

Please sign in to comment.