cell.py
import numbers
from typing import Tuple

import torch
from torch import Tensor, jit, nn
from torch.nn import Parameter

# https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py
class LayerNormLSTMCell(jit.ScriptModule):
    def __init__(self, input_size, hidden_size, decompose_layernorm=False):
        super(LayerNormLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
        # The layernorms provide learnable biases
        if decompose_layernorm:
            ln = LayerNorm
        else:
            ln = nn.LayerNorm
        self.layernorm_i = ln(4 * hidden_size)
        self.layernorm_h = ln(4 * hidden_size)
        self.layernorm_c = ln(hidden_size)

    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]
        hx, cx = state
        # Layer-normalize the input and hidden projections separately,
        # then sum them to form the pre-activation gate values.
        igates = self.layernorm_i(torch.mm(input, self.weight_ih.t()))
        hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t()))
        gates = igates + hgates
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        # The new cell state is also layer-normalized before it is gated
        # into the hidden state.
        cy = self.layernorm_c((forgetgate * cx) + (ingate * cellgate))
        hy = outgate * torch.tanh(cy)
        return hy, cy

# from: https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py
class LayerNorm(jit.ScriptModule):
    def __init__(self, normalized_shape):
        super(LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        # XXX: This is true for our LSTM / NLP use case and helps simplify code
        assert len(normalized_shape) == 1

        self.weight = Parameter(torch.ones(normalized_shape))
        self.bias = Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    @jit.script_method
    def compute_layernorm_stats(self, input):
        # Mean and biased standard deviation over the last dimension
        # (unbiased=False matches how layer norm statistics are computed).
        mu = input.mean(-1, keepdim=True)
        sigma = input.std(-1, keepdim=True, unbiased=False)
        return mu, sigma

    @jit.script_method
    def forward(self, input):
        # NB: unlike nn.LayerNorm, no eps is added to sigma before dividing.
        mu, sigma = self.compute_layernorm_stats(input)
        return (input - mu) / sigma * self.weight + self.bias
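
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the upstream benchmark file): builds the
# cell with both layer-norm variants and runs a single step on random inputs.
# The batch size and feature sizes below are arbitrary illustration values.
if __name__ == "__main__":
    batch, input_size, hidden_size = 5, 10, 20
    x = torch.randn(batch, input_size)
    hx = torch.randn(batch, hidden_size)
    cx = torch.randn(batch, hidden_size)

    for decompose in (False, True):
        cell = LayerNormLSTMCell(input_size, hidden_size,
                                 decompose_layernorm=decompose)
        hy, cy = cell(x, (hx, cx))
        # One hidden and one cell vector of size hidden_size per batch row.
        assert hy.shape == (batch, hidden_size)
        assert cy.shape == (batch, hidden_size)
    print("smoke test passed")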