Skip to content

Commit

Permalink
fea: add get_aflow_strs_from_iso_and_composition
Browse files Browse the repository at this point in the history
  • Loading branch information
CompRhys committed Jun 28, 2024
1 parent 28c9ab6 commit 8b3556f
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 3 deletions.
82 changes: 80 additions & 2 deletions aviary/wren/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,13 +269,13 @@ def get_aflow_label_from_spg_analyzer(
return aflow_label_with_chemsys


def canonicalize_elem_wyks(elem_wyks: str, spg_num: int) -> str:
def canonicalize_elem_wyks(elem_wyks: str, spg_num: int | str) -> str:
"""Given an element ordering, canonicalize the associated Wyckoff positions
based on the alphabetical weight of equivalent choices of origin.
Args:
elem_wyks (str): Wren Wyckoff string encoding element types at Wyckoff positions
spg_num (int): International space group number.
spg_num (int | str): International space group number.
Returns:
str: Canonicalized Wren Wyckoff encoding.
Expand Down Expand Up @@ -476,6 +476,84 @@ def get_isopointal_proto_from_aflow(aflow_label: str) -> str:
return "_".join((c_anom, pearson, spg, canonical[0][1]))


def _get_anom_formula_dict(anonymous_formula: str) -> dict:
"""Get a dictionary of element to count from an anonymous formula."""
subst = r"\g<1>1"
anonymous_formula = re.sub(r"([A-z](?![0-9]))", subst, anonymous_formula)
anom_list = ["".join(g) for _, g in groupby(anonymous_formula, str.isalpha)]
counts = anom_list[1::2]
dummy = anom_list[0::2]

return dict(zip(dummy, map(int, counts), strict=True))


def _find_translations(dict1: dict, dict2: dict) -> list[dict]:
"""Find all possible translations between two dictionaries."""
# Check if the dictionaries have the same values
if sorted(dict1.values()) != sorted(dict2.values()):
return []

keys1 = list(dict1.keys())
keys2 = list(dict2.keys())

valid_translations = []

# Generate all permutations of keys2
for perm in permutations(keys2):
# Create a translation dictionary
translation = dict(zip(keys1, perm))

# Apply the translation to dict1
transformed = {translation[k]: v for k, v in dict1.items()}

# Check if the transformed dictionary matches dict2
if transformed == dict2:
valid_translations.append(translation)

return valid_translations


def get_aflow_strs_from_iso_and_composition(
    isopointal_proto: str, composition: Composition
) -> list[str]:
    """Get all AFLOW-style labels compatible with a prototype and a composition.

    The isopointal prototype is split on underscores into an anonymous formula,
    a Pearson symbol, a space group and one Wyckoff field per anonymous
    element. Each assignment of the composition's elements to the anonymous
    elements that preserves the amounts yields one label.

    Args:
        isopointal_proto (str): AFLOW-style Canonicalized prototype label
        composition (Composition): pymatgen Composition object

    Returns:
        list[str]: List of possible AFLOW-style prototype labels with appended
            chemical systems that can be generated from combinations of the
            input isopointal_proto and composition. Empty if the amounts in
            the composition do not match those of the anonymous formula.

    Raises:
        ValueError: If the label has fewer than three underscore-separated
            fields, or if the number of Wyckoff fields differs from the number
            of elements in the anonymous formula.
    """
    anonymous_formula, pearson, spg, *wyckoffs = isopointal_proto.split("_")

    ele_amt_dict = composition.get_el_amt_dict()
    proto_formula = prototype_formula(composition)
    anom_amt_dict = _get_anom_formula_dict(anonymous_formula)

    # Each translation maps a real element symbol to an anonymous element
    # letter that has the same amount.
    translations = _find_translations(ele_amt_dict, anom_amt_dict)
    # Wyckoff fields are listed in the same order as the anonymous elements.
    anom_ele_to_wyk = dict(zip(anom_amt_dict.keys(), wyckoffs, strict=True))

    aflow_strs = []
    for t in translations:
        # The chemical system and the Wyckoff fields both follow the
        # alphabetical order of the element symbols.
        elem_order = sorted(t.keys())
        # Prefix every letter that is not directly preceded by a digit with an
        # explicit "1" before handing the string to canonicalize_elem_wyks.
        elem_wyks = [
            re.sub(r"(?<!\d)([a-zA-Z])", r"1\1", anom_ele_to_wyk[t[elem]])
            for elem in elem_order
        ]
        canonical = canonicalize_elem_wyks("_".join(elem_wyks), spg)
        chemsys = "-".join(elem_order)
        aflow_str = f"{proto_formula}_{pearson}_{spg}_{canonical}:{chemsys}"
        aflow_strs.append(aflow_str)

    return aflow_strs


def count_distinct_wyckoff_letters(aflow_str: str) -> int:
"""Count number of distinct Wyckoff letters in Wyckoff representation."""
aflow_str, _ = aflow_str.split(":") # drop chemical system
Expand Down
28 changes: 27 additions & 1 deletion tests/test_wyckoff_ops.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from shutil import which

import pytest
from pymatgen.core.structure import Structure
from pymatgen.core.structure import Composition, Structure

from aviary.wren.utils import (
count_crystal_dof,
count_distinct_wyckoff_letters,
count_wyckoff_positions,
get_aflow_label_from_aflow,
get_aflow_label_from_spglib,
get_aflow_strs_from_iso_and_composition,
get_isopointal_proto_from_aflow,
)

Expand Down Expand Up @@ -55,6 +56,31 @@ def test_get_isopointal_proto(aflow_label, expected):
assert get_isopointal_proto_from_aflow(aflow_label) == expected


@pytest.mark.parametrize(
    "isopointal_proto, composition, expected",
    [
        (
            "AB2C3D4_tP10_115_a_g_bg_cdg",
            "Ce2Al3GaPd4",
            "A3B2CD4_tP10_115_ag_g_b_cdg:Al-Ce-Ga-Pd",
        ),
        # Yb and Ni have the same amount, so either can sit on either of the
        # two singly-occupied sites: two labels expected (space-separated).
        (
            "ABC3_oP20_62_a_c_cd",
            "YbNiO3",
            "AB3C_oP20_62_c_cd_a:Ni-O-Yb AB3C_oP20_62_a_cd_c:Ni-O-Yb",
        ),
    ],
)
def test_get_aflow_strs_from_iso_and_composition(
    isopointal_proto, composition, expected
):
    """Check the labels generated for a prototype and a formula string.

    ``expected`` packs the expected labels into one space-separated string.
    """
    expected_labels = expected.split(" ")
    formula = Composition(composition)

    produced = get_aflow_strs_from_iso_and_composition(isopointal_proto, formula)

    assert produced == expected_labels


@pytest.mark.parametrize(
"aflow_label, expected",
[
Expand Down

0 comments on commit 8b3556f

Please sign in to comment.