diff --git a/challenges/hard/83_flash_attention/challenge.html b/challenges/hard/83_flash_attention/challenge.html new file mode 100644 index 00000000..5a9b659f --- /dev/null +++ b/challenges/hard/83_flash_attention/challenge.html @@ -0,0 +1,204 @@ +

+ Implement causal multi-head self-attention using the Flash Attention algorithm. Given query, + key, and value tensors of shape (num_heads, seq_len, head_dim), compute the + attention output where each query position can only attend to positions at or before its own + index (causal mask). Your implementation must use tiled computation: process + the sequence in blocks so that the full seq_len × seq_len attention score + matrix is never materialized in memory. +

+ + + + + + + + + + + + + NAIVE (materializes full N×N matrix — too much memory!) + + + + Q + [N×d] + + + × + + + + Kᵀ + [d×N] + + + = + + + + S + [N×N] + + + + O(N²) memory + + + + + FLASH ATTENTION (process one Q-block at a time, loop over K/V blocks) + + + for each Q block i: + + + + Q block i + [Br × d] + + + + + + + for each K/V block j (where j ≤ i for causal): + + + + K block j + + + + V block j + + + + + + + S tile + [Br×Bc] + + + + + + + update m, l, O + + + + O(Br×Bc) + + + + + + Online Softmax Update (per Q-row i, for each K/V tile j): + + + init: + m = −∞, l = 0, O = 0 + + + 1. + m_new = max( m_prev, max_k( S_ik ) ) + + + 2. + l_new = exp(m_prev − m_new) · l_prev + ∑_k exp(S_ik − m_new) + + + 3. + O_new = ( exp(m_prev − m_new) · l_prev · O_prev + ∑_k exp(S_ik − m_new) · V_k ) / l_new + + + done: + output[i] = O after all tiles processed + + +

+ The key insight of Flash Attention is using the online softmax algorithm: instead + of computing softmax over the full row at once, maintain running statistics — the current maximum + \(m_i\) and the running normalizer \(l_i\) — and update the output accumulator \(O_i\) as each + new K/V tile is processed. This makes the computation exact while never requiring the full + \(N \times N\) score matrix in memory. +

+ +\[ +\begin{aligned} +m_i^{(j)} &= \max\!\left(m_i^{(j-1)},\; \max_k S_{ik}^{(j)}\right) \\[4pt] +l_i^{(j)} &= e^{m_i^{(j-1)} - m_i^{(j)}} l_i^{(j-1)} + \sum_k e^{S_{ik}^{(j)} - m_i^{(j)}} \\[4pt] +O_i^{(j)} &= \frac{e^{m_i^{(j-1)} - m_i^{(j)}} l_i^{(j-1)} \cdot O_i^{(j-1)} + \sum_k e^{S_{ik}^{(j)} - m_i^{(j)}} V_k^{(j)}}{l_i^{(j)}} +\end{aligned} +\] + +

where \(S_{ik}^{(j)} = \text{scale} \cdot Q_i K_k^\top\) is the raw attention score between +query position \(i\) and key position \(k\) in tile \(j\), with the causal mask applied +(\(S_{ik} = -\infty\) for \(k > i\)).

+ +

Implementation Requirements

+ + +

Example

+

+ Input: num_heads = 2, seq_len = 3, head_dim = 4 +

+

+ Head 0 — \(Q_0\): + \[ + \begin{bmatrix} + 1.0 & 0.0 & 1.0 & 0.0 \\ + 0.0 & 1.0 & 0.0 & 1.0 \\ + 1.0 & 1.0 & 0.0 & 0.0 + \end{bmatrix} + \] + Head 0 — \(K_0\): + \[ + \begin{bmatrix} + 1.0 & 0.0 & 0.0 & 1.0 \\ + 0.0 & 1.0 & 1.0 & 0.0 \\ + 1.0 & 0.0 & 1.0 & 0.0 + \end{bmatrix} + \] + Head 0 — \(V_0\): + \[ + \begin{bmatrix} + 1.0 & 2.0 & 3.0 & 4.0 \\ + 5.0 & 6.0 & 7.0 & 8.0 \\ + 9.0 & 10.0 & 11.0 & 12.0 + \end{bmatrix} + \] + Head 0 — output (token 0 attends only to itself due to causal mask): + \[ + \begin{bmatrix} + 1.0000 & 2.0000 & 3.0000 & 4.0000 \\ + 3.0000 & 4.0000 & 5.0000 & 6.0000 \\ + 5.0000 & 6.0000 & 7.0000 & 8.0000 + \end{bmatrix} + \] +

+ +

Constraints

+
diff --git a/challenges/hard/83_flash_attention/challenge.py b/challenges/hard/83_flash_attention/challenge.py
new file mode 100644
index 00000000..d1e1eb38
--- /dev/null
+++ b/challenges/hard/83_flash_attention/challenge.py
@@ -0,0 +1,150 @@
+import ctypes
+import math
+from typing import Any, Dict, List
+
+import torch
+from core.challenge_base import ChallengeBase
+
+
+class Challenge(ChallengeBase):
+    def __init__(self):
+        super().__init__(
+            name="Flash Attention",
+            atol=1e-03,
+            rtol=1e-03,
+            num_gpus=1,
+            access_tier="free",
+        )
+
+    def reference_impl(
+        self,
+        Q: torch.Tensor,
+        K: torch.Tensor,
+        V: torch.Tensor,
+        output: torch.Tensor,
+        num_heads: int,
+        seq_len: int,
+        head_dim: int,
+    ):
+        assert Q.shape == (num_heads, seq_len, head_dim)
+        assert K.shape == (num_heads, seq_len, head_dim)
+        assert V.shape == (num_heads, seq_len, head_dim)
+        assert output.shape == (num_heads, seq_len, head_dim)
+        assert Q.dtype == K.dtype == V.dtype == output.dtype == torch.float32
+        assert Q.device.type == "cuda"
+        assert K.device.type == "cuda"
+        assert V.device.type == "cuda"
+        assert output.device.type == "cuda"
+
+        scale = 1.0 / math.sqrt(head_dim)
+        # scores: (num_heads, seq_len, seq_len)
+        scores = torch.bmm(Q, K.transpose(1, 2)) * scale
+        # causal mask: upper triangle (j > i) set to -inf
+        causal_mask = torch.triu(
+            torch.full((seq_len, seq_len), float("-inf"), device=Q.device, dtype=Q.dtype),
+            diagonal=1,
+        )
+        scores = scores + causal_mask.unsqueeze(0)
+        attn_weights = torch.softmax(scores, dim=-1)
+        output.copy_(torch.bmm(attn_weights, V))
+
+    def get_solve_signature(self) -> Dict[str, tuple]:
+        return {
+            "Q": (ctypes.POINTER(ctypes.c_float), "in"),
+            "K": (ctypes.POINTER(ctypes.c_float), "in"),
+            "V": (ctypes.POINTER(ctypes.c_float), "in"),
+            "output": (ctypes.POINTER(ctypes.c_float), "out"),
+            "num_heads": (ctypes.c_int, "in"),
+            "seq_len": (ctypes.c_int, "in"),
+            "head_dim": (ctypes.c_int, "in"),
+        }
+
+    def _make_test_case(self, num_heads, seq_len, head_dim, zero_inputs=False):
+        dtype = torch.float32
+        device = "cuda"
+        if zero_inputs:
+            Q = torch.zeros(num_heads, seq_len, head_dim, device=device, dtype=dtype)
+            K = torch.zeros(num_heads, seq_len, head_dim, device=device, dtype=dtype)
+            V = torch.zeros(num_heads, seq_len, head_dim, device=device, dtype=dtype)
+        else:
+            Q = torch.randn(num_heads, seq_len, head_dim, device=device, dtype=dtype)
+            K = torch.randn(num_heads, seq_len, head_dim, device=device, dtype=dtype)
+            V = torch.randn(num_heads, seq_len, head_dim, device=device, dtype=dtype)
+        output = torch.zeros(num_heads, seq_len, head_dim, device=device, dtype=dtype)
+        return {
+            "Q": Q,
+            "K": K,
+            "V": V,
+            "output": output,
+            "num_heads": num_heads,
+            "seq_len": seq_len,
+            "head_dim": head_dim,
+        }
+
+    def generate_example_test(self) -> Dict[str, Any]:
+        dtype = torch.float32
+        device = "cuda"
+        num_heads = 2
+        seq_len = 3
+        head_dim = 4
+        Q = torch.tensor(
+            [
+                [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0], [1.0, 1.0, 0.0, 0.0]],
+                [[-1.0, 0.5, 0.0, 1.0], [0.0, -1.0, 1.0, 0.0], [0.5, 0.0, -0.5, 1.0]],
+            ],
+            device=device,
+            dtype=dtype,
+        )
+        K = torch.tensor(
+            [
+                [[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0], [1.0, 0.0, 1.0, 0.0]],
+                [[0.5, -0.5, 1.0, 0.0], [1.0, 0.0, -1.0, 0.5], [-0.5, 1.0, 0.0, -1.0]],
+            ],
+            device=device,
+            dtype=dtype,
+        )
+        V = torch.tensor(
+            [
+                [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]],
+                [[-1.0, -2.0, -3.0, -4.0], [2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0]],
+            ],
+            device=device,
+            dtype=dtype,
+        )
+        output = torch.zeros(num_heads, seq_len, head_dim, device=device, dtype=dtype)
+        return {
+            "Q": Q,
+            "K": K,
+            "V": V,
+            "output": output,
+            "num_heads": num_heads,
+            "seq_len": seq_len,
+            "head_dim": head_dim,
+        }
+
+    def generate_functional_test(self) -> List[Dict[str, Any]]:
+        torch.manual_seed(42)
+        tests = []
+        # Edge: single token (only attends to itself)
+        tests.append(self._make_test_case(1, 1, 8))
+        # Edge: 2 tokens
+        tests.append(self._make_test_case(1, 2, 8))
+        # Edge: 4 tokens, 2 heads
+        tests.append(self._make_test_case(2, 4, 16))
+        # Zero inputs
+        tests.append(self._make_test_case(2, 4, 8, zero_inputs=True))
+        # Power-of-2 sizes
+        tests.append(self._make_test_case(4, 16, 32))
+        tests.append(self._make_test_case(4, 64, 64))
+        # Non-power-of-2 sizes
+        tests.append(self._make_test_case(4, 30, 32))
+        tests.append(self._make_test_case(4, 100, 64))
+        # Realistic inference sizes
+        tests.append(self._make_test_case(8, 128, 64))
+        tests.append(self._make_test_case(8, 256, 64))
+        return tests
+
+    def generate_performance_test(self) -> Dict[str, Any]:
+        torch.manual_seed(0)
+        # LLM-scale: 8 heads, seq_len=4096, head_dim=64
+        return self._make_test_case(8, 4096, 64)
diff --git a/challenges/hard/83_flash_attention/starter/starter.cu b/challenges/hard/83_flash_attention/starter/starter.cu
new file mode 100644
index 00000000..707af5d8
--- /dev/null
+++ b/challenges/hard/83_flash_attention/starter/starter.cu
@@ -0,0 +1,5 @@
+#include <cuda_runtime.h>
+
+// Q, K, V, output are device pointers
+extern "C" void solve(const float* Q, const float* K, const float* V, float* output, int num_heads,
+                      int seq_len, int head_dim) {}
diff --git a/challenges/hard/83_flash_attention/starter/starter.cute.py b/challenges/hard/83_flash_attention/starter/starter.cute.py
new file mode 100644
index 00000000..06da950a
--- /dev/null
+++ b/challenges/hard/83_flash_attention/starter/starter.cute.py
@@ -0,0 +1,16 @@
+import cutlass
+import cutlass.cute as cute
+
+
+# Q, K, V, output are tensors on the GPU
+@cute.jit
+def solve(
+    Q: cute.Tensor,
+    K: cute.Tensor,
+    V: cute.Tensor,
+    output: cute.Tensor,
+    num_heads: cute.Int32,
+    seq_len: cute.Int32,
+    head_dim: cute.Int32,
+):
+    pass
diff --git a/challenges/hard/83_flash_attention/starter/starter.jax.py b/challenges/hard/83_flash_attention/starter/starter.jax.py
new file mode 100644
index 00000000..d915992a
--- /dev/null
+++ b/challenges/hard/83_flash_attention/starter/starter.jax.py
@@ -0,0 +1,11 @@
+import jax
+import jax.numpy as jnp
+
+
+# Q, K, V are tensors on GPU
+@jax.jit
+def solve(
+    Q: jax.Array, K: jax.Array, V: jax.Array, num_heads: int, seq_len: int, head_dim: int
+) -> jax.Array:
+    # return output tensor directly
+    pass
diff --git a/challenges/hard/83_flash_attention/starter/starter.mojo b/challenges/hard/83_flash_attention/starter/starter.mojo
new file mode 100644
index 00000000..1bedbb3f
--- /dev/null
+++ b/challenges/hard/83_flash_attention/starter/starter.mojo
@@ -0,0 +1,9 @@
+from gpu.host import DeviceContext
+from gpu.id import block_dim, block_idx, thread_idx
+from memory import UnsafePointer
+from math import ceildiv
+
+# Q, K, V, output are device pointers
+@export
+def solve(Q: UnsafePointer[Float32], K: UnsafePointer[Float32], V: UnsafePointer[Float32], output: UnsafePointer[Float32], num_heads: Int32, seq_len: Int32, head_dim: Int32):
+    pass
diff --git a/challenges/hard/83_flash_attention/starter/starter.pytorch.py b/challenges/hard/83_flash_attention/starter/starter.pytorch.py
new file mode 100644
index 00000000..7ae6982d
--- /dev/null
+++ b/challenges/hard/83_flash_attention/starter/starter.pytorch.py
@@ -0,0 +1,14 @@
+import torch
+
+
+# Q, K, V, output are tensors on the GPU
+def solve(
+    Q: torch.Tensor,
+    K: torch.Tensor,
+    V: torch.Tensor,
+    output: torch.Tensor,
+    num_heads: int,
+    seq_len: int,
+    head_dim: int,
+):
+    pass
diff --git a/challenges/hard/83_flash_attention/starter/starter.triton.py b/challenges/hard/83_flash_attention/starter/starter.triton.py
new file mode 100644
index 00000000..b0e09f23
--- /dev/null
+++ b/challenges/hard/83_flash_attention/starter/starter.triton.py
@@ -0,0 +1,16 @@
+import torch
+import triton
+import triton.language as tl
+
+
+# Q, K, V, output are tensors on the GPU
+def solve(
+    Q: torch.Tensor,
+    K: torch.Tensor,
+    V: torch.Tensor,
+    output: torch.Tensor,
+    num_heads: int,
+    seq_len: int,
+    head_dim: int,
+):
+    pass