Skip to content

Commit

Permalink
Merge pull request #130 from sparks-baird/encode-as-primitive
Browse files Browse the repository at this point in the history
encode/decode as primitive True by default (temporary fix)
  • Loading branch information
sgbaird committed Jun 23, 2022
2 parents 1e8528f + 8fb7be1 commit f39f86d
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 19 deletions.
9 changes: 2 additions & 7 deletions src/xtal2png/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ def __init__(
save_dir: Union[str, "PathLike[str]"] = path.join("data", "preprocessed"),
symprec: Union[float, Tuple[float, float]] = 0.1,
angle_tolerance: Union[float, int, Tuple[float, float], Tuple[int, int]] = 5.0,
encode_as_primitive: bool = False,
decode_as_primitive: bool = False,
encode_as_primitive: bool = True,
decode_as_primitive: bool = True,
relax_on_decode: bool = False,
channels: int = 1,
verbose: bool = True,
Expand Down Expand Up @@ -558,7 +558,6 @@ def structures_to_arrays(
space_group: List[int] = []
distance_matrix_tmp: List[NDArray[np.float64]] = []

sym_structures = []
for s in self.tqdm_if_verbose(structures):
spa = SpacegroupAnalyzer(
s,
Expand All @@ -569,11 +568,7 @@ def structures_to_arrays(
s = spa.get_primitive_standard_structure()
else:
s = spa.get_refined_structure()
sym_structures.append(s)

structures = sym_structures

for s in self.tqdm_if_verbose(structures):
n_sites = len(s.atomic_numbers)
if n_sites > self.max_sites:
raise ValueError(
Expand Down
38 changes: 26 additions & 12 deletions tests/xtal2png_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def assert_structures_approximate_match(
f"{i}-th original and decoded structures do not match according to StructureMatcher(comparator=ElementComparator()).fit(s, structure).\n\nOriginal (s): {s}\n\nDecoded (structure): {structure}" # noqa: E501
)

spa = SpacegroupAnalyzer(s, symprec=0.1, angle_tolerance=5.0)
s = spa.get_refined_structure()
spa = SpacegroupAnalyzer(structure, symprec=0.1, angle_tolerance=5.0)
structure = spa.get_refined_structure()

sm = StructureMatcher(primitive_cell=False, comparator=ElementComparator())
s2 = sm.get_s2_like_s1(s, structure)

Expand Down Expand Up @@ -145,7 +150,9 @@ def test_arrays_to_structures():
xc = XtalConverter(relax_on_decode=False)
data, id_data, id_mapper = xc.structures_to_arrays(example_structures)
structures = xc.arrays_to_structures(data, id_data, id_mapper)
assert_structures_approximate_match(example_structures, structures)
assert_structures_approximate_match(
example_structures, structures, tol_multiplier=2.0
)
return structures


Expand All @@ -155,15 +162,19 @@ def test_arrays_to_structures_zero_one():
example_structures, rgb_scaling=False
)
structures = xc.arrays_to_structures(data, id_data, id_mapper, rgb_scaling=False)
assert_structures_approximate_match(example_structures, structures)
assert_structures_approximate_match(
example_structures, structures, tol_multiplier=2.0
)
return structures


def test_arrays_to_structures_single():
xc = XtalConverter(relax_on_decode=False)
data, id_data, id_mapper = xc.structures_to_arrays([example_structures[0]])
structures = xc.arrays_to_structures(data, id_data, id_mapper)
assert_structures_approximate_match([example_structures[0]], structures)
assert_structures_approximate_match(
[example_structures[0]], structures, tol_multiplier=2.0
)
return structures


Expand All @@ -189,14 +200,18 @@ def test_png2xtal():
xc = XtalConverter(relax_on_decode=False)
imgs = xc.xtal2png(example_structures, show=True, save=True)
decoded_structures = xc.png2xtal(imgs)
assert_structures_approximate_match(example_structures, decoded_structures)
assert_structures_approximate_match(
example_structures, decoded_structures, tol_multiplier=2.0
)


def test_png2xtal_single():
xc = XtalConverter(relax_on_decode=False)
imgs = xc.xtal2png([example_structures[0]], show=True, save=True)
decoded_structures = xc.png2xtal(imgs, save=False)
assert_structures_approximate_match([example_structures[0]], decoded_structures)
assert_structures_approximate_match(
[example_structures[0]], decoded_structures, tol_multiplier=2.0
)
return decoded_structures


Expand All @@ -205,7 +220,9 @@ def test_png2xtal_rgb_image():
imgs = xc.xtal2png(example_structures, show=False, save=False)
imgs = [img.convert("RGB") for img in imgs]
decoded_structures = xc.png2xtal(imgs)
assert_structures_approximate_match(example_structures, decoded_structures)
assert_structures_approximate_match(
example_structures, decoded_structures, tol_multiplier=2.0
)
return decoded_structures


Expand All @@ -216,7 +233,9 @@ def test_png2xtal_three_channels():
if img_shape != (64, 64, 3):
raise ValueError(f"Expected image shape: (3, 64, 64), received: {img_shape}")
decoded_structures = xc.png2xtal(imgs)
assert_structures_approximate_match(example_structures, decoded_structures)
assert_structures_approximate_match(
example_structures, decoded_structures, tol_multiplier=2.0
)


def test_primitive_encoding():
Expand Down Expand Up @@ -362,8 +381,3 @@ def test_plot_and_save():
1 + 1

# %% Code Graveyard
# from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
# spa = SpacegroupAnalyzer(s, symprec=0.1, angle_tolerance=5.0)
# s = spa.get_refined_structure()
# spa = SpacegroupAnalyzer(structure, symprec=0.1, angle_tolerance=5.0)
# structure = spa.get_refined_structure()

0 comments on commit f39f86d

Please sign in to comment.