Use new sdpa_kernel API with fallback #25
Conversation
1 issue found across 1 file
Flagged diff context (stream_attention/core/flashattention_v3.py, around line 88):

```python
except (AttributeError, RuntimeError):
    sdpa_ctx = nullcontext()
# Fallback for older PyTorch releases
sdpa_ctx = torch.backends.cuda.sdp_kernel(
```
Calling torch.backends.cuda.sdp_kernel without its own error handling can raise and crash on environments where the kernel is absent, removing the prior graceful fallback.
Prompt for AI agents
Address the following comment on stream_attention/core/flashattention_v3.py at line 88:
<comment>Calling torch.backends.cuda.sdp_kernel without its own error handling can raise and crash on environments where the kernel is absent, removing the prior graceful fallback.</comment>
<file context>
@@ -80,15 +80,14 @@ def forward(
if _use_flash_sdpa() and q.device.type == "cuda":
try:
# Prefer the newer torch.nn.attention API when available
- sdpa_ctx = torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION)
+ sdpa_ctx = torch.nn.attention.sdpa_kernel(
+ SDPBackend.FLASH_ATTENTION
+ )
except (AttributeError, TypeError):
- try:
</file context>
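One way to address the reviewer's point is to give the legacy call its own try/except and keep a no-op context as the last resort. The sketch below is a minimal illustration only, assuming a module-level helper named `_select_sdpa_ctx` that is not part of the PR diff:

```python
from contextlib import nullcontext

import torch


def _select_sdpa_ctx():
    """Pick a flash-attention SDPA context manager with graceful fallbacks.

    Hypothetical helper, not taken from the PR. Layer 1: the newer
    torch.nn.attention API (PyTorch >= 2.3). Layer 2: the legacy
    torch.backends.cuda.sdp_kernel toggle. Layer 3: a no-op context.
    """
    try:
        from torch.nn.attention import SDPBackend, sdpa_kernel
        return sdpa_kernel(SDPBackend.FLASH_ATTENTION)
    except (ImportError, AttributeError, TypeError):
        pass
    try:
        # Legacy API; may be missing or unsupported on some builds,
        # so it gets its own error handling as the reviewer suggests.
        return torch.backends.cuda.sdp_kernel(
            enable_flash=True, enable_math=True, enable_mem_efficient=True
        )
    except (AttributeError, RuntimeError):
        # Neither API is usable: let PyTorch pick the SDPA backend itself.
        return nullcontext()
```

With this shape, an environment where `torch.backends.cuda.sdp_kernel` is absent falls through to `nullcontext()` instead of crashing, which preserves the graceful fallback the review comment says was lost.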
@cursoragent look carefully and vigilantly at this issue, and properly fix it!!
I enhanced the robustness of the FlashAttention module by adding fallback mechanisms for SDPA kernel selection and made PyYAML an optional dependency for configuration loading. I also performed a minor cleanup in a benchmark script.
3 tasks completed
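For the "PyYAML as an optional dependency" part of that change, a config loader along the following lines would degrade gracefully when yaml is not installed. This is a sketch of the pattern only; the function name and the empty-dict default are assumptions, not the PR's actual code:

```python
try:
    import yaml  # PyYAML; treated as optional at runtime
except ImportError:
    yaml = None


def load_config(path):
    """Load a YAML config if PyYAML is installed, else return an empty dict.

    Hypothetical helper illustrating the optional-dependency pattern.
    """
    if yaml is None:
        return {}
    with open(path) as f:
        return yaml.safe_load(f) or {}
```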
Summary
- Prefer torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION) when available
- Fall back to torch.backends.cuda.sdp_kernel on TypeError for older PyTorch

Testing

- pytest -q (fails: TabError in accuracy_test.py; ModuleNotFoundError: No module named 'yaml')

https://chatgpt.com/codex/tasks/task_e_68ab6a18b2448322a01c023171329cbb
Summary by cubic
Use the new torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION) with a fallback to torch.backends.cuda.sdp_kernel for older PyTorch, and remove the redundant attention call so SDPA runs once. This improves version compatibility and avoids extra compute.
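To illustrate "SDPA runs once": whichever context manager is selected, it should wrap a single call to scaled_dot_product_attention. A rough sketch, with the function name assumed rather than taken from the diff:

```python
import torch.nn.functional as F


def _attention(q, k, v, sdpa_ctx):
    # Run scaled_dot_product_attention exactly once, inside whichever
    # kernel-selection context was chosen (new API, legacy API, or no-op).
    with sdpa_ctx:
        return F.scaled_dot_product_attention(q, k, v)
```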