diff --git a/stream_attention/core/flashattention_v3.py b/stream_attention/core/flashattention_v3.py
index 1119572..8ad2dc2 100644
--- a/stream_attention/core/flashattention_v3.py
+++ b/stream_attention/core/flashattention_v3.py
@@ -87,7 +87,7 @@ def forward(
         # Fallback for older PyTorch releases
         try:
             sdpa_ctx = torch.backends.cuda.sdp_kernel(
-                enable_math=False,
+                enable_math=True,
                 enable_flash=True,
                 enable_mem_efficient=False,
             )
diff --git a/stream_attention/core/fused_online_attention.py b/stream_attention/core/fused_online_attention.py
index c9d4c1d..5cec616 100644
--- a/stream_attention/core/fused_online_attention.py
+++ b/stream_attention/core/fused_online_attention.py
@@ -22,6 +22,12 @@
 import torch.distributed as dist
 from typing import Optional, Tuple, Dict
 import logging
+from contextlib import nullcontext
+
+try:
+    from torch.nn.attention import SDPBackend
+except ImportError:  # pragma: no cover - older PyTorch
+    SDPBackend = None
 
 try:
     import triton
@@ -374,12 +380,22 @@ def forward(
             dropout_p=(dropout_p if self.training else 0.0),
         )
 
+        sdpa_ctx = nullcontext()
         if q.is_cuda:
-            with torch.backends.cuda.sdp_kernel(
-                enable_math=True, enable_flash=True, enable_mem_efficient=False
-            ):
-                out = F.scaled_dot_product_attention(q, k, v, **sdpa_kwargs)
-        else:
+            try:
+                sdpa_ctx = torch.nn.attention.sdpa_kernel(
+                    SDPBackend.FLASH_ATTENTION
+                )
+            except (AttributeError, TypeError):
+                try:
+                    sdpa_ctx = torch.backends.cuda.sdp_kernel(
+                        enable_math=True,
+                        enable_flash=True,
+                        enable_mem_efficient=False,
+                    )
+                except Exception:  # pragma: no cover - environment dependent
+                    sdpa_ctx = nullcontext()
+        with sdpa_ctx:
             out = F.scaled_dot_product_attention(q, k, v, **sdpa_kwargs)
 
         out = (
@@ -666,19 +682,22 @@ def backward(ctx, grad_out):
                 )
                 add_mask = add_mask + tri_add
 
+        sdpa_ctx = nullcontext()
         if q.is_cuda:
-            with torch.backends.cuda.sdp_kernel(
-                enable_math=True, enable_flash=True, enable_mem_efficient=False
-            ):
-                y = F.scaled_dot_product_attention(
-                    q,
-                    k,
-                    v,
-                    attn_mask=(add_mask if add_mask is not None else None),
-                    is_causal=(False if add_mask is not None else ctx.causal),
-                    dropout_p=0.0,
+            try:
+                sdpa_ctx = torch.nn.attention.sdpa_kernel(
+                    SDPBackend.FLASH_ATTENTION
                 )
-        else:
+            except (AttributeError, TypeError):
+                try:
+                    sdpa_ctx = torch.backends.cuda.sdp_kernel(
+                        enable_math=True,
+                        enable_flash=True,
+                        enable_mem_efficient=False,
+                    )
+                except Exception:  # pragma: no cover - environment dependent
+                    sdpa_ctx = nullcontext()
+        with sdpa_ctx:
             y = F.scaled_dot_product_attention(
                 q,
                 k,