Skip to content

Commit

Permalink
bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
nocotan committed Apr 3, 2021
1 parent 961080e commit 57f72b0
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 3 deletions.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.0.7
1.0.8
2 changes: 1 addition & 1 deletion gs_divergence/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from gs_divergence.geodesical_skew_divergence import gs_div
from gs_divergence.symmetrized_geodesical_skew_divergence import symmetrized_gs_div

__version__ = '1.0.7'
__version__ = '1.0.8'
2 changes: 1 addition & 1 deletion gs_divergence/geodesical_skew_divergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def gs_div(
assert lmd >= 0 and lmd <= 1

skew_target = alpha_geodesic(input, target, alpha=alpha, lmd=lmd)
div = input * torch.log(input / skew_target)
div = input * torch.log(input / skew_target + 1e-12)
if reduction == 'batchmean':
div = div.sum() / input.size()[0]
elif reduction == 'sum':
Expand Down
1 change: 1 addition & 0 deletions tests/test_alpha_geodesic.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def test_value_inf(self):
b = torch.Tensor([[0.4, 0.4, 0.2], [0.2, 0.1, 0.7]])

g = alpha_geodesic(a, b, alpha=100, lmd=0.5)
print(g)
res = torch.min(a, b)

self.assertTrue(torch.all(torch.isclose(g, res)))
Expand Down
21 changes: 21 additions & 0 deletions tests/test_gs_div.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import unittest
import torch

from gs_divergence import gs_div


class TestGSDiv(unittest.TestCase):
def test_alpha_minus_1(self):
a = torch.Tensor([1, 2, 3])
b = torch.Tensor([4, 5, 6])
g = gs_div(a, b, alpha=-1, lmd=0.5)

self.assertIsNotNone(g)

def test_value_0_2d(self):
a = torch.Tensor([[0.1, 0.2, 0.7], [0.5, 0.5, 0.0]])
b = torch.Tensor([[0.4, 0.4, 0.2], [0.2, 0.1, 0.7]])

g = gs_div(a, b, alpha=1, lmd=0.5)

self.assertTrue(torch.isinf(g).sum() == 0)

0 comments on commit 57f72b0

Please sign in to comment.