Skip to content

Commit

Permalink
fix psitestats tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jacksund committed Jun 20, 2022
1 parent d471b57 commit b3e2954
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 19 deletions.
4 changes: 2 additions & 2 deletions matminer/featurizers/structure/sites.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def featurize(self, s):

if not s.is_ordered:
raise ValueError("Disordered structure support not built yet")
if self.elements_ is None:
if not hasattr(self, "elements_") or self.elements_ is None:
raise Exception("You must run 'fit' first!")

output = []
Expand Down Expand Up @@ -369,7 +369,7 @@ def compute_pssf(self, s, e):
return stats

def feature_labels(self):
if self.elements_ is None:
if not hasattr(self, "elements_") or self.elements_ is None:
raise Exception("You must run 'fit' first!")

labels = []
Expand Down
35 changes: 18 additions & 17 deletions matminer/featurizers/structure/tests/test_sites.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,19 @@ class PartialStructureSitesFeaturesTest(StructureFeaturesTest):
def test_partialsitestatsfingerprint(self):
# Test matrix.
op_struct_fp = PartialsSiteStatsFingerprint.from_preset("OPSiteFingerprint", stats=None)

op_struct_fp.fit([self.diamond])
opvals = op_struct_fp.featurize(self.diamond)
_ = op_struct_fp.feature_labels()
self.assertAlmostEqual(opvals[10][0], 0.9995, places=7)
self.assertAlmostEqual(opvals[10][1], 0.9995, places=7)

op_struct_fp.fit([self.nacl])
opvals = op_struct_fp.featurize(self.nacl)
self.assertAlmostEqual(opvals[18][0], 0.9995, places=7)
self.assertAlmostEqual(opvals[18][1], 0.9995, places=7)

op_struct_fp.fit([self.cscl])
opvals = op_struct_fp.featurize(self.cscl)
self.assertAlmostEqual(opvals[22][0], 0.9995, places=7)
self.assertAlmostEqual(opvals[22][1], 0.9995, places=7)
Expand Down Expand Up @@ -158,69 +163,65 @@ def test_partialsitestatsfingerprint(self):
stats=["mean"],
covariance=True,
)
prop_fp.fit([self.diamond])

# Test the feature labels
prop_fp.fit([self.diamond])
labels = prop_fp.feature_labels()
self.assertEqual(3, len(labels))

# Test a structure with all the same type (cov should be zero)
prop_fp.fit([self.diamond])
features = prop_fp.featurize(self.diamond)
self.assertArrayAlmostEqual(features, [6, 12.0107, 0])

# Test a structure with only one atom (cov should be zero too)
prop_fp.fit([self.sc])
features = prop_fp.featurize(self.sc)
self.assertArrayAlmostEqual([13, 26.9815386, 0], features)

# Test a structure with nonzero covariance
prop_fp.fit([self.nacl])
features = prop_fp.featurize(self.nacl)
self.assertArrayAlmostEqual([14, 29.22138464, 37.38969216], features)

# Test soap site featurizer
soap_fp = PartialsSiteStatsFingerprint.from_preset("SOAP_formation_energy")
soap_fp.fit([self.sc, self.diamond, self.nacl])
feats = soap_fp.featurize(self.diamond)
self.assertEqual(len(feats), 9504)
self.assertAlmostEqual(feats[0], 0.4412608, places=5)
self.assertAlmostEqual(feats[1], 0.0)
self.assertAlmostEqual(np.sum(feats), 207.88194724, places=5)
self.assertArrayAlmostEqual([11, 22.9897693, np.nan, 17, 35.453, np.nan], features)

def test_ward_prb_2017_lpd(self):
"""Test the local property difference attributes from Ward 2017"""
f = PartialsSiteStatsFingerprint.from_preset("LocalPropertyDifference_ward-prb-2017")
f.fit([self.diamond])

# Test diamond
f.fit([self.diamond])
features = f.featurize(self.diamond)
self.assertArrayAlmostEqual(features, [0] * (22 * 5))
features = f.featurize(self.diamond_no_oxi)
self.assertArrayAlmostEqual(features, [0] * (22 * 5))

# Test CsCl
f.fit([self.cscl])
big_face_area = np.sqrt(3) * 3 / 2 * (2 / 4 / 4)
small_face_area = 0.125
big_face_diff = 55 - 17
features = f.featurize(self.cscl)
labels = f.feature_labels()
my_label = "mean local difference in Number"
my_label = "Cs mean local difference in Number"
self.assertAlmostEqual(
(8 * big_face_area * big_face_diff) / (8 * big_face_area + 6 * small_face_area),
features[labels.index(my_label)],
places=3,
)
my_label = "range local difference in Electronegativity"
my_label = "Cs range local difference in Electronegativity"
self.assertAlmostEqual(0, features[labels.index(my_label)], places=3)

def test_ward_prb_2017_efftcn(self):
"""Test the effective coordination number attributes of Ward 2017"""
f = PartialsSiteStatsFingerprint.from_preset("CoordinationNumber_ward-prb-2017")

# Test Ni3Al
f.fit([self.ni3al])
features = f.featurize(self.ni3al)
labels = f.feature_labels()
my_label = "mean CN_VoronoiNN"
self.assertAlmostEqual(12, features[labels.index(my_label)])
self.assertArrayAlmostEqual([12, 12, 0, 12, 0], features)
self.assertAlmostEqual(12, features[labels.index("Al mean CN_VoronoiNN")])
self.assertAlmostEqual(12, features[labels.index("Ni mean CN_VoronoiNN")])
self.assertArrayAlmostEqual([12, 12, 0, 12, 0] * 2, features)


if __name__ == "__main__":
Expand Down

0 comments on commit b3e2954

Please sign in to comment.