# Exporting MultiHeadAttention to pytorch for unit tests
This assumes you've saved the MHA params, as well as the inputs `q`, `k`, `v` and the expected output `y`, on the Rust side like so:
```rust
mha.w_q.save("w_q.npz");
mha.w_k.save("w_k.npz");
mha.w_v.save("w_v.npz");
mha.w_o.save("w_o.npz");
numpy::save("q.npy", q.data());
numpy::save("k.npy", k.data());
numpy::save("v.npy", v.data());
// also save the forward pass output so the python side can compare against it
numpy::save("y.npy", y.data());
```

The first script checks `torch.nn.MultiheadAttention` on a single sequence (no batch dimension):

```python
import torch
import numpy as np
torch.set_printoptions(precision=8, sci_mode=False)
S = 3
E = 8
H = 2
mha = torch.nn.MultiheadAttention(E, H, batch_first=True)
w_q = np.load("w_q.npz")
w_k = np.load("w_k.npz")
w_v = np.load("w_v.npz")
w_o = np.load("w_o.npz")
q_weight = torch.from_numpy(w_q["weight"])
k_weight = torch.from_numpy(w_k["weight"])
v_weight = torch.from_numpy(w_v["weight"])
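# PyTorch packs the q, k, v input projections into a single in_proj_weight, stacked along dim 0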
weight = torch.cat([q_weight, k_weight, v_weight])
assert weight.shape == mha.in_proj_weight.shape
mha.in_proj_weight.data = weight
q_bias = torch.from_numpy(w_q["bias"])
k_bias = torch.from_numpy(w_k["bias"])
v_bias = torch.from_numpy(w_v["bias"])
bias = torch.cat([q_bias, k_bias, v_bias])
assert bias.shape == mha.in_proj_bias.shape
mha.in_proj_bias.data = bias
o_weight = torch.from_numpy(w_o["weight"])
o_bias = torch.from_numpy(w_o["bias"])
assert o_weight.shape == mha.out_proj.weight.shape
assert o_bias.shape == mha.out_proj.bias.shape
mha.out_proj.weight.data = o_weight
mha.out_proj.bias.data = o_bias
Q = torch.from_numpy(np.load("q.npy"))
K = torch.from_numpy(np.load("k.npy"))
V = torch.from_numpy(np.load("v.npy"))
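# expected output computed on the Rust side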
Y = torch.from_numpy(np.load("y.npy"))
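# (optional) manual single-sequence reference computation, kept for debugging;
# the batched script below runs the same computation for real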
# v = F.linear(V, v_weight, v_bias).reshape(S, H, E // H).transpose(1, 0)
# k = F.linear(K, k_weight, k_bias).reshape(S, H, E // H).transpose(1, 0)
# q = F.linear(Q, q_weight, q_bias).reshape(S, H, E // H).transpose(1, 0)
# w = torch.bmm(q, k.transpose(-2, -1)) / math.sqrt(E // H)
# w = w.softmax(-1)
# t = torch.bmm(w, v)
# t = t.transpose(1, 0).reshape(S, E)
# o = F.linear(t, o_weight, o_bias)
#
# assert o.shape == Y.shape, (o.shape, Y.shape)
# assert torch.allclose(o, Y), (o, Y)
y, _ = mha(query=Q, key=K, value=V)
assert y.shape == Y.shape
assert torch.allclose(y, Y)
print(y)
```

The batched version runs the same check and also spells out the multi-head attention computation by hand:

```python
import math
import torch.nn.functional as F
import torch
import numpy as np
torch.set_printoptions(precision=8, sci_mode=False)
S1 = 3
S2 = 4
E = 8
H = 2
B = 5
mha = torch.nn.MultiheadAttention(E, H, batch_first=True)
w_q = np.load("w_q.npz")
w_k = np.load("w_k.npz")
w_v = np.load("w_v.npz")
w_o = np.load("w_o.npz")
q_weight = torch.from_numpy(w_q["weight"])
k_weight = torch.from_numpy(w_k["weight"])
v_weight = torch.from_numpy(w_v["weight"])
weight = torch.cat([q_weight, k_weight, v_weight])
assert weight.shape == mha.in_proj_weight.shape
mha.in_proj_weight.data = weight
q_bias = torch.from_numpy(w_q["bias"])
k_bias = torch.from_numpy(w_k["bias"])
v_bias = torch.from_numpy(w_v["bias"])
bias = torch.cat([q_bias, k_bias, v_bias])
assert bias.shape == mha.in_proj_bias.shape
mha.in_proj_bias.data = bias
o_weight = torch.from_numpy(w_o["weight"])
o_bias = torch.from_numpy(w_o["bias"])
assert o_weight.shape == mha.out_proj.weight.shape
assert o_bias.shape == mha.out_proj.bias.shape
mha.out_proj.weight.data = o_weight
mha.out_proj.bias.data = o_bias
Q = torch.from_numpy(np.load("q.npy"))
K = torch.from_numpy(np.load("k.npy"))
V = torch.from_numpy(np.load("v.npy"))
Y = torch.from_numpy(np.load("y.npy"))
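# manual multi-head attention as a reference: project q, k, v and split into heads -> (B, H, S, E // H)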
v = F.linear(V, v_weight, v_bias).reshape(B, S2, H, E // H).permute(0, 2, 1, 3)
k = F.linear(K, k_weight, k_bias).reshape(B, S2, H, E // H).permute(0, 2, 1, 3)
q = F.linear(Q, q_weight, q_bias).reshape(B, S1, H, E // H).permute(0, 2, 1, 3)
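# scaled dot-product attention: scores, softmax over key positions, weighted sum of the values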
w = (q @ k.transpose(-2, -1)) / math.sqrt(E // H)
w = w.softmax(-1)
t = w @ v
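# merge the heads back to (B, S1, E), then apply the output projection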
t = t.permute(0, 2, 1, 3).reshape(B, S1, E)
o = F.linear(t, o_weight, o_bias)
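# the built-in module should match both the manual result and the Rust output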
y, _ = mha(query=Q, key=K, value=V)
assert o.shape == y.shape, (o.shape, y.shape)
assert torch.allclose(o, y), (o, y)
assert y.shape == Y.shape
assert torch.allclose(y, Y), (y, Y)
print(y)
```

The same approach works for `torch.nn.TransformerEncoderLayer`. Here the attention and feed-forward params are saved as single archives (`attn.npz` and `ff.npz`) keyed by parameter name:

```python
import torch
import numpy as np
def set_param(p, t):
    assert p.shape == t.shape
    p.data = t

def load_mha(mha, attn_npz):
    q_weight = torch.from_numpy(attn_npz["w_q.weight"])
    k_weight = torch.from_numpy(attn_npz["w_k.weight"])
    v_weight = torch.from_numpy(attn_npz["w_v.weight"])
    weight = torch.cat([q_weight, k_weight, v_weight])
    set_param(mha.in_proj_weight, weight)
    q_bias = torch.from_numpy(attn_npz["w_q.bias"])
    k_bias = torch.from_numpy(attn_npz["w_k.bias"])
    v_bias = torch.from_numpy(attn_npz["w_v.bias"])
    bias = torch.cat([q_bias, k_bias, v_bias])
    set_param(mha.in_proj_bias, bias)
    o_weight = torch.from_numpy(attn_npz["w_o.weight"])
    o_bias = torch.from_numpy(attn_npz["w_o.bias"])
    set_param(mha.out_proj.weight, o_weight)
    set_param(mha.out_proj.bias, o_bias)
torch.set_printoptions(precision=8, sci_mode=False)
BATCH = 3
SEQ_LEN = 5
EMBED_DIM = 9
NUM_HEADS = 3
FF_DIM = 16
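# dropout=0.0 keeps the forward pass deterministic so outputs can be compared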
encoder = torch.nn.TransformerEncoderLayer(
    EMBED_DIM, NUM_HEADS, dim_feedforward=FF_DIM, batch_first=True, dropout=0.0
)
for (name, param) in encoder.named_parameters():
    print(name, param.shape)
attn = np.load("attn.npz")
ff = np.load("ff.npz")
X = torch.from_numpy(np.load("x.npy"))
Y = torch.from_numpy(np.load("y.npy"))
load_mha(encoder.self_attn, attn)
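# the Rust side saves feed-forward params under module-path keys;
# "1.0.*" and "1.2.*" map onto linear1 and linear2 here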
set_param(encoder.linear1.weight, torch.from_numpy(ff["1.0.weight"]))
set_param(encoder.linear1.bias, torch.from_numpy(ff["1.0.bias"]))
set_param(encoder.linear2.weight, torch.from_numpy(ff["1.2.weight"]))
set_param(encoder.linear2.bias, torch.from_numpy(ff["1.2.bias"]))
y = encoder(X)
assert y.shape == Y.shape
assert torch.allclose(y, Y)
print(y)
```

Finally, `torch.nn.TransformerDecoderLayer`, which has both a self-attention and an encoder-decoder (cross) attention block:

```python
import torch
import numpy as np
def set_param(p, t):
    assert p.shape == t.shape
    p.data = t

def load_mha(mha, attn_npz):
    q_weight = torch.from_numpy(attn_npz["w_q.weight"])
    k_weight = torch.from_numpy(attn_npz["w_k.weight"])
    v_weight = torch.from_numpy(attn_npz["w_v.weight"])
    weight = torch.cat([q_weight, k_weight, v_weight])
    set_param(mha.in_proj_weight, weight)
    q_bias = torch.from_numpy(attn_npz["w_q.bias"])
    k_bias = torch.from_numpy(attn_npz["w_k.bias"])
    v_bias = torch.from_numpy(attn_npz["w_v.bias"])
    bias = torch.cat([q_bias, k_bias, v_bias])
    set_param(mha.in_proj_bias, bias)
    o_weight = torch.from_numpy(attn_npz["w_o.weight"])
    o_bias = torch.from_numpy(attn_npz["w_o.bias"])
    set_param(mha.out_proj.weight, o_weight)
    set_param(mha.out_proj.bias, o_bias)
torch.set_printoptions(precision=8, sci_mode=False)
EMBED_DIM = 12
NUM_HEADS = 6
FF_DIM = 2
decoder = torch.nn.TransformerDecoderLayer(
    EMBED_DIM, NUM_HEADS, dim_feedforward=FF_DIM, batch_first=True, dropout=0.0
)
self_attn = np.load("self_attn.npz")
mha_attn = np.load("mha_attn.npz")
ff = np.load("ff.npz")
TGT = torch.from_numpy(np.load("tgt.npy"))
MEM = torch.from_numpy(np.load("mem.npy"))
Y = torch.from_numpy(np.load("y.npy"))
load_mha(decoder.self_attn, self_attn)
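# decoder.multihead_attn is the encoder-decoder (cross) attention over the memory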
load_mha(decoder.multihead_attn, mha_attn)
print(list(ff.keys()))
set_param(decoder.linear1.weight, torch.from_numpy(ff["0.weight"]))
set_param(decoder.linear1.bias, torch.from_numpy(ff["0.bias"]))
set_param(decoder.linear2.weight, torch.from_numpy(ff["2.weight"]))
set_param(decoder.linear2.bias, torch.from_numpy(ff["2.bias"]))
y = decoder(TGT, MEM)
assert y.shape == Y.shape
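# slightly looser tolerance: small floating point differences accumulate
# across the two attention blocks and the feed-forward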
assert torch.allclose(y, Y, atol=1e-6), (y[0][0], Y[0][0])
print(y)
```