Skip to content

Commit

Permalink
Merge pull request #11 from lrcfmd/clean_emd_func
Browse files Browse the repository at this point in the history
Clean emd func and ratio vectors for non mod_petti metrics
  • Loading branch information
SurgeArrester committed Oct 18, 2021
2 parents c0bebac + c2a38b9 commit 67d189c
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 48 deletions.
136 changes: 92 additions & 44 deletions ElMD/ElMD.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,21 @@
def main():
import time
ts = time.time()
x = ElMD("Li7La3Hf2O12", metric="mod_petti")
y = ElMD("CsPbI3", metric="jarvis")
z = ElMD("Zr3AlN", metric="atomic")
x = ElMD("LiCl", metric="mod_petti")
y = ElMD("NaCl", metric="mod_petti")
# z = ElMD("Zr3AlN", metric="atomic")

print(x.elmd(y))
print(y.elmd(x))

x = ElMD("Li7La3Hf2O12", metric="magpie")
y = ElMD("CsPbI3", metric="magpie")
# z = ElMD("Zr3AlN", metric="atomic")

print(x.elmd(y))
print(y.elmd(x))


print(y.elmd(z))
print(x)
print(x.feature_vector)
Expand Down Expand Up @@ -78,32 +87,56 @@ def _get_periodic_tab(metric):

return ElementDict

def EMD(comp1, comp2, lookup, table):
'''
A numba compiled EMD function to compare two sets of labels an associated
element feature matrix, and lookup table to map elements to indices, and
return the associated EMD.
'''
if type(comp1) is str:
source_demands = ElMD(comp1).ratio_vector

def elmd(comp1, comp2, metric="mod_petti"):
if isinstance(comp1, str):
comp1 = ElMD(comp1, metric=metric)
source_demands = comp1.ratio_vector
elif isinstance(comp1, ElMD):
source_demands = comp1.ratio_vector
else:
source_demands = comp1
raise TypeError(f"First composition must be either a string or ElMD object, you input an object of type {type(comp1)}")

if isinstance(comp2, str):
comp2 = ElMD(comp2, metric=metric)
sink_demands = comp2.ratio_vector
elif isinstance(comp2, ElMD):
sink_demands = comp2.ratio_vector

if type(comp2) is ElMD:
sink_demands = ElMD(comp2.formula, metric=comp1.metric).ratio_vector
elif type(comp2) is str:
sink_demands = ElMD(comp2, metric=comp1.metric).ratio_vector
else:
sink_demands = comp2
raise TypeError(f"Second composition must be either a string or ElMD object, you input an object of type {type(comp2)}")

source_labels = np.array([table[lookup[i]] for i in np.where(source_demands > 0)[0]], dtype=int)
sink_labels = np.array([table[lookup[i]] for i in np.where(sink_demands > 0)[0]], dtype=int)
if isinstance(comp1, ElMD) and isinstance(comp2, ElMD) and comp1.metric != comp2.metric:
raise TypeError(f"Both ElMD objects must use the same metric. comp1 has metric={comp1.metric} and comp2 has metric={comp2.metric}")

source_labels = np.array([comp1.periodic_tab[comp1.petti_lookup[i]] for i in np.where(source_demands > 0)[0]], dtype=float)
sink_labels = np.array([comp2.periodic_tab[comp2.petti_lookup[i]] for i in np.where(sink_demands > 0)[0]], dtype=float)

source_demands = source_demands[np.where(source_demands > 0)[0]]
sink_demands = sink_demands[np.where(sink_demands > 0)[0]]

# Perform a floating point conversion
network_costs = np.array([np.linalg.norm(x - y) * 1000000 for x in source_labels for y in sink_labels], dtype=np.int64)
# Perform a floating point conversion to ints to ensure algorithm terminates
network_costs = np.array([[np.linalg.norm(x - y) * 1000000 for x in sink_labels] for y in source_labels], dtype=np.int64)

return EMD(source_demands, sink_demands, network_costs)


def EMD(source_demands, sink_demands, network_costs):
'''
A numba compiled EMD function from the network simplex algorithm to compare
two distributions with a given distance matrix between node labels
'''

if len(network_costs.shape) == 2:
n, m = network_costs.shape

if len(source_demands) != n or len(sink_demands) != m:
raise ValueError(f"Shape of 2D distance matrix must have n rows and m columns where n is the number of source_demands, and m is the number of sink demands. You have n={len(source_demands)} source_demands and m={len(sink_demands)} sink_demands, but your distance matrix is [{n}, {m}].")

network_costs = network_costs.ravel()

else:
raise ValueError("Must input a 2D distance matrix between the elements of both distributions")

return network_simplex(source_demands, sink_demands, network_costs)

Expand All @@ -125,6 +158,7 @@ def __init__(self, formula="", metric="mod_petti", feature_pooling="agg", strict
self.periodic_tab = _get_periodic_tab(metric)
self.lookup = self._gen_lookup()
self.petti_lookup = _get_periodic_tab("mod_petti")
self.petti_lookup = self.filter_petti_lookup()

self.composition = self._parse_formula(self.formula)
self.normed_composition = self._normalise_composition(self.composition)
Expand All @@ -135,28 +169,42 @@ def __init__(self, formula="", metric="mod_petti", feature_pooling="agg", strict
self.feature_vector = self._gen_feature_vector()
self.pretty_formula = self._gen_pretty()

def elmd(self, comp2 = None, comp1 = None, verbose=False):

def filter_petti_lookup(self):
# Remove any elements from the mod_petti dictionary that our absent from our lookup table
filtered_petti = {k: v for k, v in self.petti_lookup.items() if k in self.periodic_tab }

lookup = {k: v for k, v in filtered_petti.items() }

for k, v in filtered_petti.items():
lookup[v] = k

# Now reindex each of the values to create a linearly spaced scale for lookups
comps, vals = zip(*[(k, v) for k, v in filtered_petti.items()])
sorted_inds = np.argsort(vals)

ret_dict = {}

for i, orig_index in enumerate(sorted_inds):
ret_dict[comps[orig_index]] = i
ret_dict[i] = comps[orig_index]


return ret_dict




def elmd(self, comp2 = None, comp1 = None):
'''
Calculate the minimal cost flow between two weighted vectors using the
network simplex method. This is overloaded to accept a range of input
types.
'''
if comp1 == None:
comp1 = self.ratio_vector

if isinstance(comp1, str):
comp1 = ElMD(comp1, metric=self.metric).ratio_vector
comp1 = self

if isinstance(comp1, ElMD):
comp1 = comp1.ratio_vector

if isinstance(comp2, str):
comp2 = ElMD(comp2, metric=self.metric).ratio_vector

if isinstance(comp2, ElMD):
comp2 = ElMD(comp2.formula, metric=self.metric).ratio_vector

return EMD(comp1, comp2, self.lookup, self.periodic_tab)
return elmd(comp1, comp2, metric=self.metric)

def _gen_ratio_vector(self):
'''
Expand All @@ -172,13 +220,13 @@ def _gen_ratio_vector(self):
comp_ratios = []

for k in sorted(comp.keys()):
comp_labels.append(self._get_position(k))
comp_labels.append(self.petti_lookup[k])
comp_ratios.append(comp[k])

indices = np.array(comp_labels, dtype=np.int64)
ratios = np.array(comp_ratios, dtype=np.float64)

numeric = np.zeros(shape=len(self.periodic_tab), dtype=np.float64)
numeric = np.zeros(shape=max([x for x in self.petti_lookup.values() if isinstance(x, int)]), dtype=np.float64)
numeric[indices] = ratios

return numeric
Expand All @@ -200,7 +248,7 @@ def _gen_petti_vector(self):
indices = np.array(comp_labels, dtype=np.int64)
ratios = np.array(comp_ratios, dtype=np.float64)

numeric = np.zeros(shape=103, dtype=np.float64)
numeric = np.zeros(shape=len(self.petti_lookup), dtype=np.float64)
numeric[indices] = ratios

return numeric
Expand All @@ -210,7 +258,7 @@ def _gen_feature_vector(self):
"""
Perform the dot product between the ratio vector and its elemental representation.
"""
n = int(len(self.lookup) / 2)
n = int(len(self.petti_lookup) / 2) - 1

# If we only have an integer representation, return the vector as is
if type(self.periodic_tab["H"]) is int:
Expand All @@ -223,7 +271,7 @@ def _gen_feature_vector(self):

for i, k in enumerate(self.normed_composition.keys()):
try:
numeric[self.lookup[k]] = self.periodic_tab[k]
numeric[self.petti_lookup[k]] = self.periodic_tab[k]
except:
print(f"Failed to process {self.formula} with {self.metric} due to unknown element {k}, discarding this element.")

Expand All @@ -246,9 +294,9 @@ def _gen_pretty(self):

for i, ind in enumerate(inds):
if self.petti_vector[ind] == 1:
pretty_form = pretty_form + f"{self.petti_lookup[str(ind)]}"
pretty_form = pretty_form + f"{self.petti_lookup[ind]}"
else:
pretty_form = pretty_form + f"{self.petti_lookup[str(ind)]}{self.petti_vector[ind]:.3f}".strip('0') + ' '
pretty_form = pretty_form + f"{self.petti_lookup[ind]}{self.petti_vector[ind]:.3f}".strip('0') + ' '

return pretty_form.strip()

Expand Down Expand Up @@ -387,7 +435,7 @@ def _get_position(self, element):

except:
if self.strict_parsing:
raise KeyError(f"One of the elements in {self.composition} is not in the {self.metric} dictionary. Try a different representation or use silent=False")
raise KeyError(f"One of the elements in {self.composition} is not in the {self.metric} dictionary. Try a different representation or use strict_parsing=False")
else:
return -1

Expand Down
19 changes: 18 additions & 1 deletion ElMD/el_lookup/mod_petti.json
Original file line number Diff line number Diff line change
@@ -1 +1,18 @@
{"D": 102, "T": 102, "H": 102, "102": "H", "0": "He", "He": 0, "11": "Li", "Li": 11, "76": "Be", "Be": 76, "85": "B", "B": 85, "86": "C", "C": 86, "87": "N", "N": 87, "96": "O", "O": 96, "101": "F", "F": 101, "1": "Ne", "Ne": 1, "10": "Na", "Na": 10, "72": "Mg", "Mg": 72, "77": "Al", "Al": 77, "84": "Si", "Si": 84, "88": "P", "P": 88, "95": "S", "S": 95, "100": "Cl", "Cl": 100, "2": "Ar", "Ar": 2, "9": "K", "K": 9, "15": "Ca", "Ca": 15, "47": "Sc", "Sc": 47, "50": "Ti", "Ti": 50, "53": "V", "V": 53, "54": "Cr", "Cr": 54, "71": "Mn", "Mn": 71, "70": "Fe", "Fe": 70, "69": "Co", "Co": 69, "68": "Ni", "Ni": 68, "67": "Cu", "Cu": 67, "73": "Zn", "Zn": 73, "78": "Ga", "Ga": 78, "83": "Ge", "Ge": 83, "89": "As", "As": 89, "94": "Se", "Se": 94, "99": "Br", "Br": 99, "3": "Kr", "Kr": 3, "8": "Rb", "Rb": 8, "14": "Sr", "Sr": 14, "20": "Y", "Y": 20, "48": "Zr", "Zr": 48, "52": "Nb", "Nb": 52, "55": "Mo", "Mo": 55, "58": "Tc", "Tc": 58, "60": "Ru", "Ru": 60, "62": "Rh", "Rh": 62, "64": "Pd", "Pd": 64, "66": "Ag", "Ag": 66, "74": "Cd", "Cd": 74, "79": "In", "In": 79, "82": "Sn", "Sn": 82, "90": "Sb", "Sb": 90, "93": "Te", "Te": 93, "98": "I", "I": 98, "42": "Es", "Xe": 4, "4": "Xe", "7": "Cs", "Cs": 7, "13": "Ba", "Ba": 13, "31": "La", "La": 31, "30": "Ce", "Ce": 30, "29": "Pr", "Pr": 29, "28": "Nd", "Nd": 28, "27": "Pm", "Pm": 27, "26": "Sm", "Sm": 26, "16": "Eu", "Eu": 16, "25": "Gd", "Gd": 25, "24": "Tb", "Tb": 24, "23": "Dy", "Dy": 23, "22": "Ho", "Ho": 22, "21": "Er", "Er": 21, "19": "Tm", "Tm": 19, "17": "Yb", "Yb": 17, "18": "Lu", "Lu": 18, "49": "Hf", "Hf": 49, "51": "Ta", "Ta": 51, "56": "W", "W": 56, "57": "Re", "Re": 57, "59": "Os", "Os": 59, "61": "Ir", "Ir": 61, "63": "Pt", "Pt": 63, "65": "Au", "Au": 65, "75": "Hg", "Hg": 75, "80": "Tl", "Tl": 80, "81": "Pb", "Pb": 81, "91": "Bi", "Bi": 91, "92": "Po", "Po": 92, "97": "At", "At": 97, "5": "Rn", "Rn": 5, "6": "Fr", "Fr": 6, "12": "Ra", "Ra": 12, "32": "Ac", "Ac": 32, "33": "Th", "Th": 33, "34": "Pa", "Pa": 34, "35": "U", "U": "35", "36": "Np", "Np": 36, "37": "Pu", "Pu": 37, "38": "Am", "Am": 38, "39": "Cm", "Cm": 39, "40": "Bk", "Bk": 40, "41": "Cf", "Cf": 41, "Es": 42, "43": "Fm", "Fm": 43, "44": "Md", "Md": 44, "45": "No", "No": 45, "46": "Lr", "Lr": 46, "Rf": 0, "Db": 0, "Sg": 0, "Bh": 0, "Hs": 0, "Mt": 0, "Ds": 0, "Rg": 0, "Cn": 0, "Nh": 0, "Fl": 0, "Mc": 0, "Lv": 0, "Ts": 0, "Og": 0, "Uue": 0}
{"D": 102, "T": 102, "H": 102, "He": 0, "Li": 11,
"Be": 76, "B": 85, "C": 86, "N": 87, "O": 96, "F": 101,
"Ne": 1, "Na": 10, "Mg": 72, "Al": 77, "Si": 84, "P": 88,
"S": 95, "Cl": 100, "Ar": 2, "K": 9, "Ca": 15, "Sc": 47,
"Ti": 50, "V": 53, "Cr": 54, "Mn": 71, "Fe": 70, "Co": 69,
"Ni": 68, "Cu": 67, "Zn": 73, "Ga": 78, "Ge": 83, "As": 89,
"Se": 94, "Br": 99, "Kr": 3, "Rb": 8, "Sr": 14, "Y": 20, "Zr": 48,
"Nb": 52, "Mo": 55, "Tc": 58, "Ru": 60, "Rh": 62, "Pd": 64, "Ag": 66,
"Cd": 74, "In": 79, "Sn": 82, "Sb": 90, "Te": 93, "I": 98, "Xe": 4,
"Cs": 7, "Ba": 13, "La": 31, "Ce": 30, "Pr": 29, "Nd": 28, "Pm": 27,
"Sm": 26, "Eu": 16, "Gd": 25, "Tb": 24, "Dy": 23, "Ho": 22, "Er": 21,
"Tm": 19, "Yb": 17, "Lu": 18, "Hf": 49, "Ta": 51, "W": 56, "Re": 57,
"Os": 59, "Ir": 61, "Pt": 63, "Au": 65, "Hg": 75, "Tl": 80, "Pb": 81,
"Bi": 91, "Po": 92, "At": 97, "Rn": 5, "Fr": 6, "Ra": 12, "Ac": 32,
"Th": 33, "Pa": 34, "U": 35, "Np": 36, "Pu": 37, "Am": 38, "Cm": 39,
"Bk": 40, "Cf": 41, "Es": 42, "Fm": 43, "Md": 44, "No": 45, "Lr": 46,
"Rf": 0, "Db": 0, "Sg": 0, "Bh": 0, "Hs": 0, "Mt": 0, "Ds": 0, "Rg": 0,
"Cn": 0, "Nh": 0, "Fl": 0, "Mc": 0, "Lv": 0, "Ts": 0, "Og": 0, "Uue": 0}
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
setup(
name = 'ElMD',
packages = ['ElMD'],
version = '0.4.3',
version = '0.4.5',
license='GPL3',
description = 'An implementation of the Element movers distance for chemical similarity of ionic compositions',
author = 'Cameron Hagreaves',
author_email = '[email protected]',
url = 'https://github.com/lrcfmd/ElMD/',
download_url = 'https://github.com/lrcfmd/ElMD/archive/v0.4.3 .tar.gz',
download_url = 'https://github.com/lrcfmd/ElMD/archive/v0.4.5 .tar.gz',
keywords = ['ChemInformatics', 'Materials Science', 'Machine Learning', 'Materials Representation'],
package_data={"elementFeatures": ["ElementDict.json"]},
include_package_data=True,
Expand All @@ -31,4 +31,4 @@
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
],
)
)

0 comments on commit 67d189c

Please sign in to comment.