class Model(nn.Module):
    # Minimal wrapper around x-transformers' TransformerWrapper, used to
    # reproduce an ONNX-export failure (error text quoted at the bottom of
    # this file). NOTE(review): `nn`, `TransformerWrapper` and `Decoder` are
    # imported elsewhere (torch.nn / x_transformers) — not visible in this chunk.
    def __init__(self):
        super().__init__()
        self.lm = TransformerWrapper (
            num_tokens = 256,
            # max_seq_len = 0 presumably disables absolute positional
            # embeddings (rotary embeddings below are used instead) —
            # TODO confirm against the x-transformers docs.
            max_seq_len = 0,
            # 20 learned memory tokens are prepended internally, so the
            # model's internal sequence length exceeds the input's.
            num_memory_tokens = 20,
            attn_layers = Decoder (
                dim = 512,
                depth = 1,
                heads = 4,
                rotary_pos_emb = True,
                shift_tokens = 1,
                attn_flash = True,
                # attn_onnxable presumably selects an export-friendly
                # attention path — verify against x-transformers docs.
                attn_onnxable = True,
                use_scalenorm = True,
                sandwich_norm = True
            )
        )

    def forward(self, x, mask):
        # x: integer token ids; mask: boolean attention mask of the same
        # (batch, seq) shape — assumed, TODO confirm against caller.
        # return_embeddings=True makes the wrapper return hidden states
        # rather than logits over the vocabulary.
        return self.lm(x, mask=mask, return_embeddings=True)
# --- Repro script: one normal forward pass, then an ONNX export. ---
net = Model()

# Random token ids in [0, 256) for a batch of 4 sequences of length 1024,
# and an arbitrary boolean attention mask of the same shape.
x = torch.randint(0, 256, size=(4, 1024))
m = x < 128

# BUG FIX: the original did `x = net(x, m)`, clobbering the (batch, seq)
# integer token-id tensor with the model's rank-3 float embeddings output.
# torch.onnx.export then traced the model with that rank-3 tensor, so the
# library's `b, n, ... = *x.shape, ...` unpack received one value too many —
# the exact "too many values to unpack (expected 6)" error quoted below.
# Keep the output in a separate variable so export sees the real token ids.
out = net(x, m)
print('Normal inference ok')  # also fixes the "inferrence" typo

# Export with dynamic batch (B) and sequence-length (T) axes.
torch.onnx.export(net, (x, m), '/tmp/model.onnx', opset_version=17,
                  input_names=['x', 'mask'],
                  output_names=['embeddings'],
                  dynamic_axes={'x': {0: 'B', 1: 'T'},
                                'mask': {0: 'B', 1: 'T'},
                                'embeddings': {0: 'B', 1: 'T'}})
print('Onnx export ok')
The code above is a repro.
The export fails with this message:
"b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = *x.shape, x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
ValueError: too many values to unpack (expected 6)"