stream_attention/core/flashattention_v3.py (2 changes: 1 addition & 1 deletion)

@@ -87,7 +87,7 @@ def forward(
             # Fallback for older PyTorch releases
             try:
                 sdpa_ctx = torch.backends.cuda.sdp_kernel(
-                    enable_math=False,
+                    enable_math=True,
                     enable_flash=True,
                     enable_mem_efficient=False,
                 )
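
Note on the one-line change above: with `enable_math=False`, `F.scaled_dot_product_attention` raises a `RuntimeError` whenever the flash kernel cannot serve a call (for example, when an arbitrary additive `attn_mask` is passed), because no permitted backend remains. Enabling the math backend keeps a universal fallback available. A minimal sketch of the failure mode this avoids, with illustrative shapes and the same legacy context manager the file uses:

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
mask = torch.zeros(1, 8, 128, 128, device="cuda", dtype=torch.float16)

# Flash-only: the flash kernel rejects arbitrary additive masks, and with
# the math backend disabled there is nothing left to dispatch to.
with torch.backends.cuda.sdp_kernel(
    enable_math=False, enable_flash=True, enable_mem_efficient=False
):
    try:
        F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    except RuntimeError as err:
        print("no available kernel:", err)

# As patched: the math backend now serves the same call as a fallback.
with torch.backends.cuda.sdp_kernel(
    enable_math=True, enable_flash=True, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```
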
stream_attention/core/fused_online_attention.py (51 changes: 35 additions & 16 deletions)

@@ -22,6 +22,12 @@
 import torch.distributed as dist
 from typing import Optional, Tuple, Dict
 import logging
+from contextlib import nullcontext
+
+try:
+    from torch.nn.attention import SDPBackend
+except ImportError:  # pragma: no cover - older PyTorch
+    SDPBackend = None
 
 try:
     import triton
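
For context on the two APIs this file now bridges: `torch.nn.attention.sdpa_kernel` is the current way to pin `scaled_dot_product_attention` to specific backends, superseding the deprecated `torch.backends.cuda.sdp_kernel` (the new module ships with recent PyTorch releases; 2.3 was the first to include it, as far as I can tell). A minimal standalone usage sketch, independent of this PR:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

# Restrict SDPA dispatch to the FlashAttention kernel within this block.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```

One nuance worth noting: pinning `SDPBackend.FLASH_ATTENTION` alone leaves no math fallback, while the legacy branch retained below enables `enable_math=True`, so the two paths are not strictly equivalent when the flash kernel cannot handle an input.
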
@@ -374,12 +380,22 @@ def forward(
             dropout_p=(dropout_p if self.training else 0.0),
         )
 
+        sdpa_ctx = nullcontext()
         if q.is_cuda:
-            with torch.backends.cuda.sdp_kernel(
-                enable_math=True, enable_flash=True, enable_mem_efficient=False
-            ):
-                out = F.scaled_dot_product_attention(q, k, v, **sdpa_kwargs)
-        else:
+            try:
+                sdpa_ctx = torch.nn.attention.sdpa_kernel(
+                    SDPBackend.FLASH_ATTENTION
+                )
+            except (AttributeError, TypeError):
+                try:
+                    sdpa_ctx = torch.backends.cuda.sdp_kernel(
+                        enable_math=True,
+                        enable_flash=True,
+                        enable_mem_efficient=False,
+                    )
+                except Exception:  # pragma: no cover - environment dependent
+                    sdpa_ctx = nullcontext()
+        with sdpa_ctx:
             out = F.scaled_dot_product_attention(q, k, v, **sdpa_kwargs)
 
         out = (
@@ -666,19 +682,22 @@ def backward(ctx, grad_out):
             )
             add_mask = add_mask + tri_add
 
+        sdpa_ctx = nullcontext()
         if q.is_cuda:
-            with torch.backends.cuda.sdp_kernel(
-                enable_math=True, enable_flash=True, enable_mem_efficient=False
-            ):
-                y = F.scaled_dot_product_attention(
-                    q,
-                    k,
-                    v,
-                    attn_mask=(add_mask if add_mask is not None else None),
-                    is_causal=(False if add_mask is not None else ctx.causal),
-                    dropout_p=0.0,
+            try:
+                sdpa_ctx = torch.nn.attention.sdpa_kernel(
+                    SDPBackend.FLASH_ATTENTION
                 )
-        else:
+            except (AttributeError, TypeError):
+                try:
+                    sdpa_ctx = torch.backends.cuda.sdp_kernel(
+                        enable_math=True,
+                        enable_flash=True,
+                        enable_mem_efficient=False,
+                    )
+                except Exception:  # pragma: no cover - environment dependent
+                    sdpa_ctx = nullcontext()
+        with sdpa_ctx:
+            y = F.scaled_dot_product_attention(
+                q,
+                k,
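
Since the same try/except ladder now appears in both `forward` and `backward`, a natural follow-up would be hoisting it into a module-level helper. A sketch built on the imports this PR adds; the name `_sdpa_flash_ctx` is hypothetical, not code from this PR:

```python
def _sdpa_flash_ctx():
    """Best-effort context manager that prefers the flash SDPA kernel."""
    # New-style API; SDPBackend is None when torch.nn.attention is absent.
    if SDPBackend is not None:
        try:
            return torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION)
        except (AttributeError, TypeError):
            pass
    # Legacy API, deprecated but available on older releases.
    try:
        return torch.backends.cuda.sdp_kernel(
            enable_math=True, enable_flash=True, enable_mem_efficient=False
        )
    except Exception:  # pragma: no cover - environment dependent
        return nullcontext()
```

Both call sites would then reduce to `sdpa_ctx = _sdpa_flash_ctx() if q.is_cuda else nullcontext()` followed by the existing `with sdpa_ctx:` block.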