47 changes: 29 additions & 18 deletions stream_attention/core/fused_online_attention.py
@@ -110,7 +110,6 @@ def fused_online_attention_kernel(
N: tl.constexpr, # seq_len_k
D: tl.constexpr, # head_dim
TILE_M: tl.constexpr,
TILE_K: tl.constexpr,
TILE_N: tl.constexpr,
scale: tl.constexpr,
IS_CAUSAL: tl.constexpr,
@@ -150,11 +149,13 @@ def fused_online_attention_kernel(
q = tl.load(q_ptrs, mask=q_mask, other=0.0)
else:
q = tl.load(q_ptrs, mask=q_mask, other=0.0)
q = q.to(tl.float32)

# Accumulators
running_max = tl.full([TILE_M], value=-float("inf"), dtype=tl.float32)
acc_num = tl.zeros([TILE_M, D], dtype=tl.float32)
acc_den = tl.zeros([TILE_M], dtype=tl.float32)
has_valid = tl.zeros([TILE_M], dtype=tl.float32)

# Iterate over K/V tiles
for start_n in range(0, N, TILE_N):
@@ -176,10 +177,12 @@ def fused_online_attention_kernel(
if USE_CP_ASYNC:
# cp.async + double buffering placeholder
k = tl.load(k_ptrs, mask=kv_mask, other=0.0)
v = tl.load(v_ptrs, mask=kv_mask, other=0.0).to(tl.float32)
v = tl.load(v_ptrs, mask=kv_mask, other=0.0)
else:
k = tl.load(k_ptrs, mask=kv_mask, other=0.0)
v = tl.load(v_ptrs, mask=kv_mask, other=0.0).to(tl.float32)
v = tl.load(v_ptrs, mask=kv_mask, other=0.0)
k = k.to(tl.float32)
v = v.to(tl.float32)

# QK^T
# Hopper uses WGMMA tensor cores; Ampere uses mma.sync
Expand All @@ -198,7 +201,7 @@ def fused_online_attention_kernel(
+ (offs_m[:, None] * stride_mm + (start_n + offs_n)[None, :] * stride_mn)
)
mask_mask = (offs_m[:, None] < M) & ((start_n + offs_n)[None, :] < N)
mask_vals = tl.load(mask_ptrs, mask=mask_mask, other=0.0)
mask_vals = tl.load(mask_ptrs, mask=mask_mask, other=0.0).to(qk.dtype)
qk += mask_vals

if HAS_ALIBI:
@@ -207,14 +210,24 @@ def fused_online_attention_kernel(
k_pos = (start_n + offs_n)[None, :].to(tl.float32)
qk += slope * (k_pos - q_pos)

# Online softmax update
# Online softmax update with fully masked-row safeguards
tile_max = tl.max(qk, axis=1)
new_max = tl.maximum(running_max, tile_max)
correction = tl.exp(running_max - new_max)
prev_valid = has_valid > 0
tile_valid = tile_max > float("-inf")
new_valid = prev_valid | tile_valid

candidate_max = tl.maximum(running_max, tile_max)
safe_prev = tl.where(prev_valid, running_max, 0.0)
safe_new = tl.where(new_valid, candidate_max, 0.0)
correction = tl.where(prev_valid, tl.exp(safe_prev - safe_new), 1.0)

running_max = tl.where(new_valid, candidate_max, float("-inf"))
acc_num *= correction[:, None]
acc_den *= correction

exp_qk = tl.exp(qk - new_max[:, None])

qk_shifted = qk - safe_new[:, None]
exp_qk = tl.where(new_valid[:, None], tl.exp(qk_shifted), 0.0)
has_valid = new_valid.to(tl.float32)

if HAS_DROPOUT:
bh = off_b * H + off_h
@@ -228,7 +241,6 @@ def fused_online_attention_kernel(

acc_num += tl.dot(exp_qk, v)
acc_den += tl.sum(exp_qk, axis=1)
running_max = new_max

# Final output with safe denominator; handle rows with all keys masked
zero_den = acc_den == 0
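
For readers following the online-softmax changes above, here is a plain PyTorch sketch of the same tiled update with the fully-masked-row safeguards. It is an illustration of the logic, not the Triton kernel itself; online_attention_reference and tile_n are names introduced only for this example.

import torch

def online_attention_reference(q, k, v, bias, scale, tile_n=64):
    """Tiled softmax(q @ k.T * scale + bias) @ v with masked-row safeguards."""
    q, k, v, bias = (t.float() for t in (q, k, v, bias))
    M, D = q.shape
    N = k.shape[0]
    dev = q.device
    running_max = torch.full((M,), float("-inf"), device=dev)
    acc_num = torch.zeros(M, D, device=dev)
    acc_den = torch.zeros(M, device=dev)
    has_valid = torch.zeros(M, dtype=torch.bool, device=dev)

    for start in range(0, N, tile_n):
        kt, vt = k[start:start + tile_n], v[start:start + tile_n]
        qk = q @ kt.T * scale + bias[:, start:start + tile_n]

        tile_max = qk.max(dim=1).values
        new_valid = has_valid | (tile_max > float("-inf"))

        candidate_max = torch.maximum(running_max, tile_max)
        # Substitute 0.0 for -inf before any subtraction, so exp() never sees (-inf) - (-inf).
        safe_prev = torch.where(has_valid, running_max, torch.zeros_like(running_max))
        safe_new = torch.where(new_valid, candidate_max, torch.zeros_like(candidate_max))
        # Rows that held no valid key so far need no rescaling (correction = 1).
        correction = torch.where(has_valid, torch.exp(safe_prev - safe_new),
                                 torch.ones_like(safe_prev))

        running_max = torch.where(new_valid, candidate_max,
                                  torch.full_like(candidate_max, float("-inf")))
        acc_num *= correction[:, None]
        acc_den *= correction

        # Rows still without any valid key keep zero accumulators.
        exp_qk = torch.where(new_valid[:, None], torch.exp(qk - safe_new[:, None]),
                             torch.zeros_like(qk))
        has_valid = new_valid
        acc_num += exp_qk @ vt
        acc_den += exp_qk.sum(dim=1)

    # Rows whose keys were all masked get zero output instead of 0/0 = NaN.
    out = acc_num / torch.where(acc_den == 0, torch.ones_like(acc_den), acc_den)[:, None]
    return torch.where((acc_den == 0)[:, None], torch.zeros_like(out), out)

# Example: the first query attends to nothing; its output is all zeros and finite.
q, k, v = torch.randn(4, 8), torch.randn(16, 8), torch.randn(16, 8)
bias = torch.zeros(4, 16)
bias[0] = float("-inf")
out = online_attention_reference(q, k, v, bias, scale=8 ** -0.5, tile_n=4)
assert torch.isfinite(out).all() and out[0].abs().sum() == 0
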
@@ -311,7 +323,6 @@ def fused_online_attention_bwd_kernel(
N: tl.constexpr,
D: tl.constexpr,
TILE_M: tl.constexpr,
TILE_K: tl.constexpr,
TILE_N: tl.constexpr,
IS_CAUSAL: tl.constexpr,
HAS_MASK: tl.constexpr,
@@ -511,7 +522,7 @@ def fused_online_attention_bwd_kernel(

lse_ptrs = Lse + off_b * stride_lsb + off_h * stride_lsh + offs_m * stride_lsm
lse = tl.load(lse_ptrs, mask=offs_m < M, other=float('-inf')).to(tl.float32)
row_mask = tl.isfinite(lse)
row_mask = lse > float("-inf")
lse = tl.where(row_mask, lse, 0.0)
go = go * row_mask[:, None]
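The same safeguard carries into the backward pass: a query row whose stored logsumexp is -inf (every key masked) contributes nothing to the gradients. A tiny PyTorch mirror of the two guard lines above, illustration only:

import torch

lse = torch.tensor([1.7, float("-inf")])                   # second row: logsumexp of an all-masked row
go = torch.ones(2, 4)                                      # incoming grad_output tile
row_mask = lse > float("-inf")                             # tensor([True, False])
lse = torch.where(row_mask, lse, torch.zeros_like(lse))    # keep -inf out of later exp() terms
go = go * row_mask[:, None]                                # masked rows contribute zero gradient
assert torch.isfinite(lse).all() and go[1].abs().sum() == 0
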

@@ -786,7 +797,6 @@ def _backward_triton(
N=seq_len_k,
D=head_dim,
TILE_M=self.tile_size_q,
TILE_K=head_dim,
TILE_N=self.tile_size_k,
IS_CAUSAL=causal,
HAS_MASK=mask_tensor is not None,
@@ -1050,10 +1060,11 @@ def _prepare_attn_mask(
dtype: torch.dtype,
) -> torch.Tensor:
mask = attention_mask
if mask.dtype == torch.float16 or mask.dtype == torch.bfloat16:
mask = mask.to(dtype)
if mask.dtype == torch.bool:
pass
mask = mask.to(torch.float32)
mask = mask.masked_fill(mask > 0, float('-inf'))
else:
mask = mask.to(dtype)
if mask.dim() == 2:
mask = mask.view(batch_size, 1, 1, seq_len_k)
elif mask.dim() == 3:
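
# Illustration only (not part of the diff): what the boolean branch above produces for a
# 2-D key-padding mask, assuming, as the masked_fill implies, that True marks keys to exclude.
# The shapes and values are made up for the example.
import torch

bool_mask = torch.tensor([[False, False, True, True]])   # (batch=1, seq_len_k=4)
mask = bool_mask.to(torch.float32)                        # False -> 0.0, True -> 1.0
mask = mask.masked_fill(mask > 0, float("-inf"))          # True positions become -inf
mask = mask.view(1, 1, 1, 4)                              # broadcast over heads and queries
# mask == tensor([[[[0., 0., -inf, -inf]]]])
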
@@ -1251,7 +1262,6 @@ def _forward_triton(
M=seq_len_q,
N=seq_len_k,
D=self.head_dim,
TILE_K=self.head_dim,
scale=self.scale,
IS_CAUSAL=causal,
HAS_MASK=has_mask,
@@ -1576,7 +1586,8 @@ def backward(ctx, grad_output):
M=seq_len_q,
N=seq_len_k,
D=module.head_dim,
TILE_K=module.head_dim,
TILE_M=module.tile_size_q,
TILE_N=module.tile_size_k,
scale=module.scale,
IS_CAUSAL=ctx.causal,
HAS_MASK=has_mask,
23 changes: 19 additions & 4 deletions tests/test_attention.py
@@ -11,6 +11,13 @@
from typing import Tuple, Optional
import time
import gc
from contextlib import nullcontext

try:
from torch.nn.attention import sdpa_kernel as sdpa_kernel_ctx, SDPBackend
except ImportError: # pragma: no cover - older PyTorch
sdpa_kernel_ctx = None
SDPBackend = None

from stream_attention import StreamAttention, StreamAttentionConfig
from stream_attention.core.flashattention_v3 import FlashAttentionV3
@@ -20,6 +27,16 @@
)
from stream_attention.core.ring_attention import RingAttention
from stream_attention.core.star_attention import StarAttention


def _math_sdpa_ctx():
Review comment from @cubic-dev-ai (bot), Oct 11, 2025, on tests/test_attention.py line 32:
The test suite in tests/test_attention.py is missing a test case for the core bug being fixed: handling fully masked rows. The new logic in fused_online_attention.py is designed to prevent NaNs in this scenario, but without a dedicated test, the fix is not verified and could regress.

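(A minimal sketch of such a test follows this hunk.)
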
if sdpa_kernel_ctx is not None and SDPBackend is not None:
return sdpa_kernel_ctx(SDPBackend.MATH)
if torch.cuda.is_available():
return torch.backends.cuda.sdp_kernel(
enable_math=True, enable_flash=False, enable_mem_efficient=False
)
return nullcontext()
from stream_attention.utils.memory import create_kv_compressor, MemoryProfiler


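A minimal sketch of the regression test suggested in the review comment above, checking that rows whose keys are all masked come back finite and zero. fused_online_attention is a hypothetical entry point standing in for whatever this module actually exports, and the sketch relies on the test file's existing torch import and device fixture; adapt the call to the real API before relying on it.

def test_fused_online_attention_fully_masked_rows(device):
    torch.manual_seed(0)
    batch, heads, seq_q, seq_k, dim = 1, 2, 8, 8, 16
    q = torch.randn(batch, heads, seq_q, dim, device=device)
    k = torch.randn(batch, heads, seq_k, dim, device=device)
    v = torch.randn(batch, heads, seq_k, dim, device=device)

    # Additive mask: every key is masked out for the first query row.
    add_mask = torch.zeros(batch, heads, seq_q, seq_k, device=device)
    add_mask[:, :, 0, :] = float("-inf")

    out = fused_online_attention(q, k, v, attn_mask=add_mask)  # hypothetical call

    assert torch.isfinite(out).all(), "fully masked rows must not produce NaN/Inf"
    assert torch.allclose(out[:, :, 0, :], torch.zeros_like(out[:, :, 0, :]))
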
@@ -152,9 +169,7 @@ def test_fused_online_attention_mask_parity(self, device):
torch.full((1,), float("-inf"), dtype=q_bh.dtype, device=device),
torch.zeros(1, dtype=q_bh.dtype, device=device),
)
with torch.backends.cuda.sdp_kernel(
enable_math=True, enable_flash=False, enable_mem_efficient=False
):
with _math_sdpa_ctx():
ref = torch.nn.functional.scaled_dot_product_attention(
q_bh, k_bh, v_bh, attn_mask=add_mask, is_causal=False, dropout_p=0.0
)
@@ -253,7 +268,7 @@ def test_fused_online_attention_backward_matches_sdpa(self, device):
alibi_bias = alibi_bias.reshape(batch_size * num_heads, seq_len_q, seq_len_k)

combined_bias = mask_bh + alibi_bias
with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
with _math_sdpa_ctx():
ref_out = torch.nn.functional.scaled_dot_product_attention(
q_bh, k_bh, v_bh, attn_mask=combined_bias, is_causal=True, dropout_p=0.0
)