ONNX export failed #212

Open
pfeatherstone opened this issue Nov 27, 2023 · 14 comments

@pfeatherstone
Contributor

Here is a repro:

import torch
from torch import nn
from x_transformers import TransformerWrapper, Decoder

class Model(nn.Module):
    # thin wrapper so that x and mask can be passed positionally to torch.onnx.export
    def __init__(self):
        super().__init__()
        self.lm = TransformerWrapper(
            num_tokens          = 256,
            max_seq_len         = 0,
            num_memory_tokens   = 20,
            attn_layers = Decoder(
                dim             = 512,
                depth           = 1,
                heads           = 4,
                rotary_pos_emb  = True,
                shift_tokens    = 1,
                attn_flash      = True,
                attn_onnxable   = True,
                use_scalenorm   = True,
                sandwich_norm   = True
            )
        )
    def forward(self, x, mask):
        return self.lm(x, mask=mask, return_embeddings=True)

net = Model()
x = torch.randint(0, 256, size=(4, 1024))
m = x < 128
x = net(x, m)
print('Normal inference ok')

torch.onnx.export(net, (x,m), '/tmp/model.onnx', opset_version=17, 
                  input_names=['x', 'mask'],
                  output_names=['embeddings'],
                  dynamic_axes={'x'             : {0: 'B', 1: 'T'},
                                'mask'          : {0: 'B', 1: 'T'},
                                'embeddings'    : {0: 'B', 1: 'T'}})
print('Onnx export ok')

The export fails with:

b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = *x.shape, x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
ValueError: too many values to unpack (expected 6)
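
For illustration, that assignment has six targets, so the starred x.shape only balances when x has exactly two dimensions; a standalone sketch that reproduces the same error (the extra values here are made-up placeholders, not the library's attributes):

import torch

def unpack_six(x):
    # Mirrors the failing line: six targets on the left-hand side.
    b, n, device, num_mems, has_mems, frac = *x.shape, x.device, 20, True, 0.1
    return b, n

unpack_six(torch.randint(0, 256, (4, 1024)))   # ok: 2 shape values + 4 extras = 6
unpack_six(torch.randn(4, 1024, 512))          # ValueError: too many values to unpack (expected 6)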

@pfeatherstone
Contributor Author

Then if i change:

*x.shape,

to

x.shape[0], x.shape[1]

I get another error:

x_transformers.py", line 1238, in forward
rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length)
return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
RuntimeError: einsum(): the number of subscripts in the equation (1) does not match the number of dimensions (0) for operand 0 and no ellipsis was given

@pfeatherstone
Contributor Author

pfeatherstone commented Nov 27, 2023

It would seem that during normal inference max_rotary_emb_length is an int, while during JIT tracing or ONNX export it's a 0-dimensional tensor.

EDIT:
In general it looks like something like x.shape[0] is a plain int in eager pytorch, while in tracing, scripting or ONNX export it's a torch.Tensor. They must have changed the behaviour recently.
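
A minimal standalone sketch of that behaviour (nothing here is the library's code, and the exact types can vary with the PyTorch version):

import torch

def report_dim0(x):
    # In eager mode x.shape[0] is a plain Python int; under torch.jit.trace
    # or torch.onnx.export (recent PyTorch versions) it typically comes back
    # as a 0-dim torch.Tensor, which breaks isinstance(..., int) checks.
    print(type(x.shape[0]))
    return x * 1

x = torch.randn(2, 3)
report_dim0(x)                       # eager: <class 'int'>
torch.jit.trace(report_dim0, (x,))   # during tracing: typically <class 'torch.Tensor'>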

@pfeatherstone
Contributor Author

Changing:

if isinstance(seq_arange_or_len, int):

at line 432
to

if isinstance(seq_arange_or_len, int) or seq_arange_or_len.dim() == 0:

seems to resolve everything but i dunno, this seems like a hack.
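
A slightly more general version of the same idea as a standalone helper (hypothetical, not the library's code): treat an int or a 0-dim tensor as a length, and anything else as an already-built position tensor.

import torch

def to_positions(seq_arange_or_len, device=None):
    # Hypothetical helper: accept an int length, a 0-dim length tensor
    # (as produced under tracing / ONNX export), or a ready-made arange.
    if isinstance(seq_arange_or_len, int):
        return torch.arange(seq_arange_or_len, device=device)
    if seq_arange_or_len.dim() == 0:
        # .item() works here, but it bakes the length into the trace as a
        # constant, which can defeat a dynamic sequence axis in the export.
        return torch.arange(int(seq_arange_or_len.item()), device=device)
    return seq_arange_or_len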

lucidrains added a commit that referenced this issue Nov 27, 2023
@lucidrains
Owner

@pfeatherstone ah yea, think i may have a solution

threw in some fixes (but by no means for all configurations)

let me know if that works

@lucidrains
Owner

@pfeatherstone you set max_seq_len to 0 to turn off absolute positional embedding? may not work as you intended (but should be fixed)
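
If the goal of max_seq_len = 0 was only to drop absolute positional embeddings, a possibly cleaner route is the use_abs_pos_emb flag, assuming it is available in the installed version; a rough sketch:

from x_transformers import TransformerWrapper, Decoder

lm = TransformerWrapper(
    num_tokens      = 256,
    max_seq_len     = 1024,     # keep a real maximum length
    use_abs_pos_emb = False,    # skip absolute positions; rotary embeddings handle position
    attn_layers     = Decoder(
        dim            = 512,
        depth          = 1,
        heads          = 4,
        rotary_pos_emb = True
    )
)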

@pfeatherstone
Contributor Author

I now get the following error during export:

File ".../x_transformers/x_transformers.py", line 577, in forward
    segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts)))
  File ".../x_transformers/x_transformers.py", line 577, in <lambda>
    segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts)))
  File ".../x_transformers/x_transformers.py", line 560, in shift
    t = t.masked_fill(~mask[..., None], 0.)
RuntimeError: The size of tensor a (1044) must match the size of tensor b (524308) at non-singleton dimension 1
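
For context, token shifting masks out padding before rolling the sequence along the time axis; a minimal standalone sketch of that pattern (not the library's exact code), where the error above suggests mask and t end up disagreeing on the sequence dimension during export:

import torch
import torch.nn.functional as F

def shift_tokens(t, amount, mask=None):
    # t: (batch, seq, dim); mask: (batch, seq), True at valid positions.
    # Zero out padded positions, then shift the sequence forward by amount steps.
    if mask is not None:
        t = t.masked_fill(~mask[..., None], 0.)
    return F.pad(t, (0, 0, amount, 0))[:, :t.shape[1]]

t    = torch.randn(4, 1044, 512)
mask = torch.ones(4, 1044, dtype=torch.bool)
out  = shift_tokens(t, 1, mask)   # works when mask.shape[1] == t.shape[1]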

@pfeatherstone
Contributor Author

@lucidrains sorry to bother again, but it would be really cool to get this working with ONNX. At some point I might submit a PR that adds CI/CD; some unit tests would go a long way.
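
On the unit-test idea, a rough smoke test might look something like this (hypothetical file and config, not an existing test in the repo):

# test_onnx_export.py -- hypothetical smoke test
import io
import torch
from x_transformers import TransformerWrapper, Decoder

class ExportWrapper(torch.nn.Module):
    # Small wrapper so torch.onnx.export can pass the mask positionally.
    def __init__(self, net):
        super().__init__()
        self.net = net
    def forward(self, x, mask):
        return self.net(x, mask = mask, return_embeddings = True)

def test_onnx_export_smoke():
    net = ExportWrapper(TransformerWrapper(
        num_tokens  = 256,
        max_seq_len = 256,
        attn_layers = Decoder(dim = 64, depth = 1, heads = 2,
                              rotary_pos_emb = True, attn_onnxable = True)
    ))
    x    = torch.randint(0, 256, (2, 128))
    mask = torch.ones_like(x).bool()
    buf  = io.BytesIO()
    torch.onnx.export(net, (x, mask), buf, opset_version = 17,
                      input_names  = ['x', 'mask'],
                      output_names = ['embeddings'],
                      dynamic_axes = {'x': {0: 'B', 1: 'T'},
                                      'mask': {0: 'B', 1: 'T'},
                                      'embeddings': {0: 'B', 1: 'T'}})
    assert buf.getbuffer().nbytes > 0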

@lucidrains
Owner

@pfeatherstone can you try it without shift tokens?

@lucidrains
Owner

lucidrains commented Nov 28, 2023

@pfeatherstone yea, i know some others have already gotten onnx to work in production, so it definitely works for some configurations, just not all. the repository at this point prioritizes simplicity; it is not worth bending over backwards to make onnx work for all settings.

@pfeatherstone
Contributor Author

@lucidrains No it didn't work either

@lucidrains
Owner

ah alright, i'll have to circle back to this some other time

@pfeatherstone
Contributor Author

the repository at this point prioritizes simplicity; it is not worth bending over backwards to make onnx work for all settings.

OK cool. To be honest, once I've nailed down the configurations I want, I might rewrite it from scratch, keeping exactly what I need; then it will probably be easier to debug the ONNX export.

@lucidrains
Owner

lucidrains commented Nov 28, 2023

@pfeatherstone yes exactly, that is how i intended it to be

@pfeatherstone
Contributor Author

I've tried so many configurations and it turns out I only really need the following (a rough config sketch follows the list):

  • shifted tokens
  • rotary embeddings
  • either rms norm or scale norm
  • memory tokens
  • xl recurrence (maybe. i need to try)
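
For reference, a rough sketch of a config covering most of that list, reusing only kwargs that already appear in the repro above (XL recurrence left out since it is untested here):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens        = 256,
    max_seq_len       = 1024,
    num_memory_tokens = 20,             # memory tokens
    attn_layers = Decoder(
        dim            = 512,
        depth          = 1,
        heads          = 4,
        rotary_pos_emb = True,          # rotary embeddings
        shift_tokens   = 1,             # shifted tokens
        use_scalenorm  = True,          # scale norm (or swap in rms norm)
        attn_onnxable  = True
    )
)

x    = torch.randint(0, 256, (2, 256))
mask = torch.ones_like(x).bool()
emb  = model(x, mask = mask, return_embeddings = True)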
