Exporting MultiHeadAttention to PyTorch for unit tests

This assumes you've saved the MHA parameters as well as q, k, and v like so:

mha.w_q.save("w_q.npz");
mha.w_k.save("w_k.npz");
mha.w_v.save("w_v.npz");
mha.w_o.save("w_o.npz");
numpy::save("q.npy", q.data());
numpy::save("k.npy", k.data());
numpy::save("v.npy", v.data());

Unbatched MHA

import torch
import numpy as np

torch.set_printoptions(precision=8, sci_mode=False)

S = 3
E = 8
H = 2

mha = torch.nn.MultiheadAttention(E, H, batch_first=True)
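# batch_first=True; recent PyTorch versions (1.11+) also accept unbatched (S, E)
# inputs, which is what this unbatched script passes in.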

w_q = np.load("w_q.npz")
w_k = np.load("w_k.npz")
w_v = np.load("w_v.npz")
w_o = np.load("w_o.npz")

q_weight = torch.from_numpy(w_q["weight"])
k_weight = torch.from_numpy(w_k["weight"])
v_weight = torch.from_numpy(w_v["weight"])
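# PyTorch packs the q/k/v projections into a single in_proj_weight of shape (3*E, E),
# concatenated in q, k, v order.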
weight = torch.cat([q_weight, k_weight, v_weight])
assert weight.shape == mha.in_proj_weight.shape
mha.in_proj_weight.data = weight

q_bias = torch.from_numpy(w_q["bias"])
k_bias = torch.from_numpy(w_k["bias"])
v_bias = torch.from_numpy(w_v["bias"])
bias = torch.cat([q_bias, k_bias, v_bias])
assert bias.shape == mha.in_proj_bias.shape
mha.in_proj_bias.data = bias

o_weight = torch.from_numpy(w_o["weight"])
o_bias = torch.from_numpy(w_o["bias"])
assert o_weight.shape == mha.out_proj.weight.shape
assert o_bias.shape == mha.out_proj.bias.shape
mha.out_proj.weight.data = o_weight
mha.out_proj.bias.data = o_bias

Q = torch.from_numpy(np.load("q.npy"))
K = torch.from_numpy(np.load("k.npy"))
V = torch.from_numpy(np.load("v.npy"))
Y = torch.from_numpy(np.load("y.npy"))

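# Equivalent manual computation for the unbatched case; to run it, uncomment and add
# `import math` and `import torch.nn.functional as F`.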
# v = F.linear(V, v_weight, v_bias).reshape(S, H, E // H).transpose(1, 0)
# k = F.linear(K, k_weight, k_bias).reshape(S, H, E // H).transpose(1, 0)
# q = F.linear(Q, q_weight, q_bias).reshape(S, H, E // H).transpose(1, 0)
# w = torch.bmm(q, k.transpose(-2, -1)) / math.sqrt(E // H)
# w = w.softmax(-1)
# t = torch.bmm(w, v)
# t = t.transpose(1, 0).reshape(S, E)
# o = F.linear(t, o_weight, o_bias)
#
# assert o.shape == Y.shape, (o.shape, Y.shape)
# assert torch.allclose(o, Y), (o, Y)

y, _ = mha(query=Q, key=K, value=V)
assert y.shape == Y.shape
assert torch.allclose(y, Y)

print(y)

Batched MHA

import math
import torch.nn.functional as F
import torch
import numpy as np

torch.set_printoptions(precision=8, sci_mode=False)

S1 = 3
S2 = 4
E = 8
H = 2
B = 5

mha = torch.nn.MultiheadAttention(E, H, batch_first=True)

w_q = np.load("w_q.npz")
w_k = np.load("w_k.npz")
w_v = np.load("w_v.npz")
w_o = np.load("w_o.npz")

q_weight = torch.from_numpy(w_q["weight"])
k_weight = torch.from_numpy(w_k["weight"])
v_weight = torch.from_numpy(w_v["weight"])
weight = torch.cat([q_weight, k_weight, v_weight])
assert weight.shape == mha.in_proj_weight.shape
mha.in_proj_weight.data = weight

q_bias = torch.from_numpy(w_q["bias"])
k_bias = torch.from_numpy(w_k["bias"])
v_bias = torch.from_numpy(w_v["bias"])
bias = torch.cat([q_bias, k_bias, v_bias])
assert bias.shape == mha.in_proj_bias.shape
mha.in_proj_bias.data = bias

o_weight = torch.from_numpy(w_o["weight"])
o_bias = torch.from_numpy(w_o["bias"])
assert o_weight.shape == mha.out_proj.weight.shape
assert o_bias.shape == mha.out_proj.bias.shape
mha.out_proj.weight.data = o_weight
mha.out_proj.bias.data = o_bias

Q = torch.from_numpy(np.load("q.npy"))
K = torch.from_numpy(np.load("k.npy"))
V = torch.from_numpy(np.load("v.npy"))
Y = torch.from_numpy(np.load("y.npy"))

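# Manual batched multi-head attention: project q/k/v, split into H heads,
# attend per head, merge the heads, then apply the output projection.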
v = F.linear(V, v_weight, v_bias).reshape(B, S2, H, E // H).permute(0, 2, 1, 3)
k = F.linear(K, k_weight, k_bias).reshape(B, S2, H, E // H).permute(0, 2, 1, 3)
q = F.linear(Q, q_weight, q_bias).reshape(B, S1, H, E // H).permute(0, 2, 1, 3)

w = (q @ k.transpose(-2, -1)) / math.sqrt(E // H)
w = w.softmax(-1)
t = w @ v
t = t.permute(0, 2, 1, 3).reshape(B, S1, E)
o = F.linear(t, o_weight, o_bias)

y, _ = mha(query=Q, key=K, value=V)

assert o.shape == y.shape, (o.shape, y.shape)
assert torch.allclose(o, y), (o, y)

assert y.shape == Y.shape
assert torch.allclose(y, Y), (y, Y)

print(y)

TransformerEncoderLayer

import torch
import numpy as np


def set_param(p, t):
    assert p.shape == t.shape
    p.data = t


def load_mha(mha, attn_npz):
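    """Copy exported per-projection q/k/v/o parameters into PyTorch's packed MultiheadAttention layout."""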
    q_weight = torch.from_numpy(attn_npz["w_q.weight"])
    k_weight = torch.from_numpy(attn_npz["w_k.weight"])
    v_weight = torch.from_numpy(attn_npz["w_v.weight"])
    weight = torch.cat([q_weight, k_weight, v_weight])
    set_param(mha.in_proj_weight, weight)

    q_bias = torch.from_numpy(attn_npz["w_q.bias"])
    k_bias = torch.from_numpy(attn_npz["w_k.bias"])
    v_bias = torch.from_numpy(attn_npz["w_v.bias"])
    bias = torch.cat([q_bias, k_bias, v_bias])
    set_param(mha.in_proj_bias, bias)

    o_weight = torch.from_numpy(attn_npz["w_o.weight"])
    o_bias = torch.from_numpy(attn_npz["w_o.bias"])
    assert o_weight.shape == mha.out_proj.weight.shape
    assert o_bias.shape == mha.out_proj.bias.shape

    set_param(mha.out_proj.weight, o_weight)
    set_param(mha.out_proj.bias, o_bias)


torch.set_printoptions(precision=8, sci_mode=False)

BATCH = 3
SEQ_LEN = 5
EMBED_DIM = 9
NUM_HEADS = 3
FF_DIM = 16


encoder = torch.nn.TransformerEncoderLayer(
    EMBED_DIM, NUM_HEADS, dim_feedforward=FF_DIM, batch_first=True, dropout=0.0
)
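# dropout=0.0 keeps the forward pass deterministic, so outputs can be compared directly.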

# Print parameter names and shapes to see how the PyTorch layer organizes its weights
for name, param in encoder.named_parameters():
    print(name, param.shape)

attn = np.load("attn.npz")
ff = np.load("ff.npz")
X = torch.from_numpy(np.load("x.npy"))
Y = torch.from_numpy(np.load("y.npy"))

load_mha(encoder.self_attn, attn)

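# The "1.0.*" / "1.2.*" keys reflect how the feedforward block was exported;
# print(list(ff.keys())) shows the exact names if yours differ.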
set_param(encoder.linear1.weight, torch.from_numpy(ff["1.0.weight"]))
set_param(encoder.linear1.bias, torch.from_numpy(ff["1.0.bias"]))
set_param(encoder.linear2.weight, torch.from_numpy(ff["1.2.weight"]))
set_param(encoder.linear2.bias, torch.from_numpy(ff["1.2.bias"]))

y = encoder(X)
assert y.shape == Y.shape
assert torch.allclose(y, Y)

print(y)

TransformerDecoderLayer

import torch
import numpy as np


def set_param(p, t):
    assert p.shape == t.shape
    p.data = t


def load_mha(mha, attn_npz):
    q_weight = torch.from_numpy(attn_npz["w_q.weight"])
    k_weight = torch.from_numpy(attn_npz["w_k.weight"])
    v_weight = torch.from_numpy(attn_npz["w_v.weight"])
    weight = torch.cat([q_weight, k_weight, v_weight])
    set_param(mha.in_proj_weight, weight)

    q_bias = torch.from_numpy(attn_npz["w_q.bias"])
    k_bias = torch.from_numpy(attn_npz["w_k.bias"])
    v_bias = torch.from_numpy(attn_npz["w_v.bias"])
    bias = torch.cat([q_bias, k_bias, v_bias])
    set_param(mha.in_proj_bias, bias)

    o_weight = torch.from_numpy(attn_npz["w_o.weight"])
    o_bias = torch.from_numpy(attn_npz["w_o.bias"])
    assert o_weight.shape == mha.out_proj.weight.shape
    assert o_bias.shape == mha.out_proj.bias.shape

    set_param(mha.out_proj.weight, o_weight)
    set_param(mha.out_proj.bias, o_bias)


torch.set_printoptions(precision=8, sci_mode=False)

EMBED_DIM = 12
NUM_HEADS = 6
FF_DIM = 2

decoder = torch.nn.TransformerDecoderLayer(
    EMBED_DIM, NUM_HEADS, dim_feedforward=FF_DIM, batch_first=True, dropout=0.0
)

self_attn = np.load("self_attn.npz")
mha_attn = np.load("mha_attn.npz")
ff = np.load("ff.npz")
TGT = torch.from_numpy(np.load("tgt.npy"))
MEM = torch.from_numpy(np.load("mem.npy"))
Y = torch.from_numpy(np.load("y.npy"))


load_mha(decoder.self_attn, self_attn)
load_mha(decoder.multihead_attn, mha_attn)

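# Inspect the exported feedforward keys before assigning them to linear1/linear2 below.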
print(list(ff.keys()))

set_param(decoder.linear1.weight, torch.from_numpy(ff["0.weight"]))
set_param(decoder.linear1.bias, torch.from_numpy(ff["0.bias"]))
set_param(decoder.linear2.weight, torch.from_numpy(ff["2.weight"]))
set_param(decoder.linear2.bias, torch.from_numpy(ff["2.bias"]))

y = decoder(TGT, MEM)
assert y.shape == Y.shape
assert torch.allclose(y, Y, atol=1e-6), (y[0][0], Y[0][0])

print(y)