-
-
Notifications
You must be signed in to change notification settings - Fork 104
Exporting MultiHeadAttention to pytorch for unit tests
Corey Lowman edited this page Aug 30, 2022
·
7 revisions
This assumes you've already exported the MultiHeadAttention parameters as well as the q, k, v inputs and the expected output y.
import torch
import numpy as np
# Rebuild a torch.nn.MultiheadAttention from exported parameter files and
# verify that its output matches the exported reference output.
#
# Expected files in the working directory:
#   w_q.npz / w_k.npz / w_v.npz / w_o.npz - each containing "weight" and "bias"
#   q.npy / k.npy / v.npy                 - unbatched attention inputs, shape (S, E)
#   y.npy                                 - reference output, shape (S, E)
torch.set_printoptions(precision=8, sci_mode=False)

E = 8  # embedding dimension
H = 2  # number of attention heads

mha = torch.nn.MultiheadAttention(E, H, batch_first=True)


def _load_linear(path):
    """Load an exported linear layer's (weight, bias) tensors from an .npz file."""
    data = np.load(path)
    return torch.from_numpy(data["weight"]), torch.from_numpy(data["bias"])


q_weight, q_bias = _load_linear("w_q.npz")
k_weight, k_bias = _load_linear("w_k.npz")
v_weight, v_bias = _load_linear("w_v.npz")
o_weight, o_bias = _load_linear("w_o.npz")

# torch packs the separate q/k/v projections into a single (3*E, E) in_proj_weight
# and a single (3*E,) in_proj_bias, concatenated in q, k, v order.
weight = torch.cat([q_weight, k_weight, v_weight])
bias = torch.cat([q_bias, k_bias, v_bias])
assert weight.shape == mha.in_proj_weight.shape
assert bias.shape == mha.in_proj_bias.shape
assert o_weight.shape == mha.out_proj.weight.shape
assert o_bias.shape == mha.out_proj.bias.shape

# copy_ under no_grad is the supported way to overwrite Parameters: it validates
# shape/dtype and keeps autograd bookkeeping intact (unlike assigning .data).
with torch.no_grad():
    mha.in_proj_weight.copy_(weight)
    mha.in_proj_bias.copy_(bias)
    mha.out_proj.weight.copy_(o_weight)
    mha.out_proj.bias.copy_(o_bias)

Q = torch.from_numpy(np.load("q.npy"))
K = torch.from_numpy(np.load("k.npy"))
V = torch.from_numpy(np.load("v.npy"))
Y = torch.from_numpy(np.load("y.npy"))

# Manual reference implementation, equivalent to mha(...) for unbatched input.
# Needs: import math, import torch.nn.functional as F, and S = Q.shape[0].
# q = F.linear(Q, q_weight, q_bias).reshape(S, H, E // H).transpose(1, 0)
# k = F.linear(K, k_weight, k_bias).reshape(S, H, E // H).transpose(1, 0)
# v = F.linear(V, v_weight, v_bias).reshape(S, H, E // H).transpose(1, 0)
# w = torch.bmm(q, k.transpose(-2, -1)) / math.sqrt(E // H)
# w = w.softmax(-1)
# t = torch.bmm(w, v)
# t = t.transpose(1, 0).reshape(S, E)
# o = F.linear(t, o_weight, o_bias)
# assert o.shape == Y.shape, (o.shape, Y.shape)
# assert torch.allclose(o, Y), (o, Y)

y, _ = mha(query=Q, key=K, value=V)
assert y.shape == Y.shape
assert torch.allclose(y, Y)
print(y)