-
Notifications
You must be signed in to change notification settings - Fork 14
/
distributions.py
77 lines (63 loc) · 2.54 KB
/
distributions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from __future__ import print_function
import torch
import torch.utils.data
import torch.nn.functional as F
from torch.autograd import Variable
# Bounds used to clamp probabilities before taking logs (log_Bernoulli clamps
# its `mean` to [min_epsilon, max_epsilon]) so log(p) and log(1 - p) stay finite.
min_epsilon = 1.0e-5
max_epsilon = 1.0 - min_epsilon
#=======================================================================================================================
def log_Normal_diag(x, mean, log_var, average=False, dim=None):
    """Log-density of a diagonal Gaussian, up to an additive constant.

    Computes ``-0.5 * (log_var + (x - mean)**2 / exp(log_var))`` elementwise,
    i.e. the Gaussian log-pdf without the ``-0.5 * log(2 * pi)`` term, and then
    reduces it.

    Args:
        x: Tensor of observations.
        mean: Tensor of means, broadcastable against ``x``.
        log_var: Tensor of log-variances, broadcastable against ``x``.
        average: If True, reduce with ``torch.mean``; otherwise ``torch.sum``.
        dim: Dimension(s) to reduce over; ``None`` reduces over all elements.

    Returns:
        The reduced (unnormalized) log-density tensor.
    """
    log_normal = -0.5 * (log_var + torch.pow(x - mean, 2) / torch.exp(log_var))
    reduce_fn = torch.mean if average else torch.sum
    # Older PyTorch releases reject an explicit ``dim=None``; omit ``dim`` to
    # reduce over every element instead.
    if dim is None:
        return reduce_fn(log_normal)
    return reduce_fn(log_normal, dim)
def log_Normal_standard(x, average=False, dim=None):
    """Log-density of a standard normal N(0, I), up to an additive constant.

    Computes ``-0.5 * x**2`` elementwise and then reduces it. The
    ``-0.5 * log(2 * pi)`` term is omitted, as in ``log_Normal_diag``.

    Args:
        x: Tensor of observations.
        average: If True, reduce with ``torch.mean``; otherwise ``torch.sum``.
        dim: Dimension(s) to reduce over; ``None`` reduces over all elements.

    Returns:
        The reduced (unnormalized) log-density tensor.
    """
    log_normal = -0.5 * torch.pow(x, 2)
    reduce_fn = torch.mean if average else torch.sum
    # Older PyTorch releases reject an explicit ``dim=None``; omit ``dim`` to
    # reduce over every element instead.
    if dim is None:
        return reduce_fn(log_normal)
    return reduce_fn(log_normal, dim)
def log_Bernoulli(x, mean, average=False, dim=None):
    """Bernoulli log-likelihood of ``x`` under success probabilities ``mean``.

    Computes ``x * log(p) + (1 - x) * log(1 - p)`` elementwise, where ``p`` is
    ``mean`` clamped to ``[min_epsilon, max_epsilon]``, and then reduces it.

    Args:
        x: Tensor of targets (presumably binary or in [0, 1]; TODO confirm
            with callers).
        mean: Tensor of probabilities, broadcastable against ``x``.
        average: If True, reduce with ``torch.mean``; otherwise ``torch.sum``.
        dim: Dimension(s) to reduce over; ``None`` reduces over all elements.

    Returns:
        The reduced log-likelihood tensor.
    """
    # Keep p strictly inside (0, 1) so neither log can return -inf.
    probs = torch.clamp(mean, min=min_epsilon, max=max_epsilon)
    log_bernoulli = x * torch.log(probs) + (1. - x) * torch.log(1. - probs)
    reduce_fn = torch.mean if average else torch.sum
    # Older PyTorch releases reject an explicit ``dim=None``; omit ``dim`` to
    # reduce over every element instead.
    if dim is None:
        return reduce_fn(log_bernoulli)
    return reduce_fn(log_bernoulli, dim)
def logisticCDF(x, u, s):
    """CDF of a logistic distribution with location ``u`` and scale ``s``.

    Returns ``1 / (1 + exp(-(x - u) / s))``, evaluated with ``torch.sigmoid``.
    The hand-written formula overflows ``exp`` to inf for very negative
    arguments, and autograd then yields NaN gradients. ``torch.sigmoid`` stays
    finite in both the forward and the backward pass.

    Args:
        x: Tensor of evaluation points.
        u: Location (tensor or number broadcastable against ``x``).
        s: Scale (tensor or number broadcastable against ``x``).

    Returns:
        Tensor of CDF values in [0, 1].
    """
    return torch.sigmoid((x - u) / s)
def sigmoid(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))``.

    Delegates to ``torch.sigmoid``. The manual formula overflows ``exp(-x)`` to
    inf for very negative ``x``, and autograd then produces NaN gradients.

    Args:
        x: Input tensor.

    Returns:
        Tensor of values in [0, 1] with the same shape as ``x``.
    """
    return torch.sigmoid(x)
def log_Logistic_256(x, mean, logvar, average=False, reduce=True, dim=None):
    """Discretized-logistic likelihood term for data quantized to 256 levels.

    Implementation follows
    https://github.com/openai/iaf/blob/master/tf_utils/distributions.py#L28 :
    ``x`` is snapped down to the left edge ``lo`` of its bin of width 1/256.
    The bin's mass under a logistic with location ``mean`` and scale
    ``exp(logvar)`` is ``sigmoid((lo - mean + 1/256) / scale) -
    sigmoid((lo - mean) / scale)``.

    NOTE: despite the ``log_`` name, this returns the NEGATIVE log-mass
    ``-log(mass + 1e-7)``, unlike ``log_Normal_diag`` / ``log_Bernoulli``, which
    return log-densities. Callers wanting a log-likelihood must negate it.
    Also, ``logvar`` is used as the log of the logistic *scale*, not as a
    log-variance.

    Args:
        x: Tensor of observations (presumably in [0, 1]; TODO confirm).
        mean: Tensor of logistic locations, broadcastable against ``x``.
        logvar: Tensor of log-scales, broadcastable against ``x``.
        average: If True, reduce with ``torch.mean``; otherwise ``torch.sum``.
        reduce: If False, return the elementwise values without reducing.
        dim: Dimension(s) to reduce over; ``None`` reduces over all elements.

    Returns:
        Elementwise or reduced negative log-mass tensor.
    """
    bin_size = 1. / 256.
    scale = torch.exp(logvar)
    # Left edge of x's bin, standardized by the logistic location and scale.
    x = (torch.floor(x / bin_size) * bin_size - mean) / scale
    cdf_plus = torch.sigmoid(x + bin_size / scale)
    cdf_minus = torch.sigmoid(x)
    # The 1e-7 keeps the log finite when the bin mass underflows to zero.
    log_logist_256 = -torch.log(cdf_plus - cdf_minus + 1.e-7)
    if not reduce:
        return log_logist_256
    reduce_fn = torch.mean if average else torch.sum
    # Older PyTorch releases reject an explicit ``dim=None``; omit ``dim`` to
    # reduce over every element instead.
    if dim is None:
        return reduce_fn(log_logist_256)
    return reduce_fn(log_logist_256, dim)
def log_Logistic_512(x, mean, logvar, average=False, reduce=True, dim=None):
    """Discretized-logistic likelihood term for data quantized to 512 levels.

    Implementation follows
    https://github.com/openai/iaf/blob/master/tf_utils/distributions.py#L28 :
    ``x`` is snapped down to the left edge ``lo`` of its bin of width 1/512.
    The bin's mass under a logistic with location ``mean`` and scale
    ``exp(logvar)`` is ``sigmoid((lo - mean + 1/512) / scale) -
    sigmoid((lo - mean) / scale)``.

    NOTE: despite the ``log_`` name, this returns the NEGATIVE log-mass
    ``-log(mass + 1e-7)``, unlike ``log_Normal_diag`` / ``log_Bernoulli``, which
    return log-densities. Callers wanting a log-likelihood must negate it.
    Also, ``logvar`` is used as the log of the logistic *scale*, not as a
    log-variance.

    Args:
        x: Tensor of observations (presumably in [0, 1]; TODO confirm).
        mean: Tensor of logistic locations, broadcastable against ``x``.
        logvar: Tensor of log-scales, broadcastable against ``x``.
        average: If True, reduce with ``torch.mean``; otherwise ``torch.sum``.
        reduce: If False, return the elementwise values without reducing.
        dim: Dimension(s) to reduce over; ``None`` reduces over all elements.

    Returns:
        Elementwise or reduced negative log-mass tensor.
    """
    bin_size = 1. / 512.
    scale = torch.exp(logvar)
    # Left edge of x's bin, standardized by the logistic location and scale.
    x = (torch.floor(x / bin_size) * bin_size - mean) / scale
    cdf_plus = torch.sigmoid(x + bin_size / scale)
    cdf_minus = torch.sigmoid(x)
    # The 1e-7 keeps the log finite when the bin mass underflows to zero.
    log_logist_512 = -torch.log(cdf_plus - cdf_minus + 1.e-7)
    if not reduce:
        return log_logist_512
    reduce_fn = torch.mean if average else torch.sum
    # Older PyTorch releases reject an explicit ``dim=None``; omit ``dim`` to
    # reduce over every element instead.
    if dim is None:
        return reduce_fn(log_logist_512)
    return reduce_fn(log_logist_512, dim)