Skip to content

Commit

Permalink
reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
ardunn committed Aug 20, 2022
1 parent 0f6f62c commit e8bacc1
Show file tree
Hide file tree
Showing 12 changed files with 32 additions and 33 deletions.
8 changes: 6 additions & 2 deletions matminer/datasets/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,16 @@ def test_validate_dataset(self):
_validate_dataset(self._path, url=None, file_hash=self._hash, download_if_missing=True)

with self.assertRaises(UserWarning):
_validate_dataset(self._path, self._url, file_hash="!@#$%^&*", download_if_missing=True, n_retries_allowed=0)
_validate_dataset(
self._path, self._url, file_hash="!@#$%^&*", download_if_missing=True, n_retries_allowed=0
)
if os.path.exists(self._path):
os.remove(self._path)

with self.assertRaises(ValueError):
_validate_dataset(self._path, self._url, file_hash=self._hash, download_if_missing=True, n_retries_allowed=-1)
_validate_dataset(
self._path, self._url, file_hash=self._hash, download_if_missing=True, n_retries_allowed=-1
)

_validate_dataset(self._path, self._url, self._hash, download_if_missing=True)
self.assertTrue(os.path.exists(self._path))
Expand Down
19 changes: 7 additions & 12 deletions matminer/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,7 @@ def _get_data_home(data_home=None):
return data_home


def _validate_dataset(
data_path,
url=None,
file_hash=None,
download_if_missing=True,
n_retries_allowed=3
):
def _validate_dataset(data_path, url=None, file_hash=None, download_if_missing=True, n_retries_allowed=3):
"""
Checks to see if a dataset is on the local machine,
if not tries to download if download_if_missing is set to true,
Expand All @@ -76,7 +70,7 @@ def _validate_dataset(
Returns (None)
"""
DOWNLOAD_RETRY_WAIT = 60
download_retry_wait = 60

if n_retries_allowed < 0:
raise ValueError("Number of retries for download cannot be less than 0.")
Expand All @@ -100,8 +94,9 @@ def _validate_dataset(

do_download = True

hash_mismatch_msg = "Error, hash of downloaded file does not match that " \
"included in metadata, the data may be corrupt or altered"
hash_mismatch_msg = (
"Error, hash of downloaded file does not match that " "included in metadata, the data may be corrupt or altered"
)
if do_download:
n_retries = 0
while n_retries <= n_retries_allowed:
Expand All @@ -116,8 +111,8 @@ def _validate_dataset(
except UserWarning:
warnings.warn(hash_mismatch_msg)
if n_retries < n_retries_allowed:
warnings.warn(f"Waiting {DOWNLOAD_RETRY_WAIT}s and trying again...")
time.sleep(DOWNLOAD_RETRY_WAIT)
warnings.warn(f"Waiting {download_retry_wait}s and trying again...")
time.sleep(download_retry_wait)
else:
raise UserWarning(
f"File could not be downloaded to {data_path} after {n_retries_allowed} retries"
Expand Down
6 changes: 3 additions & 3 deletions matminer/featurizers/composition/alloy.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,7 @@ def compute_lambda(yang_delta, entropy):
float
"""
if yang_delta != 0:
return entropy / yang_delta**2
return entropy / yang_delta ** 2
else:
return 0

Expand All @@ -791,8 +791,8 @@ def compute_gamma_radii(miracle_radius_stats):
mrmin = miracle_radius_stats["min"]
mrmax = miracle_radius_stats["max"]

numerator = 1 - np.sqrt((mrmean * mrmin + mrmin**2) / (mrmean + mrmin) ** 2)
denominator = 1 - np.sqrt((mrmean * mrmax + mrmax**2) / (mrmean + mrmax) ** 2)
numerator = 1 - np.sqrt((mrmean * mrmin + mrmin ** 2) / (mrmean + mrmin) ** 2)
denominator = 1 - np.sqrt((mrmean * mrmax + mrmax ** 2) / (mrmean + mrmax) ** 2)
return numerator / denominator

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion matminer/featurizers/site/fingerprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def featurize(self, struct, idx):
site_list.append(n)
this_av_inv_drel += 1.0 / (neigh_dist[j][1])
this_av_inv_drel = this_av_inv_drel / float(this_cn)
d_fac = this_av_inv_drel**self.dist_exp
d_fac = this_av_inv_drel ** self.dist_exp
for cn in range(max(2, prev_cn + 1), min(this_cn + 1, 13)):
# Set all OPs of non-CN-complying neighbor environments
# to zero if applicable.
Expand Down
4 changes: 2 additions & 2 deletions matminer/featurizers/site/rdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def g2(eta, rs, cutoff):
Returns:
(float) Gaussian radial symmetry function.
"""
ridge = np.exp(-eta * (rs**2.0) / (cutoff**2.0)) * GaussianSymmFunc.cosine_cutoff(rs, cutoff)
ridge = np.exp(-eta * (rs ** 2.0) / (cutoff ** 2.0)) * GaussianSymmFunc.cosine_cutoff(rs, cutoff)
return ridge.sum()

@staticmethod
Expand Down Expand Up @@ -118,7 +118,7 @@ def g4(etas, zetas, gammas, neigh_dist, neigh_coords, cutoff):
ind = 0
for eta in etas:
# Compute the eta term
eta_term = np.exp(-eta * (r_ij**2.0 + r_ik**2.0 + r_jk**2.0) / (cutoff**2.0)) * cutoff_fun
eta_term = np.exp(-eta * (r_ij ** 2.0 + r_ik ** 2.0 + r_jk ** 2.0) / (cutoff ** 2.0)) * cutoff_fun
for zeta in zetas:
for gamma in gammas:
term = (1.0 + gamma * cos_theta) ** zeta * eta_term
Expand Down
4 changes: 2 additions & 2 deletions matminer/featurizers/site/tests/test_fingerprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ def test_off_center_cscl(self):
self.assertArrayAlmostEqual(-1 * site1[8:], site2[8:])

# Make sure the site-ones are as expected.
right_dist = 4.209 * np.sqrt(0.45**2 + 2 * 0.5**2)
right_dist = 4.209 * np.sqrt(0.45 ** 2 + 2 * 0.5 ** 2)
right_xdist = 4.209 * 0.45
left_dist = 4.209 * np.sqrt(0.55**2 + 2 * 0.5**2)
left_dist = 4.209 * np.sqrt(0.55 ** 2 + 2 * 0.5 ** 2)
left_xdist = 4.209 * 0.55
self.assertAlmostEqual(
4
Expand Down
2 changes: 1 addition & 1 deletion matminer/featurizers/structure/bonding.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def _approximate_bonds(self, local_bonds):
d0 = u_mends[0] - l_mends[0]
d1 = u_mends[1] - l_mends[1]

d = (d0**2.0 + d1**2.0) ** 0.5
d = (d0 ** 2.0 + d1 ** 2.0) ** 0.5
if not d_min:
d_min = d
nearest = [abss]
Expand Down
2 changes: 1 addition & 1 deletion matminer/featurizers/structure/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def featurize(self, s):
raise ValueError("Disordered structure support not built yet.")
total_rad = 0
for site in s:
total_rad += site.specie.atomic_radius**3
total_rad += site.specie.atomic_radius ** 3
output.append(4 * math.pi * total_rad / (3 * s.volume))

return output
Expand Down
10 changes: 5 additions & 5 deletions matminer/featurizers/utils/grdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def volume(self, cutoff):
(float): Volume of bin
"""

results = integrate.quad(lambda x: 4.0 * pi * self(x) * x**2.0, 0, cutoff)
results = integrate.quad(lambda x: 4.0 * pi * self(x) * x ** 2.0, 0, cutoff)
if results[1] > 1e-5:
raise ValueError("Numerical integration fails for this function." " Please implement analytic integral")
return results[0]
Expand All @@ -92,7 +92,7 @@ def __call__(self, r_ij):
)

def volume(self, cutoff):
return 4.0 / 3 * np.pi * (min(self.start + self.width, cutoff) ** 3 - self.start**3)
return 4.0 / 3 * np.pi * (min(self.start + self.width, cutoff) ** 3 - self.start ** 3)


class Gaussian(AbstractPairwise):
Expand All @@ -117,7 +117,7 @@ def volume(self, cutoff):
* self.width
* (
np.sqrt(pi)
* (2 * self.center**2 + self.width**2)
* (2 * self.center ** 2 + self.width ** 2)
* (erf((cutoff - self.center) / self.width) + erf(self.center / self.width))
+ 2 * self.width * (self.center * self(0) - (self.center + cutoff) * self(cutoff))
)
Expand All @@ -143,7 +143,7 @@ def volume(self, cutoff):
4
* pi
* (((self.a * cutoff) ** 2 - 2) * np.sin(self.a * cutoff) + 2 * self.a * cutoff * np.cos(self.a * cutoff))
/ self.a**3
/ self.a ** 3
)


Expand All @@ -170,7 +170,7 @@ def volume(self, cutoff):
+ 2 * self.a * cutoff * np.sin(self.a * cutoff)
- 2
)
/ self.a**3
/ self.a ** 3
)


Expand Down
4 changes: 2 additions & 2 deletions matminer/featurizers/utils/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def skewness(data_lst, weights=None):
u2 = np.dot(weights, np.power(diff, 2)) / total_weight
if np.isclose(u3, 0):
return 0
return u3 / u2**1.5
return u3 / u2 ** 1.5

@staticmethod
def kurtosis(data_lst, weights=None):
Expand Down Expand Up @@ -203,7 +203,7 @@ def kurtosis(data_lst, weights=None):
u2 = np.dot(weights, diff_sq)
if np.isclose(u4, 0):
return 0
return u4 / u2**2 * total_weight
return u4 / u2 ** 2 * total_weight

@staticmethod
def geom_std_dev(data_lst, weights=None):
Expand Down
2 changes: 1 addition & 1 deletion matminer/featurizers/utils/tests/test_grdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_gaussian(self):
def test_histogram(self):
h = Histogram(1, 4)
self.assertArrayAlmostEqual([0, 1, 0], h([0.5, 2, 5]))
self.assertAlmostEqual(h.volume(10), 4 / 3.0 * np.pi * (5**3 - 1**3))
self.assertAlmostEqual(h.volume(10), 4 / 3.0 * np.pi * (5 ** 3 - 1 ** 3))

def test_cosine(self):
c = Cosine(2)
Expand Down
2 changes: 1 addition & 1 deletion matminer/utils/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ def gaussian_kernel(arr0, arr1, SIGMA):
kernel trick.
"""
diff = arr0 - arr1
return np.exp(-np.linalg.norm(diff.A1, ord=2) ** 2 / 2 / SIGMA**2)
return np.exp(-np.linalg.norm(diff.A1, ord=2) ** 2 / 2 / SIGMA ** 2)

0 comments on commit e8bacc1

Please sign in to comment.