-
Notifications
You must be signed in to change notification settings - Fork 83
/
utils.py
66 lines (52 loc) · 1.85 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
import torch.nn as nn
import numpy as np
def weights_init(m):
    """
    Initialise weights of the model (DCGAN scheme).

    Conv / ConvTranspose layers get N(0, 0.02) weights; BatchNorm layers
    get N(1, 0.02) weights and zero bias. Intended to be used via
    ``model.apply(weights_init)``.
    """
    # isinstance is the idiomatic type check and also covers subclasses,
    # unlike the exact type(m) == T comparison.
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
class NormalNLLLoss:
    """
    Negative log-likelihood of a factored Gaussian.

    Used as the mutual-information term for continuous latent codes,
    treating Q(c_j | x) as an independent Gaussian per dimension.
    This quantity is to be minimised.
    """

    def __call__(self, x, mu, var):
        # Per-element Gaussian log-density; the small epsilon guards both
        # the log and the division against var == 0.
        eps = 1e-6
        log_norm_term = -0.5 * torch.log(var * (2 * np.pi) + eps)
        quad_term = (x - mu) ** 2 / (var * 2.0 + eps)
        logli = log_norm_term - quad_term
        # Sum over code dimensions, average over the batch, then negate.
        return -logli.sum(1).mean()
def noise_sample(n_dis_c, dis_c_dim, n_con_c, n_z, batch_size, device):
    """
    Sample a random noise vector for training.

    INPUT
    --------
    n_dis_c : Number of discrete latent codes.
    dis_c_dim : Dimension of each discrete latent code.
    n_con_c : Number of continuous latent codes.
    n_z : Dimension of incompressible noise.
    batch_size : Batch Size
    device : GPU/CPU

    Returns (noise, idx): `noise` of shape
    (batch_size, n_z + n_dis_c * dis_c_dim + n_con_c, 1, 1) and `idx`,
    an integer array of shape (n_dis_c, batch_size) holding the sampled
    category index of each discrete code.
    """
    z = torch.randn(batch_size, n_z, 1, 1, device=device)

    # Integer dtype is required: idx[i] is used below to index into the
    # one-hot tensor, and PyTorch rejects float arrays as indices.
    idx = np.zeros((n_dis_c, batch_size), dtype=np.int64)
    if n_dis_c != 0:
        dis_c = torch.zeros(batch_size, n_dis_c, dis_c_dim, device=device)

        for i in range(n_dis_c):
            # One-hot encode a uniformly random category per sample.
            idx[i] = np.random.randint(dis_c_dim, size=batch_size)
            dis_c[torch.arange(0, batch_size), i, idx[i]] = 1.0

        dis_c = dis_c.view(batch_size, -1, 1, 1)

    if n_con_c != 0:
        # Random uniform between -1 and 1.
        con_c = torch.rand(batch_size, n_con_c, 1, 1, device=device) * 2 - 1

    # Concatenate along the channel dimension: z, then discrete codes,
    # then continuous codes.
    noise = z
    if n_dis_c != 0:
        noise = torch.cat((z, dis_c), dim=1)
    if n_con_c != 0:
        noise = torch.cat((noise, con_c), dim=1)

    return noise, idx