54 changes: 28 additions & 26 deletions stream_attention/core/fused_online_attention.py
@@ -896,11 +896,21 @@ def forward(
         effective_dropout = dropout_p if self.training else 0.0
         deterministic_mode = self.deterministic if deterministic is None else deterministic

-        mask_supported = (attention_mask is None) or (
-            attention_mask.dim() == 2
-            and attention_mask.shape[0] == batch_size
-            and attention_mask.shape[1] == seq_len_k
-        )
+        mask_supported = attention_mask is None
+        if not mask_supported:
+            dims = attention_mask.dim()
+            if dims == 2:
+                mask_supported = attention_mask.shape[0] == batch_size and attention_mask.shape[1] == seq_len_k
+            elif dims == 3:
+                mask_supported = (
+                    attention_mask.shape[0] == batch_size
+                    and attention_mask.shape[1] == seq_len_q
+                    and attention_mask.shape[2] == seq_len_k
+                )
+            elif dims == 4:
+                mask_supported = attention_mask.shape[0] == batch_size and attention_mask.shape[-1] == seq_len_k
Review comment:
Bug: Incomplete Mask Validation Causes Runtime Failures

The 4D mask validation is incomplete. It only checks shape[0] (batch size) and shape[-1] (seq_len_k), missing compatibility checks for shape[1] (num_heads) and shape[2] (seq_len_q). This allows incompatible masks to pass validation, leading to runtime failures during mask expansion instead of falling back to PyTorch SDPA.


Review comment from @cubic-dev-ai (bot, Contributor), Oct 10, 2025:

4D attention mask validation only checks batch and K lengths, allowing masks with mismatched head or Q dimensions to slip through and fail during mask expansion. Tighten the 4D check to also validate shape[1] (allow 1 or num_heads) and shape[2] (seq_len_q).

Prompt for AI agents
Address the following comment on stream_attention/core/fused_online_attention.py at line 911:

<comment>4D attention mask validation only checks batch and K lengths, allowing masks with mismatched head or Q dimensions to slip through and fail during mask expansion. Tighten the 4D check to also validate shape[1] (allow 1 or num_heads) and shape[2] (seq_len_q).</comment>

<file context>
@@ -896,11 +896,21 @@ def forward(
+                    and attention_mask.shape[2] == seq_len_k
+                )
+            elif dims == 4:
+                mask_supported = attention_mask.shape[0] == batch_size and attention_mask.shape[-1] == seq_len_k
+            else:
+                mask_supported = False
</file context>
Suggested change:
-                mask_supported = attention_mask.shape[0] == batch_size and attention_mask.shape[-1] == seq_len_k
+                mask_supported = (attention_mask.shape[0] == batch_size and attention_mask.shape[1] in (1, num_heads) and attention_mask.shape[2] == seq_len_q and attention_mask.shape[3] == seq_len_k)

+            else:
+                mask_supported = False

         use_triton = (
             TRITON_AVAILABLE
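
To make the review comments above concrete, here is a minimal, self-contained sketch (shapes, variable names, and printed values are illustrative, not taken from the repository) showing how a 4D mask with a mismatched query dimension passes the loose check from the diff but then fails during mask expansion, while the tightened check rejects it so the layer can fall back to PyTorch SDPA:

import torch

batch_size, num_heads, seq_len_q, seq_len_k = 2, 4, 8, 16

# A 4D mask whose query dimension does not match seq_len_q.
bad_mask = torch.zeros(batch_size, num_heads, seq_len_q + 3, seq_len_k, dtype=torch.bool)

# Loose check from the diff: only the batch and K dimensions are validated, so this passes.
loose_ok = bad_mask.shape[0] == batch_size and bad_mask.shape[-1] == seq_len_k
print(loose_ok)  # True, even though the Q dimension is wrong

# Tightened check from the review: heads and Q are validated too, so the mask is rejected
# up front and the layer can fall back to PyTorch SDPA instead of crashing later.
strict_ok = (
    bad_mask.shape[0] == batch_size
    and bad_mask.shape[1] in (1, num_heads)
    and bad_mask.shape[2] == seq_len_q
    and bad_mask.shape[3] == seq_len_k
)
print(strict_ok)  # False

# With only the loose check, the mask reaches expansion and broadcasting fails at runtime.
scores = torch.randn(batch_size, num_heads, seq_len_q, seq_len_k)
try:
    scores.masked_fill(bad_mask, float("-inf"))
except RuntimeError as exc:
    print("mask expansion failed:", exc)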
@@ -1010,18 +1020,13 @@ def forward(
         sdpa_ctx = nullcontext()
         if q.is_cuda:
             try:
-                sdpa_ctx = torch.nn.attention.sdpa_kernel(
-                    SDPBackend.FLASH_ATTENTION
-                )
-            except (AttributeError, TypeError):
-                try:
-                    sdpa_ctx = torch.backends.cuda.sdp_kernel(
-                        enable_math=True,
-                        enable_flash=True,
-                        enable_mem_efficient=False,
-                    )
-                except Exception:  # pragma: no cover - environment dependent
-                    sdpa_ctx = nullcontext()
+                sdpa_ctx = torch.backends.cuda.sdp_kernel(
+                    enable_math=True,
+                    enable_flash=False,
+                    enable_mem_efficient=False,
+                )
+            except Exception:  # pragma: no cover - environment dependent
+                sdpa_ctx = nullcontext()
         with sdpa_ctx:
             out = F.scaled_dot_product_attention(q, k, v, **sdpa_kwargs)

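For reference, the backend-selection pattern that the new code in this hunk (and the backward-pass hunk below) relies on is roughly the following. This is a minimal sketch, assuming a CUDA build of PyTorch; torch.backends.cuda.sdp_kernel is the older context-manager API that recent releases deprecate in favour of torch.nn.attention.sdpa_kernel, which is why the call is wrapped in try/except with a nullcontext fallback. The helper name sdpa_math_only is made up for this sketch:

from contextlib import nullcontext

import torch
import torch.nn.functional as F

def sdpa_math_only(q, k, v):
    # Prefer the math backend for scaled_dot_product_attention on CUDA;
    # let PyTorch pick its default backend if the context manager is unavailable.
    ctx = nullcontext()
    if q.is_cuda:
        try:
            ctx = torch.backends.cuda.sdp_kernel(
                enable_math=True,
                enable_flash=False,
                enable_mem_efficient=False,
            )
        except Exception:  # this build may not expose the older API
            ctx = nullcontext()
    with ctx:
        return F.scaled_dot_product_attention(q, k, v)
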
@@ -1649,16 +1654,13 @@ def backward(ctx, grad_output):
         sdpa_ctx = nullcontext()
         if q_ref.is_cuda:
             try:
-                sdpa_ctx = torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION)
-            except (AttributeError, TypeError):
-                try:
-                    sdpa_ctx = torch.backends.cuda.sdp_kernel(
-                        enable_math=True,
-                        enable_flash=True,
-                        enable_mem_efficient=False,
-                    )
-                except Exception:  # pragma: no cover - environment dependent
-                    sdpa_ctx = nullcontext()
+                sdpa_ctx = torch.backends.cuda.sdp_kernel(
+                    enable_math=True,
+                    enable_flash=False,
+                    enable_mem_efficient=False,
+                )
+            except Exception:  # pragma: no cover - environment dependent
+                sdpa_ctx = nullcontext()

         go = grad_output.permute(0, 2, 1, 3).reshape(bsz * nh, seq_len_q, hd)
