@@ -37,7 +37,6 @@
 
 
 class TmaAutoTuneHelper:
-
     # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498
     class KernelParamWrapper:
         def __init__(self, desc):
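
The comment in this hunk describes plain duck typing: `KernelParamWrapper` only has to expose the same accessor the launcher expects from Triton's `TmaDescKernelParam` (Triton PR #4498). A minimal sketch of that idea, assuming the expected interface amounts to a single `tma_desc_cpu_ptr()` method over a host-side descriptor buffer (an assumption for illustration, not the verbatim class from this file):

```python
# Duck-typing sketch (assumed interface): any object exposing tma_desc_cpu_ptr()
# can stand in for Triton's TmaDescKernelParam at kernel-launch time.
class KernelParamWrapperSketch:
    def __init__(self, desc):
        # desc: assumed to be a CPU torch.Tensor holding the packed TMA descriptor bytes
        self.desc = desc

    def tma_desc_cpu_ptr(self):
        # the launcher reads the descriptor through this host pointer
        return self.desc.data_ptr()
```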

@@ -734,7 +733,6 @@ def _attn_fwd_tma(  # Q, V, desc_k, desc_v, sm_scale, M, Out,  #
     HEAD_DIM: tl.constexpr,  #
     STAGE: tl.constexpr,  #
 ):
-
     tl.static_assert(BLOCK_N <= HEAD_DIM)
     start_m = tl.program_id(0)
     off_hz = tl.program_id(1)

@@ -848,7 +846,14 @@ def _attn_fwd_tma(  # Q, V, desc_k, desc_v, sm_scale, M, Out,  #
 
 @triton.jit
 def _attn_bwd_preprocess(
-    O, DO, Delta, Z, H, N_CTX, BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr  # # # #
+    O,
+    DO,
+    Delta,
+    Z,
+    H,
+    N_CTX,
+    BLOCK_M: tl.constexpr,
+    HEAD_DIM: tl.constexpr,  # # # #
 ):
     off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
     off_hz = tl.program_id(1)
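
For reference, in the standard FlashAttention-2 backward formulation a preprocess kernel with this signature stores, per query row, the dot product of the forward output `O` with the incoming gradient `DO` (the `Delta` term reused by the main backward kernel). A non-Triton reference sketch of that quantity, assuming `(Z, H, N_CTX, HEAD_DIM)` tensor layouts (an assumption for illustration):

```python
import torch

def attn_bwd_preprocess_reference(o: torch.Tensor, do: torch.Tensor) -> torch.Tensor:
    # o, do: (Z, H, N_CTX, HEAD_DIM) -> Delta: (Z, H, N_CTX)
    # row-wise sum of O * dO, accumulated in float32 for numerical stability
    return (o.float() * do.float()).sum(dim=-1)
```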

@@ -1179,7 +1184,6 @@ def _attn_bwd(
 
 
 class _attention_ws(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, q, k, v, causal, sm_scale):
         # shape constraints

@@ -1232,7 +1236,7 @@ def forward(ctx, q, k, v, causal, sm_scale):
             N_CTX=q.shape[2],  #
             HEAD_DIM=HEAD_DIM_K,  #
             STAGE=stage,  #
-            **extra_kern_args
+            **extra_kern_args,
         )
 
         ctx.save_for_backward(q, k, v, o, M)
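
As a usage note: `_attention_ws` subclasses `torch.autograd.Function`, so callers never instantiate it; they invoke it through `.apply`, which runs this `forward` and later dispatches to the class's `backward`. A hypothetical call sketch, assuming the class from this file is in scope (shapes, dtype, and the `sm_scale` value are assumptions for illustration):

```python
import torch

attention_ws = _attention_ws.apply  # autograd.Function subclasses are invoked via .apply

q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

o = attention_ws(q, k, v, True, 0.5)  # causal=True, sm_scale=0.5
o.sum().backward()                    # routes into _attention_ws.backward
```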

@@ -1304,7 +1308,6 @@ def backward(ctx, do):
 
 
 class _attention(torch.autograd.Function):
-
     @staticmethod
    def forward(ctx, q, k, v, causal, sm_scale):
         # shape constraints

@@ -1355,7 +1358,7 @@ def forward(ctx, q, k, v, causal, sm_scale):
             N_CTX=q.shape[2],  #
             HEAD_DIM=HEAD_DIM_K,  #
             STAGE=stage,  #
-            **extra_kern_args
+            **extra_kern_args,
         )
 
         ctx.save_for_backward(q, k, v, o, M)

@@ -1427,7 +1430,6 @@ def backward(ctx, do):
 
 
 class _attention_tma(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, q, k, v, causal, sm_scale):
         # shape constraints

@@ -1587,7 +1589,7 @@ def grid_tma(META):
             N_CTX=q.shape[2],  #
             HEAD_DIM=HEAD_DIM_K,  #
             STAGE=stage,  #
-            **extra_kern_args
+            **extra_kern_args,
         )
 
         ctx.save_for_backward(q, k, v, o, M)
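
The hunk context above references `grid_tma(META)`. In Triton, a launch grid may be a callable that receives the kernel's meta-parameters (including autotuned `tl.constexpr` values such as `BLOCK_M`) and returns the grid dimensions. A generic sketch of that pattern, assuming the usual `(ceil(N_CTX / BLOCK_M), Z * H, 1)` launch shape; the real `grid_tma` may also rebuild TMA descriptors per autotuning configuration, which is omitted here:

```python
import triton

def make_grid(q):
    # q: (Z, H, N_CTX, HEAD_DIM); the returned callable is passed as the launch grid
    def grid(META):
        # one program per BLOCK_M query rows, one per (batch, head) pair
        return (triton.cdiv(q.shape[2], META["BLOCK_M"]), q.shape[0] * q.shape[1], 1)
    return grid
```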