Merged

Changes from 1 commit
18 changes: 1 addition & 17 deletions stream_attention/core/flashattention_v3.py
@@ -100,23 +100,7 @@ def forward(
)
sdpa_ctx = nullcontext()

try:
with sdpa_ctx:
out = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=causal,
)
except RuntimeError as e: # pragma: no cover - device/kernel dependent
# If a forced-flash configuration leads to "no available kernel",
# retry without any forced backend so PyTorch can choose a valid one.
logger.debug(
"FlashAttention SDPA failed under forced settings, retrying default: %s",
e,
)
with sdpa_ctx:
Contributor
The fallback logic for the SDPA kernel is inconsistent between files and has been made more fragile. This file's fallback for older PyTorch versions disables the `math` kernel (`enable_math=False`), while `fused_online_attention.py` enables it. Removing the `try/except` block here means that if the `flash` kernel is unavailable for certain inputs, the operation will now raise an unhandled `RuntimeError` instead of gracefully falling back.

Prompt for AI agents
Address the following comment on stream_attention/core/flashattention_v3.py at line 103:

<comment>The fallback logic for the SDPA kernel is inconsistent between files and has been made more fragile. This file's fallback for older PyTorch versions disables the `math` kernel (`enable_math=False`), while `fused_online_attention.py` enables it. Removing the `try/except` block here means that if the `flash` kernel is unavailable for certain inputs, the operation will now raise an unhandled `RuntimeError` instead of gracefully falling back.</comment>

<file context>
@@ -100,23 +100,7 @@ def forward(
                     )
                     sdpa_ctx = nullcontext()
 
-        try:
-            with sdpa_ctx:
-                out = F.scaled_dot_product_attention(
-                    q,
-                    k,
-                    v,
</file context>
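
For reference, a minimal sketch (not the PR's code; the helper name and module layout are illustrative) of the graceful fallback the comment describes: prefer the flash SDPA backend when it can be forced, but retry with PyTorch's default backend selection if the forced configuration leaves no usable kernel.

```python
import logging
from contextlib import nullcontext

import torch
import torch.nn.functional as F

logger = logging.getLogger(__name__)


def sdpa_with_fallback(q, k, v, attn_mask=None, dropout_p=0.0, causal=False):
    """Run scaled_dot_product_attention, preferring the flash kernel but
    falling back to PyTorch's default backend choice on RuntimeError."""
    sdpa_ctx = nullcontext()
    if q.is_cuda:
        try:
            # Preferred API on recent PyTorch releases.
            from torch.nn.attention import SDPBackend, sdpa_kernel
            sdpa_ctx = sdpa_kernel(SDPBackend.FLASH_ATTENTION)
        except ImportError:  # older PyTorch: legacy context manager
            sdpa_ctx = torch.backends.cuda.sdp_kernel(
                enable_math=True, enable_flash=True, enable_mem_efficient=False
            )
    try:
        with sdpa_ctx:
            return F.scaled_dot_product_attention(
                q, k, v, attn_mask, dropout_p=dropout_p, is_causal=causal
            )
    except RuntimeError as e:  # no kernel available under forced settings
        logger.debug("Forced SDPA backend failed, retrying default: %s", e)
        return F.scaled_dot_product_attention(
            q, k, v, attn_mask, dropout_p=dropout_p, is_causal=causal
        )
```

Keeping the retry outside the forced-backend context (and leaving `enable_math=True` in the legacy path) would make the two files behave consistently while still preferring the flash kernel when it is available.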

out = F.scaled_dot_product_attention(
q,
k,
51 changes: 35 additions & 16 deletions stream_attention/core/fused_online_attention.py
@@ -22,6 +22,12 @@
import torch.distributed as dist
from typing import Optional, Tuple, Dict
import logging
from contextlib import nullcontext

try:
from torch.nn.attention import SDPBackend
except ImportError: # pragma: no cover - older PyTorch
SDPBackend = None

try:
import triton
@@ -374,12 +380,22 @@ def forward(
dropout_p=(dropout_p if self.training else 0.0),
)

sdpa_ctx = nullcontext()
if q.is_cuda:
with torch.backends.cuda.sdp_kernel(
enable_math=True, enable_flash=True, enable_mem_efficient=False
):
out = F.scaled_dot_product_attention(q, k, v, **sdpa_kwargs)
else:
try:
sdpa_ctx = torch.nn.attention.sdpa_kernel(
SDPBackend.FLASH_ATTENTION
)
except (AttributeError, TypeError):
try:
sdpa_ctx = torch.backends.cuda.sdp_kernel(
enable_math=True,
enable_flash=True,
enable_mem_efficient=False,
)
except Exception: # pragma: no cover - environment dependent
sdpa_ctx = nullcontext()
with sdpa_ctx:
out = F.scaled_dot_product_attention(q, k, v, **sdpa_kwargs)

out = (
@@ -666,19 +682,22 @@ def backward(ctx, grad_out):
)
add_mask = add_mask + tri_add

sdpa_ctx = nullcontext()
if q.is_cuda:
with torch.backends.cuda.sdp_kernel(
enable_math=True, enable_flash=True, enable_mem_efficient=False
):
y = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=(add_mask if add_mask is not None else None),
is_causal=(False if add_mask is not None else ctx.causal),
dropout_p=0.0,
try:
sdpa_ctx = torch.nn.attention.sdpa_kernel(
SDPBackend.FLASH_ATTENTION
)
else:
except (AttributeError, TypeError):
try:
sdpa_ctx = torch.backends.cuda.sdp_kernel(
enable_math=True,
enable_flash=True,
enable_mem_efficient=False,
)
except Exception: # pragma: no cover - environment dependent
sdpa_ctx = nullcontext()
with sdpa_ctx:
y = F.scaled_dot_product_attention(
q,
k,