diff --git a/stream_attention/core/fused_online_attention.py b/stream_attention/core/fused_online_attention.py
index 3bc6795..8f40f24 100644
--- a/stream_attention/core/fused_online_attention.py
+++ b/stream_attention/core/fused_online_attention.py
@@ -110,7 +110,6 @@ def fused_online_attention_kernel(
     N: tl.constexpr,  # seq_len_k
     D: tl.constexpr,  # head_dim
     TILE_M: tl.constexpr,
-    TILE_K: tl.constexpr,
     TILE_N: tl.constexpr,
     scale: tl.constexpr,
     IS_CAUSAL: tl.constexpr,
@@ -150,11 +149,13 @@
         q = tl.load(q_ptrs, mask=q_mask, other=0.0)
     else:
         q = tl.load(q_ptrs, mask=q_mask, other=0.0)
+    q = q.to(tl.float32)
 
     # Accumulators
     running_max = tl.full([TILE_M], value=-float("inf"), dtype=tl.float32)
     acc_num = tl.zeros([TILE_M, D], dtype=tl.float32)
     acc_den = tl.zeros([TILE_M], dtype=tl.float32)
+    has_valid = tl.zeros([TILE_M], dtype=tl.float32)
 
     # Iterate over K/V tiles
     for start_n in range(0, N, TILE_N):
@@ -176,10 +177,12 @@
         if USE_CP_ASYNC:
             # cp.async + double buffering placeholder
             k = tl.load(k_ptrs, mask=kv_mask, other=0.0)
-            v = tl.load(v_ptrs, mask=kv_mask, other=0.0).to(tl.float32)
+            v = tl.load(v_ptrs, mask=kv_mask, other=0.0)
         else:
             k = tl.load(k_ptrs, mask=kv_mask, other=0.0)
-            v = tl.load(v_ptrs, mask=kv_mask, other=0.0).to(tl.float32)
+            v = tl.load(v_ptrs, mask=kv_mask, other=0.0)
+        k = k.to(tl.float32)
+        v = v.to(tl.float32)
 
         # QK^T
         # Hopper uses WGMMA tensor cores; Ampere uses mma.sync
@@ -198,7 +201,7 @@
                 + (offs_m[:, None] * stride_mm + (start_n + offs_n)[None, :] * stride_mn)
             )
             mask_mask = (offs_m[:, None] < M) & ((start_n + offs_n)[None, :] < N)
-            mask_vals = tl.load(mask_ptrs, mask=mask_mask, other=0.0)
+            mask_vals = tl.load(mask_ptrs, mask=mask_mask, other=0.0).to(qk.dtype)
             qk += mask_vals
 
         if HAS_ALIBI:
@@ -207,14 +210,24 @@
             k_pos = (start_n + offs_n)[None, :].to(tl.float32)
             qk += slope * (k_pos - q_pos)
 
-        # Online softmax update
+        # Online softmax update with fully masked-row safeguards
         tile_max = tl.max(qk, axis=1)
-        new_max = tl.maximum(running_max, tile_max)
-        correction = tl.exp(running_max - new_max)
+        prev_valid = has_valid > 0
+        tile_valid = tile_max > float("-inf")
+        new_valid = prev_valid | tile_valid
+
+        candidate_max = tl.maximum(running_max, tile_max)
+        safe_prev = tl.where(prev_valid, running_max, 0.0)
+        safe_new = tl.where(new_valid, candidate_max, 0.0)
+        correction = tl.where(prev_valid, tl.exp(safe_prev - safe_new), 1.0)
+
+        running_max = tl.where(new_valid, candidate_max, float("-inf"))
         acc_num *= correction[:, None]
         acc_den *= correction
-
-        exp_qk = tl.exp(qk - new_max[:, None])
+
+        qk_shifted = qk - safe_new[:, None]
+        exp_qk = tl.where(new_valid[:, None], tl.exp(qk_shifted), 0.0)
+        has_valid = new_valid.to(tl.float32)
 
         if HAS_DROPOUT:
             bh = off_b * H + off_h
@@ -228,7 +241,6 @@
 
         acc_num += tl.dot(exp_qk, v)
         acc_den += tl.sum(exp_qk, axis=1)
-        running_max = new_max
 
     # Final output with safe denominator; handle rows with all keys masked
    zero_den = acc_den == 0
@@ -311,7 +323,6 @@ def fused_online_attention_bwd_kernel(
     N: tl.constexpr,
     D: tl.constexpr,
     TILE_M: tl.constexpr,
-    TILE_K: tl.constexpr,
     TILE_N: tl.constexpr,
     IS_CAUSAL: tl.constexpr,
     HAS_MASK: tl.constexpr,
@@ -511,7 +522,7 @@
 
     lse_ptrs = Lse + off_b * stride_lsb + off_h * stride_lsh + offs_m * stride_lsm
     lse = tl.load(lse_ptrs, mask=offs_m < M, other=float('-inf')).to(tl.float32)
-    row_mask = tl.isfinite(lse)
+    row_mask = lse > float("-inf")
     lse = tl.where(row_mask, lse, 0.0)
     go = go * row_mask[:, None]
 
@@ -786,7 +797,6 @@ def _backward_triton(
             N=seq_len_k,
             D=head_dim,
             TILE_M=self.tile_size_q,
-            TILE_K=head_dim,
             TILE_N=self.tile_size_k,
             IS_CAUSAL=causal,
             HAS_MASK=mask_tensor is not None,
@@ -1050,10 +1060,11 @@ def _prepare_attn_mask(
         dtype: torch.dtype,
     ) -> torch.Tensor:
         mask = attention_mask
-        if mask.dtype == torch.float16 or mask.dtype == torch.bfloat16:
-            mask = mask.to(dtype)
         if mask.dtype == torch.bool:
-            pass
+            mask = mask.to(torch.float32)
+            mask = mask.masked_fill(mask > 0, float('-inf'))
+        else:
+            mask = mask.to(dtype)
         if mask.dim() == 2:
             mask = mask.view(batch_size, 1, 1, seq_len_k)
         elif mask.dim() == 3:
@@ -1251,7 +1262,6 @@ def _forward_triton(
            M=seq_len_q,
            N=seq_len_k,
            D=self.head_dim,
-            TILE_K=self.head_dim,
            scale=self.scale,
            IS_CAUSAL=causal,
            HAS_MASK=has_mask,
@@ -1576,7 +1586,8 @@ def backward(ctx, grad_output):
            M=seq_len_q,
            N=seq_len_k,
            D=module.head_dim,
-            TILE_K=module.head_dim,
+            TILE_M=module.tile_size_q,
+            TILE_N=module.tile_size_k,
            scale=module.scale,
            IS_CAUSAL=ctx.causal,
            HAS_MASK=has_mask,
diff --git a/tests/test_attention.py b/tests/test_attention.py
index f6058db..12c82ca 100644
--- a/tests/test_attention.py
+++ b/tests/test_attention.py
@@ -11,6 +11,13 @@
 from typing import Tuple, Optional
 import time
 import gc
+from contextlib import nullcontext
+
+try:
+    from torch.nn.attention import sdpa_kernel as sdpa_kernel_ctx, SDPBackend
+except ImportError:  # pragma: no cover - older PyTorch
+    sdpa_kernel_ctx = None
+    SDPBackend = None
 
 from stream_attention import StreamAttention, StreamAttentionConfig
 from stream_attention.core.flashattention_v3 import FlashAttentionV3
@@ -20,6 +27,16 @@
 )
 from stream_attention.core.ring_attention import RingAttention
 from stream_attention.core.star_attention import StarAttention
+
+
+def _math_sdpa_ctx():
+    if sdpa_kernel_ctx is not None and SDPBackend is not None:
+        return sdpa_kernel_ctx(SDPBackend.MATH)
+    if torch.cuda.is_available():
+        return torch.backends.cuda.sdp_kernel(
+            enable_math=True, enable_flash=False, enable_mem_efficient=False
+        )
+    return nullcontext()
 
 from stream_attention.utils.memory import create_kv_compressor, MemoryProfiler
 
@@ -152,9 +169,7 @@ def test_fused_online_attention_mask_parity(self, device):
             torch.full((1,), float("-inf"), dtype=q_bh.dtype, device=device),
             torch.zeros(1, dtype=q_bh.dtype, device=device),
         )
-        with torch.backends.cuda.sdp_kernel(
-            enable_math=True, enable_flash=False, enable_mem_efficient=False
-        ):
+        with _math_sdpa_ctx():
             ref = torch.nn.functional.scaled_dot_product_attention(
                 q_bh, k_bh, v_bh, attn_mask=add_mask, is_causal=False, dropout_p=0.0
             )
@@ -253,7 +268,7 @@ def test_fused_online_attention_backward_matches_sdpa(self, device):
         alibi_bias = alibi_bias.reshape(batch_size * num_heads, seq_len_q, seq_len_k)
         combined_bias = mask_bh + alibi_bias
 
-        with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
+        with _math_sdpa_ctx():
             ref_out = torch.nn.functional.scaled_dot_product_attention(
                 q_bh, k_bh, v_bh, attn_mask=combined_bias, is_causal=True, dropout_p=0.0
             )
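
Reviewer note: the sketch below is a plain-PyTorch reference for the masked-row-safe online softmax that the kernel hunk above implements; it is not part of the patch. The helper name online_attention_reference, the single (M, D) query / (N, D) key-value slice layout, and the tile size are illustrative assumptions. It can be compared against torch.nn.functional.scaled_dot_product_attention under the MATH backend, the same way the updated tests do.

# Reference sketch only -- NOT code from stream_attention. It mirrors the
# has_valid / safe_new bookkeeping added to fused_online_attention_kernel so the
# numerics can be checked on CPU without Triton.
import torch

def online_attention_reference(q, k, v, bias=None, tile=128, scale=None):
    # q: [M, D]; k, v: [N, D]; bias: optional [M, N] additive mask (may contain -inf rows).
    M, D = q.shape
    N = k.shape[0]
    scale = D ** -0.5 if scale is None else scale

    running_max = torch.full((M,), float("-inf"), dtype=torch.float32, device=q.device)
    acc_num = torch.zeros((M, D), dtype=torch.float32, device=q.device)
    acc_den = torch.zeros((M,), dtype=torch.float32, device=q.device)
    has_valid = torch.zeros((M,), dtype=torch.bool, device=q.device)

    for start in range(0, N, tile):
        stop = min(start + tile, N)
        qk = (q.float() @ k[start:stop].float().T) * scale  # [M, n] scores for this tile
        if bias is not None:
            qk = qk + bias[:, start:stop].float()

        tile_max = qk.max(dim=1).values
        tile_valid = tile_max > float("-inf")
        new_valid = has_valid | tile_valid

        # Substitute 0.0 for -inf maxima so exp() never sees inf - inf.
        candidate_max = torch.maximum(running_max, tile_max)
        safe_prev = torch.where(has_valid, running_max, torch.zeros_like(running_max))
        safe_new = torch.where(new_valid, candidate_max, torch.zeros_like(candidate_max))
        correction = torch.where(has_valid, torch.exp(safe_prev - safe_new), torch.ones_like(safe_prev))

        running_max = torch.where(new_valid, candidate_max, torch.full_like(candidate_max, float("-inf")))
        acc_num = acc_num * correction[:, None]
        acc_den = acc_den * correction

        # Rows with no valid key so far contribute exactly zero probability mass.
        exp_qk = torch.where(new_valid[:, None], torch.exp(qk - safe_new[:, None]), torch.zeros_like(qk))
        acc_num = acc_num + exp_qk @ v[start:stop].float()
        acc_den = acc_den + exp_qk.sum(dim=1)
        has_valid = new_valid

    # Fully masked rows keep acc_den == 0; emit zeros instead of NaN, as the kernel does.
    safe_den = torch.where(acc_den == 0, torch.ones_like(acc_den), acc_den)
    return (acc_num / safe_den[:, None]).to(q.dtype)

A query row whose bias is all -inf (for example a fully padded row) comes out as zeros rather than NaN, which is the case the new has_valid bookkeeping in the forward kernel and the lse > float("-inf") row guard in the backward kernel are protecting.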