
Conversation

@keshavvinayak01 (Contributor) commented Nov 4, 2025

Description

  • Added support for PyTorch's flex_attention Higher-Order Operator in torch-mlir.
  • Implemented Torch_AtenFlexAttentionOp with 6 operands (query, key, value, scale, enable_gqa, return_lse) and 2 optional attributes (score_mod_fn, mask_mod_fn) for function references.
  • The FX importer (_import_hop_flex_attention) correctly extracts score/mask modification functions from get_attr nodes using module IDs, following the while_loop HOP pattern.
  • Includes TODO markers for kernel_options performance tuning parameters.
  • Imports flex_attention from PyTorch FX graphs into valid MLIR (see the usage sketch below).
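
For context (not part of the PR), here is a minimal sketch of the kind of PyTorch program this importer targets, assuming a recent PyTorch build (2.5+) that provides torch.nn.attention.flex_attention and can export the flex_attention HOP; the shapes and the rel_bias score_mod are illustrative only:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention


def rel_bias(score, b, h, q_idx, kv_idx):
    # score_mod callable: adjusts each attention logit. The importer records a
    # reference to the traced subgraph of this function via the score_mod_fn
    # attribute on the imported op.
    return score + (q_idx - kv_idx)


class FlexAttn(torch.nn.Module):
    def forward(self, q, k, v):
        return flex_attention(q, k, v, score_mod=rel_bias)


q = k = v = torch.randn(2, 4, 128, 64)
# Exporting yields an FX graph containing the flex_attention higher-order op,
# which _import_hop_flex_attention then maps into MLIR.
exported = torch.export.export(FlexAttn(), (q, k, v))
print(exported.graph)
```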

keshavvinayak01 and others added 17 commits October 22, 2025 09:41
Change 1: Converts builtin tensors → Torch tensors when entering the loop body
Change 2: Ensures Torch tensors → builtin tensors when yielding back to the loop condition
Without these fixes, the conversion would fail when while loops carry tensor values

Also modified basic_test.py FILECHECK statements.

1. Better documentation for AtenFlexAttentionOp
2. Function references added as attributes to aten.flex_attention
3. Updated _import_hop_flex_attention to reflect the latest changes to module import.
4. Removed discardable attributes; score_mod_fn and mask_mod_fn added as OptionalAttr

Remove note about method usage for HOPs.
@keshavvinayak01 changed the title from "Keshavvinayak01/torch aten flex attention" to "[TORCH] Added flex_attention hop function" on Nov 4, 2025
Removed TODO note for grouped query attention support in the docstring and comments.
@keshavvinayak01 force-pushed the keshavvinayak01/torch-aten-flex_attention branch from 095cb61 to 5e024f6 on November 6, 2025 at 09:36
@keshavvinayak01 marked this pull request as ready for review on November 6, 2025 at 09:37
@zjgarvey (Collaborator) left a comment


This does enable importing to MLIR.

However, the changes don't actually provide "support" for this op, since the torch op can neither be decomposed nor lowered to any other dialect.

Although we could review/merge this and subsequently add a lowering path for the op in MLIR, I would personally prefer the e2e support is added in the same PR as the import support.

This is a rather unique operator, so having passing e2e tests would give me a lot more confidence in the choices made here. Otherwise I'm basically just hoping that what you did generally makes sense (or doing a significant amount of work myself to check it out), because there really isn't much precedent for these kinds of choices in our codebase.

@Groverkss (Member) commented Nov 11, 2025

> This does enable importing to MLIR.
>
> However, the changes don't actually provide "support" for this op, since the torch op can neither be decomposed nor lowered to any other dialect.
>
> Although we could review/merge this and subsequently add a lowering path for the op in MLIR, I would personally prefer the e2e support is added in the same PR as the import support.
>
> This is a rather unique operator, so having passing e2e tests would give me a lot more confidence in the choices made here. Otherwise I'm basically just hoping that what you did generally makes sense (or doing a significant amount of work myself to check it out), because there really isn't much precedent for these kinds of choices in our codebase.

The only thing needed to have this pass e2e tests is implementing TilingInterface for this operation:

LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,

With that said, it's an unreasonable bar to set that every operation must compile e2e through torch-mlir. Torch-MLIR is not a compiler, even though it has tests for e2e paths. The project docs explicitly call this out:

> Torch-MLIR is primarily a project that is integrated into compilers to bridge them to PyTorch and ONNX. If contemplating a new integration, it may be helpful to refer to existing downstreams:
>
> • IREE
> • Blade
>
> While most of the project is exercised via testing paths, there are some ways that an end user can directly use the APIs without further integration:

It should be okay to land support for ops through the importer without it running e2e tests in torch-mlir. I've looked at the implementation of e2e tests for more complex ops like attention, and they are not good implementations; they don't add much value.

We should, as a project, allow landing PRs that add support to the importer separately from e2e tests (at least for HOPs). I don't think having a dummy implementation for an op should be the bar to land an operation.
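
To make the importer-only bar concrete, here is a minimal sketch of how such a change could still be regression-tested without any e2e lowering, assuming torch-mlir's torch_mlir.fx.export_and_import helper (the exact helper and its signature are an assumption and may differ across builds):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention
from torch_mlir import fx  # assumed helper module; may differ across builds


class FlexAttn(torch.nn.Module):
    def forward(self, q, k, v):
        return flex_attention(q, k, v)


q = k = v = torch.randn(2, 4, 128, 64)
# Export and import the FX graph into the torch dialect, then print the module;
# a FileCheck-based test (as in basic_test.py) could assert that a
# torch.aten.flex_attention op with the expected operands is produced.
module = fx.export_and_import(FlexAttn(), q, k, v)
print(module)
```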
