Commit 095cb61

Added basic lit test for flex_attention
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Parent: 40c76f2

File tree: 1 file changed, +78 -0 lines changed


test/python/fx_importer/basic_test.py

Lines changed: 78 additions & 0 deletions
@@ -5,6 +5,8 @@
 
 # RUN: %PYTHON %s | FileCheck %s
 
+from torch._tensor import Tensor
+from torch.nn.attention.flex_attention import _mask_mod_signature
 from typing import List
 
 import torch
@@ -251,6 +253,82 @@ def body(i, x):
     print(m)
 
 
+@run
+# CHECK-LABEL: test_flex_attention
+# Check that helper functions are emitted as private functions
+# CHECK: func.func private @sdpa_score{{[0-9]+}}(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>, %arg4: !torch.vtensor<[],si32>) -> !torch.vtensor<[],f32>
+# CHECK: func.func private @sdpa_mask{{[0-9]+}}(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>) -> !torch.vtensor<[],i1>
+# Check the main function calls flex_attention with both score_mod and mask_mod
+# CHECK: func.func @test_flex_attention(
+# CHECK-SAME: %arg{{[0-9]+}}: !torch.vtensor<[2,4,16,8],f32>, %arg{{[0-9]+}}: !torch.vtensor<[2,4,16,8],f32>, %arg{{[0-9]+}}: !torch.vtensor<[2,4,16,8],f32>)
+# CHECK-SAME: -> (!torch.vtensor<[2,4,16,8],f32>, !torch.vtensor<[2,4,16],f32>)
+# CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00
+# CHECK: %[[TRUE:.*]] = torch.constant.bool true
+# CHECK: %[[OUTPUT:.*]], %[[LSE:.*]] = torch.aten.flex_attention %{{.*}}, %{{.*}}, %{{.*}}, %[[SCALE]], %[[TRUE]] {mask_mod_fn = @sdpa_mask{{[0-9]+}}, score_mod_fn = @sdpa_score{{[0-9]+}}}
+# CHECK: return %[[OUTPUT]], %[[LSE]]
+def test_flex_attention():
+    from torch._higher_order_ops.flex_attention import flex_attention
+    from torch.nn.attention.flex_attention import create_block_mask
+
+    class FlexAttentionWithMaskModule(nn.Module):
+        def __init__(self, B, H, M):
+            super().__init__()
+
+            # Create block mask and register tensors as buffers
+            def causal_mask(b, h, q_idx, kv_idx):
+                return q_idx >= kv_idx
+
+            bm = create_block_mask(causal_mask, B, H, M, M, device="cpu")
+            bm_tuple = bm.as_tuple()
+
+            # Register each tensor component as a buffer so torch.export can track them
+            # as part of the module's state. Without this, the export tracer cannot
+            # properly capture these tensors in the graph since they're created outside
+            # the forward pass and would appear as untracked external references.
+            for idx, tensor in enumerate(bm_tuple):
+                if isinstance(tensor, torch.Tensor):
+                    self.register_buffer(f"bm_tensor_{idx}", tensor, persistent=False)
+                else:
+                    setattr(self, f"bm_scalar_{idx}", tensor)
+
+            self.bm_tuple_length = len(bm_tuple)
+
+        def forward(self, q, k, v):
+            def score_mod(score, b, h, q_idx, kv_idx):
+                return score * 0.5
+
+            # Reconstruct block mask tuple from buffers
+            bm_tuple = tuple(
+                (
+                    getattr(self, f"bm_tensor_{idx}")
+                    if hasattr(self, f"bm_tensor_{idx}")
+                    else getattr(self, f"bm_scalar_{idx}")
+                )
+                for idx in range(self.bm_tuple_length)
+            )
+
+            return flex_attention(
+                q,
+                k,
+                v,
+                score_mod=score_mod,
+                block_mask=bm_tuple,
+                scale=1.0,
+                kernel_options={},
+            )
+
+    # Export -> import to Torch-MLIR
+    B, H, M, K = 2, 4, 16, 8
+    q = torch.randn(B, H, M, K)
+    k = torch.randn(B, H, M, K)
+    v = torch.randn(B, H, M, K)
+
+    m = fx.export_and_import(
+        FlexAttentionWithMaskModule(B, H, M), q, k, v, func_name="test_flex_attention"
+    )
+    print(m)
+
+
 @run
 # CHECK-LABEL: test_stack_trace
 # CHECK: #loc[[LOC1:.+]] = loc(
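
For reference, below is a minimal eager-mode sketch of the pattern this test exercises: the same causal mask_mod and 0.5x score_mod, but driven through the public torch.nn.attention.flex_attention wrapper rather than the internal higher-order op and flattened block-mask tuple used in the test. The wrapper API (create_block_mask, flex_attention with score_mod/block_mask/scale) is assumed from upstream PyTorch and is not part of this commit.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Same shapes as the lit test: batch 2, 4 heads, 16 queries/keys, head dim 8.
B, H, M, K = 2, 4, 16, 8
q, k, v = (torch.randn(B, H, M, K) for _ in range(3))

# Same causal mask_mod as the test: keep positions where q_idx >= kv_idx.
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# Same score_mod as the test: scale every attention score by 0.5.
def score_mod(score, b, h, q_idx, kv_idx):
    return score * 0.5

block_mask = create_block_mask(causal_mask, B, H, M, M, device="cpu")
out = flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask, scale=1.0)
print(out.shape)  # torch.Size([2, 4, 16, 8])

The lit test above checks the Torch-MLIR import of this pattern: the exported program should produce a single torch.aten.flex_attention op whose mask_mod_fn and score_mod_fn attributes reference the private @sdpa_mask / @sdpa_score helper functions.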
