layers.py
from flax import linen as nn
from jax import numpy as jnp


class FeedForward(nn.Module):
    """Position-wise feed-forward block with Gated Linear Units and GELU activation
    (GEGLU), as in https://arxiv.org/pdf/2002.05202.pdf
    """
    multiplicative: int = 4  # d_ff / d_model
    dropout_rate: float = 0.

    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        d_model = x.shape[-1]
        # Scale the hidden width by 2/3 to keep a similar parameter count to a
        # non-GLU feed-forward (e.g. multiplicative=4 -> mult=2, d_ff=2*d_model).
        mult = int(self.multiplicative / 3 * 2)
        d_ff = d_model * mult
        gate = nn.Dense(d_ff, use_bias=False, name="wi_0")(x)
        x = nn.Dense(d_ff, use_bias=False, name="wi_1")(x)
        x = nn.gelu(gate, approximate=True) * x
        x = nn.Dropout(rate=self.dropout_rate, name="dropout")(x, deterministic=deterministic)
        x = nn.Dense(d_model, use_bias=False, name="wo")(x)
        return x


class SelfAttention(nn.Module):
    num_heads: int
    causal: bool = True
    dropout_rate: float = 0.0

    @nn.compact
    def __call__(self, x, attn_mask, deterministic: bool = True):
        """
        Args:
            - x of shape (b, n, d)
            - attn_mask of shape (b, n), 1 for tokens to attend to, 0 for padding
        """
        b, n, d, h = *x.shape, self.num_heads
        head_size = d // h  # d must be divisible by num_heads
        # (b, n, d * 3)
        x = nn.Dense(d * 3, use_bias=False, name="qkv_projection")(x)
        # (b, n, d) -> (b, n, h, hsize) -> (b, h, n, hsize)
        q, k, v = [_x.reshape(b, n, h, head_size).transpose((0, 2, 1, 3))
                   for _x in jnp.split(x, 3, axis=-1)]
        # scaled dot-product scores: (b, h, n, n)
        attention = (q @ k.transpose((0, 1, 3, 2))) * (head_size ** -0.5)
        # Build a broadcastable (.., n, n) mask: 1 = keep, 0 = block
        if attn_mask is not None:
            # padding mask: (b, 1, n, 1) * (b, 1, 1, n) -> (b, 1, n, n)
            attn_mask = attn_mask[:, None, :, None] * attn_mask[:, None, None, :]
        elif self.causal:
            # lower-triangular causal mask, broadcast over batch and heads
            attn_mask = 1. - jnp.triu(jnp.ones((n, n)), k=1)[None, None, :, :]
        if attn_mask is not None:
            # fill blocked positions with a large negative value before softmax
            attention = jnp.where(attn_mask == 0., jnp.finfo(attention.dtype).min, attention)
        attention_weights = nn.softmax(attention, axis=-1)
        attention_weights = nn.Dropout(self.dropout_rate)(attention_weights, deterministic=deterministic)
        # context: (b, h, n, n) @ (b, h, n, hsize) -> (b, h, n, hsize)
        context = attention_weights @ v
        context = context.transpose((0, 2, 1, 3)).reshape((b, n, d))
        context = nn.Dense(d, name="out_projection")(context)
        return context


class SubLayer(nn.Module):
    num_heads: int
    ff_multiplicative: int = 4
    causal: bool = False
    attention_dropout: float = 0.
    ff_dropout: float = 0.

    def setup(self):
        self.ln0 = nn.LayerNorm()
        self.self_attn = SelfAttention(num_heads=self.num_heads, causal=self.causal,
                                       dropout_rate=self.attention_dropout)
        self.ln1 = nn.LayerNorm()
        self.ffn = FeedForward(multiplicative=self.ff_multiplicative, dropout_rate=self.ff_dropout)

    def __call__(self, x, attention_mask=None, deterministic=False):
        # pre-norm residual blocks: x + Attn(LN(x)), then x + FFN(LN(x))
        x = self.self_attn(self.ln0(x), attention_mask, deterministic=deterministic) + x
        x = self.ffn(self.ln1(x), deterministic=deterministic) + x
        return x


class Encoder(nn.Module):
    N: int
    num_heads: int
    causal: bool = False
    ff_multiplicative: int = 4
    attention_dropout: float = 0.
    ff_dropout: float = 0.

    def setup(self):
        self.layers = [
            SubLayer(num_heads=self.num_heads,
                     causal=self.causal,
                     ff_multiplicative=self.ff_multiplicative,
                     attention_dropout=self.attention_dropout,
                     ff_dropout=self.ff_dropout)
            for _ in range(self.N)
        ]

    def __call__(self, x, src_mask=None, deterministic=False):
        for layer in self.layers:
            x = layer(x, attention_mask=src_mask, deterministic=deterministic)
        return x


class Transformer(nn.Module):
    num_tokens: int
    embed_size: int
    max_seq_len: int
    N: int
    num_heads: int
    causal: bool = False
    ff_multiplicative: int = 4
    attention_dropout: float = 0.
    ff_dropout: float = 0.

    def setup(self):
        self.token_emb = nn.Embed(num_embeddings=self.num_tokens, features=self.embed_size)
        self.pos_emb = nn.Embed(num_embeddings=self.max_seq_len, features=self.embed_size)
        self.encoder = Encoder(N=self.N, num_heads=self.num_heads, causal=self.causal,
                               ff_multiplicative=self.ff_multiplicative,
                               attention_dropout=self.attention_dropout,
                               ff_dropout=self.ff_dropout)

    def __call__(self, x, deterministic: bool = False, *args, **kwargs):
        """
        Args:
            - x : token ids of shape (b, n)
        """
        batch_size, seq_len = x.shape
        x = self.token_emb(x)  # (b, n, d)
        # learned positional embeddings, broadcast over the batch
        x = x + self.pos_emb(jnp.arange(seq_len))[None, :, :]
        # no explicit source mask: SelfAttention falls back to a causal mask when causal=True
        x = self.encoder(x, src_mask=None, deterministic=deterministic, *args, **kwargs)
        # weight tying: logits = x @ token_embedding.T, (b, n, d) @ (d, V) -> (b, n, V)
        x = self.token_emb.attend(x)
        return x
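

# Usage sketch (not part of the original file): a minimal example of how the
# Transformer above could be initialized and applied with Flax/JAX. The
# hyperparameters and the __main__ guard are illustrative assumptions only.
if __name__ == "__main__":
    import jax

    model = Transformer(num_tokens=1000, embed_size=64, max_seq_len=128,
                        N=2, num_heads=4, causal=True)
    tokens = jnp.zeros((2, 16), dtype=jnp.int32)  # (batch, seq_len) token ids
    params = model.init(jax.random.PRNGKey(0), tokens, deterministic=True)
    logits = model.apply(params, tokens, deterministic=True)
    print(logits.shape)  # (2, 16, 1000)
    # For training with dropout enabled, pass deterministic=False and a dropout RNG:
    # model.apply(params, tokens, deterministic=False,
    #             rngs={"dropout": jax.random.PRNGKey(1)})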