@@ -5,6 +5,8 @@
 
 # RUN: %PYTHON %s | FileCheck %s
 
+from torch._tensor import Tensor
+from torch.nn.attention.flex_attention import _mask_mod_signature
 from typing import List
 
 import torch
@@ -251,6 +253,82 @@ def body(i, x):
     print(m)
 
 
+@run
+# CHECK-LABEL: test_flex_attention
+# Check that helper functions are emitted as private functions
+# CHECK: func.func private @sdpa_score{{[0-9]+}}(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>, %arg4: !torch.vtensor<[],si32>) -> !torch.vtensor<[],f32>
+# CHECK: func.func private @sdpa_mask{{[0-9]+}}(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>) -> !torch.vtensor<[],i1>
+# Check the main function calls flex_attention with both score_mod and mask_mod
+# CHECK: func.func @test_flex_attention(
+# CHECK-SAME: %arg{{[0-9]+}}: !torch.vtensor<[2,4,16,8],f32>, %arg{{[0-9]+}}: !torch.vtensor<[2,4,16,8],f32>, %arg{{[0-9]+}}: !torch.vtensor<[2,4,16,8],f32>)
+# CHECK-SAME: -> (!torch.vtensor<[2,4,16,8],f32>, !torch.vtensor<[2,4,16],f32>)
+# CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00
+# CHECK: %[[TRUE:.*]] = torch.constant.bool true
+# CHECK: %[[OUTPUT:.*]], %[[LSE:.*]] = torch.aten.flex_attention %{{.*}}, %{{.*}}, %{{.*}}, %[[SCALE]], %[[TRUE]] {mask_mod_fn = @sdpa_mask{{[0-9]+}}, score_mod_fn = @sdpa_score{{[0-9]+}}}
+# CHECK: return %[[OUTPUT]], %[[LSE]]
+def test_flex_attention():
+    from torch._higher_order_ops.flex_attention import flex_attention
+    from torch.nn.attention.flex_attention import create_block_mask
+
+    class FlexAttentionWithMaskModule(nn.Module):
+        def __init__(self, B, H, M):
+            super().__init__()
+
+            # Create block mask and register tensors as buffers
+            def causal_mask(b, h, q_idx, kv_idx):
+                return q_idx >= kv_idx
+
+            bm = create_block_mask(causal_mask, B, H, M, M, device="cpu")
+            bm_tuple = bm.as_tuple()
+
+            # Register each tensor component as a buffer so torch.export can track them
+            # as part of the module's state. Without this, the export tracer cannot
+            # properly capture these tensors in the graph since they're created outside
+            # the forward pass and would appear as untracked external references.
+            for idx, tensor in enumerate(bm_tuple):
+                if isinstance(tensor, torch.Tensor):
+                    self.register_buffer(f"bm_tensor_{idx}", tensor, persistent=False)
+                else:
+                    setattr(self, f"bm_scalar_{idx}", tensor)
+
+            self.bm_tuple_length = len(bm_tuple)
+
+        def forward(self, q, k, v):
+            def score_mod(score, b, h, q_idx, kv_idx):
+                return score * 0.5
+
+            # Reconstruct block mask tuple from buffers
+            bm_tuple = tuple(
+                (
+                    getattr(self, f"bm_tensor_{idx}")
+                    if hasattr(self, f"bm_tensor_{idx}")
+                    else getattr(self, f"bm_scalar_{idx}")
+                )
+                for idx in range(self.bm_tuple_length)
+            )
+
+            return flex_attention(
+                q,
+                k,
+                v,
+                score_mod=score_mod,
+                block_mask=bm_tuple,
+                scale=1.0,
+                kernel_options={},
+            )
+
+    # Export -> import to Torch-MLIR
+    B, H, M, K = 2, 4, 16, 8
+    q = torch.randn(B, H, M, K)
+    k = torch.randn(B, H, M, K)
+    v = torch.randn(B, H, M, K)
+
+    m = fx.export_and_import(
+        FlexAttentionWithMaskModule(B, H, M), q, k, v, func_name="test_flex_attention"
+    )
+    print(m)
+
+
 @run
 # CHECK-LABEL: test_stack_trace
 # CHECK: #loc[[LOC1:.+]] = loc(
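
For readers unfamiliar with the block-mask plumbing exercised above, the sketch below shows the same causal mask and score_mod expressed through the public eager torch.nn.attention.flex_attention API. It is an illustrative assumption about typical eager usage, not part of this change; the test itself only verifies the torch.export -> Torch-MLIR import path.

# Eager-mode sketch (illustrative, not part of the test): the same causal mask
# and score_mod driven through the public torch.nn.attention.flex_attention API.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def causal_mask(b, h, q_idx, kv_idx):
    # Allow attention only to key positions at or before the query position.
    return q_idx >= kv_idx


def score_mod(score, b, h, q_idx, kv_idx):
    # Same modification as in the test: halve every attention score.
    return score * 0.5


B, H, M, K = 2, 4, 16, 8
q, k, v = (torch.randn(B, H, M, K) for _ in range(3))

# create_block_mask builds the BlockMask that the test flattens via as_tuple().
block_mask = create_block_mask(causal_mask, B, H, M, M, device="cpu")
out = flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask, scale=1.0)
print(out.shape)  # expected: torch.Size([2, 4, 16, 8])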