
Conversation

@MagellaX (Owner) commented Aug 24, 2025

Summary

  • add StreamMultiheadAttention, wrapping FusedOnlineAttention behind an nn.MultiheadAttention-like API
  • provide a create_stream_attention factory that chooses the Triton fused or PyTorch SDPA backend
  • document usage via a new snippet in the README

Testing

  • pytest

https://chatgpt.com/codex/tasks/task_e_68ab6a1e403c8322b7d0cbace7d428ec


Summary by cubic

Adds a MultiheadAttention-style layer backed by the fused online attention kernel, plus a factory that auto-selects the best backend. This makes it easy to drop in a familiar API while using Triton when available, with a safe SDPA fallback.

  • New Features

    • StreamMultiheadAttention: a drop-in nn.MultiheadAttention wrapper using the fused backend; supports batch_first and key_padding_mask (attn_mask not supported).
    • create_stream_attention: returns the fused implementation when Triton + CUDA are available, otherwise falls back to PyTorch’s nn.MultiheadAttention; README updated with a usage snippet (an illustrative sketch follows the Migration note below).
  • Migration

    • The old config-based factory is now build_stream_attention(config). The name create_stream_attention now refers to the Multihead-style factory. Update any previous create_stream_attention(config) calls to build_stream_attention(config).
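For illustration, a minimal usage sketch of the new API, assuming an import path of stream_attention and constructor arguments mirroring nn.MultiheadAttention; the exact names are assumptions, not taken from the merged code:

import torch
from stream_attention import create_stream_attention  # import path assumed

# The factory returns the Triton fused backend when Triton + CUDA are available,
# otherwise a plain torch.nn.MultiheadAttention.
attn = create_stream_attention(embed_dim=512, num_heads=8, batch_first=True)

x = torch.randn(2, 128, 512)                       # (batch, seq, embed) with batch_first=True
pad_mask = torch.zeros(2, 128, dtype=torch.bool)   # True marks padded positions
out, _ = attn(x, x, x, key_padding_mask=pad_mask)  # attn_mask is not supported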

@cubic-dev-ai (bot) left a comment


6 issues found across 6 files


key: torch.Tensor,
value: torch.Tensor,
key_padding_mask: Optional[torch.Tensor] = None,
need_weights: bool = False,

Missing average_attn_weights parameter despite claiming to mirror MultiheadAttention’s forward signature.

Prompt for AI agents
Address the following comment on stream_attention/core/multihead_attention.py at line 67:

<comment>Missing average_attn_weights parameter despite claiming to mirror MultiheadAttention’s forward signature.</comment>

<file context>
@@ -0,0 +1,147 @@
+"""Multihead-style wrapper around FusedOnlineAttention.
+
+This module provides :class:`StreamMultiheadAttention`, mirroring the
+constructor and forward signature of :class:`torch.nn.MultiheadAttention`
+while delegating the actual attention computation to
+:class:`~stream_attention.core.fused_online_attention.FusedOnlineAttention`.
+
+A convenience factory :func:`create_stream_attention` is also provided to
+instantiate the optimal backend automatically. If Triton kernels are
</file context>
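For reference, a sketch of the forward signature the comment asks for, with average_attn_weights added and defaulted to True as in torch.nn.MultiheadAttention (illustrative only, not the merged code):

def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    key_padding_mask: Optional[torch.Tensor] = None,
    need_weights: bool = False,
    average_attn_weights: bool = True,  # mirrors nn.MultiheadAttention's default
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    ...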

if not self.batch_first:
attn_output = attn_output.transpose(0, 1)
if need_weights:
return attn_output, None

need_weights=True returns None for weights; either compute weights or raise to avoid silent API mismatch.

Prompt for AI agents
Address the following comment on stream_attention/core/multihead_attention.py at line 109:

<comment>need_weights=True returns None for weights; either compute weights or raise to avoid silent API mismatch.</comment>

<file context>
@@ -0,0 +1,147 @@
+"""Multihead-style wrapper around FusedOnlineAttention.
+
+This module provides :class:`StreamMultiheadAttention`, mirroring the
+constructor and forward signature of :class:`torch.nn.MultiheadAttention`
+while delegating the actual attention computation to
+:class:`~stream_attention.core.fused_online_attention.FusedOnlineAttention`.
+
+A convenience factory :func:`create_stream_attention` is also provided to
+instantiate the optimal backend automatically. If Triton kernels are
</file context>
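One way to address this, sketched below: fail loudly instead of silently returning None, since the fused kernel never materializes the full attention matrix (this mirrors the fix described later in the thread):

if need_weights:
    raise NotImplementedError(
        "StreamMultiheadAttention does not return attention weights; "
        "call with need_weights=False"
    )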

attn_output = attn_output.transpose(0, 1)
if need_weights:
return attn_output, None
return attn_output

Return type does not match the annotated Tuple when need_weights is False; return a 2-tuple consistently to avoid API inconsistency.

Prompt for AI agents
Address the following comment on stream_attention/core/multihead_attention.py at line 110:

<comment>Return type does not match the annotated Tuple when need_weights is False; return a 2-tuple consistently to avoid API inconsistency.</comment>

<file context>
@@ -0,0 +1,147 @@
+"""Multihead-style wrapper around FusedOnlineAttention.
+
+This module provides :class:`StreamMultiheadAttention`, mirroring the
+constructor and forward signature of :class:`torch.nn.MultiheadAttention`
+while delegating the actual attention computation to
+:class:`~stream_attention.core.fused_online_attention.FusedOnlineAttention`.
+
+A convenience factory :func:`create_stream_attention` is also provided to
+instantiate the optimal backend automatically. If Triton kernels are
</file context>
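A sketch of the consistent-return fix suggested here, always returning a 2-tuple so the annotated return type holds for every call (illustrative only):

if not self.batch_first:
    attn_output = attn_output.transpose(0, 1)
# Always return (output, weights); weights are None because they are not computed.
return attn_output, None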

query=q,
key=k,
value=v,
causal=True,

Forces causal masking for all calls, diverging from MultiheadAttention’s default behavior and potentially breaking non-causal use cases.

Prompt for AI agents
Address the following comment on stream_attention/core/multihead_attention.py at line 99:

<comment>Forces causal masking for all calls, diverging from MultiheadAttention’s default behavior and potentially breaking non-causal use cases.</comment>

<file context>
@@ -0,0 +1,147 @@
+"""Multihead-style wrapper around FusedOnlineAttention.
+
+This module provides :class:`StreamMultiheadAttention`, mirroring the
+constructor and forward signature of :class:`torch.nn.MultiheadAttention`
+while delegating the actual attention computation to
+:class:`~stream_attention.core.fused_online_attention.FusedOnlineAttention`.
+
+A convenience factory :func:`create_stream_attention` is also provided to
+instantiate the optimal backend automatically. If Triton kernels are
</file context>
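A sketch of making causality configurable instead of hard-coded, assuming the fused module is held as self.inner (as referenced later in this thread) and that a self.causal attribute is set from a constructor argument (the attribute name is an assumption):

attn_output = self.inner(
    query=q,
    key=k,
    value=v,
    causal=self.causal,  # default False, matching nn.MultiheadAttention's non-causal behavior
    attention_mask=attn_mask_inner,
)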

key=k,
value=v,
causal=True,
attention_mask=attn_mask_inner,

Attention dropout is never applied because dropout_p is not forwarded to FusedOnlineAttention during training.

Prompt for AI agents
Address the following comment on stream_attention/core/multihead_attention.py at line 100:

<comment>Attention dropout is never applied because dropout_p is not forwarded to FusedOnlineAttention during training.</comment>

<file context>
@@ -0,0 +1,147 @@
+"""Multihead-style wrapper around FusedOnlineAttention.
+
+This module provides :class:`StreamMultiheadAttention`, mirroring the
+constructor and forward signature of :class:`torch.nn.MultiheadAttention`
+while delegating the actual attention computation to
+:class:`~stream_attention.core.fused_online_attention.FusedOnlineAttention`.
+
+A convenience factory :func:`create_stream_attention` is also provided to
+instantiate the optimal backend automatically. If Triton kernels are
</file context>
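A sketch of forwarding dropout only during training, assuming FusedOnlineAttention accepts a dropout_p argument (as this comment implies) and that the dropout rate is stored on the module; both names are assumptions:

attn_output = self.inner(
    query=q,
    key=k,
    value=v,
    causal=self.causal,                                # assumed flag, see previous comment
    attention_mask=attn_mask_inner,
    dropout_p=self.dropout if self.training else 0.0,  # attention dropout only while training
)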

Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
@MagellaX (Owner, Author) commented:
@cursoragent fix those issues

@cursor (bot) commented Aug 24, 2025

I updated the StreamMultiheadAttention module to support configurable causality and apply dropout, while also clarifying that attention weights are not returned.

  • I made causality configurable in StreamMultiheadAttention.forward by adding an is_causal argument that is passed through to scaled_dot_product_attention.
  • I enabled dropout during attention by passing dropout_p=(self.inner.dropout if self.training else 0.0) to F.scaled_dot_product_attention.
  • I enforced need_weights=False by raising NotImplementedError and simplified the return to (attn_output, None). A standalone illustration of the resulting SDPA call follows this list.
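As a standalone illustration of the described SDPA call (the tensor shapes and dropout value below are made up; on the module they would come from self.inner.dropout and self.training):

import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 2, 8, 16, 64
q = k = v = torch.randn(batch, heads, seq, head_dim)
dropout = 0.1       # would be self.inner.dropout
training = True     # would be self.training
is_causal = False   # now configurable rather than forced

attn_output = F.scaled_dot_product_attention(
    q, k, v,
    dropout_p=dropout if training else 0.0,  # dropout applied only while training
    is_causal=is_causal,
)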

5 tasks completed

  • Add average_attn_weights and is_causal to StreamMultiheadAttention.forward
  • Return (attn_output, None) consistently; error on need_weights=True
  • Forward dropout to FusedOnlineAttention; avoid forcing causal
  • Ensure README example unpacks weights correctly
  • Run test suite and fix any failures

@MagellaX merged commit 5e8e02c into main Aug 24, 2025
2 checks passed