From d6a533daded57b813c475f5eb9955556a269613d Mon Sep 17 00:00:00 2001 From: yash solanki Date: Fri, 10 Oct 2025 15:48:57 +0530 Subject: [PATCH 01/10] Fix mask conversion and supply TILE_M/N in backward --- stream_attention/core/fused_online_attention.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/stream_attention/core/fused_online_attention.py b/stream_attention/core/fused_online_attention.py index 3bc6795..45e9e67 100644 --- a/stream_attention/core/fused_online_attention.py +++ b/stream_attention/core/fused_online_attention.py @@ -1050,10 +1050,11 @@ def _prepare_attn_mask( dtype: torch.dtype, ) -> torch.Tensor: mask = attention_mask - if mask.dtype == torch.float16 or mask.dtype == torch.bfloat16: - mask = mask.to(dtype) if mask.dtype == torch.bool: - pass + mask = mask.to(torch.float32) + mask = mask.masked_fill(mask > 0, float('-inf')) + else: + mask = mask.to(dtype) if mask.dim() == 2: mask = mask.view(batch_size, 1, 1, seq_len_k) elif mask.dim() == 3: @@ -1576,6 +1577,8 @@ def backward(ctx, grad_output): M=seq_len_q, N=seq_len_k, D=module.head_dim, + TILE_M=module.tile_size_q, + TILE_N=module.tile_size_k, TILE_K=module.head_dim, scale=module.scale, IS_CAUSAL=ctx.causal, From 37d0af6f4048c86ba96085066fcf9a34b34b94e6 Mon Sep 17 00:00:00 2001 From: yash solanki Date: Fri, 10 Oct 2025 16:22:09 +0530 Subject: [PATCH 02/10] Fix Triton fused attention mask handling --- .../core/fused_online_attention.py | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/stream_attention/core/fused_online_attention.py b/stream_attention/core/fused_online_attention.py index 45e9e67..ec90fdb 100644 --- a/stream_attention/core/fused_online_attention.py +++ b/stream_attention/core/fused_online_attention.py @@ -110,7 +110,6 @@ def fused_online_attention_kernel( N: tl.constexpr, # seq_len_k D: tl.constexpr, # head_dim TILE_M: tl.constexpr, - TILE_K: tl.constexpr, TILE_N: tl.constexpr, scale: tl.constexpr, IS_CAUSAL: tl.constexpr, @@ -207,14 +206,25 @@ def fused_online_attention_kernel( k_pos = (start_n + offs_n)[None, :].to(tl.float32) qk += slope * (k_pos - q_pos) - # Online softmax update + # Online softmax update with fully masked-row safeguards tile_max = tl.max(qk, axis=1) - new_max = tl.maximum(running_max, tile_max) - correction = tl.exp(running_max - new_max) + row_has_finite = tl.any(tl.isfinite(qk), axis=1) + candidate_max = tl.maximum(running_max, tile_max) + new_max = tl.where(row_has_finite, candidate_max, running_max) + + prev_max_safe = tl.where(row_has_finite, running_max, 0.0) + new_max_safe = tl.where(row_has_finite, new_max, prev_max_safe) + correction = tl.exp(prev_max_safe - new_max_safe) + correction = tl.where(row_has_finite, correction, 1.0) + acc_num *= correction[:, None] acc_den *= correction - - exp_qk = tl.exp(qk - new_max[:, None]) + running_max = new_max + + running_max_safe = tl.where(row_has_finite, running_max, 0.0) + qk_shifted = qk - running_max_safe[:, None] + qk_shifted = tl.where(row_has_finite[:, None], qk_shifted, float("-inf")) + exp_qk = tl.exp(qk_shifted) if HAS_DROPOUT: bh = off_b * H + off_h @@ -311,7 +321,6 @@ def fused_online_attention_bwd_kernel( N: tl.constexpr, D: tl.constexpr, TILE_M: tl.constexpr, - TILE_K: tl.constexpr, TILE_N: tl.constexpr, IS_CAUSAL: tl.constexpr, HAS_MASK: tl.constexpr, @@ -786,7 +795,6 @@ def _backward_triton( N=seq_len_k, D=head_dim, TILE_M=self.tile_size_q, - TILE_K=head_dim, TILE_N=self.tile_size_k, IS_CAUSAL=causal, HAS_MASK=mask_tensor is not None, @@ -1252,7 +1260,6 
@@ def _forward_triton( M=seq_len_q, N=seq_len_k, D=self.head_dim, - TILE_K=self.head_dim, scale=self.scale, IS_CAUSAL=causal, HAS_MASK=has_mask, @@ -1579,7 +1586,6 @@ def backward(ctx, grad_output): D=module.head_dim, TILE_M=module.tile_size_q, TILE_N=module.tile_size_k, - TILE_K=module.head_dim, scale=module.scale, IS_CAUSAL=ctx.causal, HAS_MASK=has_mask, From 27464854de02e0cf9dd2cfcc6754995f37350930 Mon Sep 17 00:00:00 2001 From: yash solanki Date: Fri, 10 Oct 2025 16:24:19 +0530 Subject: [PATCH 03/10] Ensure masked row logic works on Triton 2.1 --- stream_attention/core/fused_online_attention.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/stream_attention/core/fused_online_attention.py b/stream_attention/core/fused_online_attention.py index ec90fdb..aeaa23f 100644 --- a/stream_attention/core/fused_online_attention.py +++ b/stream_attention/core/fused_online_attention.py @@ -208,14 +208,17 @@ def fused_online_attention_kernel( # Online softmax update with fully masked-row safeguards tile_max = tl.max(qk, axis=1) - row_has_finite = tl.any(tl.isfinite(qk), axis=1) + tile_has_finite = tile_max > float("-inf") + prev_has_finite = running_max > float("-inf") + row_has_finite = tile_has_finite | prev_has_finite + candidate_max = tl.maximum(running_max, tile_max) new_max = tl.where(row_has_finite, candidate_max, running_max) - prev_max_safe = tl.where(row_has_finite, running_max, 0.0) + prev_max_safe = tl.where(prev_has_finite, running_max, 0.0) new_max_safe = tl.where(row_has_finite, new_max, prev_max_safe) correction = tl.exp(prev_max_safe - new_max_safe) - correction = tl.where(row_has_finite, correction, 1.0) + correction = tl.where(prev_has_finite, correction, 1.0) acc_num *= correction[:, None] acc_den *= correction From 4323cb62babb541a98fc945697058146f133dc32 Mon Sep 17 00:00:00 2001 From: yash solanki Date: Fri, 10 Oct 2025 16:26:53 +0530 Subject: [PATCH 04/10] Restore masked parity without tl.isfinite --- stream_attention/core/fused_online_attention.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/stream_attention/core/fused_online_attention.py b/stream_attention/core/fused_online_attention.py index aeaa23f..8f254eb 100644 --- a/stream_attention/core/fused_online_attention.py +++ b/stream_attention/core/fused_online_attention.py @@ -216,18 +216,18 @@ def fused_online_attention_kernel( new_max = tl.where(row_has_finite, candidate_max, running_max) prev_max_safe = tl.where(prev_has_finite, running_max, 0.0) - new_max_safe = tl.where(row_has_finite, new_max, prev_max_safe) - correction = tl.exp(prev_max_safe - new_max_safe) - correction = tl.where(prev_has_finite, correction, 1.0) + new_max_safe = tl.where(row_has_finite, new_max, 0.0) + correction = tl.where( + prev_has_finite, tl.exp(prev_max_safe - new_max_safe), 0.0 + ) acc_num *= correction[:, None] acc_den *= correction - running_max = new_max + running_max = tl.where(row_has_finite, new_max, float("-inf")) - running_max_safe = tl.where(row_has_finite, running_max, 0.0) - qk_shifted = qk - running_max_safe[:, None] + qk_shifted = qk - new_max_safe[:, None] qk_shifted = tl.where(row_has_finite[:, None], qk_shifted, float("-inf")) - exp_qk = tl.exp(qk_shifted) + exp_qk = tl.where(row_has_finite[:, None], tl.exp(qk_shifted), 0.0) if HAS_DROPOUT: bh = off_b * H + off_h @@ -523,7 +523,7 @@ def fused_online_attention_bwd_kernel( lse_ptrs = Lse + off_b * stride_lsb + off_h * stride_lsh + offs_m * stride_lsm lse = tl.load(lse_ptrs, mask=offs_m < M, 
other=float('-inf')).to(tl.float32) - row_mask = tl.isfinite(lse) + row_mask = lse > float("-inf") lse = tl.where(row_mask, lse, 0.0) go = go * row_mask[:, None] From 56a04d7650874c8de072d5c34e4309b712c13e0f Mon Sep 17 00:00:00 2001 From: yash solanki Date: Fri, 10 Oct 2025 16:56:14 +0530 Subject: [PATCH 05/10] Guard online softmax updates for masked rows --- .../core/fused_online_attention.py | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/stream_attention/core/fused_online_attention.py b/stream_attention/core/fused_online_attention.py index 8f254eb..109d4ec 100644 --- a/stream_attention/core/fused_online_attention.py +++ b/stream_attention/core/fused_online_attention.py @@ -208,26 +208,23 @@ def fused_online_attention_kernel( # Online softmax update with fully masked-row safeguards tile_max = tl.max(qk, axis=1) - tile_has_finite = tile_max > float("-inf") - prev_has_finite = running_max > float("-inf") - row_has_finite = tile_has_finite | prev_has_finite + prev_valid = running_max > float("-inf") + tile_valid = tile_max > float("-inf") + row_valid = prev_valid | tile_valid candidate_max = tl.maximum(running_max, tile_max) - new_max = tl.where(row_has_finite, candidate_max, running_max) + new_max = tl.where(row_valid, candidate_max, float("-inf")) - prev_max_safe = tl.where(prev_has_finite, running_max, 0.0) - new_max_safe = tl.where(row_has_finite, new_max, 0.0) - correction = tl.where( - prev_has_finite, tl.exp(prev_max_safe - new_max_safe), 0.0 - ) + safe_prev = tl.where(prev_valid, running_max, 0.0) + safe_new = tl.where(row_valid, new_max, 0.0) + correction = tl.where(prev_valid, tl.exp(safe_prev - safe_new), 1.0) acc_num *= correction[:, None] acc_den *= correction - running_max = tl.where(row_has_finite, new_max, float("-inf")) + running_max = new_max - qk_shifted = qk - new_max_safe[:, None] - qk_shifted = tl.where(row_has_finite[:, None], qk_shifted, float("-inf")) - exp_qk = tl.where(row_has_finite[:, None], tl.exp(qk_shifted), 0.0) + qk_shifted = qk - safe_new[:, None] + exp_qk = tl.where(row_valid[:, None], tl.exp(qk_shifted), 0.0) if HAS_DROPOUT: bh = off_b * H + off_h From 748daac320673070d8911e9bdd4c9245cb7eee4d Mon Sep 17 00:00:00 2001 From: yash solanki Date: Fri, 10 Oct 2025 17:05:36 +0530 Subject: [PATCH 06/10] Stabilize Triton online softmax for masked rows --- .../core/fused_online_attention.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/stream_attention/core/fused_online_attention.py b/stream_attention/core/fused_online_attention.py index 109d4ec..d672b45 100644 --- a/stream_attention/core/fused_online_attention.py +++ b/stream_attention/core/fused_online_attention.py @@ -154,6 +154,7 @@ def fused_online_attention_kernel( running_max = tl.full([TILE_M], value=-float("inf"), dtype=tl.float32) acc_num = tl.zeros([TILE_M, D], dtype=tl.float32) acc_den = tl.zeros([TILE_M], dtype=tl.float32) + has_valid = tl.zeros([TILE_M], dtype=tl.int1) # Iterate over K/V tiles for start_n in range(0, N, TILE_N): @@ -208,23 +209,21 @@ def fused_online_attention_kernel( # Online softmax update with fully masked-row safeguards tile_max = tl.max(qk, axis=1) - prev_valid = running_max > float("-inf") tile_valid = tile_max > float("-inf") - row_valid = prev_valid | tile_valid + new_has_valid = has_valid | tile_valid candidate_max = tl.maximum(running_max, tile_max) - new_max = tl.where(row_valid, candidate_max, float("-inf")) - - safe_prev = tl.where(prev_valid, running_max, 0.0) - safe_new = 
tl.where(row_valid, new_max, 0.0) - correction = tl.where(prev_valid, tl.exp(safe_prev - safe_new), 1.0) + safe_prev = tl.where(has_valid, running_max, 0.0) + safe_new = tl.where(new_has_valid, candidate_max, 0.0) + correction = tl.where(has_valid, tl.exp(safe_prev - safe_new), 1.0) + running_max = tl.where(new_has_valid, candidate_max, float("-inf")) acc_num *= correction[:, None] acc_den *= correction - running_max = new_max qk_shifted = qk - safe_new[:, None] - exp_qk = tl.where(row_valid[:, None], tl.exp(qk_shifted), 0.0) + exp_qk = tl.where(new_has_valid[:, None], tl.exp(qk_shifted), 0.0) + has_valid = new_has_valid if HAS_DROPOUT: bh = off_b * H + off_h @@ -238,7 +237,6 @@ def fused_online_attention_kernel( acc_num += tl.dot(exp_qk, v) acc_den += tl.sum(exp_qk, axis=1) - running_max = new_max # Final output with safe denominator; handle rows with all keys masked zero_den = acc_den == 0 From 809697fa840a2518db5fcd84af5481a9550f4c26 Mon Sep 17 00:00:00 2001 From: yash solanki Date: Fri, 10 Oct 2025 17:16:49 +0530 Subject: [PATCH 07/10] Use explicit validity mask in online softmax --- stream_attention/core/fused_online_attention.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/stream_attention/core/fused_online_attention.py b/stream_attention/core/fused_online_attention.py index d672b45..84b6810 100644 --- a/stream_attention/core/fused_online_attention.py +++ b/stream_attention/core/fused_online_attention.py @@ -154,7 +154,7 @@ def fused_online_attention_kernel( running_max = tl.full([TILE_M], value=-float("inf"), dtype=tl.float32) acc_num = tl.zeros([TILE_M, D], dtype=tl.float32) acc_den = tl.zeros([TILE_M], dtype=tl.float32) - has_valid = tl.zeros([TILE_M], dtype=tl.int1) + has_valid = tl.zeros([TILE_M], dtype=tl.int32) # Iterate over K/V tiles for start_n in range(0, N, TILE_N): @@ -209,21 +209,22 @@ def fused_online_attention_kernel( # Online softmax update with fully masked-row safeguards tile_max = tl.max(qk, axis=1) + prev_valid = has_valid.to(tl.int1) > 0 tile_valid = tile_max > float("-inf") - new_has_valid = has_valid | tile_valid + new_valid = prev_valid | tile_valid candidate_max = tl.maximum(running_max, tile_max) - safe_prev = tl.where(has_valid, running_max, 0.0) - safe_new = tl.where(new_has_valid, candidate_max, 0.0) - correction = tl.where(has_valid, tl.exp(safe_prev - safe_new), 1.0) + safe_prev = tl.where(prev_valid, running_max, 0.0) + safe_new = tl.where(new_valid, candidate_max, 0.0) + correction = tl.where(prev_valid, tl.exp(safe_prev - safe_new), 1.0) - running_max = tl.where(new_has_valid, candidate_max, float("-inf")) + running_max = tl.where(new_valid, candidate_max, float("-inf")) acc_num *= correction[:, None] acc_den *= correction qk_shifted = qk - safe_new[:, None] - exp_qk = tl.where(new_has_valid[:, None], tl.exp(qk_shifted), 0.0) - has_valid = new_has_valid + exp_qk = tl.where(new_valid[:, None], tl.exp(qk_shifted), 0.0) + has_valid = new_valid.to(tl.int32) if HAS_DROPOUT: bh = off_b * H + off_h From eaa46497c9be4b96572969b3c194d2fc8d4f1fc3 Mon Sep 17 00:00:00 2001 From: yash solanki Date: Fri, 10 Oct 2025 17:39:55 +0530 Subject: [PATCH 08/10] Fix validity tracking in Triton mask path --- stream_attention/core/fused_online_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stream_attention/core/fused_online_attention.py b/stream_attention/core/fused_online_attention.py index 84b6810..0860d5a 100644 --- a/stream_attention/core/fused_online_attention.py +++ 
b/stream_attention/core/fused_online_attention.py @@ -209,7 +209,7 @@ def fused_online_attention_kernel( # Online softmax update with fully masked-row safeguards tile_max = tl.max(qk, axis=1) - prev_valid = has_valid.to(tl.int1) > 0 + prev_valid = has_valid > 0 tile_valid = tile_max > float("-inf") new_valid = prev_valid | tile_valid From ae154cb87058368abb3e4de542b18920fcc3f1e4 Mon Sep 17 00:00:00 2001 From: yash solanki Date: Fri, 10 Oct 2025 17:43:54 +0530 Subject: [PATCH 09/10] Track mask validity in float space --- stream_attention/core/fused_online_attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stream_attention/core/fused_online_attention.py b/stream_attention/core/fused_online_attention.py index 0860d5a..829f92f 100644 --- a/stream_attention/core/fused_online_attention.py +++ b/stream_attention/core/fused_online_attention.py @@ -154,7 +154,7 @@ def fused_online_attention_kernel( running_max = tl.full([TILE_M], value=-float("inf"), dtype=tl.float32) acc_num = tl.zeros([TILE_M, D], dtype=tl.float32) acc_den = tl.zeros([TILE_M], dtype=tl.float32) - has_valid = tl.zeros([TILE_M], dtype=tl.int32) + has_valid = tl.zeros([TILE_M], dtype=tl.float32) # Iterate over K/V tiles for start_n in range(0, N, TILE_N): @@ -224,7 +224,7 @@ def fused_online_attention_kernel( qk_shifted = qk - safe_new[:, None] exp_qk = tl.where(new_valid[:, None], tl.exp(qk_shifted), 0.0) - has_valid = new_valid.to(tl.int32) + has_valid = new_valid.to(tl.float32) if HAS_DROPOUT: bh = off_b * H + off_h From 90133308196e471939d2fb42da30554f5d6f57b6 Mon Sep 17 00:00:00 2001 From: yash solanki Date: Sat, 11 Oct 2025 00:16:32 +0530 Subject: [PATCH 10/10] Align Triton mask parity with math backend --- .../core/fused_online_attention.py | 9 +++++--- tests/test_attention.py | 23 +++++++++++++++---- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/stream_attention/core/fused_online_attention.py b/stream_attention/core/fused_online_attention.py index 829f92f..8f40f24 100644 --- a/stream_attention/core/fused_online_attention.py +++ b/stream_attention/core/fused_online_attention.py @@ -149,6 +149,7 @@ def fused_online_attention_kernel( q = tl.load(q_ptrs, mask=q_mask, other=0.0) else: q = tl.load(q_ptrs, mask=q_mask, other=0.0) + q = q.to(tl.float32) # Accumulators running_max = tl.full([TILE_M], value=-float("inf"), dtype=tl.float32) @@ -176,10 +177,12 @@ def fused_online_attention_kernel( if USE_CP_ASYNC: # cp.async + double buffering placeholder k = tl.load(k_ptrs, mask=kv_mask, other=0.0) - v = tl.load(v_ptrs, mask=kv_mask, other=0.0).to(tl.float32) + v = tl.load(v_ptrs, mask=kv_mask, other=0.0) else: k = tl.load(k_ptrs, mask=kv_mask, other=0.0) - v = tl.load(v_ptrs, mask=kv_mask, other=0.0).to(tl.float32) + v = tl.load(v_ptrs, mask=kv_mask, other=0.0) + k = k.to(tl.float32) + v = v.to(tl.float32) # QK^T # Hopper uses WGMMA tensor cores; Ampere uses mma.sync @@ -198,7 +201,7 @@ def fused_online_attention_kernel( + (offs_m[:, None] * stride_mm + (start_n + offs_n)[None, :] * stride_mn) ) mask_mask = (offs_m[:, None] < M) & ((start_n + offs_n)[None, :] < N) - mask_vals = tl.load(mask_ptrs, mask=mask_mask, other=0.0) + mask_vals = tl.load(mask_ptrs, mask=mask_mask, other=0.0).to(qk.dtype) qk += mask_vals if HAS_ALIBI: diff --git a/tests/test_attention.py b/tests/test_attention.py index f6058db..12c82ca 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -11,6 +11,13 @@ from typing import Tuple, Optional import time import gc +from contextlib import nullcontext + 
+try: + from torch.nn.attention import sdpa_kernel as sdpa_kernel_ctx, SDPBackend +except ImportError: # pragma: no cover - older PyTorch + sdpa_kernel_ctx = None + SDPBackend = None from stream_attention import StreamAttention, StreamAttentionConfig from stream_attention.core.flashattention_v3 import FlashAttentionV3 @@ -20,6 +27,16 @@ ) from stream_attention.core.ring_attention import RingAttention from stream_attention.core.star_attention import StarAttention + + +def _math_sdpa_ctx(): + if sdpa_kernel_ctx is not None and SDPBackend is not None: + return sdpa_kernel_ctx(SDPBackend.MATH) + if torch.cuda.is_available(): + return torch.backends.cuda.sdp_kernel( + enable_math=True, enable_flash=False, enable_mem_efficient=False + ) + return nullcontext() from stream_attention.utils.memory import create_kv_compressor, MemoryProfiler @@ -152,9 +169,7 @@ def test_fused_online_attention_mask_parity(self, device): torch.full((1,), float("-inf"), dtype=q_bh.dtype, device=device), torch.zeros(1, dtype=q_bh.dtype, device=device), ) - with torch.backends.cuda.sdp_kernel( - enable_math=True, enable_flash=False, enable_mem_efficient=False - ): + with _math_sdpa_ctx(): ref = torch.nn.functional.scaled_dot_product_attention( q_bh, k_bh, v_bh, attn_mask=add_mask, is_causal=False, dropout_p=0.0 ) @@ -253,7 +268,7 @@ def test_fused_online_attention_backward_matches_sdpa(self, device): alibi_bias = alibi_bias.reshape(batch_size * num_heads, seq_len_q, seq_len_k) combined_bias = mask_bh + alibi_bias - with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False): + with _math_sdpa_ctx(): ref_out = torch.nn.functional.scaled_dot_product_attention( q_bh, k_bh, v_bh, attn_mask=combined_bias, is_causal=True, dropout_p=0.0 )
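
Reviewer note (not part of the patch series): patches 05-09 converge on an online softmax that tracks per-row validity explicitly (`has_valid`, `safe_prev`, `safe_new`, `correction`) so that rows whose keys are all masked to -inf never produce NaN from exp(-inf - (-inf)). The standalone PyTorch sketch below mirrors that update rule for host-side parity checking; it is a minimal illustration, not the Triton kernel. The function name `online_attention_reference`, the default tile size, and the additive-mask convention (0 = attend, -inf = masked) are assumptions chosen for the example.

import torch

def online_attention_reference(q, k, v, bias, tile_n=64, scale=None):
    """Streaming softmax attention for a single (batch, head).

    q: [M, D], k/v: [N, D], bias: [M, N] additive mask (0 = attend, -inf = masked).
    Mirrors the masked-row safeguards of the Triton kernel after patch 09.
    """
    M, D = q.shape
    N = k.shape[0]
    scale = D ** -0.5 if scale is None else scale
    dev = q.device

    running_max = torch.full((M,), float("-inf"), dtype=torch.float32, device=dev)
    acc_num = torch.zeros(M, D, dtype=torch.float32, device=dev)
    acc_den = torch.zeros(M, dtype=torch.float32, device=dev)
    has_valid = torch.zeros(M, dtype=torch.bool, device=dev)

    qf = q.float()
    for start in range(0, N, tile_n):
        kt = k[start:start + tile_n].float()
        vt = v[start:start + tile_n].float()
        qk = (qf @ kt.T) * scale + bias[:, start:start + tile_n].float()

        tile_max = qk.max(dim=1).values
        tile_valid = tile_max > float("-inf")
        new_valid = has_valid | tile_valid

        # Use 0.0 stand-ins for -inf maxima so no -inf minus -inf is ever formed.
        candidate_max = torch.maximum(running_max, tile_max)
        safe_prev = torch.where(has_valid, running_max, torch.zeros_like(running_max))
        safe_new = torch.where(new_valid, candidate_max, torch.zeros_like(candidate_max))

        # Rescale previous accumulators; rows with no valid key so far keep factor 1.
        correction = torch.where(has_valid, torch.exp(safe_prev - safe_new),
                                 torch.ones_like(safe_prev))
        running_max = torch.where(new_valid, candidate_max,
                                  torch.full_like(candidate_max, float("-inf")))
        acc_num *= correction[:, None]
        acc_den *= correction

        # Shift by the safe max; fully masked rows contribute exactly zero.
        exp_qk = torch.where(new_valid[:, None],
                             torch.exp(qk - safe_new[:, None]),
                             torch.zeros_like(qk))
        acc_num += exp_qk @ vt
        acc_den += exp_qk.sum(dim=1)
        has_valid = new_valid

    # Safe denominator: rows with every key masked yield zeros instead of NaN,
    # matching the kernel's zero_den handling.
    den = torch.where(acc_den == 0, torch.ones_like(acc_den), acc_den)
    out = acc_num / den[:, None]
    return torch.where(has_valid[:, None], out, torch.zeros_like(out))

Comparing this sketch against torch.nn.functional.scaled_dot_product_attention (under the MATH backend, as the updated tests do) on rows that keep at least one key, and against zeros on fully masked rows, exercises the same guarantees the Triton changes aim for.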