Skip to content

Commit

Permalink
fea: faster implementations to get wyckoffs from isopointal and compo…
Browse files Browse the repository at this point in the history
…sition
  • Loading branch information
CompRhys committed Jul 11, 2024
1 parent e87b3aa commit 6678d96
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 42 deletions.
92 changes: 53 additions & 39 deletions aviary/wren/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import re
import subprocess
from collections import Counter, defaultdict
from itertools import chain, groupby, permutations, product
from operator import itemgetter
from os.path import abspath, dirname, join
Expand Down Expand Up @@ -494,39 +495,59 @@ def get_isopointal_proto_from_aflow(aflow_label: str) -> str:

def _get_anom_formula_dict(anonymous_formula: str) -> dict:
"""Get a dictionary of element to count from an anonymous formula."""
subst = r"\g<1>1"
anonymous_formula = re.sub(r"([A-z](?![0-9]))", subst, anonymous_formula)
anom_list = ["".join(g) for _, g in groupby(anonymous_formula, str.isalpha)]
counts = anom_list[1::2]
dummy = anom_list[0::2]
result: defaultdict = defaultdict(int)
element = ""
count = ""

for char in anonymous_formula:
if char.isalpha():
if element:
result[element] += int(count) if count else 1
count = ""
element = char
else:
count += char

return dict(zip(dummy, map(int, counts), strict=True))
if element:
result[element] += int(count) if count else 1

return dict(result)

def _find_translations(dict1: dict, dict2: dict) -> list[dict]:

def _find_translations(
dict1: dict[str, int], dict2: dict[str, int]
) -> list[dict[str, str]]:
"""Find all possible translations between two dictionaries."""
# Check if the dictionaries have the same values
if sorted(dict1.values()) != sorted(dict2.values()):
if Counter(dict1.values()) != Counter(dict2.values()):
return []

keys1 = list(dict1.keys())
keys2 = list(dict2.keys())
used = set()

def backtrack(translation, index):
if index == len(dict1):
return [translation.copy()]

valid_translations = []
key1 = list(dict1.keys())[index]
value1 = dict1[key1]
valid_translations = []

# Generate all permutations of keys2
for perm in permutations(keys2):
# Create a translation dictionary
translation = dict(zip(keys1, perm))
for key2 in keys2:
if key2 not in used and dict2[key2] == value1:
used.add(key2)
translation[key1] = key2
valid_translations.extend(backtrack(translation, index + 1))
used.remove(key2)
del translation[key1]

# Apply the translation to dict1
transformed = {translation[k]: v for k, v in dict1.items()}
return valid_translations

# Check if the transformed dictionary matches dict2
if transformed == dict2:
valid_translations.append(translation)
return backtrack({}, 0)

return valid_translations

# Precompile regular expressions
re_wyckoff = re.compile(r"(?<!\d)([a-zA-Z])")
re_anonymous = re.compile(r"([A-Z])(?![0-9])")


def get_aflow_strs_from_iso_and_composition(
Expand All @@ -550,24 +571,17 @@ def get_aflow_strs_from_iso_and_composition(
anom_amt_dict = _get_anom_formula_dict(anonymous_formula)

translations = _find_translations(ele_amt_dict, anom_amt_dict)
anom_ele_to_wyk = dict(zip(anom_amt_dict.keys(), wyckoffs, strict=True))

subst = r"1\g<1>"
anonymous_formula = re.sub(r"([A-z](?![0-9]))", subst, anonymous_formula)

aflow_strs = []
for t in translations:
elem_order = sorted(t.keys())
elem_wyks = [
re.sub(r"(?<!\d)([a-zA-Z])", r"1\1", anom_ele_to_wyk[t[elem]])
for elem in elem_order
]
canonical = canonicalize_elem_wyks("_".join(elem_wyks), spg)
chemsys = "-".join(elem_order)
aflow_str = f"{proto_formula}_{pearson}_{spg}_{canonical}:{chemsys}"
aflow_strs.append(aflow_str)

return aflow_strs
anom_ele_to_wyk = dict(zip(anom_amt_dict.keys(), wyckoffs))
anonymous_formula = re_anonymous.sub(r"1\1", anonymous_formula)

return [
f"{proto_formula}_{pearson}_{spg}_"
f"{canonicalize_elem_wyks('_'.join(
re_wyckoff.sub(r'1\1', anom_ele_to_wyk[t[elem]])
for elem in sorted(t.keys())
), spg)}:{'-'.join(sorted(t.keys()))}"
for t in translations
]


def count_distinct_wyckoff_letters(aflow_str: str) -> int:
Expand Down
67 changes: 64 additions & 3 deletions tests/test_wyckoff_ops.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from itertools import permutations
from shutil import which

import pytest
from pymatgen.core.structure import Composition, Structure

from aviary.wren.utils import (
_find_translations,
count_crystal_dof,
count_distinct_wyckoff_letters,
count_wyckoff_positions,
Expand Down Expand Up @@ -82,12 +84,71 @@ def test_get_aflow_strs_from_iso_and_composition(
)
assert aflows == expected.split(" ")

# check the round trip
assert all(
get_isopointal_proto_from_aflow(aflow) == isopointal_proto for aflow in aflows
)


@pytest.mark.parametrize(
"dict1, dict2, expected",
[
# Test case 1: Simple valid translation
({"a": 1, "b": 2}, {"x": 1, "y": 2}, [{"a": "x", "b": "y"}]),
# Test case 2: Multiple valid translations
(
{"a": 1, "b": 1, "c": 1},
{"x": 1, "y": 1, "z": 1},
[
dict(zip(["a", "b", "c"], perm))
for perm in permutations(["x", "y", "z"])
],
),
# Test case 3: No valid translations (different values)
({"a": 1, "b": 2}, {"x": 1, "y": 3}, []),
# Test case 4: No valid translations (different number of items)
({"a": 1, "b": 2}, {"x": 1, "y": 2, "z": 3}, []),
# Test case 5: Empty dictionaries
({}, {}, [{}]),
# Test case 6: Larger dictionaries
(
{"a": 1, "b": 4, "c": 3, "d": 4},
{"w": 4, "x": 3, "y": 4, "z": 1},
[
{"a": "z", "b": "y", "c": "x", "d": "w"},
{"a": "z", "b": "w", "c": "x", "d": "y"},
],
),
],
)
def test_find_translations(dict1, dict2, expected):
result = _find_translations(dict1, dict2)
assert len(result) == len(expected)
for translation in result:
assert translation in expected


def test_prototype_formula():
assert prototype_formula(Composition("Ce2Al3GaPd4")) == "A3B2CD4"
# Additional test for performance with larger input
def test_find_translations_performance():
dict1 = {f"key{i}": i for i in range(8)}
dict2 = {f"val{i}": i for i in range(8)}
result = _find_translations(dict1, dict2)
assert len(result) == 1 # There should be only one valid translation


def test_get_anom_formula_from_prototype_formula():
@pytest.mark.parametrize(
"composition, expected",
[("Ce2Al3GaPd4", "A3B2CD4"), ("YbNiO3", "AB3C"), ("K2NaAlF6", "AB6C2D")],
)
def test_prototype_formula(composition: str, expected: str):
assert prototype_formula(Composition(composition)) == expected


@pytest.mark.parametrize(
"composition, expected",
[("Ce2Al3GaPd4", "AB2C3D4"), ("YbNiO3", "ABC3"), ("K2NaAlF6", "ABC2D6")],
)
def test_get_anom_formula_from_prototype_formula(composition: str, expected: str):
assert get_anom_formula_from_prototype_formula("A3B2CD4") == "AB2C3D4"


Expand Down

0 comments on commit 6678d96

Please sign in to comment.