[TORCH] Added flex_attention hop function #4366
base: main
Conversation
Change 1: Converts builtin tensors → Torch tensors when entering the loop body.
Change 2: Ensures Torch tensors → builtin tensors when yielding back to the loop condition.
Without these fixes, the conversion would fail when while loops carry tensor values. Also modified basic_test.py FILECHECK statements.
Signed-off-by: Keshav Vinayak Jha <[email protected]>
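To make the scenario concrete, below is a minimal sketch (not code from this PR) of a while_loop HOP whose carried values are tensors, i.e. the case where the loop boundary has to convert between builtin and Torch tensor types during import. The module and the while_loop entry point shown are assumptions and may differ across PyTorch versions.

```python
# Minimal sketch (assumed, not from this PR): a while_loop HOP carrying tensor
# values, which exercises the builtin <-> Torch tensor conversion at the loop
# boundary described in the commit above.
import torch
from torch._higher_order_ops.while_loop import while_loop  # entry point may vary by version


class CountingLoop(torch.nn.Module):
    def forward(self, x):
        def cond_fn(i, acc):
            # The condition sees the carried tensors coming back from the body.
            return i < 5

        def body_fn(i, acc):
            # The body yields new carried tensors to the condition each iteration.
            return i + 1, acc + x

        i0 = torch.zeros((), dtype=torch.int64)
        _, acc = while_loop(cond_fn, body_fn, (i0, torch.zeros_like(x)))
        return acc
```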
1. Better documentation for AtenFlexAttentionOp.
2. Function references added as attributes to aten.flex_attention.
3. Updates to _import_hop_flex_attention reflecting the latest changes to module import.
4. Removed discardable attributes; score_mod_fn and mask_mod_fn added as OptionalAttr.
Signed-off-by: Keshav Vinayak Jha <[email protected]>
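As a rough illustration of item 2 above (not the PR's actual implementation), attaching the score/mask modification functions as symbol-reference attributes via the MLIR Python bindings could look like the following; the helper name and surrounding plumbing are assumptions, and only the attribute names come from the description.

```python
# Hypothetical helper: attach score_mod_fn / mask_mod_fn as optional
# FlatSymbolRefAttr attributes on an aten.flex_attention operation.
from typing import Optional

from torch_mlir import ir


def attach_mod_fn_attrs(op: ir.Operation, score_mod_sym: str,
                        mask_mod_sym: Optional[str]) -> None:
    ctx = op.context
    op.attributes["score_mod_fn"] = ir.FlatSymbolRefAttr.get(score_mod_sym, context=ctx)
    if mask_mod_sym is not None:  # mask_mod_fn is optional and may be absent
        op.attributes["mask_mod_fn"] = ir.FlatSymbolRefAttr.get(mask_mod_sym, context=ctx)
```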
Remove note about method usage for HOPs.
Removed TODO note for grouped query attention support in the docstring and comments.
Force-pushed 095cb61 to 5e024f6
zjgarvey left a comment
This does enable importing to MLIR.
However, the changes don't actually provide "support" for this op, since the torch op can neither be decomposed nor lowered to any other dialects.
Although we could review/merge this and subsequently add a lowering path for the op in MLIR, I would personally prefer that the e2e support be added in the same PR as the import support.
This is a rather unique operator, so having passing e2e tests would give me a lot more confidence in the choices made here. Otherwise I'm basically just hoping that what you did generally makes sense (or doing a significant amount of work myself to check it out), because there really isn't much precedent for these kinds of choices in our codebase.
The only thing needed to get this passing e2e tests is implementing TilingInterface for this operation.
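For reference, an e2e test in torch-mlir's suite for this op would roughly follow the existing test pattern; the sketch below is an assumption about what such a test could look like (class name, shapes, and the score_mod are illustrative), not something this PR adds.

```python
# Hypothetical e2e test sketch in the style of torch-mlir's e2e suite.
# flex_attention comes from torch.nn.attention.flex_attention (PyTorch >= 2.5);
# everything else here (class name, shapes, annotations) is illustrative.
import torch
from torch.nn.attention.flex_attention import flex_attention

from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export


class FlexAttentionBasicModule(torch.nn.Module):
    @export
    @annotate_args([None,
                    ([2, 4, 8, 16], torch.float32, True),
                    ([2, 4, 8, 16], torch.float32, True),
                    ([2, 4, 8, 16], torch.float32, True)])
    def forward(self, q, k, v):
        # A simple score_mod that scales every attention score.
        def score_mod(score, b, h, q_idx, kv_idx):
            return score * 0.5
        return flex_attention(q, k, v, score_mod=score_mod)


@register_test_case(module_factory=lambda: FlexAttentionBasicModule())
def FlexAttentionBasicModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 4, 8, 16), tu.rand(2, 4, 8, 16), tu.rand(2, 4, 8, 16))
```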
With that said, it's an unreasonable bar to set that every operation must compile e2e through torch-mlir. Torch-MLIR is not a compiler, even though it has tests for e2e paths. The project docs explicitly call this out: "Torch-MLIR is primarily a project that is integrated into compilers to bridge them to PyTorch and ONNX. If contemplating a new integration, it may be helpful to refer to existing downstreams."

It should be okay to land support for ops through the importer without them running e2e tests in torch-mlir. I've looked at the implementation of e2e tests for more complex ops like attention, and they are not good implementations; they don't add much value. As a project, we should allow landing PRs that add support to the importer separately from e2e tests (at least for HOPs). I don't think having a dummy implementation for an op should be the bar to land an operation.
Description
- Added Torch_AtenFlexAttentionOp with 6 operands (query, key, value, scale, enable_gqa, return_lse) and 2 optional attributes (score_mod_fn, mask_mod_fn) for function references.
- The importer (_import_hop_flex_attention) correctly extracts score/mask modification functions from get_attr nodes using module IDs, following the while_loop HOP pattern.
- … kernel_options performance tuning parameters.
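To make the operand/attribute mapping concrete, here is a rough usage sketch (an assumption, not part of the PR): a module calling flex_attention with a score_mod, exported and imported through the FX importer. Whether torch_mlir.fx.export_and_import is the intended entry point, and the exact IR produced, are assumptions here.

```python
# Rough usage sketch. Assumptions: torch_mlir.fx.export_and_import is used as the
# entry point, and the importer from this PR turns the traced score_mod subgraph
# into the score_mod_fn function reference described above.
import torch
from torch.nn.attention.flex_attention import flex_attention
from torch_mlir import fx


class RelativeBiasAttention(torch.nn.Module):
    def forward(self, q, k, v):
        # score_mod is traced into its own FX subgraph; per the description, the
        # importer picks it up from a get_attr node and references it as score_mod_fn.
        def score_mod(score, b, h, q_idx, kv_idx):
            return score + (q_idx - kv_idx)
        return flex_attention(q, k, v, score_mod=score_mod)


q = k = v = torch.randn(2, 4, 8, 16)
module = fx.export_and_import(RelativeBiasAttention(), q, k, v)
print(module)  # expected to contain torch.aten.flex_attention with a score_mod_fn attribute
```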