distributions.py
import torch
import torch.distributions as P
class Distribution(object):
    """Abstract base class for probability distributions."""

    def sample(self):
        raise NotImplementedError

    def get_param(self):
        raise NotImplementedError

    def logli(self, val):
        raise NotImplementedError

    def kl_div(self, other):
        raise NotImplementedError

    def logli_ratio(self, other, val):
        # Importance ratio p_self(val) / p_other(val), computed in log
        # space for numerical stability.
        logli_new = self.logli(val)
        logli_old = other.logli(val)
        return (logli_new - logli_old).exp()
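
# Example sketch, not part of the original module: a discrete-action
# counterpart built on the same abstract interface, to show how
# Distribution is meant to be subclassed. The name CategoricalDist is
# hypothetical and nothing else in this file depends on it.
class CategoricalDist(Distribution):
    def __init__(self, logits):
        self.logits = logits
        self.cat = P.Categorical(logits=logits)

    def sample(self):
        return self.cat.sample()

    def get_param(self):
        return dict(logits=self.logits)

    def logli(self, val):
        return self.cat.log_prob(val)

    def kl_div(self, other):
        # unsqueeze to match DiagonalGaussian.kl_div's (batch, 1) shape.
        return P.kl_divergence(self.cat, other.cat).unsqueeze(-1)

    def entropy(self):
        return self.cat.entropy()
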
class DiagonalGaussian(Distribution):
    """Gaussian with diagonal covariance, parameterized by mean and log-std."""

    def __init__(self, mean, log_std):
        self.mean = mean
        self.log_std = log_std
        self.normal = P.Normal(self.mean, self.log_std.exp())
        # Treat the last dimension as the event dimension so that
        # log_prob sums over the components.
        self.diagn = P.Independent(self.normal, 1)

    def detach(self):
        # Rebuild the distribution from detached parameters so that no
        # gradients flow back into the tensors that produced them.
        self.mean = self.mean.detach()
        self.log_std = self.log_std.detach()
        self.normal = P.Normal(self.mean, self.log_std.exp())
        self.diagn = P.Independent(self.normal, 1)
    def sample(self):
        return self.diagn.sample()

    def get_param(self):
        return dict(mean=self.mean, log_std=self.log_std)

    def logli(self, val):
        return self.diagn.log_prob(val)

    def kl_div(self, other):
        # Closed-form KL(self || other) for diagonal Gaussians, per component:
        #   KL = log(s2 / s1) + (s1^2 + (m1 - m2)^2) / (2 * s2^2) - 1/2,
        # summed over the event dimension. The epsilon guards against
        # division by zero when other.log_std is very negative.
        deviations = other.log_std - self.log_std
        d1 = (2.0 * self.log_std).exp()
        d2 = (2.0 * other.log_std).exp()
        sqmeans = (self.mean - other.mean).pow(2)
        d_KL = (sqmeans + d1 - d2) / (2.0 * d2 + 1e-8) + deviations
        d_KL = d_KL.sum(1, keepdim=True)
        return d_KL

    def entropy(self):
        return self.diagn.entropy()
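
# Usage sketch: how these classes would typically be driven in a
# policy-gradient loop. The tensors below stand in for the outputs of a
# policy network; any tensors of matching shape would do.
def _usage_example():
    actor_mean = torch.randn(4, 3)      # batch of 4, 3-dim actions
    actor_log_std = torch.zeros(4, 3)   # unit std in log space
    pi_new = DiagonalGaussian(actor_mean, actor_log_std)

    # Snapshot the "old" policy and cut its gradient ties.
    pi_old = DiagonalGaussian(actor_mean, actor_log_std)
    pi_old.detach()

    actions = pi_old.sample()
    ratio = pi_new.logli_ratio(pi_old, actions)  # importance weights, shape (4,)
    kl = pi_new.kl_div(pi_old)                   # per-sample KL, shape (4, 1)
    return ratio, kl
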
def distributions_test():
    torch.manual_seed(60)
    mean = torch.Tensor([[0, 0, 0], [1, 2, 3]])
    log_std = torch.Tensor([[1.0, 1.0, 1.0], [1.0, 0.0, 1.0]])
    dist1 = DiagonalGaussian(mean, log_std)

    # Cross-check against torch's MultivariateNormal with the matching
    # diagonal covariance.
    cov = torch.diag_embed(log_std.exp().pow(2))
    dist2 = P.MultivariateNormal(mean, cov)
    assert dist1.sample().shape == dist2.sample().shape
    assert torch.allclose(dist1.logli(mean), dist2.log_prob(mean))
    assert torch.allclose(dist1.entropy(), dist2.entropy())

    # Against an identically parameterized copy, KL should be zero and
    # the likelihood ratio should be one.
    dist3 = DiagonalGaussian(mean, log_std)
    assert torch.allclose(dist1.kl_div(dist3), torch.zeros(2, 1))
    assert torch.allclose(dist1.logli_ratio(dist3, torch.Tensor([1, 3, 3])),
                          torch.ones(2))

    print(dist1.sample())
    # exp(log-likelihood) at the mean is each density's peak value.
    print(dist1.logli(mean).exp())


if __name__ == "__main__":
    distributions_test()