Add Multihead-style wrapper and backend factory #28
Conversation
6 issues found across 6 files
stream_attention/core/multihead_attention.py, line 67:

    key: torch.Tensor,
    value: torch.Tensor,
    key_padding_mask: Optional[torch.Tensor] = None,
    need_weights: bool = False,

Missing average_attn_weights parameter despite claiming to mirror MultiheadAttention’s forward signature.
File context (new file; this hunk is shared by all of the comments below):

    @@ -0,0 +1,147 @@
    +"""Multihead-style wrapper around FusedOnlineAttention.
    +
    +This module provides :class:`StreamMultiheadAttention`, mirroring the
    +constructor and forward signature of :class:`torch.nn.MultiheadAttention`
    +while delegating the actual attention computation to
    +:class:`~stream_attention.core.fused_online_attention.FusedOnlineAttention`.
    +
    +A convenience factory :func:`create_stream_attention` is also provided to
    +instantiate the optimal backend automatically. If Triton kernels are
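One way to close the gap is to accept average_attn_weights with the same default as nn.MultiheadAttention.forward, alongside the arguments already visible in the diff. A minimal sketch of the signature; attn_mask and is_causal are assumptions about the parts of the signature this snippet does not show:

    from typing import Optional, Tuple

    import torch

    # Method of StreamMultiheadAttention (class body omitted in this sketch).
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = False,
        attn_mask: Optional[torch.Tensor] = None,
        average_attn_weights: bool = True,  # accepted for parity with nn.MultiheadAttention
        is_causal: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        ...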
stream_attention/core/multihead_attention.py, line 109:

    if not self.batch_first:
        attn_output = attn_output.transpose(0, 1)
    if need_weights:
        return attn_output, None

need_weights=True returns None for weights; either compute weights or raise to avoid silent API mismatch.
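Because a fused online-softmax backend streams over key/value blocks and never materializes the full attention matrix, one option is to fail loudly rather than silently return None. A sketch of that choice for the branch above:

    if need_weights:
        # The fused backend never builds the (tgt_len, src_len) weight matrix,
        # so real per-head weights cannot be returned here.
        raise NotImplementedError(
            "StreamMultiheadAttention cannot return attention weights; "
            "call it with need_weights=False."
        )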
stream_attention/core/multihead_attention.py, line 110:

        attn_output = attn_output.transpose(0, 1)
    if need_weights:
        return attn_output, None
    return attn_output

Return type does not match the annotated Tuple when need_weights is False; return a 2-tuple consistently to avoid API inconsistency.
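If the Tuple[Tensor, Optional[Tensor]] annotation is kept, the simplest fix is to return a 2-tuple on every path, matching nn.MultiheadAttention's convention. A sketch of the rewritten tail of forward:

    if not self.batch_first:
        attn_output = attn_output.transpose(0, 1)
    # Always return (output, weights) so callers can unpack unconditionally;
    # weights stay None because the fused backend does not materialize them.
    return attn_output, None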
stream_attention/core/multihead_attention.py, line 99:

    query=q,
    key=k,
    value=v,
    causal=True,

Forces causal masking for all calls, diverging from MultiheadAttention’s default behavior and potentially breaking non-causal use cases.
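A sketch that defaults to non-causal attention, as nn.MultiheadAttention does, and lets callers opt in through an is_causal argument; self.attention is an assumed attribute name, since the wrapped FusedOnlineAttention instance is not visible in this snippet:

    attn_output = self.attention(
        query=q,
        key=k,
        value=v,
        causal=is_causal,  # forwarded from forward(..., is_causal=False)
        attention_mask=attn_mask_inner,
    )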
stream_attention/core/multihead_attention.py, line 100:

    key=k,
    value=v,
    causal=True,
    attention_mask=attn_mask_inner,

Attention dropout is never applied because dropout_p is not forwarded to FusedOnlineAttention during training.
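Extending the sketch above, and assuming FusedOnlineAttention accepts a dropout_p keyword (not confirmed by this diff), the layer's dropout probability could be forwarded only while training, mirroring the convention of F.scaled_dot_product_attention:

    attn_output = self.attention(
        query=q,
        key=k,
        value=v,
        causal=is_causal,
        attention_mask=attn_mask_inner,
        dropout_p=self.dropout if self.training else 0.0,  # disabled at eval time
    )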
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
@cursoragent fix those issues
I updated the
5 tasks completed
Co-authored-by: alphacr792 <[email protected]>
Summary
- StreamMultiheadAttention wrapping FusedOnlineAttention with an nn.MultiheadAttention-like API
- create_stream_attention factory that chooses the Triton fused or PyTorch SDPA backend

Testing

- pytest

https://chatgpt.com/codex/tasks/task_e_68ab6a1e403c8322b7d0cbace7d428ec
Summary by cubic
Adds a MultiheadAttention-style layer backed by the fused online attention kernel, plus a factory that auto-selects the best backend. This makes it easy to drop in a familiar API while using Triton when available, with a safe SDPA fallback.
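A hypothetical usage sketch of the factory; the keyword names below are assumptions, since the PR only states that the factory auto-selects the backend, and the unpacked return assumes the consistent 2-tuple discussed in the review above:

    import torch

    from stream_attention.core.multihead_attention import create_stream_attention

    # Picks the Triton fused kernel when it is available, otherwise the PyTorch
    # SDPA fallback, and returns an nn.MultiheadAttention-style module.
    attn = create_stream_attention(embed_dim=512, num_heads=8, batch_first=True)

    x = torch.randn(2, 128, 512)  # (batch, seq, embed) because batch_first=True
    out, _ = attn(x, x, x)        # same call shape as nn.MultiheadAttention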