diff --git a/stream_attention/core/fused_online_attention.py b/stream_attention/core/fused_online_attention.py
index e92ceb5..e7ac50d 100644
--- a/stream_attention/core/fused_online_attention.py
+++ b/stream_attention/core/fused_online_attention.py
@@ -896,11 +896,21 @@ def forward(
         effective_dropout = dropout_p if self.training else 0.0
         deterministic_mode = self.deterministic if deterministic is None else deterministic

-        mask_supported = (attention_mask is None) or (
-            attention_mask.dim() == 2
-            and attention_mask.shape[0] == batch_size
-            and attention_mask.shape[1] == seq_len_k
-        )
+        mask_supported = attention_mask is None
+        if not mask_supported:
+            dims = attention_mask.dim()
+            if dims == 2:
+                mask_supported = attention_mask.shape[0] == batch_size and attention_mask.shape[1] == seq_len_k
+            elif dims == 3:
+                mask_supported = (
+                    attention_mask.shape[0] == batch_size
+                    and attention_mask.shape[1] == seq_len_q
+                    and attention_mask.shape[2] == seq_len_k
+                )
+            elif dims == 4:
+                mask_supported = attention_mask.shape[0] == batch_size and attention_mask.shape[-1] == seq_len_k
+            else:
+                mask_supported = False

         use_triton = (
             TRITON_AVAILABLE
@@ -1010,18 +1020,13 @@ def forward(
         sdpa_ctx = nullcontext()
         if q.is_cuda:
             try:
-                sdpa_ctx = torch.nn.attention.sdpa_kernel(
-                    SDPBackend.FLASH_ATTENTION
+                sdpa_ctx = torch.backends.cuda.sdp_kernel(
+                    enable_math=True,
+                    enable_flash=False,
+                    enable_mem_efficient=False,
                 )
-            except (AttributeError, TypeError):
-                try:
-                    sdpa_ctx = torch.backends.cuda.sdp_kernel(
-                        enable_math=True,
-                        enable_flash=True,
-                        enable_mem_efficient=False,
-                    )
-                except Exception:  # pragma: no cover - environment dependent
-                    sdpa_ctx = nullcontext()
+            except Exception:  # pragma: no cover - environment dependent
+                sdpa_ctx = nullcontext()

         with sdpa_ctx:
             out = F.scaled_dot_product_attention(q, k, v, **sdpa_kwargs)
@@ -1649,16 +1654,13 @@ def backward(ctx, grad_output):
         sdpa_ctx = nullcontext()
         if q_ref.is_cuda:
             try:
-                sdpa_ctx = torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION)
-            except (AttributeError, TypeError):
-                try:
-                    sdpa_ctx = torch.backends.cuda.sdp_kernel(
-                        enable_math=True,
-                        enable_flash=True,
-                        enable_mem_efficient=False,
-                    )
-                except Exception:  # pragma: no cover - environment dependent
-                    sdpa_ctx = nullcontext()
+                sdpa_ctx = torch.backends.cuda.sdp_kernel(
+                    enable_math=True,
+                    enable_flash=False,
+                    enable_mem_efficient=False,
+                )
+            except Exception:  # pragma: no cover - environment dependent
+                sdpa_ctx = nullcontext()

         go = grad_output.permute(0, 2, 1, 3).reshape(bsz * nh, seq_len_q, hd)
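
Reviewer note: the first hunk widens the mask gate from 2-D only to 2-D, 3-D, and 4-D masks. For review purposes it can be exercised in isolation; the helper below is a hypothetical standalone transcription of the patched check (the name mask_shape_supported is ours, not part of the module's API):

    import torch

    def mask_shape_supported(attention_mask, batch_size, seq_len_q, seq_len_k):
        # Hypothetical transcription of the patched gate; not a stream_attention API.
        if attention_mask is None:
            return True
        dims = attention_mask.dim()
        if dims == 2:
            # (batch, seq_len_k): key-padding mask
            return attention_mask.shape[0] == batch_size and attention_mask.shape[1] == seq_len_k
        if dims == 3:
            # (batch, seq_len_q, seq_len_k): per-example attention bias
            return (
                attention_mask.shape[0] == batch_size
                and attention_mask.shape[1] == seq_len_q
                and attention_mask.shape[2] == seq_len_k
            )
        if dims == 4:
            # (batch, heads, *, seq_len_k): only batch and key length are validated
            return attention_mask.shape[0] == batch_size and attention_mask.shape[-1] == seq_len_k
        return False

    b, h, lq, lk = 2, 4, 8, 16
    assert mask_shape_supported(None, b, lq, lk)
    assert mask_shape_supported(torch.ones(b, lk, dtype=torch.bool), b, lq, lk)
    assert mask_shape_supported(torch.zeros(b, lq, lk), b, lq, lk)
    assert mask_shape_supported(torch.zeros(b, h, lq, lk), b, lq, lk)
    assert not mask_shape_supported(torch.zeros(lk), b, lq, lk)  # unsupported rank

Note that the 4-D branch deliberately checks only the batch and key-length dimensions, leaving the head and query dimensions free for broadcasting.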
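The remaining two hunks give the forward and backward SDPA fallback sites the same shape: on CUDA, restrict scaled_dot_product_attention to the math (reference) backend; anywhere else, or if backend selection fails, fall back to a no-op context. This drops the nested try/except around torch.nn.attention.sdpa_kernel entirely. A minimal sketch of the pattern, with math_sdpa as an illustrative name only (torch.backends.cuda.sdp_kernel is deprecated in newer PyTorch in favor of torch.nn.attention.sdpa_kernel, which is what the removed first attempt used):

    from contextlib import nullcontext

    import torch
    import torch.nn.functional as F

    def math_sdpa(q, k, v, **sdpa_kwargs):
        # Sketch of the patched fallback path; not the module's actual method.
        sdpa_ctx = nullcontext()
        if q.is_cuda:
            try:
                # Keep only the math (reference) backend enabled, as the patch does.
                sdpa_ctx = torch.backends.cuda.sdp_kernel(
                    enable_math=True,
                    enable_flash=False,
                    enable_mem_efficient=False,
                )
            except Exception:  # environment dependent
                sdpa_ctx = nullcontext()
        with sdpa_ctx:
            return F.scaled_dot_product_attention(q, k, v, **sdpa_kwargs)

    q = k = v = torch.randn(2, 4, 8, 16)  # (batch, heads, seq, head_dim); CPU skips the ctx
    out = math_sdpa(q, k, v, is_causal=True)
    print(out.shape)  # torch.Size([2, 4, 8, 16])

Switching enable_flash from True to False also changes behaviour, not just structure: the fallback now never selects the flash kernel, pinning the path to the math backend.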