
Conversation

@MagellaX (Owner) commented Aug 24, 2025

Summary

  • Switch FlashAttentionV3 to torch.nn.attention.sdpa_kernel with the FLASH backend and a CUDA fallback (a minimal sketch of the kernel selection follows this list).
  • Simplify the wrapper to compute SDPA once.
  • Migrate fused online attention to the new SDPA context manager with a legacy fallback.
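
A minimal sketch of the kernel-selection logic described above, assuming a hypothetical helper name `_sdpa_context` (not the PR's actual function name) and representative flag choices:

```python
from contextlib import nullcontext

import torch


def _sdpa_context():
    """Pick the best available SDPA kernel context manager."""
    try:
        # Newer PyTorch (>= 2.3): public API for pinning SDPA backends.
        from torch.nn.attention import SDPBackend, sdpa_kernel
        return sdpa_kernel(SDPBackend.FLASH_ATTENTION)
    except ImportError:
        pass
    if torch.cuda.is_available() and hasattr(torch.backends.cuda, "sdp_kernel"):
        # Legacy context manager on older CUDA builds of PyTorch.
        return torch.backends.cuda.sdp_kernel(
            enable_flash=True,
            enable_math=True,  # kept on as a safety net; see the review thread below
            enable_mem_efficient=False,
        )
    # CPU or very old PyTorch: no-op context, let SDPA choose any backend.
    return nullcontext()
```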

Testing

  • pytest -q

https://chatgpt.com/codex/tasks/task_e_68ab745edc4483229d99bc35f2dcc2af


Summary by cubic

Adopted torch.nn.attention.sdpa_kernel with FLASH backend and robust fallbacks, simplifying our attention code and improving PyTorch compatibility on CUDA.

  • Refactors
    • Use torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION) when available.
    • Fall back to torch.backends.cuda.sdp_kernel, else a no-op context for CPU/older PyTorch.
    • Simplify the FlashAttentionV3 wrapper and unify the SDPA calls in the forward/backward paths for fused online attention (see the wrapper sketch after this list).
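
For illustration only, the simplified wrapper shape this refactor describes; the real FlashAttentionV3 module in the repo carries more configuration, and `_sdpa_context` is the assumed helper from the sketch above:

```python
import torch.nn as nn
import torch.nn.functional as F


class FlashAttentionV3(nn.Module):
    """Illustrative-only sketch of the simplified wrapper."""

    def forward(self, q, k, v, is_causal=False):
        # A single SDPA call under the preferred kernel context;
        # the review thread below layers a retry on top of this.
        with _sdpa_context():  # assumed helper from the earlier sketch
            return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
```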

@cubic-dev-ai (Contributor, bot) left a comment


1 issue found across 2 files


"FlashAttention SDPA failed under forced settings, retrying default: %s",
e,
)
with sdpa_ctx:

The fallback logic for the SDPA kernel is inconsistent between files and has been made more fragile. This file's fallback for older PyTorch versions disables the math kernel (enable_math=False), while fused_online_attention.py enables it. Removing the try/except block here means that if the flash kernel is unavailable for certain inputs, the operation will now raise an unhandled RuntimeError instead of gracefully falling back.

Prompt for AI agents
Address the following comment on stream_attention/core/flashattention_v3.py at line 103:

<comment>The fallback logic for the SDPA kernel is inconsistent between files and has been made more fragile. This file's fallback for older PyTorch versions disables the `math` kernel (`enable_math=False`), while `fused_online_attention.py` enables it. Removing the `try/except` block here means that if the `flash` kernel is unavailable for certain inputs, the operation will now raise an unhandled `RuntimeError` instead of gracefully falling back.</comment>

<file context>
@@ -100,23 +100,7 @@ def forward(
                     )
                     sdpa_ctx = nullcontext()
 
-        try:
-            with sdpa_ctx:
-                out = F.scaled_dot_product_attention(
-                    q,
-                    k,
-                    v,
</file context>

@MagellaX (Owner, Author) commented:

@cursoragent fix this

@cursor (bot) commented Aug 24, 2025

I improved the robustness of the FlashAttentionV3 module by adjusting SDPA kernel settings and implementing a retry mechanism for scaled_dot_product_attention failures.

  • I modified stream_attention/core/flashattention_v3.py to set enable_math=True in torch.backends.cuda.sdp_kernel within FlashAttentionV3.forward, potentially broadening kernel availability.
  • I added a try...except RuntimeError block around F.scaled_dot_product_attention. If the operation fails under the forced SDPA settings, I retry it without sdpa_ctx, letting PyTorch select a compatible kernel (a sketch of this retry pattern follows below).
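
A rough sketch of that retry, reusing the assumed `_sdpa_context` helper from the earlier sketch; the warning message mirrors the snippet quoted in the review thread:

```python
import logging

import torch.nn.functional as F

logger = logging.getLogger(__name__)


def _attention(q, k, v, is_causal=False):
    sdpa_ctx = _sdpa_context()
    try:
        # First attempt: force the flash backend via the selected context.
        with sdpa_ctx:
            return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
    except RuntimeError as e:
        # The flash kernel can be unavailable for some inputs (dtype, head dim,
        # masks, ...): log and let PyTorch pick any compatible backend instead
        # of surfacing the error.
        logger.warning(
            "FlashAttention SDPA failed under forced settings, retrying default: %s",
            e,
        )
        return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
```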

1 task completed

  • Align SDPA fallback in flashattention_v3.py (enable math, add try/except)
  • Run tests to verify no regressions after fallback fix

@MagellaX merged commit 4a98071 into main on Aug 24, 2025
2 checks passed
