Skip to content

Commit

Permalink
enable gradient
Browse files Browse the repository at this point in the history
  • Loading branch information
nocotan committed Apr 3, 2021
1 parent d069e62 commit 62f54fb
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 7 deletions.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.0.5
1.0.6
2 changes: 1 addition & 1 deletion gs_divergence/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from gs_divergence.geodesical_skew_divergence import gs_div
from gs_divergence.symmetrized_geodesical_skew_divergence import symmetrized_gs_div

__version__ = '1.0.5'
__version__ = '1.0.6'
10 changes: 5 additions & 5 deletions gs_divergence/alpha_geodesic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,18 @@ def alpha_geodesic(
$\alpha$-geodesic between two probability distributions
"""

a += 1e-12
b += 1e-12
a_ = a + 1e-12
b_ = b + 1e-12
if alpha == 1:
return torch.exp((1 - lmd) * torch.log(a) + lmd * torch.log(b))
return torch.exp((1 - lmd) * torch.log(a_) + lmd * torch.log(b_))
elif alpha >= 1e+9:
return torch.min(a, b)
elif alpha <= -1e+9:
return torch.max(a, b)
else:
p = (1 - alpha) / 2
lhs = a ** p
rhs = b ** p
lhs = a_ ** p
rhs = b_ ** p
g = ((1 - lmd) * lhs + lmd * rhs) ** (1/p)

if alpha > 0 and (g == 0).sum() > 0:
Expand Down
8 changes: 8 additions & 0 deletions tests/test_alpha_geodesic.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,11 @@ def test_value_inf(self):
res = torch.min(a, b)

self.assertTrue(torch.equal(g, res))

def test_grad(self):
a = torch.tensor([[0.1, 0.2, 0.7], [0.5, 0.5, 0.0]], requires_grad=True)
b = torch.tensor([[0.4, 0.4, 0.2], [0.2, 0.1, 0.7]])

g = alpha_geodesic(a, b, alpha=1, lmd=0.5)

self.assertIsNotNone(g.grad_fn)

0 comments on commit 62f54fb

Please sign in to comment.