Skip to content

Commit

Permalink
Reintroduced njit
Browse files Browse the repository at this point in the history
  • Loading branch information
SurgeArrester committed Nov 8, 2021
1 parent 631b421 commit 2047b81
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 16 deletions.
30 changes: 16 additions & 14 deletions ElMD/ElMD.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def main():
print(x.elmd(y))
print(y.elmd(x))

print(np.sum(np.abs(np.cumsum(x.ratio_vector - y.ratio_vector))))

x = ElMD("Li7La3Hf2O12", metric="magpie")
y = ElMD("CsPbI3", metric="magpie")
# z = ElMD("Zr3AlN", metric="atomic")
Expand Down Expand Up @@ -453,7 +455,7 @@ def __gt__(self, other):
BSD license.
'''

# @njit(cache=True)
@njit(cache=True)
def reduced_cost(i, costs, potentials, tails, heads, flows):
"""Return the reduced cost of an edge i.
"""
Expand All @@ -464,7 +466,7 @@ def reduced_cost(i, costs, potentials, tails, heads, flows):
else:
return -c

# @njit(cache=True)
@njit(cache=True)
def find_entering_edges(e, f, tails, heads, costs, potentials, flows):
"""Yield entering edges until none can be found.
"""
Expand Down Expand Up @@ -520,7 +522,7 @@ def find_entering_edges(e, f, tails, heads, costs, potentials, flows):
# All edges have nonnegative reduced costs. The flow is optimal.
return -1, -1, -1, -1

# @njit(cache=True)
@njit(cache=True)
def find_apex(p, q, size, parent):
"""Find the lowest common ancestor of nodes p and q in the spanning
tree.
Expand All @@ -544,7 +546,7 @@ def find_apex(p, q, size, parent):
else:
return p

# @njit(cache=True)
@njit(cache=True)
def trace_path(p, w, edge, parent):
"""Return the nodes and edges on the path from node p to its ancestor
w.
Expand All @@ -559,7 +561,7 @@ def trace_path(p, w, edge, parent):

return cycle_nodes, cycle_edges

# @njit(cache=True)
@njit(cache=True)
def find_cycle(i, p, q, size, edge, parent):
"""Return the nodes and edges on the cycle containing edge i == (p, q)
when the latter is added to the spanning tree.
Expand All @@ -584,7 +586,7 @@ def find_cycle(i, p, q, size, edge, parent):

return cycle_nodes, cycle_edges

# @njit(cache=True)
@njit(cache=True)
def residual_capacity(i, p, capac, flows, tails):
"""Return the residual capacity of an edge i in the direction away
from its endpoint p.
Expand All @@ -595,7 +597,7 @@ def residual_capacity(i, p, capac, flows, tails):
else:
return flows[np.int64(i)]

# @njit(cache=True)
@njit(cache=True)
def find_leaving_edge(cycle_nodes, cycle_edges, capac, flows, tails, heads):
"""Return the leaving edge in a cycle represented by cycle_nodes and
cycle_edges.
Expand All @@ -617,7 +619,7 @@ def find_leaving_edge(cycle_nodes, cycle_edges, capac, flows, tails, heads):
t = heads[np.int64(j)] if tails[np.int64(j)] == s else tails[np.int64(j)]
return j, s, t

# @njit(cache=True)
@njit(cache=True)
def augment_flow(cycle_nodes, cycle_edges, f, tails, flows):
"""Augment f units of flow along a cycle representing Wn with cycle_edges.
"""
Expand All @@ -627,7 +629,7 @@ def augment_flow(cycle_nodes, cycle_edges, f, tails, flows):
else:
flows[int(i)] -= f

# @njit(cache=True)
@njit(cache=True)
def trace_subtree(p, last, next):
"""Yield the nodes in the subtree rooted at a node p.
"""
Expand All @@ -641,7 +643,7 @@ def trace_subtree(p, last, next):

return np.array(tree, dtype=np.int64)

# @njit(cache=True)
@njit(cache=True)
def remove_edge(s, t, size, prev, last, next, parent, edge):
"""Remove an edge (s, t) where parent[t] == s from the spanning tree.
"""
Expand All @@ -666,7 +668,7 @@ def remove_edge(s, t, size, prev, last, next, parent, edge):
last[s] = prev_t
s = parent[s]

# @njit(cache=True)
@njit(cache=True)
def make_root(q, parent, size, last, prev, next, edge):
"""
Make a node q the root of its containing subtree.
Expand Down Expand Up @@ -714,7 +716,7 @@ def make_root(q, parent, size, last, prev, next, edge):
prev[q] = last_p
last[q] = last_p

# @njit(cache=True)
@njit(cache=True)
def add_edge(i, p, q, next, prev, last, size, parent, edge):
"""Add an edge (p, q) to the spanning tree where q is the root of a
subtree.
Expand All @@ -740,7 +742,7 @@ def add_edge(i, p, q, next, prev, last, size, parent, edge):
last[p] = last_q
p = parent[p]

# @njit(cache=True)
@njit(cache=True)
def update_potentials(i, p, q, heads, potentials, costs, last, next):
"""Update the potentials of the nodes in the subtree rooted at a node
q connected to its parent p by an edge i.
Expand All @@ -754,7 +756,7 @@ def update_potentials(i, p, q, heads, potentials, costs, last, next):
for q in tree:
potentials[q] += d

# @njit(cache=True)
@njit(cache=True)
def network_simplex(source_demands, sink_demands, network_costs):
'''
This is a port of the network simplex algorithm implented by Loïc Séguin-C
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
setup(
name = 'ElMD',
packages = ['ElMD'],
version = '0.4.14',
version = '0.4.15',
license='GPL3',
description = 'An implementation of the Element movers distance for chemical similarity of ionic compositions',
author = 'Cameron Hagreaves',
author_email = '[email protected]',
url = 'https://github.com/lrcfmd/ElMD/',
download_url = 'https://github.com/lrcfmd/ElMD/archive/v0.4.14.tar.gz',
download_url = 'https://github.com/lrcfmd/ElMD/archive/v0.4.15.tar.gz',
keywords = ['ChemInformatics', 'Materials Science', 'Machine Learning', 'Materials Representation'],
package_data={"elementFeatures": ["el_lookup/*.json"]},
include_package_data=True,
Expand Down

0 comments on commit 2047b81

Please sign in to comment.