-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathlayer.py
210 lines (162 loc) · 6.79 KB
/
layer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
# Basic layers: residual feed-forward, fast weight, and self-referential weight matrix (SRWM) layers.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from fast_weight import fast_weight_delta
from self_ref_v0 import self_ref_v0, stateful_self_ref_v0
@torch.jit.script
def elu_p1(x: torch.Tensor) -> torch.Tensor:
    """Return ``ELU(x) + 1`` (ELU with alpha=1), an element-wise feature map.

    ELU with alpha=1 is bounded below by -1, so the shifted result is
    non-negative; it is used to make queries/keys positive before
    sum-normalization.
    """
    shifted = F.elu(x, alpha=1., inplace=False)
    return shifted + 1.
@torch.jit.script
def sum_norm(x: torch.Tensor) -> torch.Tensor:
    """Scale ``x`` so that every slice along the last dimension sums to 1.

    No epsilon is added: callers must guarantee the last-dimension sums are
    non-zero (here it is applied to the non-negative output of ``elu_p1``).
    """
    totals = torch.sum(x, dim=-1, keepdim=True)
    return x / totals
# A block of residual feed-forward layers in Transformer
class TransformerFFlayers(nn.Module):
def __init__(self, ff_dim, res_dim, dropout, use_layernorm=True):
super(TransformerFFlayers, self).__init__()
self.res_dim = res_dim
self.ff_dim = ff_dim
self.dropout = dropout
self.use_layernorm = use_layernorm
self.ff_layers = nn.Sequential(
nn.Linear(res_dim, ff_dim), nn.ReLU(inplace=True),
nn.Dropout(dropout),
nn.Linear(ff_dim, res_dim),
nn.Dropout(dropout),
)
if use_layernorm:
self.layer_norm = nn.LayerNorm(res_dim)
def forward(self, x):
out = self.layer_norm(x) if self.use_layernorm else x
out = self.ff_layers(out) + x
return out
# Fast weight layer with feed-forward fast net
class FastFFlayer(nn.Module):
def __init__(self, num_head, dim_head, in_dim, dropout):
super(FastFFlayer, self).__init__()
self.num_head = num_head
self.dim_head = dim_head
self.in_dim = in_dim
self.fw_layer = fast_weight_delta
self.slow_net = nn.Linear(
in_dim, num_head * (3 * dim_head + 1), bias=False)
self.layer_norm = nn.LayerNorm(in_dim)
self.out_linear = nn.Linear(num_head * dim_head, in_dim, bias=False)
self.drop = nn.Dropout(dropout)
def forward(self, x):
# x shape: (len, B, n_head * d_head)
slen, bsz, _ = x.size()
out = self.layer_norm(x)
qkvb = self.slow_net(out)
qkvb = qkvb.view(slen, bsz, self.num_head, 3 * self.dim_head + 1)
head_q, head_k, head_v, head_beta = torch.split(
qkvb, (self.dim_head,) * 3 + (1,), -1)
head_beta = torch.sigmoid(head_beta)
# reshape to (B, heads, len, dim)
head_q = head_q.permute(1, 2, 0, 3)
head_k = head_k.permute(1, 2, 0, 3)
head_v = head_v.permute(1, 2, 0, 3)
head_beta = head_beta.permute(1, 2, 0, 3)
head_q = elu_p1(head_q)
head_k = elu_p1(head_k)
# normalize k and q, crucial for stable training.
head_k = sum_norm(head_k)
head_q = sum_norm(head_q)
fast_weights = torch.zeros(
bsz, self.num_head, self.dim_head, self.dim_head,
device=head_k.device)
out = self.fw_layer(head_q, head_k, head_v, head_beta, fast_weights)
out = out.transpose(1, 2)
out = out.reshape(bsz, slen, self.num_head * self.dim_head)
out = out.transpose(0, 1)
# expect [qlen, B, n_head * d_head]
# linear projection
out = self.out_linear(out)
out = self.drop(out)
out = x + out
return out
# self referential weight matrix layer
class SRWMlayer(nn.Module):
def __init__(self, num_head, dim_head, in_dim, dropout, use_ln=True,
use_input_softmax=False, beta_init=-1.0, stateful=False):
super(SRWMlayer, self).__init__()
self.num_head = num_head
self.dim_head = dim_head
self.in_dim = in_dim
self.use_ln = use_ln
self.use_input_softmax = use_input_softmax
self.stateful = stateful
if stateful:
self.sr_layer = stateful_self_ref_v0
else:
self.sr_layer = self_ref_v0
n_head = num_head
d_head = dim_head
self.W_y = nn.Parameter(torch.Tensor(1, n_head, d_head, d_head),
requires_grad=True)
self.W_q = nn.Parameter(torch.Tensor(1, n_head, d_head, d_head),
requires_grad=True)
self.W_k = nn.Parameter(torch.Tensor(1, n_head, d_head, d_head),
requires_grad=True)
self.w_b = nn.Parameter(torch.Tensor(1, n_head, d_head, 4),
requires_grad=True)
if use_ln:
self.layer_norm = nn.LayerNorm(in_dim)
self.out_linear = nn.Linear(num_head * dim_head, in_dim, bias=False)
self.drop = nn.Dropout(dropout)
self.reset_parameters(beta_init)
def reset_parameters(self, beta_init):
std = 1.0 / math.sqrt(self.dim_head)
nn.init.normal_(self.W_y, mean=0., std=std)
nn.init.normal_(self.W_q, mean=0., std=std)
nn.init.normal_(self.W_k, mean=0., std=std)
# tried -1 for beta but 0 seems to be better
# nn.init.normal_(self.w_b, mean=-5., std=std)
nn.init.normal_(self.w_b, mean=beta_init, std=std)
def forward(self, h, state=None, get_state=False):
# x shape: (len, B, n_head * d_head)
slen, bsz, _ = h.size()
# out = self.layer_norm(x)
x = h.reshape(slen, bsz, self.num_head, self.dim_head)
if self.use_input_softmax:
x = F.softmax(x, dim=-1)
# reshape to (B, heads, len, dim)
x = x.permute(1, 2, 0, 3)
if state is not None: # state stores the shift from the base weights.
W_y_bc, W_q_bc, W_k_bc, w_b_bc = state
W_y_bc = W_y_bc + self.W_y.repeat(bsz, 1, 1, 1)
W_q_bc = W_q_bc + self.W_q.repeat(bsz, 1, 1, 1)
W_k_bc = W_k_bc + self.W_k.repeat(bsz, 1, 1, 1)
w_b_bc = w_b_bc + self.w_b.repeat(bsz, 1, 1, 1)
else:
W_y_bc = self.W_y.repeat(bsz, 1, 1, 1)
W_q_bc = self.W_q.repeat(bsz, 1, 1, 1)
W_k_bc = self.W_k.repeat(bsz, 1, 1, 1)
w_b_bc = self.w_b.repeat(bsz, 1, 1, 1)
if self.stateful:
out, W_y_bc, W_q_bc, W_k_bc, w_b_bc = self.sr_layer(x, W_y_bc, W_q_bc, W_k_bc, w_b_bc)
else:
out = self.sr_layer(x, W_y_bc, W_q_bc, W_k_bc, w_b_bc)
out = out.transpose(1, 2)
out = out.reshape(bsz, slen, self.num_head * self.dim_head)
out = out.transpose(0, 1)
# expect [qlen, B, n_head * d_head]
# linear projection
out = self.out_linear(out)
out = self.drop(out)
if self.use_ln:
out = self.layer_norm(h) + out
else:
out = h + out
# out = self.layer_norm(h) + out
# compute the new shift (not very efficient; get it better from kernel)
# if state is not None and get_state:
if get_state:
W_y_bc = W_y_bc - self.W_y.repeat(bsz, 1, 1, 1)
W_q_bc = W_q_bc - self.W_q.repeat(bsz, 1, 1, 1)
W_k_bc = W_k_bc - self.W_k.repeat(bsz, 1, 1, 1)
w_b_bc = w_b_bc - self.w_b.repeat(bsz, 1, 1, 1)
state = (W_y_bc, W_q_bc, W_k_bc, w_b_bc)
return out, state
return out