Skip to content

Commit

Permalink
Merge pull request #131 from sparks-baird/encode-cell-type
Browse files Browse the repository at this point in the history
convert encode/decode_as_primitive kwargs to encode/decode_cell_type kwargs
  • Loading branch information
sgbaird committed Jun 23, 2022
2 parents f39f86d + adad563 commit e39144c
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 34 deletions.
54 changes: 25 additions & 29 deletions src/xtal2png/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from pymatgen.core.lattice import Lattice
from pymatgen.core.structure import Structure
from pymatgen.io.cif import CifWriter
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from tqdm import tqdm

from xtal2png import __version__
Expand All @@ -29,6 +28,7 @@
get_image_mode,
rgb_scaler,
rgb_unscaler,
unit_cell_converter,
)

# from sklearn.preprocessing import MinMaxScaler
Expand Down Expand Up @@ -124,14 +124,16 @@ class XtalConverter:
``func:pymatgen.symmetry.analyzer.SpaceGroupAnalyzer.get_refined_structure``. If
specified as a tuple, then ``angle_tolerance[0]`` applies to encoding and
``angle_tolerance[1]`` applies to decoding. By default 5.0.
encode_as_primitive : bool, optional
Encode structures as symmetrized, primitive structures. Uses ``symprec`` if
``symprec`` is of type float, else uses ``symprec[0]`` if ``symprec`` is of type
tuple. Same applies for ``angle_tolerance``. By default True
decode_as_primitive : bool, optional
Decode structures as symmetrized, primitive structures. Uses ``symprec`` if
``symprec`` is of type float, else uses ``symprec[1]`` if ``symprec`` is of type
tuple. Same applies for ``angle_tolerance``. By default True
encode_cell_type : Optional[str], optional
Encode structures as-is (None), or after applying a certain tranformation. Uses
``symprec`` if ``symprec`` is of type float, else uses ``symprec[0]`` if
``symprec`` is of type tuple. Same applies for ``angle_tolerance``. By default
None
decode_cell_type : Optional[str], optional
Decode structures as-is (None), or after applying a certain tranformation. Uses
``symprec`` if ``symprec`` is of type float, else uses ``symprec[0]`` if
``symprec`` is of type tuple. Same applies for ``angle_tolerance``. By default
None
relax_on_decode: bool, optional
Use m3gnet to relax the decoded crystal structures.
channels : int, optional
Expand Down Expand Up @@ -166,8 +168,8 @@ def __init__(
save_dir: Union[str, "PathLike[str]"] = path.join("data", "preprocessed"),
symprec: Union[float, Tuple[float, float]] = 0.1,
angle_tolerance: Union[float, int, Tuple[float, float], Tuple[int, int]] = 5.0,
encode_as_primitive: bool = True,
decode_as_primitive: bool = True,
encode_cell_type: Optional[str] = None,
decode_cell_type: Optional[str] = None,
relax_on_decode: bool = False,
channels: int = 1,
verbose: bool = True,
Expand Down Expand Up @@ -199,8 +201,8 @@ def __init__(
self.encode_angle_tolerance = angle_tolerance[0]
self.decode_angle_tolerance = angle_tolerance[1]

self.encode_as_primitive = encode_as_primitive
self.decode_as_primitive = decode_as_primitive
self.encode_cell_type = encode_cell_type
self.decode_cell_type = decode_cell_type
self.relax_on_decode = relax_on_decode

self.channels = channels
Expand Down Expand Up @@ -559,15 +561,12 @@ def structures_to_arrays(
distance_matrix_tmp: List[NDArray[np.float64]] = []

for s in self.tqdm_if_verbose(structures):
spa = SpacegroupAnalyzer(
s = unit_cell_converter(
s,
self.encode_cell_type,
symprec=self.encode_symprec,
angle_tolerance=self.encode_angle_tolerance,
)
if self.encode_as_primitive:
s = spa.get_primitive_standard_structure()
else:
s = spa.get_refined_structure()
) # noqa: E501

n_sites = len(s.atomic_numbers)
if n_sites > self.max_sites:
Expand Down Expand Up @@ -987,24 +986,21 @@ def arrays_to_structures(
lattice = Lattice.from_parameters(
a=a, b=b, c=c, alpha=alpha, beta=beta, gamma=gamma
)
structure = Structure(lattice, at, fr)
s = Structure(lattice, at, fr)

# REVIEW: round fractional coordinates to nearest multiple?
if self.relax_on_decode:
relaxed_results = relaxer.relax(structure, verbose=self.verbose)
structure = relaxed_results["final_structure"]
relaxed_results = relaxer.relax(s, verbose=self.verbose)
s = relaxed_results["final_structure"]

spa = SpacegroupAnalyzer(
structure,
s = unit_cell_converter(
s,
self.decode_cell_type,
symprec=self.decode_symprec,
angle_tolerance=self.decode_angle_tolerance,
)
if self.decode_as_primitive:
structure = spa.get_primitive_standard_structure()
else:
structure = spa.get_refined_structure()

S.append(structure)
S.append(s)

if self.relax_on_decode:
# restore default https://stackoverflow.com/a/51340381/13697228
Expand Down
80 changes: 79 additions & 1 deletion src/xtal2png/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from numpy.typing import ArrayLike
from pymatgen.core.lattice import Lattice
from pymatgen.core.structure import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

coords = [[0, 0, 0], [0.75, 0.5, 0.75]]
lattice = Lattice.from_parameters(a=3.84, b=3.84, c=3.84, alpha=120, beta=90, gamma=60)
Expand Down Expand Up @@ -191,7 +192,35 @@ def rgb_unscaler(
return X_scaled


def get_image_mode(d):
def get_image_mode(d: np.ndarray) -> str:
"""Get the image mode (i.e. "RGB" vs. grayscale ("L")) for an image array.
Parameters
----------
d : np.ndarray
A NumPy array with 3 dimensions, where the first dimension corresponds to the
of image channels and the second and third dimensions correspond to the height and
width of the image.
Returns
-------
mode : str
"RGB" for 3-channel images and "L" for grayscale images.
Raises
------
ValueError
"expected an array with 3 dimensions, received {d.ndim} dims"
ValueError
"Expected a single-channel or 3-channel array, but received a {d.ndim}-channel
array."
Examples
--------
>>> d = np.zeros((1, 64, 64), dtype=np.uint8) # grayscale image
>>> mode = get_image_mode(d)
"L"
"""
if d.ndim != 3:
raise ValueError("expected an array with 3 dimensions, received {d.ndim} dims")
if d.shape[0] == 3:
Expand All @@ -204,3 +233,52 @@ def get_image_mode(d):
)

return mode


def unit_cell_converter(
s: Structure, cell_type: Optional[str] = None, symprec=0.1, angle_tolerance=5.0
):
"""Convert from the original unit cell type to another unit cell via pymatgen.
Parameters
----------
s : Structure
a pymatgen Structure.
cell_type : Optional[str], optional
The cell type as a str or None if leaving the structure as-is. Possible options
are "primitive_standard", "conventional_standard", "refined", "reduced", and
None. By default None
Returns
-------
s : Structure
The converted Structure.
Raises
------
ValueError
"Expected one of 'primitive_standard', 'conventional_standard', 'refined',
'reduced' or None, got {cell_type}"
Examples
--------
>>> s = unit_cell_converter(s, cell_type="reduced")
"""
spa = SpacegroupAnalyzer(
s,
symprec=symprec,
angle_tolerance=angle_tolerance,
)
if cell_type == "primitive_standard":
s = spa.get_primitive_standard_structure()
elif cell_type == "conventional_standard":
s = spa.get_conventional_standard_structure()
elif cell_type == "refined":
s = spa.get_refined_structure()
elif cell_type == "reduced":
s = s.get_reduced_structure()
elif cell_type is not None:
raise ValueError(
f"Expected one of 'primitive_standard', 'conventional_standard', 'refined', 'reduced' or None, got {cell_type}" # noqa: E501
)
return s
8 changes: 4 additions & 4 deletions tests/xtal2png_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,8 @@ def test_primitive_encoding():
xc = XtalConverter(
symprec=0.1,
angle_tolerance=5.0,
encode_as_primitive=True,
decode_as_primitive=False,
encode_cell_type="primitive_standard",
decode_cell_type=None,
relax_on_decode=False,
)
input_structures = [
Expand All @@ -264,8 +264,8 @@ def test_primitive_decoding():
xc = XtalConverter(
symprec=0.1,
angle_tolerance=5.0,
encode_as_primitive=False,
decode_as_primitive=True,
encode_cell_type=None,
decode_cell_type="primitive_standard",
relax_on_decode=False,
)
input_structures = [
Expand Down

0 comments on commit e39144c

Please sign in to comment.