-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathtransformer.py
240 lines (170 loc) · 6.88 KB
/
transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
"""
Pure-from-the-ground-up transformer, based on https://github.com/vpj/jax_transformer/blob/master/transformer.py
"""
from timer import timer
import jax
from jax import vmap
import jax.numpy as jnp
from functools import partial
import jax.experimental.host_callback
from jaxutils.Arg import Arg
from jaxutils.ParamsDict import ParamsDict
def rand(rng, f, shape, **kwargs):
    """
    Draw a random array with a freshly split key, returning the advanced rng too.

    ``rng`` is split into (next_rng, subkey). ``f(subkey, shape, **kwargs)`` is
    sampled, and ``(next_rng, sample)`` is returned, so callers can thread the rng:

        rng = ... from previous code ...
        rng, vals1 = rand(rng, jax.random.uniform, (9,3), minval=-2.0, maxval=2.0)
        # ^-- rng is now newly split
        rng, vals2 = rand(rng, jax.random.normal, (3,9))
        # ^-- rng is split again
    """
    next_rng, subkey = jax.random.split(rng)
    sample = f(subkey, shape, **kwargs)
    return next_rng, sample
def linear_init_uniform(rng: jax.random.PRNGKey, in_features: int, out_features: int):
    """
    Create parameters for an affine layer: uniform weights and zero bias.

    The weight has shape (in_features, out_features). It is sampled uniformly
    from [-1/sqrt(in_features), 1/sqrt(in_features)]. The bias is a zero
    vector of length out_features.

    Returns (rng, params). The rng has been advanced by one split via ``rand``.
    """
    limit = 1 / in_features**0.5
    rng, weight = rand(
        rng,
        jax.random.uniform,
        (in_features, out_features),
        minval=-limit,
        maxval=limit,
    )
    params = ParamsDict()
    params.weight = weight
    params.bias = jnp.zeros((out_features,))
    return rng, params
# Layer norm
def elementwise_linear_init_identity(shape):
    """
    Create an identity elementwise-affine layer: gain of ones, bias of zeros.

    Both arrays have the given ``shape``. Used for the learned gain/bias that
    follows each standardization in the transformer.
    """
    gain = jnp.ones(shape)
    bias = jnp.zeros(shape)
    return ParamsDict(gain=gain, bias=bias)
def linear(params, x: jnp.ndarray):
    """
    Apply an affine map: ``x @ params.weight + params.bias``.

    params: has ``weight`` of shape (in, out) and ``bias`` of shape (out,),
            as built by linear_init_uniform.
    x: array whose last axis has size ``in``; any leading axes are batch axes.
    returns: array of shape ``x.shape[:-1] + (out,)``.
    """
    # Trailing-axis broadcasting adds the bias to every row for any input rank.
    # (Writing ``bias[None, :]`` would turn a 1-D input's result into (1, out).)
    return x @ params.weight + params.bias
def elementwise_linear(params, x: jnp.ndarray):
    """
    Apply a per-feature affine map: ``params.gain * x + params.bias``.

    params: has ``gain`` and ``bias`` of shape (d,), as built by
            elementwise_linear_init_identity.
    x: array whose last axis has size d; any leading axes are batch axes.
    returns: array with the same shape as ``x``.
    """
    # Trailing-axis broadcasting applies gain/bias per feature for any rank,
    # instead of forcing an extra leading axis onto 1-D inputs.
    return params.gain * x + params.bias
def standardize(x, eps=1e-5):
    """
    Shift ``x`` to zero mean and divide by its standard deviation plus ``eps``.

    Mean and std are taken over all elements of ``x``. The transformer applies
    this per row via ``vmap``. ``eps`` guards against division by zero for
    constant inputs.
    """
    centered = x - x.mean()
    return centered / (x.std() + eps)
# Option read (via flip_pe_coef()) in transformer_init when choosing the
# embedding scale coefficients config.lambda_e / config.lambda_pe.
# NOTE(review): assumes jaxutils.Arg(name, default, help) registers a
# command-line flag "--flip-pe" defaulting to False; confirm against jaxutils.
flip_pe_coef = Arg("flip-pe", False, "Scale token embedding, not position embedding")
def transformer_init(
    rng: jax.random.PRNGKey,
    n_vocab: int,
    d_model: int,
    n_layers: int,
    n_heads: int,
    d_k: int,
    d_ff: int,
    max_len=4096,
):
    """
    Build the hyperparameter config and randomly initialized parameters.

    rng: PRNG key, threaded through every random initializer
    n_vocab: vocabulary size (rows of the token embedding, width of the output)
    d_model: residual-stream width; must equal d_k * n_heads
    n_layers: number of transformer layers
    n_heads: attention heads per layer
    d_k: per-head query/key/value width
    d_ff: hidden width of each layer's feed-forward block
    max_len: number of positional-encoding rows (longest supported sequence)

    Returns (rng, config, params):
      config holds d_k, heads, the embedding scales lambda_e / lambda_pe, and
      the attention softmax scale tau = 1/sqrt(d_k).
      params holds embeddings, positional_encodings, layers, pre_output_norm
      and output.
    """
    assert d_k * n_heads == d_model

    # Build config struct for call
    config = ParamsDict()
    config.d_k = d_k
    config.heads = n_heads

    # The d_model**-0.5 scale goes on exactly one of the two embeddings.
    # --flip-pe ("Scale token embedding, not position embedding") puts it on the
    # token embedding; by default it goes on the positional encoding instead.
    # NOTE: earlier code had identical branches here (so the flag did nothing and
    # the token embedding was always scaled); the default now follows the flag's
    # help text.
    if flip_pe_coef():
        config.lambda_e = d_model**-0.5
        config.lambda_pe = 1.0
    else:
        config.lambda_e = 1.0
        config.lambda_pe = d_model**-0.5

    # Softmax temperature applied to attention scores
    config.tau = 1 / d_k**0.5

    # Build initializers for params
    params = ParamsDict()

    # Token embedding table: n_vocab x d_model, standard normal
    rng, params.embeddings = rand(rng, jax.random.normal, (n_vocab, d_model))

    # Positional encodings initialized to zeros (learned)
    params.positional_encodings = jnp.zeros((max_len, d_model))

    # For transformer layers
    params.layers = []
    for _ in range(n_layers):
        layer = ParamsDict()
        layer.norm_self_attn = elementwise_linear_init_identity(d_model)

        layer.heads = []
        for _ in range(n_heads):
            head = ParamsDict()
            rng, head.query = linear_init_uniform(rng, d_model, d_k)
            rng, head.key = linear_init_uniform(rng, d_model, d_k)
            rng, head.value = linear_init_uniform(rng, d_model, d_k)
            layer.heads.append(head)

        layer.norm_ff = elementwise_linear_init_identity(d_model)
        rng, layer.ffn1 = linear_init_uniform(rng, d_model, d_ff)
        rng, layer.ffn2 = linear_init_uniform(rng, d_ff, d_model)

        params.layers.append(layer)

    # Final normalization and output layer
    params.pre_output_norm = elementwise_linear_init_identity(d_model)
    rng, params.output = linear_init_uniform(rng, d_model, n_vocab)

    return rng, config, params
# Format off for the size annotations
# fmt: off
@partial(jax.jit, static_argnums=0)
def transformer(cfg, params, x: jnp.ndarray):
    """
    Forward pass of a pre-norm, causally masked transformer on one sequence.

    cfg: Config from transformer_init, holds hyperparameters. It is a static
         jit argument (static_argnums=0), so it must be hashable.
    params: Current transformer parameters, initialized in transformer_init
    x: 1D array of L integer token ids; vmap/pmap handle batching
    output: L x n_vocab logits

    Each layer adds its attention output, and then its feed-forward output, into
    the residual stream ``embeddings``. Each sub-layer reads a standardized
    copy of the stream.
    """
    # Python-side print: runs only while jit is tracing (e.g. for each new L)
    print("Compiling for L=", x.shape)

    L, = x.shape # x is just 1D. Vmap/pmap will handle batching

    # Create mask: 0 to attend, -Inf to ignore (log(1) = 0, log(0) = -inf).
    # Lower-triangular => position i attends only to positions <= i.
    mask = jnp.log(jnp.tril(jnp.ones((L, L))))

    # Start with token embeddings
    embeddings = cfg.lambda_e * params.embeddings[x, :]     # L x Dm

    # Add (learned) positional encodings
    embeddings += cfg.lambda_pe * params.positional_encodings[:L, :]

    # Apply the transformer layers
    for layer in params.layers:

        # Layer-normalize embeddings for the attention sub-layer
        t1 = vmap(standardize)(embeddings)
        t1 = elementwise_linear(layer.norm_self_attn, t1)   # L x Dm

        # Multi-head self-attention
        self_attns = []
        for head in layer.heads:

            # Project into this head's query/key space
            query = linear(head.query, t1)                  # L x Dk
            key = linear(head.key, t1)                      # L x Dk

            # Compute L x L attention matrix
            score = query @ key.T + mask                    # L x L
            attn = jax.nn.softmax(cfg.tau * score, axis=1)  # L x L

            value = linear(head.value, t1)                  # L x Dk
            self_attn = attn @ value                        # L x Dk

            self_attns += [self_attn]                       # [L x Dk for #heads]

        # Add the heads' contribution into the residual stream.
        # (Adding it to the normalized t1 instead would drop the attention
        # output from `embeddings` except via the feed-forward block.)
        embeddings += jnp.hstack(self_attns)                # L x Dm

        # Layer-normalize embeddings for the feed-forward sub-layer
        t2 = vmap(standardize)(embeddings)
        t2 = elementwise_linear(layer.norm_ff, t2)          # L x Dm

        # Feedforward fully connected
        t2 = linear(layer.ffn1, t2)                         # L x Dff
        t2 = jax.nn.relu(t2)
        t2 = linear(layer.ffn2, t2)                         # L x Dm

        # Add this layer's contribution into embeddings
        embeddings += t2

    # Layer-normalize embeddings
    embeddings = vmap(standardize)(embeddings)
    embeddings = elementwise_linear(params.pre_output_norm, embeddings)

    # And linearly project to output dimension
    return linear(params.output, embeddings)                # L x n_vocab
# fmt: on
def crossentropy(output: jnp.ndarray, target: int):
    """
    Negative log-probability of class ``target`` under the logits ``output``.

    ``output`` is a 1-D vector of unnormalized logits. It is normalized with
    log_softmax before indexing.
    """
    log_probs = jax.nn.log_softmax(output)
    return -log_probs[target]
def seq_crossentropy(output: jnp.ndarray, targets: jnp.ndarray):
    """
    Mean cross-entropy over a sequence.

    output: one row of logits per position
    targets: one integer class per position
    Each row is scored with ``crossentropy`` (vectorized via vmap), then the
    per-position losses are averaged.
    """
    per_position = vmap(crossentropy)(output, targets)
    return jnp.mean(per_position)
def transformer_loss(cfg, params, x):
    """
    Next-token cross-entropy loss of the transformer on one example.

    cfg: Config, from transformer_init
    params: Current transformer parameters, initialized in transformer_init
    x: 1D array of integers, representing the input sequence

    The logits at position i are scored against token i+1. The last position
    has no next token and is dropped.
    """
    logits = transformer(cfg, params, x)
    predictions = logits[:-1]
    next_tokens = x[1:]
    return seq_crossentropy(predictions, next_tokens)
# Deliberately not jitted: under jit this Python loop would unroll into
# `length` copies of the forward pass and take a long time to compile.
def transformer_sample(cfg, params, seq: jnp.ndarray, length: int = 20):
    """
    Greedily extend ``seq`` by ``length`` tokens.

    Each step runs the transformer on the whole current sequence and appends
    the argmax of the last position's logits. Because ``seq`` grows by one
    token per step, the jitted ``transformer`` presumably retraces for each
    new length.

    Returns the original tokens followed by the ``length`` generated tokens.
    """
    for _step in range(length):
        logits = transformer(cfg, params, seq)
        next_token = jnp.argmax(logits[-1])
        seq = jnp.concatenate((seq, next_token[None]))
    return seq