Streaming backward #35
Conversation
1 issue found across 2 files
Prompt for AI agents (1 issue)
Understand the root cause of the following issue and fix it.
<file name="tests/test_attention.py">
<violation number="1" location="tests/test_attention.py:32">
The test suite in `tests/test_attention.py` is missing a test case for the core bug being fixed: handling fully masked rows. The new logic in `fused_online_attention.py` is designed to prevent NaNs in this scenario, but without a dedicated test, the fix is not verified and could regress.</violation>
</file>
The test suite in tests/test_attention.py is missing a test case for the core bug being fixed: handling fully masked rows. The new logic in fused_online_attention.py is designed to prevent NaNs in this scenario, but without a dedicated test, the fix is not verified and could regress.
Prompt for AI agents
Address the following comment on tests/test_attention.py at line 32:
<comment>The test suite in `tests/test_attention.py` is missing a test case for the core bug being fixed: handling fully masked rows. The new logic in `fused_online_attention.py` is designed to prevent NaNs in this scenario, but without a dedicated test, the fix is not verified and could regress.</comment>
<file context>
@@ -20,6 +27,16 @@
from stream_attention.core.star_attention import StarAttention
+
+
+def _math_sdpa_ctx():
+    if sdpa_kernel_ctx is not None and SDPBackend is not None:
+        return sdpa_kernel_ctx(SDPBackend.MATH)
</file context>
Summary by cubic
Improved masked-row handling in the Triton fused online attention kernel so that masking and softmax match PyTorch SDPA and fully masked rows no longer produce NaNs. The backward pass now receives the correct tile sizes, and the tests use a portable SDPA math context for parity checks.