32 changes: 12 additions & 20 deletions README.md
@@ -11,10 +11,12 @@ This repository provides:
- Optional integration helpers for Hugging Face models

## Recent Updates

- Added a single-sweep Triton backward path (streaming dQ/dK/dV using saved `lse`) when `dropout_p == 0`.
- Triton forward now supports boolean/additive masks, dropout, deterministic Philox seeding, and ALiBi bias without SDPA fallback.
- Expanded tests/docs covering mask/dropout/ALiBi parity plus deterministic mode usage.

- Updated `FlashAttentionV3` wrapper to use PyTorch 2.8 `sdpa_kernel` API with a fallback for older releases.
- Restored the Triton fused kernel and added a parity test against FlashAttention‑3.
- Normalized indentation in the Triton kernel to avoid `TabError` on Colab and other environments.
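
The single-sweep backward item above can be exercised roughly as follows. This is a minimal sketch, assuming the package exposes `FusedOnlineAttention` from a top-level `stream_attention` import and that its constructor takes `num_heads`/`head_dim`; neither is confirmed by this diff.

```python
import torch
from stream_attention import FusedOnlineAttention  # assumed import path

# Assumed constructor arguments, for illustration only.
attn = FusedOnlineAttention(num_heads=8, head_dim=64).cuda().half()

B, T, H, D = 2, 1024, 8, 64
q = torch.randn(B, T, H, D, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q).requires_grad_()
v = torch.randn_like(q).requires_grad_()

# With dropout_p == 0 the Triton kernel reuses the saved `lse` for a
# single-sweep backward (streaming dQ/dK/dV); with dropout enabled the
# module falls back to SDPA for gradients.
out = attn(q, k, v, causal=True, dropout_p=0.0)
out.sum().backward()
print(q.grad.shape, k.grad.shape, v.grad.shape)
```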


## Installation
@@ -61,7 +63,7 @@ with torch.no_grad():
y = attention(hidden_states=x)
print(y.shape)

# Explicit Q,K,V path (supports attention mask via SDPA fallback)
# Explicit Q,K,V path (supports attn_mask/dropout/ALiBi in Triton, falling back to SDPA when needed)
q = torch.randn(batch_size, seq_len, config.num_heads, config.head_dim, device=x.device, dtype=x.dtype)
k = torch.randn_like(q)
v = torch.randn_like(q)
@@ -79,18 +81,10 @@ print(y_qkv.shape)
### StreamAttention
- Purpose: High-level module. Accepts either `hidden_states` ([B, T, H*D]) or explicit `(query, key, value)` ([B, T, H, D]).
- Signature (selected):
- `forward(hidden_states: Tensor, ..., use_cache: bool=False, causal: bool=True)` → `Tensor` or `(Tensor, (k, v))` if `use_cache=True`
- `forward(query: Tensor, key: Tensor, value: Tensor, causal: bool=True, attention_mask: Optional[Tensor]=None)` → `Tensor`
- Shapes: `[batch, seq, heads, dim]` for QKV mode.
- Dtypes: fp16/bf16 (CUDA), fp32 (CPU by default). On CPU, inputs are upcast to fp32 when required.
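
A short sketch of the cache-returning call described above, assuming `StreamAttention` is importable from `stream_attention` and is constructed from the config object used in the quickstart (both assumptions):

```python
import torch
from stream_attention import StreamAttention, StreamAttentionConfig  # assumed import path

config = StreamAttentionConfig(num_heads=8, head_dim=64)  # assumed constructor fields
attention = StreamAttention(config).cuda().half()

x = torch.randn(2, 512, config.num_heads * config.head_dim,
                device="cuda", dtype=torch.float16)

# With use_cache=True the module returns the output plus the (k, v) cache.
y, (k_cache, v_cache) = attention(hidden_states=x, use_cache=True, causal=True)
print(y.shape, k_cache.shape, v_cache.shape)
```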

### FusedOnlineAttention
- Purpose: Low-level fused online softmax attention (Triton when available; SDPA fallback otherwise).
- Signature (selected):
- `forward(query, key, value, causal: bool=True, return_lse: bool=False, attention_mask: Optional[Tensor]=None)` → `Tensor` (and `lse` if requested)
- `benchmark(seq_len: int, batch_size: int=1, warmup: int=10, iterations: int=100)` → metrics dict
- Autograd: If gradients are required, the module automatically falls back to PyTorch SDPA to ensure correct backward support. The Triton path is intended for forward-critical inference/benchmarking.
- Dropout is not supported in the fused kernel; apply it outside the module if needed.
- `forward(query, key, value, causal: bool=True, return_lse: bool=False, attention_mask: Optional[Tensor]=None, dropout_p: float=0.0, alibi_slopes: Optional[Tensor]=None, deterministic: Optional[bool]=None)` -> `Tensor` (and `lse` if requested)
- `benchmark(seq_len: int, batch_size: int=1, warmup: int=10, iterations: int=100)` -> metrics dict
- `set_deterministic(enabled: bool, seed: Optional[int]=None)` -> control deterministic dropout/mask behavior
- Autograd: When gradients are required and `dropout_p == 0`, the Triton kernel executes a single-sweep backward pass (streaming dQ/dK/dV using the saved `lse`). If dropout is enabled during training, the module falls back to PyTorch SDPA for gradient computation.
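
A hedged usage sketch of the deterministic/dropout/ALiBi options listed above. The keyword names follow the signature shown in this diff; the import path, constructor arguments, boolean-mask layout (`True` = keep), and per-head ALiBi slope shape are assumptions.

```python
import torch
from stream_attention import FusedOnlineAttention  # assumed import path

attn = FusedOnlineAttention(num_heads=8, head_dim=64).cuda().half()  # assumed constructor

B, T, H, D = 1, 2048, 8, 64
q = torch.randn(B, T, H, D, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

alibi = torch.full((H,), 0.5, device="cuda", dtype=torch.float32)  # assumed per-head slopes
mask = torch.ones(B, 1, T, T, device="cuda", dtype=torch.bool)     # assumed broadcastable layout

# set_deterministic seeds the Philox stream so dropout sampling is reproducible.
attn.set_deterministic(True, seed=1234)
out, lse = attn(q, k, v, causal=True, return_lse=True,
                attention_mask=mask, dropout_p=0.1, alibi_slopes=alibi)
print(out.shape, lse.shape)
```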

### Multihead-style wrapper
Use `create_stream_attention` to obtain an attention layer with a familiar
@@ -138,7 +132,7 @@ stream-attention-test --seq 1024 --batch 2 --heads 8 --dim 64 --dtype fp16
Behavior and methodology:
- On CUDA, the baseline uses PyTorch SDPA with the flash backend (FlashAttention-3 path). On CPU, both implementations use SDPA in fp32.
- Metrics reported: per-iteration latency, estimated TFLOPS, and approximate bandwidth based on tensor traffic. Measurements are averaged after warmup.
- The fused kernel uses Triton on CUDA for the forward pass; when gradients are required, it falls back to an SDPA-backed autograd path. Otherwise, SDPA is used to ensure correctness, masking, and training.
- The fused kernel uses Triton on CUDA for the forward pass; when gradients are required and `dropout_p == 0`, the streaming backward (single sweep over K/V) is invoked. If dropout is active during training, the module falls back to SDPA for gradient computation.
- For reproducibility, fix random seeds, pin CUDA clocks if applicable, and isolate runs. Actual performance depends on GPU architecture, drivers, and PyTorch/Triton versions.
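
For reference, one common way to estimate the reported metrics from problem size and measured latency — not necessarily the exact accounting the CLI uses — is:

```python
def attention_tflops(batch, heads, seq_len, head_dim, latency_s, causal=True):
    # Two matmuls (Q @ K^T and P @ V), each ~2*B*H*T^2*D FLOPs; causal masking
    # touches roughly half the score matrix.
    flops = 4 * batch * heads * seq_len ** 2 * head_dim
    if causal:
        flops /= 2
    return flops / latency_s / 1e12

def attention_gbps(batch, heads, seq_len, head_dim, latency_s, bytes_per_elem=2):
    # Approximate tensor traffic: Q, K, V read and O written once each,
    # ignoring tiles that stay on-chip.
    traffic = 4 * batch * heads * seq_len * head_dim * bytes_per_elem
    return traffic / latency_s / 1e9

print(attention_tflops(2, 8, 1024, 64, latency_s=1e-3))
```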

Example output (format):
@@ -227,9 +221,8 @@ print(f"Replaced {num_replaced} attention modules")

## Roadmap

- Backward implementation for the Triton fused kernel
- Native RoPE / relative positional bias fusion in the Triton kernel (forward + backward)
- Advanced pipelining (warp specialization, asynchronous staging) and Hopper-specific paths (WGMMA/TMA)
- Full support for attention masks in the fused kernel
- Extended autotune coverage across architectures and sequence regimes
- Optional FP8 path with block-wise scaling

@@ -240,7 +233,6 @@ print(f"Replaced {num_replaced} attention modules")
- Accuracy checks: `stream-attention-test` CLI
- Examples: `examples/` directory includes basic usage, integrations, and long-context runs
- Environment variables: see `.env.example`; `StreamAttentionConfig.from_env()` can bootstrap configuration
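
A minimal bootstrap sketch for the `from_env()` path. The variable names below are purely hypothetical placeholders — the authoritative list is in `.env.example` — and the import path is assumed.

```python
import os
from stream_attention import StreamAttentionConfig  # assumed import path

# Hypothetical variable names for illustration; consult .env.example for the real ones.
os.environ.setdefault("STREAM_ATTENTION_NUM_HEADS", "8")
os.environ.setdefault("STREAM_ATTENTION_HEAD_DIM", "64")

config = StreamAttentionConfig.from_env()
print(config.num_heads, config.head_dim)
```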


## License
4 changes: 2 additions & 2 deletions docs/Index.md
@@ -56,9 +56,9 @@ Use log-sum-exp recentering:

The chosen tile size (`TILE_K = 64`) is a starting point. Optimal performance may require experimenting with different tile sizes on your hardware. If performance is suboptimal, use profiling tools (e.g., NVIDIA Nsight Compute) to identify and resolve bottlenecks.
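
If you do experiment with tile sizes, a sweep can be expressed with Triton's autotuner; the kernel name and argument list below are illustrative placeholders, not the repository's actual kernel.

```python
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"TILE_K": 32}, num_warps=4),
        triton.Config({"TILE_K": 64}, num_warps=4),
        triton.Config({"TILE_K": 128}, num_warps=8),
    ],
    key=["seq_len", "head_dim"],
)
@triton.jit
def attn_fwd_kernel(q_ptr, k_ptr, v_ptr, o_ptr,
                    seq_len, head_dim,
                    TILE_K: tl.constexpr):
    # Kernel body elided; TILE_K controls how many K/V rows each iteration loads.
    pass
```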

### Triton Limitations
### Triton Notes

Currently, Triton does not expose `cp.async`. This implementation relies on `tl.load` with masking and autotuned tile sizes. The module automatically falls back to PyTorch SDPA for autograd or masking.
Currently, Triton does not expose `cp.async`. This implementation relies on `tl.load` with masking and autotuned tile sizes. The fused forward supports native boolean/additive attention masks, dropout, and ALiBi biasing. Deterministic mode (`set_deterministic`) seeds the Philox stream so dropout/mask sampling is reproducible. When `dropout_p == 0`, the same saved `lse` is reused to run a single-sweep backward pass; otherwise we fall back to PyTorch SDPA for gradients.
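
For intuition, here is a plain-PyTorch sketch (single head, no masking or dropout) of why the saved `lse` is sufficient for a one-sweep backward: probabilities are rebuilt tile by tile as `exp(scores - lse)` and dQ/dK/dV are accumulated in the same pass. This mirrors the idea, not the Triton kernel itself.

```python
import torch

def streaming_backward(q, k, v, o, do, lse, tile=64):
    # q, o, do: [Tq, D]; k, v: [Tk, D]; lse: [Tq] = logsumexp of scaled scores per query row.
    scale = q.shape[-1] ** -0.5
    delta = (do * o).sum(-1, keepdim=True)            # rowsum(dO * O), FlashAttention-style
    dq, dk, dv = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)
    for s0 in range(0, k.shape[0], tile):
        ks, vs = k[s0:s0 + tile], v[s0:s0 + tile]
        scores = (q @ ks.T) * scale
        p = torch.exp(scores - lse[:, None])          # softmax probs rebuilt from saved lse
        dv[s0:s0 + tile] += p.T @ do
        dp = do @ vs.T
        ds = p * (dp - delta)                         # softmax backward
        dq += (ds @ ks) * scale
        dk[s0:s0 + tile] += (ds.T @ q) * scale
    return dq, dk, dv
```

Comparing this against `torch.autograd` on small fp32 tensors is a quick sanity check of the recurrence.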

### Distributed Setup Issues
