etm.py
import torch
import torch.nn.functional as F
import numpy as np
import math
from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ETM(nn.Module):
    def __init__(self, num_topics, vocab_size, t_hidden_size, rho_size, emsize,
                 theta_act, embeddings=None, train_embeddings=True, enc_drop=0.5):
        super(ETM, self).__init__()

        ## define hyperparameters
        self.num_topics = num_topics
        self.vocab_size = vocab_size
        self.t_hidden_size = t_hidden_size
        self.rho_size = rho_size
        self.enc_drop = enc_drop
        self.emsize = emsize
        self.t_drop = nn.Dropout(enc_drop)
        self.theta_act = self.get_activation(theta_act)

        ## define the word embedding matrix \rho
        if train_embeddings:
            # learn the word embeddings jointly with the topic embeddings
            self.rho = nn.Linear(rho_size, vocab_size, bias=False)
        else:
            # use fixed, pre-trained embeddings (tensor of shape vocab_size x rho_size)
            self.rho = embeddings.clone().float().to(device)

        ## define the matrix containing the topic embeddings
        self.alphas = nn.Linear(rho_size, num_topics, bias=False)

        ## define the variational distribution for \theta_{1:D} via amortization
        self.q_theta = nn.Sequential(
            nn.Linear(vocab_size, t_hidden_size),
            self.theta_act,
            nn.Linear(t_hidden_size, t_hidden_size),
            self.theta_act,
        )
        self.mu_q_theta = nn.Linear(t_hidden_size, num_topics, bias=True)
        self.logsigma_q_theta = nn.Linear(t_hidden_size, num_topics, bias=True)

    def get_activation(self, act):
        if act == 'tanh':
            act = nn.Tanh()
        elif act == 'relu':
            act = nn.ReLU()
        elif act == 'softplus':
            act = nn.Softplus()
        elif act == 'rrelu':
            act = nn.RReLU()
        elif act == 'leakyrelu':
            act = nn.LeakyReLU()
        elif act == 'elu':
            act = nn.ELU()
        elif act == 'selu':
            act = nn.SELU()
        elif act == 'glu':
            act = nn.GLU()
        else:
            print('Defaulting to tanh activations...')
            act = nn.Tanh()
        return act

    def reparameterize(self, mu, logvar):
        """Returns a sample from a Gaussian distribution via reparameterization."""
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps.mul_(std).add_(mu)
        else:
            return mu

    def encode(self, bows):
        r"""Returns parameters of the variational distribution for \theta.

        input: bows
            batch of bag-of-words vectors, tensor of shape bsz x V
        output: mu_theta, logsigma_theta, kl_theta
        """
        q_theta = self.q_theta(bows)
        if self.enc_drop > 0:
            q_theta = self.t_drop(q_theta)
        mu_theta = self.mu_q_theta(q_theta)
        logsigma_theta = self.logsigma_q_theta(q_theta)
        # kl_theta = -0.5 * torch.sum(1 + logsigma_theta - mu_theta.pow(2) - logsigma_theta.exp(), dim=-1).mean()
        # Roland changed the prior on theta to have less variance, so that the document-topic distribution will be less sparse
        # (take the device from the input: self.rho may be an nn.Linear, which has no .device attribute)
        variance = 0.025 * torch.ones(1, device=bows.device)
        kl_theta = -0.5 * torch.sum(
            1 + logsigma_theta - torch.log(variance)
            - (1. / variance) * mu_theta.pow(2)
            - (1. / variance) * logsigma_theta.exp(),
            dim=-1).mean()
        return mu_theta, logsigma_theta, kl_theta
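
    # Note: for the variational posterior q = N(mu_theta, diag(exp(logsigma_theta)))
    # and the prior p = N(0, variance * I) used above, the closed-form KL divergence
    # per document is
    #     KL(q || p) = -0.5 * sum_k [ 1 + logsigma_k - log(variance)
    #                                 - (mu_k^2 + exp(logsigma_k)) / variance ],
    # which is the expression computed in encode(); setting variance = 1 recovers the
    # standard-normal KL shown in the commented-out line.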

    def get_beta(self):
        try:
            logit = self.alphas(self.rho.weight)  # trained embeddings: rho is an nn.Linear
        except AttributeError:
            logit = self.alphas(self.rho)  # fixed pre-trained embeddings: rho is a plain tensor
        # beta = F.softmax(logit, dim=0).transpose(1, 0) ## softmax over vocab dimension
        # Roland added a multiplicative (temperature-like) factor on the logits to adjust
        # the sparsity of the distribution over words
        factor = 0.6 * torch.ones(1, device=logit.device)
        beta = F.softmax(factor * logit, dim=0).transpose(1, 0)  ## softmax over vocab dimension
        return beta

    def get_theta(self, normalized_bows):
        mu_theta, logsigma_theta, kld_theta = self.encode(normalized_bows)
        z = self.reparameterize(mu_theta, logsigma_theta)
        theta = F.softmax(z, dim=-1)  # document-topic proportions, shape bsz x num_topics
        return theta, kld_theta

    def decode(self, theta, beta):
        res = torch.mm(theta, beta)  # per-document distribution over the vocabulary, shape bsz x V
        preds = torch.log(res + 1e-6)
        return preds

    def forward(self, bows, normalized_bows, theta=None, aggregate=True):
        ## get \theta
        if theta is None:
            theta, kld_theta = self.get_theta(normalized_bows)
        else:
            kld_theta = None

        ## get \beta
        beta = self.get_beta()

        ## get prediction loss
        preds = self.decode(theta, beta)
        recon_loss = -(preds * bows).sum(1)
        if aggregate:
            recon_loss = recon_loss.mean()
        return recon_loss, kld_theta
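

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original class): it instantiates the
    # model with small, purely illustrative hyperparameters and runs one forward pass
    # on a random bag-of-words batch. How recon_loss and kld_theta are weighted, and
    # the optimizer loop itself, are left to the training script.
    vocab_size, num_topics, bsz = 1000, 20, 8
    model = ETM(num_topics=num_topics, vocab_size=vocab_size, t_hidden_size=300,
                rho_size=300, emsize=300, theta_act='relu',
                embeddings=None, train_embeddings=True, enc_drop=0.5).to(device)

    # random integer word counts and their row-normalized version (encoder input)
    bows = torch.randint(0, 5, (bsz, vocab_size), device=device).float()
    normalized_bows = bows / bows.sum(1, keepdim=True).clamp(min=1.0)

    recon_loss, kld_theta = model(bows, normalized_bows)
    print('reconstruction loss:', recon_loss.item(), 'KL(theta):', kld_theta.item())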