Use torch.nn.attention sdpa_kernel API #31
Conversation
1 issue found across 2 files
| "FlashAttention SDPA failed under forced settings, retrying default: %s", | ||
| e, | ||
| ) | ||
| with sdpa_ctx: |
The fallback logic for the SDPA kernel is inconsistent between files and has been made more fragile. This file's fallback for older PyTorch versions disables the math kernel (enable_math=False), while fused_online_attention.py enables it. Removing the try/except block here means that if the flash kernel is unavailable for certain inputs, the operation will now raise an unhandled RuntimeError instead of gracefully falling back.
Prompt for AI agents
Address the following comment on stream_attention/core/flashattention_v3.py at line 103:
<comment>The fallback logic for the SDPA kernel is inconsistent between files and has been made more fragile. This file's fallback for older PyTorch versions disables the `math` kernel (`enable_math=False`), while `fused_online_attention.py` enables it. Removing the `try/except` block here means that if the `flash` kernel is unavailable for certain inputs, the operation will now raise an unhandled `RuntimeError` instead of gracefully falling back.</comment>
<file context>
@@ -100,23 +100,7 @@ def forward(
         )
         sdpa_ctx = nullcontext()
-        try:
-            with sdpa_ctx:
-                out = F.scaled_dot_product_attention(
-                    q,
-                    k,
-                    v,
</file context>
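For illustration, a minimal sketch of the graceful fallback the comment describes, assuming PyTorch 2.3+ where `torch.nn.attention.sdpa_kernel` is available; the helper name `sdpa_with_flash_fallback` and its signature are hypothetical, not taken from the PR:

```python
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel


def sdpa_with_flash_fallback(q, k, v, is_causal=False):
    """Try the flash kernel first, then fall back to the default backend set.

    Illustrative sketch only; not the repository's actual implementation.
    """
    try:
        # Restrict scaled_dot_product_attention to the FlashAttention backend.
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
    except RuntimeError:
        # Flash can be unsupported for a given dtype/head_dim/device; retry
        # without restrictions so the efficient/math backends can handle it.
        return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
```

Keeping the retry outside the restricted context preserves the graceful-degradation behaviour the reviewer says was lost when the try/except block was removed.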
@cursoragent fix this
I improved the robustness of the SDPA fallback logic.
1 task completed
Co-authored-by: alphacr792 <[email protected]>
Summary
torch.nn.attention.sdpa_kernel with flash backend and CUDA fallback

Testing
pytest -q

https://chatgpt.com/codex/tasks/task_e_68ab745edc4483229d99bc35f2dcc2af
Summary by cubic
Adopted torch.nn.attention.sdpa_kernel with FLASH backend and robust fallbacks, simplifying our attention code and improving PyTorch compatibility on CUDA.
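As a rough sketch of the PyTorch-compatibility angle mentioned above, the version shim might look like the following; this assumes the legacy `torch.backends.cuda.sdp_kernel` context manager on older PyTorch, and the `flash_sdpa_context` helper is illustrative, not the repository's code:

```python
from contextlib import nullcontext

import torch


def flash_sdpa_context():
    """Return a context manager that prefers the flash SDPA kernel.

    Hypothetical compatibility shim; not taken from the repository.
    """
    try:
        # PyTorch 2.3+: public torch.nn.attention API.
        from torch.nn.attention import SDPBackend, sdpa_kernel
        return sdpa_kernel(SDPBackend.FLASH_ATTENTION)
    except ImportError:
        if torch.cuda.is_available():
            # Older PyTorch: legacy (now deprecated) context manager. Keep the
            # math kernel enabled so a fallback remains when flash cannot
            # handle the inputs.
            return torch.backends.cuda.sdp_kernel(
                enable_flash=True, enable_math=True, enable_mem_efficient=True
            )
        return nullcontext()
```

Keeping `enable_math=True` in the legacy path would match the behaviour of fused_online_attention.py noted in the review, so both files fall back consistently.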