diff --git a/README.md b/README.md
index c628ca4..a96b009 100644
--- a/README.md
+++ b/README.md
@@ -92,6 +92,20 @@ print(y_qkv.shape)
 - Autograd: If gradients are required, the module automatically falls back to PyTorch SDPA to ensure correct backward support. The Triton path is intended for forward-critical inference/benchmarking.
 - Dropout is not supported in the fused kernel; apply it outside the module if needed.
 
+### Multihead-style wrapper
+Use `create_stream_attention` to obtain an attention layer with a familiar
+`nn.MultiheadAttention` interface. Triton kernels are used automatically when
+available; otherwise, PyTorch's SDPA backend is selected:
+
+```python
+import torch
+from stream_attention import create_stream_attention
+
+mha = create_stream_attention(embed_dim=512, num_heads=8, batch_first=True)
+x = torch.randn(2, 16, 512)
+out, _ = mha(x, x, x)
+```
+
 ### FlashAttentionV3
 - Purpose: Baseline using PyTorch SDPA with the flash backend on CUDA, falling back gracefully on CPU.
 - Signature (selected):
diff --git a/stream_attention/__init__.py b/stream_attention/__init__.py
index 875162d..7e293bc 100644
--- a/stream_attention/__init__.py
+++ b/stream_attention/__init__.py
@@ -22,7 +22,8 @@
     FusedOnlineAttention,
     create_fused_online_attention,
 )
-from .core.attention import StreamAttention, create_stream_attention
+from .core.attention import StreamAttention
+from .core.multihead_attention import StreamMultiheadAttention, create_stream_attention
 from .core.flashattention_v3 import FlashAttentionV3
 
 # Utilities
@@ -36,6 +37,7 @@
 __all__ = [
     # Main modules
     "StreamAttention",
+    "StreamMultiheadAttention",
     "FusedOnlineAttention",
     "StreamAttentionConfig",
     "FlashAttentionV3",
diff --git a/stream_attention/benchmarks/accuracy_test.py b/stream_attention/benchmarks/accuracy_test.py
index 5eb3522..92b774b 100644
--- a/stream_attention/benchmarks/accuracy_test.py
+++ b/stream_attention/benchmarks/accuracy_test.py
@@ -59,3 +59,4 @@ def main():
 
 if __name__ == "__main__":
     main()
+
diff --git a/stream_attention/core/__init__.py b/stream_attention/core/__init__.py
index 7512f4d..2e1ed13 100644
--- a/stream_attention/core/__init__.py
+++ b/stream_attention/core/__init__.py
@@ -1,6 +1,7 @@
 """Core StreamAttention modules"""
 
-from .attention import StreamAttention, create_stream_attention
+from .attention import StreamAttention, build_stream_attention
+from .multihead_attention import StreamMultiheadAttention, create_stream_attention
 from .flashattention_v3 import FlashAttentionV3
 from .fused_online_attention import FusedOnlineAttention, create_fused_online_attention
 from .ring_attention import RingAttention
@@ -9,6 +10,7 @@
 
 __all__ = [
     "StreamAttention",
+    "StreamMultiheadAttention",
     "FlashAttentionV3",
     "FusedOnlineAttention",
     "RingAttention",
diff --git a/stream_attention/core/attention.py b/stream_attention/core/attention.py
index 0b24bc4..c704c57 100644
--- a/stream_attention/core/attention.py
+++ b/stream_attention/core/attention.py
@@ -198,7 +198,7 @@ def replace_attention_in_model(
             for part in parent_name.split("."):
                 parent = getattr(parent, part)
             # Create a fresh instance per replacement
-            new_attn = create_stream_attention(self.config)
+            new_attn = StreamAttention(self.config)
             setattr(parent, attr_name, new_attn)
             replaced_count += 1
             logger.info(f"Replaced {name} with StreamAttention")
@@ -244,6 +244,6 @@
     return results
 
 
-def create_stream_attention(config: StreamAttentionConfig) -> StreamAttention:
-    """Factory function to create StreamAttention instance"""
+def build_stream_attention(config: StreamAttentionConfig) -> StreamAttention:
+    """Legacy factory for creating :class:`StreamAttention` from a config."""
     return StreamAttention(config)
diff --git a/stream_attention/core/multihead_attention.py b/stream_attention/core/multihead_attention.py
new file mode 100644
index 0000000..ef197ff
--- /dev/null
+++ b/stream_attention/core/multihead_attention.py
@@ -0,0 +1,153 @@
+"""Multihead-style wrapper around FusedOnlineAttention.
+
+This module provides :class:`StreamMultiheadAttention`, mirroring the
+constructor and forward signature of :class:`torch.nn.MultiheadAttention`
+while delegating the actual attention computation to
+:class:`~stream_attention.core.fused_online_attention.FusedOnlineAttention`.
+
+A convenience factory :func:`create_stream_attention` is also provided to
+instantiate the best available backend automatically. If Triton kernels are
+available, the fused implementation is used; otherwise it falls back to
+PyTorch's native ``nn.MultiheadAttention``, which relies on SDPA.
+"""
+
+from __future__ import annotations
+
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+from .fused_online_attention import FusedOnlineAttention, TRITON_AVAILABLE
+
+
+class StreamMultiheadAttention(nn.Module):
+    """Drop-in replacement for :class:`torch.nn.MultiheadAttention`.
+
+    Parameters mirror those of ``nn.MultiheadAttention`` but the module uses
+    :class:`FusedOnlineAttention` internally when available.
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        bias: bool = True,
+        batch_first: bool = False,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ) -> None:
+        super().__init__()
+        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.head_dim = embed_dim // num_heads
+        self.batch_first = batch_first
+
+        factory_kwargs = {"device": device, "dtype": dtype}
+        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
+        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
+        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
+
+        self.inner = FusedOnlineAttention(
+            num_heads=num_heads,
+            head_dim=self.head_dim,
+            dropout=dropout,
+            dtype=dtype or torch.float32,
+        )
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_padding_mask: Optional[torch.Tensor] = None,
+        need_weights: bool = False,
+        attn_mask: Optional[torch.Tensor] = None,
+        average_attn_weights: bool = True,
+        is_causal: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Compute attention.
+
+        Arguments match those of ``nn.MultiheadAttention``. Masks other than
+        ``key_padding_mask`` are currently unsupported and will raise.
+ """ + + if attn_mask is not None: + raise NotImplementedError( + "attn_mask is not supported in StreamMultiheadAttention" + ) + + if need_weights: + raise NotImplementedError( + "need_weights=True is not supported; set need_weights=False (default)" + ) + + if not self.batch_first: + # (L, N, E) -> (N, L, E) + query = query.transpose(0, 1) + key = key.transpose(0, 1) + value = value.transpose(0, 1) + + bsz, tgt_len, _ = query.shape + src_len = key.shape[1] + + q = self.q_proj(query).view(bsz, tgt_len, self.num_heads, self.head_dim) + k = self.k_proj(key).view(bsz, src_len, self.num_heads, self.head_dim) + v = self.v_proj(value).view(bsz, src_len, self.num_heads, self.head_dim) + + attn_mask_inner = key_padding_mask if key_padding_mask is not None else None + attn_output = self.inner( + query=q, + key=k, + value=v, + causal=bool(is_causal), + attention_mask=attn_mask_inner, + dropout_p=(self.inner.dropout if self.training else 0.0), + ) + + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output = self.out_proj(attn_output) + + if not self.batch_first: + attn_output = attn_output.transpose(0, 1) + return attn_output, None + + +def create_stream_attention( + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + bias: bool = True, + batch_first: bool = False, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +) -> nn.Module: + """Factory that returns an attention module using the best available backend. + + If Triton fused kernels are available, returns :class:`StreamMultiheadAttention`. + Otherwise, falls back to PyTorch's ``nn.MultiheadAttention`` which utilises + SDPA under the hood. + """ + + if TRITON_AVAILABLE and torch.cuda.is_available(): + return StreamMultiheadAttention( + embed_dim, + num_heads, + dropout=dropout, + bias=bias, + batch_first=batch_first, + device=device, + dtype=dtype, + ) + return nn.MultiheadAttention( + embed_dim, + num_heads, + dropout=dropout, + bias=bias, + batch_first=batch_first, + device=device, + dtype=dtype, + )