[WIP]: Initial DynamicFlexAttention wrapper class for dynamic sequence lengths #1960

Open
wants to merge 1 commit into base: main

Conversation

zyklotomic

Had a stab at making Flex Attention work without excessive recompilation. I am not fully confident in this approach; it feels pretty janky, so I wanted confirmation that this is the right direction.

In essence, the kernel has to recompile every time the input sizes change. So why not compile a kernel for a larger fixed size, pad inputs up to it when necessary, and slice the result back down before returning? See the code for more thorough comments.
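
To make the idea concrete, here is a minimal sketch of the pad-then-slice approach. The names, the 128 bucket, and the masking details are illustrative assumptions rather than the actual wrapper class in this PR, and it assumes PyTorch >= 2.5 with FlexAttention available:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

BUCKET = 128  # pad sequence lengths up to the next multiple of this size

# Compile once with dynamic=False: the kernel only ever sees bucketed shapes,
# so at most one compile happens per bucket size rather than per sequence length.
compiled_flex_attention = torch.compile(flex_attention, dynamic=False)

def padded_flex_attention(q, k, v):
    # q, k, v: (batch, heads, seq_len, head_dim)
    seq_len = q.shape[-2]
    padded_len = ((seq_len + BUCKET - 1) // BUCKET) * BUCKET
    pad = padded_len - seq_len
    if pad:
        # Zero-pad the sequence dimension on the right.
        q = torch.nn.functional.pad(q, (0, 0, 0, pad))
        k = torch.nn.functional.pad(k, (0, 0, 0, pad))
        v = torch.nn.functional.pad(v, (0, 0, 0, pad))

    seq_len_t = torch.tensor(seq_len, device=q.device)

    def pad_mask(b, h, q_idx, kv_idx):
        # Padded key positions must not receive attention weight.
        return kv_idx < seq_len_t

    block_mask = create_block_mask(pad_mask, B=None, H=None,
                                   Q_LEN=padded_len, KV_LEN=padded_len,
                                   device=q.device)
    out = compiled_flex_attention(q, k, v, block_mask=block_mask)
    # Slice the padding back off before returning.
    return out[..., :seq_len, :]
```

A real wrapper would also need to compose this padding mask with whatever score_mod or block_mask the caller passes in, which is where most of the fiddly details live.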

I haven't had the chance to really test the performance yet. There are potential enhancements too that I mention in the comments.

Will attach testing code for a demo in a bit.


@danielhanchen
Contributor

Hey! Great PR! Do you know why dynamic = True fails? Hmm, padding to 128 will also require the inputs, i.e. the data collator, to pad to 128.
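
(For reference, one way to make the data side line up with a 128 bucket is the collator's pad_to_multiple_of argument. This sketch assumes the Hugging Face DataCollatorWithPadding and a GPT-2 tokenizer purely for illustration:)

```python
from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
collator = DataCollatorWithPadding(tokenizer, padding=True, pad_to_multiple_of=128)

batch = collator([{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}])
print(batch["input_ids"].shape)  # torch.Size([2, 128]): both rows padded up to the bucket
```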

@zyklotomic
Author

I have some interesting findings to report back! I should have dug deeper initially. It turns out that dynamic-shape support is something that has been worked on, and it is apparently available in the nightly version of PyTorch.

Links of interest:
pytorch/pytorch#135206
pytorch/pytorch#147756
tolleybot/pytorch@4a57fd0

https://github.com/pytorch/pytorch/blob/8d08b4901586f230353a558ee00c16ad57f95178/torch/_inductor/kernel/flex_attention.py#L705 (most recent commit as of writing) -> which points to https://github.com/pytorch/pytorch/blob/main/torch/_inductor/kernel/flex_decoding.py#L336

  • Something interesting to note about the above is the check that the query length is under 128, as seen in sympy.Lt(query.get_size()[-2], 128), before the dynamic kernel is used. So maybe there is a case for the approach in this PR after all for larger sequence lengths? It does line up with my understanding that dynamic=True is slow and not worth it for large sizes.

I did try my example notebook with dynamic=True on the nightly build and ended up with a different failure, though. Granted, it is a nightly release after all. With dynamic=False, I saw that a lot of auto-tuning logic has since been added.
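
(For context, "set dynamic=True" above means compiling FlexAttention with symbolic shapes, roughly like the sketch below; the exact notebook isn't reproduced here:)

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Ask torch.compile to trace with symbolic sequence lengths instead of
# specializing (and recompiling) on each concrete shape.
flex_dyn = torch.compile(flex_attention, dynamic=True)

q = torch.randn(1, 8, 517, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 517, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 517, 64, device="cuda", dtype=torch.float16)

# Per the discussion above, this path errors on the stable release and fails
# differently on the nightly build.
out = flex_dyn(q, k, v)
```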

I'm not a torch.compile expert; a lot of this is new to me, and I definitely did not understand 100% of the links I shared, lol.

As for your question on why dynamic=True failed, I suspect at least one reason is that the shapes just didn't match up with the hard-coded candidate block sizes (https://github.com/pytorch/pytorch/blob/v2.5.1/torch/_inductor/kernel/flex_attention.py#L822)? A lot of speculation here.

What do you think is the best course of action? Should we wait for the PyTorch folks to stabilize dynamic-shape support instead?
