diff --git a/README.md b/README.md
index ab3a7cd..b908b44 100644
--- a/README.md
+++ b/README.md
@@ -11,10 +11,12 @@ This repository provides:
 - Optional integration helpers for Hugging Face models
 
 ## Recent Updates
+
+- Added a single-sweep streaming backward path for the Triton forward (dQ/dK/dV rebuilt tile by tile from the saved `lse`) when `dropout_p == 0`.
+- Triton forward now supports boolean/additive masks, dropout, deterministic Philox seeding, and ALiBi bias without SDPA fallback.
+- Expanded tests/docs covering mask/dropout/ALiBi parity plus deterministic mode usage.
-- Updated `FlashAttentionV3` wrapper to use PyTorch 2.8 `sdpa_kernel` API with a fallback for older releases.
-- Restored the Triton fused kernel and added a parity test against FlashAttention‑3.
-- Normalized indentation in the Triton kernel to avoid `TabError` on Colab and other environments.
 
 ## Installation
@@ -61,7 +63,7 @@ with torch.no_grad():
     y = attention(hidden_states=x)
     print(y.shape)
 
-# Explicit Q,K,V path (supports attention mask via SDPA fallback)
+# Explicit Q,K,V path (supports attn_mask/dropout/ALiBi in Triton, falling back to SDPA when needed)
 q = torch.randn(batch_size, seq_len, config.num_heads, config.head_dim, device=x.device, dtype=x.dtype)
 k = torch.randn_like(q)
 v = torch.randn_like(q)
@@ -79,18 +81,10 @@ print(y_qkv.shape)
 
 ### StreamAttention
 - Purpose: High-level module. Accepts either `hidden_states` ([B, T, H*D]) or explicit `(query, key, value)` ([B, T, H, D]).
 - Signature (selected):
   - `forward(hidden_states: Tensor, ..., use_cache: bool=False, causal: bool=True)` → `Tensor` or `(Tensor, (k, v))` if `use_cache=True`
   - `forward(query: Tensor, key: Tensor, value: Tensor, causal: bool=True, attention_mask: Optional[Tensor]=None)` → `Tensor`
 - Shapes: `[batch, seq, heads, dim]` for QKV mode.
 - Dtypes: fp16/bf16 (CUDA), fp32 (CPU by default). On CPU, inputs upcast to fp32 if required.
 
 ### FusedOnlineAttention
 - Purpose: Low-level fused online softmax attention (Triton when available; SDPA fallback otherwise).
 - Signature (selected):
-  - `forward(query, key, value, causal: bool=True, return_lse: bool=False, attention_mask: Optional[Tensor]=None)` → `Tensor` (and `lse` if requested)
-  - `benchmark(seq_len: int, batch_size: int=1, warmup: int=10, iterations: int=100)` → metrics dict
-- Autograd: If gradients are required, the module automatically falls back to PyTorch SDPA to ensure correct backward support. The Triton path is intended for forward-critical inference/benchmarking.
-- Dropout is not supported in the fused kernel; apply it outside the module if needed.
+  - `forward(query, key, value, causal: bool=True, return_lse: bool=False, attention_mask: Optional[Tensor]=None, dropout_p: float=0.0, alibi_slopes: Optional[Tensor]=None, deterministic: Optional[bool]=None)` → `Tensor` (and `lse` if requested)
+  - `benchmark(seq_len: int, batch_size: int=1, warmup: int=10, iterations: int=100)` → metrics dict
+  - `set_deterministic(enabled: bool, seed: Optional[int]=None)` → control deterministic dropout/mask behavior
+- Autograd: When gradients are required and `dropout_p == 0`, the module runs a single-sweep streaming backward (dQ/dK/dV accumulated tile by tile from the saved `lse`). If dropout is enabled during training, the module falls back to PyTorch SDPA for gradient computation; see the example below.
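+
+A short usage sketch of the new arguments (illustrative values; dropout is
+applied only while the module is in training mode, and `True` entries in a
+boolean mask mark positions to drop):
+
+```python
+import torch
+from stream_attention.core.fused_online_attention import FusedOnlineAttention
+
+attn = FusedOnlineAttention(num_heads=8, head_dim=64).to("cuda")
+attn.set_deterministic(True, seed=1234)  # reproducible Philox dropout sampling
+
+B, T = 2, 1024
+q = torch.randn(B, T, 8, 64, device="cuda", dtype=torch.float16)
+k, v = torch.randn_like(q), torch.randn_like(q)
+
+pad_mask = torch.zeros(B, T, dtype=torch.bool, device="cuda")  # True = masked key
+slopes = torch.full((8,), 0.1, device="cuda")                  # one ALiBi slope per head
+
+out = attn(q, k, v, causal=True, attention_mask=pad_mask,
+           dropout_p=0.1, alibi_slopes=slopes)
+```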
### Multihead-style wrapper
 
 Use `create_stream_attention` to obtain an attention layer with a familiar
@@ -138,7 +132,7 @@ stream-attention-test --seq 1024 --batch 2 --heads 8 --dim 64 --dtype fp16
 
 Behavior and methodology:
 - On CUDA, the baseline uses PyTorch SDPA with the flash backend (FlashAttention-3 path). On CPU, both implementations use SDPA in fp32.
 - Metrics reported: per-iteration latency, estimated TFLOPS, and approximate bandwidth based on tensor traffic. Measurements are averaged after warmup.
-- The fused kernel uses Triton on CUDA for the forward pass; when gradients are required, it falls back to an SDPA-backed autograd path. Otherwise, SDPA is used to ensure correctness, masking, and training.
+- The fused kernel uses Triton on CUDA for the forward pass; when gradients are required and `dropout_p == 0`, the streaming backward (a single sweep over K/V) is invoked. If dropout is active during training, the module falls back to SDPA for gradient computation.
 - For reproducibility, fix random seeds, pin CUDA clocks if applicable, and isolate runs. Actual performance depends on GPU architecture, drivers, and PyTorch/Triton versions.
 
 Example output (format):
@@ -227,9 +221,8 @@ print(f"Replaced {num_replaced} attention modules")
 
 ## Roadmap
 
-- Backward implementation for the Triton fused kernel
+- Native RoPE / relative positional bias fusion in the Triton kernel (forward + backward)
 - Advanced pipelining (warp specialization, asynchronous staging) and Hopper-specific paths (WGMMA/TMA)
-- Full support for attention masks in the fused kernel
 - Extended autotune coverage across architectures and sequence regimes
 - Optional FP8 path with block-wise scaling
 
@@ -240,7 +233,6 @@ print(f"Replaced {num_replaced} attention modules")
 
 - Accuracy checks: `stream-attention-test` CLI
 - Examples: `examples/` directory includes basic usage, integrations, and long-context runs
 - Environment variables: see `.env.example`; `StreamAttentionConfig.from_env()` can bootstrap configuration
-  - Environment variables: see `.env.example`; `StreamAttentionConfig.from_env()` can bootstrap configuration
 
 ## License
 
diff --git a/docs/Index.md b/docs/Index.md
index e35ce40..ceac140 100644
--- a/docs/Index.md
+++ b/docs/Index.md
@@ -56,9 +56,9 @@ Use log-sum-exp recentering:
 
 The chosen tile size (`TILE_K = 64`) is a starting point. Optimal performance may require experimenting with different tile sizes based on your hardware. If performance is suboptimal, use profiling tools (e.g., NVIDIA Nsight Compute) to identify and resolve bottlenecks.
 
-### Triton Limitations
+### Triton Notes
 
-Currently, Triton does not expose `cp.async`. This implementation relies on `tl.load` with masking and autotuned tile sizes. The module automatically falls back to PyTorch SDPA for autograd or masking.
+Currently, Triton does not expose `cp.async`. This implementation relies on `tl.load` with masking and autotuned tile sizes. The fused forward supports native boolean/additive attention masks, dropout, and ALiBi biasing. Deterministic mode (`set_deterministic`) seeds the Philox stream so dropout/mask sampling is reproducible. When `dropout_p == 0`, the saved `lse` is reused to run a single-sweep backward pass; otherwise the module falls back to PyTorch SDPA for gradients.
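+
+For reference, a simplified single-head sketch of the per-tile backward
+recurrence (fp32, no dropout; `lse` is the log-sum-exp saved by the forward,
+and the helper below is illustrative, not part of the public API):
+
+```python
+import torch
+
+def backward_tile(q_tile, k, v, go_tile, lse_tile, scale):
+    S = q_tile @ k.transpose(0, 1) * scale   # recompute tile logits
+    P = torch.exp(S - lse_tile[:, None])     # exact softmax via saved lse
+    dV = P.transpose(0, 1) @ go_tile         # [seq_k, head_dim]
+    dP = go_tile @ v.transpose(0, 1)         # dO @ V^T -> [tile_m, seq_k]
+    dS = (dP - (dP * P).sum(dim=1, keepdim=True)) * P
+    dQ = dS @ k * scale
+    dK = dS.transpose(0, 1) @ q_tile * scale
+    return dQ, dK, dV
+```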
### Distributed Setup Issues diff --git a/stream_attention/core/fused_online_attention.py b/stream_attention/core/fused_online_attention.py index 5cec616..5355785 100644 --- a/stream_attention/core/fused_online_attention.py +++ b/stream_attention/core/fused_online_attention.py @@ -73,7 +73,6 @@ def fused_online_attention_kernel( V, Out, Lse, # Log-sum-exp for numerical stability - # Optional key padding mask [B, N] (bool as int32 1:valid,0:masked) Mask, stride_qb, stride_qh, @@ -95,7 +94,17 @@ def fused_online_attention_kernel( stride_lh, stride_lm, stride_mb, + stride_mh, + stride_mm, stride_mn, + dropout_p, + dropout_scale, + rng_seed, + rng_offset, + AlibiSlopes, + global_M, + global_N, + q_start, H: tl.constexpr, # num heads M: tl.constexpr, # seq_len_q N: tl.constexpr, # seq_len_k @@ -106,6 +115,8 @@ def fused_online_attention_kernel( scale: tl.constexpr, IS_CAUSAL: tl.constexpr, HAS_MASK: tl.constexpr, + HAS_DROPOUT: tl.constexpr, + HAS_ALIBI: tl.constexpr, USE_WGMMA: tl.constexpr, USE_TMA: tl.constexpr, USE_CP_ASYNC: tl.constexpr, @@ -179,11 +190,22 @@ def fused_online_attention_kernel( causal_mask = offs_m[:, None] >= (start_n + offs_n)[None, :] qk = tl.where(causal_mask, qk, float("-inf")) - # Key padding mask: valid_k=1 means keep; 0 means masked out if HAS_MASK: - mask_ptrs = Mask + off_b * stride_mb + (start_n + offs_n) * stride_mn - valid_k = tl.load(mask_ptrs, mask=(start_n + offs_n) < N, other=0) - qk = tl.where(valid_k[None, :] != 0, qk, float("-inf")) + mask_ptrs = ( + Mask + + off_b * stride_mb + + off_h * stride_mh + + (offs_m[:, None] * stride_mm + (start_n + offs_n)[None, :] * stride_mn) + ) + mask_mask = (offs_m[:, None] < M) & ((start_n + offs_n)[None, :] < N) + mask_vals = tl.load(mask_ptrs, mask=mask_mask, other=0.0) + qk += mask_vals + + if HAS_ALIBI: + slope = tl.load(AlibiSlopes + off_h).to(tl.float32) + q_pos = (offs_m[:, None] + q_start).to(tl.float32) + k_pos = (start_n + offs_n)[None, :].to(tl.float32) + qk += slope * (k_pos - q_pos) # Online softmax update tile_max = tl.max(qk, axis=1) @@ -193,13 +215,25 @@ def fused_online_attention_kernel( acc_den *= correction exp_qk = tl.exp(qk - new_max[:, None]) + + if HAS_DROPOUT: + bh = off_b * H + off_h + row_global = (offs_m[:, None] + q_start) + col_global = (start_n + offs_n)[None, :] + rng_offsets = ( + (bh * global_M + row_global) * global_N + col_global + rng_offset + ).to(tl.int32) + keep = tl.rand(rng_seed, rng_offsets) > dropout_p + exp_qk = exp_qk * keep.to(exp_qk.dtype) * dropout_scale + acc_num += tl.dot(exp_qk, v) acc_den += tl.sum(exp_qk, axis=1) running_max = new_max # Final output with safe denominator; handle rows with all keys masked - denom_safe = tl.where(acc_den > 0, acc_den, 1.0) - out = acc_num / denom_safe[:, None] + zero_den = acc_den == 0 + inv_den = tl.where(zero_den, 0.0, 1.0 / acc_den) + out = acc_num * inv_den[:, None] out_ptrs = ( Out @@ -211,7 +245,7 @@ def fused_online_attention_kernel( tl.store(out_ptrs, out.to(Out.dtype.element_ty), mask=out_mask) # LSE: set to -inf for fully masked rows (acc_den == 0) - lse = tl.where(acc_den > 0, running_max + tl.log(acc_den), float("-inf")) + lse = tl.where(zero_den, float("-inf"), running_max + tl.log(acc_den)) lse_ptrs = Lse + off_b * stride_lb + off_h * stride_lh + offs_m * stride_lm lse_mask = offs_m < M tl.store(lse_ptrs, lse, mask=lse_mask) @@ -244,6 +278,9 @@ def __init__( torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") ) self.dtype = dtype + self.deterministic = False + self._det_seed: Optional[int] = None + 
self._det_offset: int = 0 if self.device.type == "cuda": cap = torch.cuda.get_device_capability(self.device) self.sm = cap[0] * 10 + cap[1] @@ -256,6 +293,18 @@ def __init__( f"FusedOnlineAttention initialized: heads={num_heads}, dim={head_dim}, tile_q={tile_size_q}, tile_k={tile_size_k}, world_size={self.world_size}, sm={self.sm}, triton={TRITON_AVAILABLE}" ) + def set_deterministic(self, enabled: bool, seed: Optional[int] = None): + """Enable/disable deterministic mode for Triton dropout RNG.""" + self.deterministic = enabled + if enabled: + if seed is None: + seed = torch.initial_seed() + self._det_seed = int(seed & 0xFFFFFFFF) + self._det_offset = 0 + else: + self._det_seed = None + self._det_offset = 0 + def forward( self, query: torch.Tensor, @@ -265,12 +314,24 @@ def forward( return_lse: bool = False, attention_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + deterministic: Optional[bool] = None, ) -> torch.Tensor: batch_size, seq_len_q, num_heads_q, head_dim_q = query.shape _, seq_len_k, num_heads_k, head_dim_k = key.shape assert num_heads_q == num_heads_k == self.num_heads assert head_dim_q == head_dim_k == self.head_dim + if alibi_slopes is not None: + if not isinstance(alibi_slopes, torch.Tensor): + raise ValueError('alibi_slopes must be a Tensor of shape [num_heads]') + if alibi_slopes.numel() != self.num_heads: + raise ValueError( + f'alibi_slopes must have length {self.num_heads}, got {alibi_slopes.numel()}' + ) + + full_seq_len_q = seq_len_q + start_idx = 0 if self.world_size > 1: queries_per_gpu = seq_len_q // self.world_size start_idx = self.rank * queries_per_gpu @@ -280,6 +341,9 @@ def forward( query = query[:, start_idx:end_idx] seq_len_q = query.shape[1] + effective_dropout = dropout_p if self.training else 0.0 + deterministic_mode = self.deterministic if deterministic is None else deterministic + mask_supported = (attention_mask is None) or ( attention_mask.dim() == 2 and attention_mask.shape[0] == batch_size @@ -292,25 +356,30 @@ def forward( and key.is_cuda and value.is_cuda and mask_supported - and (dropout_p == 0.0 or not self.training) ) - if use_triton and ( + requires_grad = ( torch.is_grad_enabled() and (query.requires_grad or key.requires_grad or value.requires_grad) - ): - if return_lse: + ) + + if use_triton and requires_grad: + if effective_dropout != 0.0 or return_lse: use_triton = False else: - return FusedOnlineAttentionAutogradFn.apply( + output = FusedOnlineAttentionFunction.apply( self, query, key, value, bool(causal), attention_mask, - float(dropout_p), + alibi_slopes, + deterministic_mode, + full_seq_len_q, + start_idx, ) + return output if use_triton: return self._forward_triton( @@ -319,7 +388,11 @@ def forward( value, causal=causal, attention_mask=attention_mask, - dropout_p=dropout_p, + dropout_p=effective_dropout, + alibi_slopes=alibi_slopes, + deterministic_mode=deterministic_mode, + full_seq_len_q=full_seq_len_q, + q_start=start_idx, return_lse=return_lse, ) else: @@ -334,76 +407,78 @@ def forward( batch_size * self.num_heads, seq_len_k, self.head_dim ) - attn_mask_bh = None - if attention_mask is not None: - attn_mask_bh = self._prepare_attn_mask( - attention_mask, - batch_size, - self.num_heads, - seq_len_q, - seq_len_k, - q.device, - q.dtype, - ) - - if attn_mask_bh is not None: - if attn_mask_bh.dtype == torch.bool: - add_mask = torch.where( - attn_mask_bh, - torch.full((1,), float("-inf"), dtype=q.dtype, device=q.device), - torch.zeros(1, dtype=q.dtype, device=q.device), - ) - 
else: - add_mask = attn_mask_bh - if causal: - tri = torch.triu( - torch.ones( - seq_len_q, seq_len_k, dtype=torch.bool, device=q.device - ), - diagonal=1, - ).unsqueeze(0) - tri_add = torch.where( - tri, - torch.full((1,), float("-inf"), dtype=q.dtype, device=q.device), - torch.zeros(1, dtype=q.dtype, device=q.device), - ) - add_mask = add_mask + tri_add - sdpa_kwargs = dict( - attn_mask=add_mask, - is_causal=False, - dropout_p=(dropout_p if self.training else 0.0), + add_mask = None + attn_mask_bh = None + if attention_mask is not None: + attn_mask_bh = self._prepare_attn_mask( + attention_mask, + batch_size, + self.num_heads, + seq_len_q, + seq_len_k, + q.device, + q.dtype, + ) + if attn_mask_bh.dtype == torch.bool: + add_mask = torch.where( + attn_mask_bh, + torch.full((1,), float('-inf'), dtype=q.dtype, device=q.device), + torch.zeros(1, dtype=q.dtype, device=q.device), ) else: - sdpa_kwargs = dict( - attn_mask=None, - is_causal=causal, - dropout_p=(dropout_p if self.training else 0.0), - ) + add_mask = attn_mask_bh.to(q.dtype) + + if alibi_slopes is not None: + slopes = alibi_slopes.to(q.device, dtype=torch.float32) + pos_q = torch.arange(seq_len_q, device=q.device, dtype=torch.float32) + pos_k = torch.arange(seq_len_k, device=q.device, dtype=torch.float32) + delta = pos_k.unsqueeze(0) - pos_q.unsqueeze(1) + bias_h = slopes.view(self.num_heads, 1, 1) * delta + bias_bh = bias_h.unsqueeze(0).expand(batch_size, self.num_heads, seq_len_q, seq_len_k) + bias_bh = bias_bh.reshape(batch_size * self.num_heads, seq_len_q, seq_len_k).to(q.dtype) + add_mask = bias_bh if add_mask is None else add_mask + bias_bh + + is_causal = causal + if add_mask is not None and causal: + tri = torch.triu( + torch.ones(seq_len_q, seq_len_k, dtype=torch.bool, device=q.device), + diagonal=1, + ).unsqueeze(0) + tri = tri.expand(batch_size * self.num_heads, seq_len_q, seq_len_k) + tri_add = torch.where( + tri, + torch.full((1,), float('-inf'), dtype=q.dtype, device=q.device), + torch.zeros(1, dtype=q.dtype, device=q.device), + ) + add_mask = add_mask + tri_add + is_causal = False + + sdpa_kwargs = dict(attn_mask=add_mask, is_causal=is_causal, dropout_p=effective_dropout) - sdpa_ctx = nullcontext() - if q.is_cuda: + sdpa_ctx = nullcontext() + if q.is_cuda: + try: + sdpa_ctx = torch.nn.attention.sdpa_kernel( + SDPBackend.FLASH_ATTENTION + ) + except (AttributeError, TypeError): try: - sdpa_ctx = torch.nn.attention.sdpa_kernel( - SDPBackend.FLASH_ATTENTION + sdpa_ctx = torch.backends.cuda.sdp_kernel( + enable_math=True, + enable_flash=True, + enable_mem_efficient=False, ) - except (AttributeError, TypeError): - try: - sdpa_ctx = torch.backends.cuda.sdp_kernel( - enable_math=True, - enable_flash=True, - enable_mem_efficient=False, - ) - except Exception: # pragma: no cover - environment dependent - sdpa_ctx = nullcontext() - with sdpa_ctx: - out = F.scaled_dot_product_attention(q, k, v, **sdpa_kwargs) - - out = ( - out.reshape(batch_size, self.num_heads, seq_len_q, self.head_dim) - .permute(0, 2, 1, 3) - .contiguous() - ) - return (out, None) if return_lse else out + except Exception: # pragma: no cover - environment dependent + sdpa_ctx = nullcontext() + with sdpa_ctx: + out = F.scaled_dot_product_attention(q, k, v, **sdpa_kwargs) + + out = ( + out.reshape(batch_size, self.num_heads, seq_len_q, self.head_dim) + .permute(0, 2, 1, 3) + .contiguous() + ) + return (out, None) if return_lse else out def _prepare_attn_mask( self, @@ -441,6 +516,32 @@ def _prepare_attn_mask( ).to(device) return bh_mask + def 
_prepare_triton_mask( + self, + attention_mask: torch.Tensor, + batch_size: int, + num_heads: int, + seq_len_q: int, + seq_len_k: int, + device: torch.device, + ) -> torch.Tensor: + mask = attention_mask + if mask.dim() == 2: + mask = mask[:, None, None, :] + elif mask.dim() == 3: + mask = mask[:, None, :, :] + elif mask.dim() != 4: + raise ValueError("Unsupported attention_mask shape. Expected 2D, 3D, or 4D tensor.") + + mask = mask.to(device) + mask = mask.expand(batch_size, num_heads, seq_len_q, seq_len_k).contiguous() + if mask.dtype == torch.bool: + mask = mask.to(torch.float32) + mask = mask.masked_fill(mask > 0, float('-inf')) + else: + mask = mask.to(torch.float32) + return mask + def _forward_triton( self, query: torch.Tensor, @@ -449,41 +550,82 @@ def _forward_triton( causal: bool, attention_mask: Optional[torch.Tensor], dropout_p: float, + alibi_slopes: Optional[torch.Tensor], + deterministic_mode: bool, + full_seq_len_q: int, + q_start: int, return_lse: bool = False, ): batch_size, seq_len_q = query.shape[0], query.shape[1] seq_len_k = key.shape[1] + output = torch.empty_like(query) lse = torch.empty( (batch_size, self.num_heads, seq_len_q), dtype=torch.float32, device=query.device, ) + grid = lambda meta: ( triton.cdiv(seq_len_q, meta["TILE_M"]), batch_size, self.num_heads, ) + + mask_ptr = query + stride_mb = stride_mh = stride_mm = stride_mn = 0 has_mask = attention_mask is not None if has_mask: - if attention_mask.dtype == torch.bool: - mask_norm = (~attention_mask).to(torch.int32) - else: - mask_norm = (attention_mask == 0).to(torch.int32) - if ( - mask_norm.dim() != 2 - or mask_norm.shape[0] != batch_size - or mask_norm.shape[1] != seq_len_k - ): + mask_tensor = self._prepare_triton_mask( + attention_mask, + batch_size, + self.num_heads, + seq_len_q, + seq_len_k, + query.device, + ) + mask_ptr = mask_tensor + stride_mb, stride_mh, stride_mm, stride_mn = mask_ptr.stride() + + has_alibi = alibi_slopes is not None + if has_alibi: + if not isinstance(alibi_slopes, torch.Tensor): + raise ValueError("alibi_slopes must be a Tensor of shape [num_heads]") + if alibi_slopes.numel() != self.num_heads: raise ValueError( - "attention_mask must be shape [batch, seq_len_k] for fused Triton path" + f"alibi_slopes must have length {self.num_heads}, got {alibi_slopes.numel()}" ) - mask_norm = mask_norm.contiguous().to(query.device) - mask_ptr = mask_norm - stride_mb, stride_mn = mask_ptr.stride(0), mask_ptr.stride(1) + alibi_ptr = alibi_slopes.to(query.device, dtype=torch.float32).contiguous() else: - mask_ptr = output # dummy - stride_mb, stride_mn = 0, 0 + alibi_ptr = query + + has_dropout = dropout_p > 0.0 + if has_dropout: + if deterministic_mode: + if self._det_seed is None: + self._det_seed = int(torch.initial_seed() & 0xFFFFFFFF) + self._det_offset = 0 + rng_seed = self._det_seed + rng_offset = self._det_offset + consumed = batch_size * self.num_heads * full_seq_len_q * seq_len_k + self._det_offset += consumed + else: + rng_seed = int( + torch.randint( + 0, + 2**31, + (1,), + device=query.device, + dtype=torch.int64, + ).item() + ) + rng_offset = 0 + dropout_scale = 1.0 / (1.0 - dropout_p) + else: + rng_seed = 0 + rng_offset = 0 + dropout_scale = 1.0 + fused_online_attention_kernel[grid]( query, key, @@ -511,7 +653,17 @@ def _forward_triton( lse.stride(1), lse.stride(2), stride_mb, + stride_mh, + stride_mm, stride_mn, + float(dropout_p), + dropout_scale, + int(rng_seed), + int(rng_offset), + alibi_ptr, + full_seq_len_q, + seq_len_k, + q_start, H=self.num_heads, M=seq_len_q, 
N=seq_len_k, @@ -520,16 +672,30 @@ def _forward_triton( scale=self.scale, IS_CAUSAL=causal, HAS_MASK=has_mask, + HAS_DROPOUT=has_dropout, + HAS_ALIBI=has_alibi, USE_WGMMA=self.sm >= 90, USE_TMA=self.sm >= 90, USE_CP_ASYNC=self.sm >= 80 and self.sm < 90, ) + if self.world_size > 1: output_list = [torch.empty_like(output) for _ in range(self.world_size)] dist.all_gather(output_list, output) output = torch.cat(output_list, dim=1) + if self.verify: - self._verify_output(query, key, value, output, causal, attention_mask, dropout_p) + self._verify_output( + query, + key, + value, + output, + causal, + attention_mask, + dropout_p, + alibi_slopes, + ) + return (output, lse) if return_lse else output def _verify_output( @@ -541,14 +707,20 @@ def _verify_output( causal: bool, attention_mask: Optional[torch.Tensor], dropout_p: float, + alibi_slopes: Optional[torch.Tensor], ) -> None: """Compare Triton output against PyTorch reference.""" + if dropout_p > 0.0: + # Skip verification when dropout introduces randomness. + return + bsz, sq, _, _ = query.shape sk = key.shape[1] q = query.permute(0, 2, 1, 3).reshape(bsz * self.num_heads, sq, self.head_dim) k = key.permute(0, 2, 1, 3).reshape(bsz * self.num_heads, sk, self.head_dim) v = value.permute(0, 2, 1, 3).reshape(bsz * self.num_heads, sk, self.head_dim) - attn_mask_bh = None + + add_mask = None if attention_mask is not None: attn_mask_bh = self._prepare_attn_mask( attention_mask, @@ -559,11 +731,40 @@ def _verify_output( q.device, q.dtype, ) - sdpa_kwargs = dict( - attn_mask=attn_mask_bh, - is_causal=causal if attn_mask_bh is None else False, - dropout_p=0.0, - ) + if attn_mask_bh.dtype == torch.bool: + add_mask = torch.where( + attn_mask_bh, + torch.full((1,), float("-inf"), dtype=q.dtype, device=q.device), + torch.zeros(1, dtype=q.dtype, device=q.device), + ) + else: + add_mask = attn_mask_bh.to(q.dtype) + + if alibi_slopes is not None: + slopes = alibi_slopes.to(q.device, dtype=torch.float32) + pos_q = torch.arange(sq, device=q.device, dtype=torch.float32) + pos_k = torch.arange(sk, device=q.device, dtype=torch.float32) + delta = pos_k.unsqueeze(0) - pos_q.unsqueeze(1) + bias_h = slopes.view(self.num_heads, 1, 1) * delta + bias_bh = bias_h.unsqueeze(0).expand(bsz, self.num_heads, sq, sk) + bias_bh = bias_bh.reshape(bsz * self.num_heads, sq, sk).to(q.dtype) + add_mask = bias_bh if add_mask is None else add_mask + bias_bh + + is_causal = causal + if add_mask is not None and causal: + tri = torch.triu( + torch.ones(sq, sk, dtype=torch.bool, device=q.device), diagonal=1 + ).unsqueeze(0) + tri = tri.expand(bsz * self.num_heads, sq, sk) + tri_add = torch.where( + tri, + torch.full((1,), float("-inf"), dtype=q.dtype, device=q.device), + torch.zeros(1, dtype=q.dtype, device=q.device), + ) + add_mask = add_mask + tri_add + is_causal = False + + sdpa_kwargs = dict(attn_mask=add_mask, is_causal=is_causal, dropout_p=0.0) ref = F.scaled_dot_product_attention(q, k, v, **sdpa_kwargs) ref = ( ref.reshape(bsz, self.num_heads, sq, self.head_dim) @@ -608,7 +809,7 @@ def benchmark( } -class FusedOnlineAttentionAutogradFn(torch.autograd.Function): +class FusedOnlineAttentionFunction(torch.autograd.Function): @staticmethod def forward( ctx, @@ -618,103 +819,165 @@ def forward( value, causal: bool, attention_mask: Optional[torch.Tensor], - dropout_p: float, + alibi_slopes: Optional[torch.Tensor], + deterministic_mode: bool, + full_seq_len_q: int, + q_start: int, ): + output, lse = module._forward_triton( + query, + key, + value, + causal=causal, + 
attention_mask=attention_mask, + dropout_p=0.0, + alibi_slopes=alibi_slopes, + deterministic_mode=deterministic_mode, + full_seq_len_q=full_seq_len_q, + q_start=q_start, + return_lse=True, + ) + + if alibi_slopes is not None: + alibi_tensor = alibi_slopes.to(query.device) + else: + alibi_tensor = query.new_empty(0) + ctx.module = module ctx.causal = bool(causal) - ctx.has_mask = attention_mask is not None + ctx.scale = module.scale + ctx.tile_size_q = module.tile_size_q + ctx.batch_size = query.shape[0] + ctx.num_heads = query.shape[2] + ctx.seq_len_q = query.shape[1] + ctx.seq_len_k = key.shape[1] + ctx.head_dim = query.shape[3] ctx.attention_mask = attention_mask - ctx.save_for_backward(query, key, value) - with torch.no_grad(): - out = module._forward_triton( - query, - key, - value, - causal=causal, - attention_mask=attention_mask, - dropout_p=dropout_p, - return_lse=False, - ) - return out + ctx.alibi_used = alibi_slopes is not None + + ctx.save_for_backward(query, key, value, lse, alibi_tensor) + return output @staticmethod - def backward(ctx, grad_out): + def backward(ctx, grad_output): module: FusedOnlineAttention = ctx.module - query, key, value = ctx.saved_tensors - bsz, sq, nh, hd = query.shape - sk = key.shape[1] - - q = ( - query.detach().requires_grad_(True).permute(0, 2, 1, 3).reshape(bsz * nh, sq, hd) - ) - k = ( - key.detach().requires_grad_(True).permute(0, 2, 1, 3).reshape(bsz * nh, sk, hd) - ) - v = ( - value.detach().requires_grad_(True).permute(0, 2, 1, 3).reshape(bsz * nh, sk, hd) - ) - - attn_mask_bh = None - if ctx.has_mask and ctx.attention_mask is not None: - mask = ctx.attention_mask - if mask.dtype != torch.bool: - mask = mask == 0 # numeric: 0 valid - else: - mask = ~mask # boolean: True means masked -> invert - mask = mask.view(bsz, 1, 1, sk).expand(bsz, 1, sq, sk) - attn_mask_bh = mask.expand(bsz, nh, sq, sk).reshape(bsz * nh, sq, sk).to(q.device) - - add_mask = None - if attn_mask_bh is not None: - add_mask = torch.where( - attn_mask_bh, - torch.zeros(1, dtype=q.dtype, device=q.device), - torch.full((1,), float("-inf"), dtype=q.dtype, device=q.device), - ) - if ctx.causal: - tri = torch.triu( - torch.ones(sq, sk, dtype=torch.bool, device=q.device), diagonal=1 - ).unsqueeze(0) - tri_add = torch.where( - tri, - torch.full((1,), float("-inf"), dtype=q.dtype, device=q.device), - torch.zeros(1, dtype=q.dtype, device=q.device), - ) - add_mask = add_mask + tri_add - - sdpa_ctx = nullcontext() - if q.is_cuda: - try: - sdpa_ctx = torch.nn.attention.sdpa_kernel( - SDPBackend.FLASH_ATTENTION - ) - except (AttributeError, TypeError): - try: - sdpa_ctx = torch.backends.cuda.sdp_kernel( - enable_math=True, - enable_flash=True, - enable_mem_efficient=False, - ) - except Exception: # pragma: no cover - environment dependent - sdpa_ctx = nullcontext() - with sdpa_ctx: - y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=(add_mask if add_mask is not None else None), - is_causal=(False if add_mask is not None else ctx.causal), - dropout_p=0.0, + query, key, value, lse, alibi_tensor = ctx.saved_tensors + + batch_size = ctx.batch_size + num_heads = ctx.num_heads + seq_len_q = ctx.seq_len_q + seq_len_k = ctx.seq_len_k + head_dim = ctx.head_dim + scale = ctx.scale + tile_m = ctx.tile_size_q + causal = ctx.causal + + grad_output = grad_output.contiguous() + + q = query.permute(0, 2, 1, 3).contiguous() + k = key.permute(0, 2, 1, 3).contiguous() + v = value.permute(0, 2, 1, 3).contiguous() + go = grad_output.permute(0, 2, 1, 3).contiguous() + lse_bh = 
lse.to(torch.float32)
+
+        q_float = q.to(torch.float32)
+        k_float = k.to(torch.float32)
+        v_float = v.to(torch.float32)
+        go_float = go.to(torch.float32)
+
+        dQ = torch.zeros_like(q_float)
+        dK = torch.zeros_like(k_float)
+        dV = torch.zeros_like(v_float)
+
+        attention_mask = ctx.attention_mask
+        mask_bh = None
+        if attention_mask is not None:
+            mask_tensor = module._prepare_triton_mask(
+                attention_mask,
+                batch_size,
+                num_heads,
+                seq_len_q,
+                seq_len_k,
+                query.device,
+            )
+            mask_bh = mask_tensor
-        y = y.reshape(bsz, nh, sq, hd).permute(0, 2, 1, 3).contiguous()
-        grads = torch.autograd.grad(y, (q, k, v), grad_out, allow_unused=False)
-        dq = grads[0].reshape(bsz, nh, sq, hd).permute(0, 2, 1, 3).contiguous()
-        dk = grads[1].reshape(bsz, nh, sk, hd).permute(0, 2, 1, 3).contiguous()
-        dv = grads[2].reshape(bsz, nh, sk, hd).permute(0, 2, 1, 3).contiguous()
-        return None, dq, dk, dv, None, None, None
-
-
+        alibi_used = ctx.alibi_used
+        if alibi_used:
+            slopes = alibi_tensor.to(torch.float32)
+            grad_alibi = torch.zeros_like(slopes)
+        else:
+            slopes = None
+            grad_alibi = None
+
+        pos_q = torch.arange(seq_len_q, device=query.device, dtype=torch.float32)
+        pos_k = torch.arange(seq_len_k, device=query.device, dtype=torch.float32)
+
+        for b in range(batch_size):
+            for h in range(num_heads):
+                q_b = q_float[b, h]
+                k_b = k_float[b, h]
+                v_b = v_float[b, h]
+                go_b = go_float[b, h]
+                lse_b = lse_bh[b, h]
+                mask_b = mask_bh[b, h] if mask_bh is not None else None
+                slope = slopes[h] if alibi_used else None
+
+                for m_start in range(0, seq_len_q, tile_m):
+                    m_end = min(m_start + tile_m, seq_len_q)
+                    q_tile = q_b[m_start:m_end]
+                    go_tile = go_b[m_start:m_end]
+                    lse_tile = lse_b[m_start:m_end]
+
+                    logits = torch.matmul(q_tile, k_b.transpose(0, 1)) * scale
+                    if mask_b is not None:
+                        logits = logits + mask_b[m_start:m_end]
+                    if alibi_used:
+                        delta = pos_k.unsqueeze(0) - pos_q[m_start:m_end].unsqueeze(1)
+                        logits = logits + slope * delta
+                    if causal:
+                        row_idx = pos_q[m_start:m_end].unsqueeze(1)
+                        causal_mask = pos_k.unsqueeze(0) > row_idx
+                        logits = logits.masked_fill(causal_mask, float("-inf"))
+
+                    # Recover the softmax probabilities from the saved log-sum-exp;
+                    # fully masked rows (lse == -inf) yield NaNs that are zeroed.
+                    exp_term = logits - lse_tile.unsqueeze(1)
+                    probs = torch.exp(exp_term)
+                    probs = torch.where(torch.isfinite(probs), probs, torch.zeros_like(probs))
+
+                    dV[b, h] += torch.matmul(probs.transpose(0, 1), go_tile)
+
+                    # dP = dO @ V^T -> [tile_m, seq_len_k]
+                    dP = torch.matmul(go_tile, v_b.transpose(0, 1))
+                    attn_dot = (dP * probs).sum(dim=1, keepdim=True)
+                    dS = (dP - attn_dot) * probs
+
+                    dQ[b, h, m_start:m_end] += torch.matmul(dS, k_b) * scale
+                    dK[b, h] += torch.matmul(dS.transpose(0, 1), q_tile) * scale
+
+                    if alibi_used:
+                        delta = pos_k.unsqueeze(0) - pos_q[m_start:m_end].unsqueeze(1)
+                        grad_alibi[h] += torch.sum(dS * delta)
+
+        grad_query = dQ.to(query.dtype).permute(0, 2, 1, 3).contiguous()
+        grad_key = dK.to(key.dtype).permute(0, 2, 1, 3).contiguous()
+        grad_value = dV.to(value.dtype).permute(0, 2, 1, 3).contiguous()
+
+        if alibi_used:
+            grad_alibi_slopes = grad_alibi.to(alibi_tensor.dtype)
+        else:
+            grad_alibi_slopes = None
+
+        return (
+            None,
+            grad_query,
+            grad_key,
+            grad_value,
+            None,
+            None,
+            grad_alibi_slopes,
+            None,
+            None,
+            None,
+        )
 
 
 def create_fused_online_attention(
     num_heads: int, head_dim: int, **kwargs
 ) -> FusedOnlineAttention:
diff --git a/tests/test_attention.py b/tests/test_attention.py
index 2a1bfbd..a6484cf 100644
--- a/tests/test_attention.py
+++ b/tests/test_attention.py
@@ -14,6 +14,10 @@
 from stream_attention import StreamAttention, StreamAttentionConfig
 from stream_attention.core.flashattention_v3 import FlashAttentionV3
+from stream_attention.core.fused_online_attention import (
+    FusedOnlineAttention,
+    TRITON_AVAILABLE as FUSED_TRITON_AVAILABLE,
+)
 from stream_attention.core.ring_attention import RingAttention
 from stream_attention.core.star_attention import StarAttention
 from stream_attention.utils.memory import create_kv_compressor, MemoryProfiler
@@ -105,6 +109,161 @@
         return output.to(dtype)
 
+    def test_fused_online_attention_mask_parity(self, device):
+        """Ensure Triton fused path matches SDPA when masks are present."""
+        if not (torch.cuda.is_available() and FUSED_TRITON_AVAILABLE):
+            pytest.skip("CUDA + Triton required for fused attention mask test")
+
+        fused = FusedOnlineAttention(num_heads=4, head_dim=32).to(device)
+        fused.eval()
+
+        batch_size = 2
+        seq_len_q = 96
+        seq_len_k = 96
+        dtype = torch.float16 if device.type == "cuda" else torch.float32
+
+        q, k, v = self.create_test_tensors(
+            batch_size, seq_len_q, fused.num_heads, fused.head_dim, device, dtype
+        )
+
+        mask = torch.zeros(
+            batch_size, seq_len_q, seq_len_k, dtype=torch.bool, device=device
+        )
+        mask[:, :, seq_len_k // 2 :] = True
+
+        with torch.no_grad():
+            fused_out = fused(
+                q,
+                k,
+                v,
+                causal=False,
+                attention_mask=mask,
+                dropout_p=0.0,
+            )
+
+        q_bh = q.permute(0, 2, 1, 3).reshape(batch_size * fused.num_heads, seq_len_q, fused.head_dim)
+        k_bh = k.permute(0, 2, 1, 3).reshape(batch_size * fused.num_heads, seq_len_k, fused.head_dim)
+        v_bh = v.permute(0, 2, 1, 3).reshape(batch_size * fused.num_heads, seq_len_k, fused.head_dim)
+
+        mask_bh = mask.unsqueeze(1).expand(batch_size, fused.num_heads, seq_len_q, seq_len_k)
+        mask_bh = mask_bh.reshape(batch_size * fused.num_heads, seq_len_q, seq_len_k)
+        add_mask = torch.where(
+            mask_bh,
+            torch.full((1,), float("-inf"), dtype=q_bh.dtype, device=device),
+            torch.zeros(1, dtype=q_bh.dtype, device=device),
+        )
+        # Dense attn_mask tensors are not supported by the flash backend, so let
+        # SDPA choose any available kernel for the reference computation.
+        ref = torch.nn.functional.scaled_dot_product_attention(
+            q_bh, k_bh, v_bh, attn_mask=add_mask, is_causal=False, dropout_p=0.0
+        )
+        ref = ref.reshape(batch_size, fused.num_heads, seq_len_q, fused.head_dim).permute(0, 2, 1, 3).contiguous()
+
+        torch.testing.assert_close(fused_out, ref, rtol=5e-3, atol=5e-3)
+
+    def test_fused_online_attention_dropout_determinism(self, device):
+        """Dropout masks repeat with deterministic seeding."""
+        if not (torch.cuda.is_available() and FUSED_TRITON_AVAILABLE):
+            pytest.skip("CUDA + Triton required for fused attention dropout test")
+
+        fused = FusedOnlineAttention(num_heads=4, head_dim=32).to(device)
+        fused.train()
+
+        batch_size = 1
+        seq_len = 64
+        dtype = torch.float16 if device.type == "cuda" else torch.float32
+
+        q, k, v = self.create_test_tensors(
+            batch_size, seq_len, fused.num_heads, fused.head_dim, device, dtype
+        )
+
+        dropout_p = 0.2
+
+        fused.set_deterministic(True, seed=2024)
+        with torch.no_grad():
+            ref_out = fused(q, k, v, causal=False, dropout_p=dropout_p)
+
+        fused.set_deterministic(True, seed=2024)
+        with torch.no_grad():
+            reproducible_out = fused(q, k, v, causal=False, dropout_p=dropout_p)
+        torch.testing.assert_close(ref_out, reproducible_out, rtol=1e-5, atol=1e-5)
+
+        fused.set_deterministic(True, seed=2025)
+        with torch.no_grad():
+            different_out = fused(q, k, v, causal=False, dropout_p=dropout_p)
+        assert not torch.allclose(ref_out, different_out, atol=1e-4, rtol=1e-4)
+
+        fused.eval()
+        fused.set_deterministic(True, seed=2024)
+        with torch.no_grad():
+            baseline_out = fused(q, k, v, causal=False, dropout_p=0.0)
+        fused.train()
+        assert not torch.allclose(ref_out, baseline_out, atol=1e-4, rtol=1e-4)
+
+    def test_fused_online_attention_backward_matches_sdpa(self, device):
+        """Backward gradients align with SDPA for mask + ALiBi."""
+        if not (torch.cuda.is_available() and FUSED_TRITON_AVAILABLE):
+            pytest.skip("CUDA + Triton required for fused attention backward test")
+
+        torch.manual_seed(42)
+
+        num_heads = 2
+        head_dim = 32
+        batch_size = 1
+        seq_len_q = 48
+        seq_len_k = 48
+        dtype = torch.float32
+
+        fused = FusedOnlineAttention(
+            num_heads=num_heads, head_dim=head_dim, tile_size_q=16, tile_size_k=16
+        ).to(device)
+        fused.train()
+
+        slopes = torch.linspace(0.1, 0.4, steps=num_heads, device=device, dtype=torch.float32)
+        mask = torch.zeros(batch_size, seq_len_q, seq_len_k, device=device, dtype=torch.float32)
+        mask[:, :, seq_len_k // 2 :] = -1e4
+
+        q = torch.randn(batch_size, seq_len_q, num_heads, head_dim, device=device, dtype=dtype, requires_grad=True)
+        k = torch.randn(batch_size, seq_len_k, num_heads, head_dim, device=device, dtype=dtype, requires_grad=True)
+        v = torch.randn(batch_size, seq_len_k, num_heads, head_dim, device=device, dtype=dtype, requires_grad=True)
+
+        out = fused(q, k, v, causal=True, attention_mask=mask, alibi_slopes=slopes)
+        loss = out.sum()
+        loss.backward()
+
+        grad_q_triton = q.grad.detach().clone()
+        grad_k_triton = k.grad.detach().clone()
+        grad_v_triton = v.grad.detach().clone()
+
+        q_ref = q.detach().clone().requires_grad_(True)
+        k_ref = k.detach().clone().requires_grad_(True)
+        v_ref = v.detach().clone().requires_grad_(True)
+
+        q_bh = q_ref.permute(0, 2, 1, 3).reshape(batch_size * num_heads, seq_len_q, head_dim)
+        k_bh = k_ref.permute(0, 2, 1, 3).reshape(batch_size * num_heads, seq_len_k, head_dim)
+        v_bh = v_ref.permute(0, 2, 1, 3).reshape(batch_size * num_heads, seq_len_k, head_dim)
+
+        mask_bh = mask.unsqueeze(1).expand(batch_size, num_heads, seq_len_q, seq_len_k)
+        mask_bh = mask_bh.reshape(batch_size * num_heads, seq_len_q, seq_len_k)
+
+        pos_q = torch.arange(seq_len_q, device=device, dtype=torch.float32)
+        pos_k = torch.arange(seq_len_k, device=device, dtype=torch.float32)
+        delta = pos_k.unsqueeze(0) - pos_q.unsqueeze(1)
+        alibi_bias = slopes.to(device=device, dtype=torch.float32).view(num_heads, 1, 1) * delta
+        alibi_bias = alibi_bias.unsqueeze(0).expand(batch_size, num_heads, seq_len_q, seq_len_k)
+        alibi_bias = alibi_bias.reshape(batch_size * num_heads, seq_len_q, seq_len_k)
+
+        combined_bias = mask_bh + alibi_bias
+        # SDPA rejects an explicit attn_mask together with is_causal=True, so the
+        # causal structure is folded into the additive bias instead.
+        causal_tri = torch.triu(
+            torch.ones(seq_len_q, seq_len_k, dtype=torch.bool, device=device), diagonal=1
+        )
+        combined_bias = combined_bias.masked_fill(causal_tri.unsqueeze(0), float("-inf"))
+        ref_out = torch.nn.functional.scaled_dot_product_attention(
+            q_bh, k_bh, v_bh, attn_mask=combined_bias, is_causal=False, dropout_p=0.0
+        )
+        ref_out = ref_out.reshape(batch_size, num_heads, seq_len_q, head_dim).permute(0, 2, 1, 3)
+
+        ref_loss = ref_out.sum()
+        ref_loss.backward()
+
+        torch.testing.assert_close(grad_q_triton, q_ref.grad, rtol=5e-3, atol=5e-3)
+        torch.testing.assert_close(grad_k_triton, k_ref.grad, rtol=5e-3, atol=5e-3)
+        torch.testing.assert_close(grad_v_triton, v_ref.grad, rtol=5e-3, atol=5e-3)
+
     def test_flash_attention_correctness(self, config, device):
         """Test FlashAttention V3 correctness"""
         if not torch.cuda.is_available():