14 changes: 14 additions & 0 deletions README.md
@@ -92,6 +92,20 @@ print(y_qkv.shape)
- Autograd: If gradients are required, the module automatically falls back to PyTorch SDPA to ensure correct backward support. The Triton path is intended for forward-critical inference/benchmarking (see the sketch after this list).
- Dropout is not supported in the fused kernel; apply it outside the module if needed.
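
As a minimal sketch of this autograd fallback (the class name and keyword arguments follow the `FusedOnlineAttention` usage in `multihead_attention.py` below; the shapes and hyperparameters here are illustrative, not canonical):

```python
import torch
from stream_attention import FusedOnlineAttention

# Layout used by the multihead wrapper: (batch, seq_len, num_heads, head_dim)
attn = FusedOnlineAttention(num_heads=8, head_dim=64, dropout=0.0, dtype=torch.float32)
q = torch.randn(2, 16, 8, 64, requires_grad=True)
k = torch.randn(2, 16, 8, 64)
v = torch.randn(2, 16, 8, 64)

out = attn(query=q, key=k, value=v, causal=True)
# Because q requires gradients, the module takes the SDPA path so that the
# backward pass is supported; the Triton kernel is reserved for inference.
out.sum().backward()
```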

### Multihead-style wrapper
Use `create_stream_attention` to obtain an attention layer with a familiar
`nn.MultiheadAttention` interface. Triton kernels are used automatically when
available; otherwise, PyTorch's SDPA backend is selected:

```python
import torch
from stream_attention import create_stream_attention

mha = create_stream_attention(embed_dim=512, num_heads=8, batch_first=True)
x = torch.randn(2, 16, 512)
out = mha(x, x, x)
```

### FlashAttentionV3
- Purpose: Baseline using PyTorch SDPA with the flash backend on CUDA, falling back gracefully on CPU.
- Signature (selected):
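
The full signature is collapsed in this view. As a rough, hedged sketch of what such an SDPA flash-backend baseline with CPU fallback looks like in plain PyTorch (illustrative only, not the module's own code; `sdpa_flash_baseline` is a hypothetical helper and `torch.nn.attention.sdpa_kernel` requires PyTorch 2.3+):

```python
import torch
import torch.nn.functional as F


def sdpa_flash_baseline(q, k, v, causal: bool = True):
    """Prefer the flash SDPA backend on CUDA; fall back to default dispatch on CPU."""
    if q.is_cuda:
        from torch.nn.attention import SDPBackend, sdpa_kernel
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            return F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    # CPU (or non-CUDA) path: let PyTorch pick any available SDPA backend.
    return F.scaled_dot_product_attention(q, k, v, is_causal=causal)
```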
4 changes: 3 additions & 1 deletion stream_attention/__init__.py
@@ -22,7 +22,8 @@
FusedOnlineAttention,
create_fused_online_attention,
)
from .core.attention import StreamAttention, create_stream_attention
from .core.attention import StreamAttention
from .core.multihead_attention import StreamMultiheadAttention, create_stream_attention
from .core.flashattention_v3 import FlashAttentionV3

# Utilities
@@ -36,6 +37,7 @@
__all__ = [
# Main modules
"StreamAttention",
"StreamMultiheadAttention",
"FusedOnlineAttention",
"StreamAttentionConfig",
"FlashAttentionV3",
4 changes: 0 additions & 4 deletions stream_attention/benchmarks/accuracy_test.py
@@ -57,8 +57,4 @@ def main():


if __name__ == "__main__":

main()

main()

4 changes: 3 additions & 1 deletion stream_attention/core/__init__.py
@@ -1,6 +1,7 @@
"""Core StreamAttention modules"""

from .attention import StreamAttention, create_stream_attention
from .attention import StreamAttention, build_stream_attention
from .multihead_attention import StreamMultiheadAttention, create_stream_attention
from .flashattention_v3 import FlashAttentionV3
from .fused_online_attention import FusedOnlineAttention, create_fused_online_attention
from .ring_attention import RingAttention
@@ -9,6 +10,7 @@

__all__ = [
"StreamAttention",
"StreamMultiheadAttention",
"FlashAttentionV3",
"FusedOnlineAttention",
"RingAttention",
6 changes: 3 additions & 3 deletions stream_attention/core/attention.py
@@ -198,7 +198,7 @@ def replace_attention_in_model(
for part in parent_name.split("."):
parent = getattr(parent, part)
# Create a fresh instance per replacement
new_attn = create_stream_attention(self.config)
new_attn = StreamAttention(self.config)
setattr(parent, attr_name, new_attn)
replaced_count += 1
logger.info(f"Replaced {name} with StreamAttention")
@@ -244,6 +244,6 @@ def benchmark_speedup(
return results


def create_stream_attention(config: StreamAttentionConfig) -> StreamAttention:
"""Factory function to create StreamAttention instance"""
def build_stream_attention(config: StreamAttentionConfig) -> StreamAttention:
"""Legacy factory for creating :class:`StreamAttention` from a config."""
return StreamAttention(config)
147 changes: 147 additions & 0 deletions stream_attention/core/multihead_attention.py
@@ -0,0 +1,147 @@
"""Multihead-style wrapper around FusedOnlineAttention.
This module provides :class:`StreamMultiheadAttention`, mirroring the
constructor and forward signature of :class:`torch.nn.MultiheadAttention`
while delegating the actual attention computation to
:class:`~stream_attention.core.fused_online_attention.FusedOnlineAttention`.
A convenience factory :func:`create_stream_attention` is also provided to
instantiate the optimal backend automatically. If Triton kernels are
available, the fused implementation is used; otherwise it falls back to
PyTorch's native ``nn.MultiheadAttention`` which relies on SDPA.
"""

from __future__ import annotations

from typing import Optional, Tuple

import torch
import torch.nn as nn

from .fused_online_attention import FusedOnlineAttention, TRITON_AVAILABLE


class StreamMultiheadAttention(nn.Module):
"""Drop-in replacement for :class:`torch.nn.MultiheadAttention`.
Parameters mirror those of ``nn.MultiheadAttention`` but the module uses
:class:`FusedOnlineAttention` internally when available.
"""

def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
bias: bool = True,
batch_first: bool = False,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
super().__init__()
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.batch_first = batch_first

factory_kwargs = {"device": device, "dtype": dtype}
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

self.inner = FusedOnlineAttention(
num_heads=num_heads,
head_dim=self.head_dim,
dropout=dropout,
dtype=dtype or torch.float32,
)

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_padding_mask: Optional[torch.Tensor] = None,
need_weights: bool = False,
Contributor review comment (line 67): Missing average_attn_weights parameter despite claiming to mirror MultiheadAttention's forward signature.

attn_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Compute attention.
Arguments are equivalent to ``nn.MultiheadAttention``. Masks other than
``key_padding_mask`` are currently unsupported and will raise.
"""

if attn_mask is not None:
raise NotImplementedError(
"attn_mask is not supported in StreamMultiheadAttention"
)

if not self.batch_first:
# (L, N, E) -> (N, L, E)
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)

bsz, tgt_len, _ = query.shape
src_len = key.shape[1]

q = self.q_proj(query).view(bsz, tgt_len, self.num_heads, self.head_dim)
k = self.k_proj(key).view(bsz, src_len, self.num_heads, self.head_dim)
v = self.v_proj(value).view(bsz, src_len, self.num_heads, self.head_dim)

attn_mask_inner = key_padding_mask if key_padding_mask is not None else None
attn_output = self.inner(
query=q,
key=k,
value=v,
causal=True,
Contributor review comment (line 99): Forces causal masking for all calls, diverging from MultiheadAttention's default behavior and potentially breaking non-causal use cases.

attention_mask=attn_mask_inner,
Contributor review comment (line 100): Attention dropout is never applied because dropout_p is not forwarded to FusedOnlineAttention during training.

)

attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)

if not self.batch_first:
attn_output = attn_output.transpose(0, 1)
if need_weights:
return attn_output, None
Contributor review comment (line 109): need_weights=True returns None for weights; either compute weights or raise to avoid a silent API mismatch.

return attn_output
Contributor review comment (line 110): Return type does not match the annotated Tuple when need_weights is False; return a 2-tuple consistently to avoid API inconsistency.



def create_stream_attention(
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
bias: bool = True,
batch_first: bool = False,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> nn.Module:
"""Factory that returns an attention module using the best available backend.
If Triton fused kernels are available, returns :class:`StreamMultiheadAttention`.
Otherwise, falls back to PyTorch's ``nn.MultiheadAttention`` which utilises
SDPA under the hood.
"""

if TRITON_AVAILABLE and torch.cuda.is_available():
return StreamMultiheadAttention(
embed_dim,
num_heads,
dropout=dropout,
bias=bias,
batch_first=batch_first,
device=device,
dtype=dtype,
)
return nn.MultiheadAttention(
embed_dim,
num_heads,
dropout=dropout,
bias=bias,
batch_first=batch_first,
device=device,
dtype=dtype,
)
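
For reference, a short, hedged usage sketch of the factory with its default `batch_first=False` layout (shapes are illustrative; on the Triton-backed path the module and tensors would typically live on CUDA). Note how the two backends currently differ in return type, which is the inconsistency flagged in the review comments above:

```python
import torch
from stream_attention import create_stream_attention

# Default batch_first=False: inputs are (seq_len, batch, embed_dim),
# matching nn.MultiheadAttention's default layout.
mha = create_stream_attention(embed_dim=256, num_heads=4)
x = torch.randn(32, 2, 256)

result = mha(x, x, x)
# The nn.MultiheadAttention fallback returns an (output, weights) tuple, while
# StreamMultiheadAttention currently returns only the output tensor when
# need_weights=False, so unpack defensively.
out = result[0] if isinstance(result, tuple) else result
print(out.shape)  # torch.Size([32, 2, 256])
```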