From 7fe1867b3e3a149e9aa10648cd63084140d5ec15 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Tue, 17 Mar 2026 04:15:42 +0000 Subject: [PATCH 1/4] Add challenge 83: Flash Attention (Hard) Co-Authored-By: Claude Sonnet 4.6 --- .../hard/83_flash_attention/challenge.html | 150 ++++++++++++++++++ .../hard/83_flash_attention/challenge.py | 150 ++++++++++++++++++ .../83_flash_attention/solution/solution.cu | 91 +++++++++++ .../83_flash_attention/starter/starter.cu | 5 + .../starter/starter.cute.py | 16 ++ .../83_flash_attention/starter/starter.jax.py | 11 ++ .../83_flash_attention/starter/starter.mojo | 9 ++ .../starter/starter.pytorch.py | 14 ++ .../starter/starter.triton.py | 16 ++ 9 files changed, 462 insertions(+) create mode 100644 challenges/hard/83_flash_attention/challenge.html create mode 100644 challenges/hard/83_flash_attention/challenge.py create mode 100644 challenges/hard/83_flash_attention/solution/solution.cu create mode 100644 challenges/hard/83_flash_attention/starter/starter.cu create mode 100644 challenges/hard/83_flash_attention/starter/starter.cute.py create mode 100644 challenges/hard/83_flash_attention/starter/starter.jax.py create mode 100644 challenges/hard/83_flash_attention/starter/starter.mojo create mode 100644 challenges/hard/83_flash_attention/starter/starter.pytorch.py create mode 100644 challenges/hard/83_flash_attention/starter/starter.triton.py diff --git a/challenges/hard/83_flash_attention/challenge.html b/challenges/hard/83_flash_attention/challenge.html new file mode 100644 index 00000000..00215193 --- /dev/null +++ b/challenges/hard/83_flash_attention/challenge.html @@ -0,0 +1,150 @@ +

+ Implement causal multi-head self-attention using the Flash Attention algorithm. Given query, + key, and value tensors of shape (num_heads, seq_len, head_dim), compute the + attention output where each query position can only attend to positions at or before its own + index (causal mask). Your implementation must use tiled computation: process + the sequence in blocks so that the full seq_len × seq_len attention score + matrix is never materialized in memory. +

+ + + + + + + + + + + + + Flash Attention: Tiled Computation + + + Q + + + Q₀ (rows 0..Br) + + + Q₁ + + + Q₂ + + + + + seq_len / Br blocks + + + + + + Kᵀ, V + + K₀ + + + K₁ + + + K₂ + + + + + seq_len / Bc blocks + + + + Tile Sᵢⱼ = Qᵢ Kⱼᵀ · scale + Apply causal mask + Update (m, l, O) online + + + + loop + over j + + + O + + O₀ (accumulated) + + + + + + + Online Softmax Update (for each K/V tile j): + m_new = max(m_prev, rowmax(Sᵢⱼ)) + l_new = exp(m_prev − m_new)·l_prev + rowsum(exp(Sᵢⱼ − m_new)) + O_new = (exp(m_prev − m_new)·l_prev·O_prev + exp(Sᵢⱼ − m_new)·Vⱼ) / l_new + + +

+ The key insight of Flash Attention is using the online softmax algorithm: instead + of computing softmax over the full row at once, maintain running statistics — the current maximum + \(m_i\) and the running normalizer \(l_i\) — and update the output accumulator \(O_i\) as each + new K/V tile is processed. This makes the computation exact while never requiring the full + \(N \times N\) score matrix in memory. +

+ +\[ +\begin{aligned} +m_i^{(j)} &= \max\!\left(m_i^{(j-1)},\; \max_k S_{ik}^{(j)}\right) \\[4pt] +l_i^{(j)} &= e^{m_i^{(j-1)} - m_i^{(j)}} l_i^{(j-1)} + \sum_k e^{S_{ik}^{(j)} - m_i^{(j)}} \\[4pt] +O_i^{(j)} &= \frac{e^{m_i^{(j-1)} - m_i^{(j)}} l_i^{(j-1)} \cdot O_i^{(j-1)} + e^{S_{ik}^{(j)} - m_i^{(j)}} V^{(j)}}{l_i^{(j)}} +\end{aligned} +\] + +

where \(S_{ik}^{(j)} = \text{scale} \cdot Q_i K_k^\top\) is the raw attention score between +query position \(i\) and key position \(k\) in tile \(j\), with the causal mask applied +(\(S_{ik} = -\infty\) for \(k > i\)).

+ +

Implementation Requirements

+ + +

Example

+

With num_heads = 2, seq_len = 3, head_dim = 4:

+
+Input:  Q.shape = (2, 3, 4)   # 2 heads, 3 tokens, dim 4
+        K.shape = (2, 3, 4)
+        V.shape = (2, 3, 4)
+Output: output.shape = (2, 3, 4)
+
+# For head 0, token 0 (first query):
+#   Attends only to key 0 (causal) → output[0][0] = V[0][0]
+
+ +

Constraints

+ diff --git a/challenges/hard/83_flash_attention/challenge.py b/challenges/hard/83_flash_attention/challenge.py new file mode 100644 index 00000000..d1e1eb38 --- /dev/null +++ b/challenges/hard/83_flash_attention/challenge.py @@ -0,0 +1,150 @@ +import ctypes +import math +from typing import Any, Dict, List + +import torch +from core.challenge_base import ChallengeBase + + +class Challenge(ChallengeBase): + def __init__(self): + super().__init__( + name="Flash Attention", + atol=1e-03, + rtol=1e-03, + num_gpus=1, + access_tier="free", + ) + + def reference_impl( + self, + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + output: torch.Tensor, + num_heads: int, + seq_len: int, + head_dim: int, + ): + assert Q.shape == (num_heads, seq_len, head_dim) + assert K.shape == (num_heads, seq_len, head_dim) + assert V.shape == (num_heads, seq_len, head_dim) + assert output.shape == (num_heads, seq_len, head_dim) + assert Q.dtype == K.dtype == V.dtype == output.dtype == torch.float32 + assert Q.device.type == "cuda" + assert K.device.type == "cuda" + assert V.device.type == "cuda" + assert output.device.type == "cuda" + + scale = 1.0 / math.sqrt(head_dim) + # scores: (num_heads, seq_len, seq_len) + scores = torch.bmm(Q, K.transpose(1, 2)) * scale + # causal mask: upper triangle (j > i) set to -inf + causal_mask = torch.triu( + torch.full((seq_len, seq_len), float("-inf"), device=Q.device, dtype=Q.dtype), + diagonal=1, + ) + scores = scores + causal_mask.unsqueeze(0) + attn_weights = torch.softmax(scores, dim=-1) + output.copy_(torch.bmm(attn_weights, V)) + + def get_solve_signature(self) -> Dict[str, tuple]: + return { + "Q": (ctypes.POINTER(ctypes.c_float), "in"), + "K": (ctypes.POINTER(ctypes.c_float), "in"), + "V": (ctypes.POINTER(ctypes.c_float), "in"), + "output": (ctypes.POINTER(ctypes.c_float), "out"), + "num_heads": (ctypes.c_int, "in"), + "seq_len": (ctypes.c_int, "in"), + "head_dim": (ctypes.c_int, "in"), + } + + def _make_test_case(self, num_heads, seq_len, head_dim, zero_inputs=False): + dtype = torch.float32 + device = "cuda" + if zero_inputs: + Q = torch.zeros(num_heads, seq_len, head_dim, device=device, dtype=dtype) + K = torch.zeros(num_heads, seq_len, head_dim, device=device, dtype=dtype) + V = torch.zeros(num_heads, seq_len, head_dim, device=device, dtype=dtype) + else: + Q = torch.randn(num_heads, seq_len, head_dim, device=device, dtype=dtype) + K = torch.randn(num_heads, seq_len, head_dim, device=device, dtype=dtype) + V = torch.randn(num_heads, seq_len, head_dim, device=device, dtype=dtype) + output = torch.zeros(num_heads, seq_len, head_dim, device=device, dtype=dtype) + return { + "Q": Q, + "K": K, + "V": V, + "output": output, + "num_heads": num_heads, + "seq_len": seq_len, + "head_dim": head_dim, + } + + def generate_example_test(self) -> Dict[str, Any]: + dtype = torch.float32 + device = "cuda" + num_heads = 2 + seq_len = 3 + head_dim = 4 + Q = torch.tensor( + [ + [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0], [1.0, 1.0, 0.0, 0.0]], + [[-1.0, 0.5, 0.0, 1.0], [0.0, -1.0, 1.0, 0.0], [0.5, 0.0, -0.5, 1.0]], + ], + device=device, + dtype=dtype, + ) + K = torch.tensor( + [ + [[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0], [1.0, 0.0, 1.0, 0.0]], + [[0.5, -0.5, 1.0, 0.0], [1.0, 0.0, -1.0, 0.5], [-0.5, 1.0, 0.0, -1.0]], + ], + device=device, + dtype=dtype, + ) + V = torch.tensor( + [ + [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]], + [[-1.0, -2.0, -3.0, -4.0], [2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0]], + ], + device=device, + dtype=dtype, + ) + output = torch.zeros(num_heads, seq_len, head_dim, device=device, dtype=dtype) + return { + "Q": Q, + "K": K, + "V": V, + "output": output, + "num_heads": num_heads, + "seq_len": seq_len, + "head_dim": head_dim, + } + + def generate_functional_test(self) -> List[Dict[str, Any]]: + torch.manual_seed(42) + tests = [] + # Edge: single token (only attends to itself) + tests.append(self._make_test_case(1, 1, 8)) + # Edge: 2 tokens + tests.append(self._make_test_case(1, 2, 8)) + # Edge: 4 tokens, 2 heads + tests.append(self._make_test_case(2, 4, 16)) + # Zero inputs + tests.append(self._make_test_case(2, 4, 8, zero_inputs=True)) + # Power-of-2 sizes + tests.append(self._make_test_case(4, 16, 32)) + tests.append(self._make_test_case(4, 64, 64)) + # Non-power-of-2 sizes + tests.append(self._make_test_case(4, 30, 32)) + tests.append(self._make_test_case(4, 100, 64)) + # Realistic inference sizes + tests.append(self._make_test_case(8, 128, 64)) + tests.append(self._make_test_case(8, 256, 64)) + return tests + + def generate_performance_test(self) -> Dict[str, Any]: + torch.manual_seed(0) + # LLM-scale: 8 heads, seq_len=4096, head_dim=64 + return self._make_test_case(8, 4096, 64) diff --git a/challenges/hard/83_flash_attention/solution/solution.cu b/challenges/hard/83_flash_attention/solution/solution.cu new file mode 100644 index 00000000..97bb44a1 --- /dev/null +++ b/challenges/hard/83_flash_attention/solution/solution.cu @@ -0,0 +1,91 @@ +#include +#include +#include + +#define BR 32 +#define BC 32 +#define MAX_HEAD_DIM 256 + +__global__ void flash_attn_kernel(const float* __restrict__ Q, const float* __restrict__ K, + const float* __restrict__ V, float* __restrict__ output, + int seq_len, int head_dim, float scale) { + int h = blockIdx.x; + int qi_block = blockIdx.y; + int ti = threadIdx.x; + int qi = qi_block * BR + ti; + if (qi >= seq_len) + return; + + const float* Qh = Q + (long)h * seq_len * head_dim; + const float* Kh = K + (long)h * seq_len * head_dim; + const float* Vh = V + (long)h * seq_len * head_dim; + float* Oh = output + (long)h * seq_len * head_dim; + + float m = -FLT_MAX; + float l = 0.0f; + float acc[MAX_HEAD_DIM]; + for (int d = 0; d < head_dim; d++) + acc[d] = 0.0f; + + for (int kj_block = 0; (long)kj_block * BC <= qi; kj_block++) { + int kj_start = kj_block * BC; + int kj_end = kj_start + BC; + if (kj_end > seq_len) + kj_end = seq_len; + if (kj_end > qi + 1) + kj_end = qi + 1; + int actual_len = kj_end - kj_start; + + float s[BC]; + for (int j = 0; j < actual_len; j++) { + int kj = kj_start + j; + float dot = 0.0f; + for (int d = 0; d < head_dim; d++) { + dot += Qh[qi * head_dim + d] * Kh[kj * head_dim + d]; + } + s[j] = dot * scale; + } + + float m_tile = -FLT_MAX; + for (int j = 0; j < actual_len; j++) { + if (s[j] > m_tile) + m_tile = s[j]; + } + + float m_new = (m > m_tile) ? m : m_tile; + float alpha = expf(m - m_new); + + float p[BC]; + float l_tile = 0.0f; + for (int j = 0; j < actual_len; j++) { + p[j] = expf(s[j] - m_new); + l_tile += p[j]; + } + float l_new = alpha * l + l_tile; + + for (int d = 0; d < head_dim; d++) { + float pv = 0.0f; + for (int j = 0; j < actual_len; j++) { + pv += p[j] * Vh[(kj_start + j) * head_dim + d]; + } + acc[d] = (alpha * l * acc[d] + pv) / l_new; + } + + m = m_new; + l = l_new; + } + + for (int d = 0; d < head_dim; d++) { + Oh[qi * head_dim + d] = acc[d]; + } +} + +extern "C" void solve(const float* Q, const float* K, const float* V, float* output, int num_heads, + int seq_len, int head_dim) { + float scale = 1.0f / sqrtf((float)head_dim); + int num_q_blocks = (seq_len + BR - 1) / BR; + dim3 grid(num_heads, num_q_blocks); + dim3 block(BR); + flash_attn_kernel<<>>(Q, K, V, output, seq_len, head_dim, scale); + cudaDeviceSynchronize(); +} diff --git a/challenges/hard/83_flash_attention/starter/starter.cu b/challenges/hard/83_flash_attention/starter/starter.cu new file mode 100644 index 00000000..707af5d8 --- /dev/null +++ b/challenges/hard/83_flash_attention/starter/starter.cu @@ -0,0 +1,5 @@ +#include + +// Q, K, V, output are device pointers +extern "C" void solve(const float* Q, const float* K, const float* V, float* output, int num_heads, + int seq_len, int head_dim) {} diff --git a/challenges/hard/83_flash_attention/starter/starter.cute.py b/challenges/hard/83_flash_attention/starter/starter.cute.py new file mode 100644 index 00000000..06da950a --- /dev/null +++ b/challenges/hard/83_flash_attention/starter/starter.cute.py @@ -0,0 +1,16 @@ +import cutlass +import cutlass.cute as cute + + +# Q, K, V, output are tensors on the GPU +@cute.jit +def solve( + Q: cute.Tensor, + K: cute.Tensor, + V: cute.Tensor, + output: cute.Tensor, + num_heads: cute.Int32, + seq_len: cute.Int32, + head_dim: cute.Int32, +): + pass diff --git a/challenges/hard/83_flash_attention/starter/starter.jax.py b/challenges/hard/83_flash_attention/starter/starter.jax.py new file mode 100644 index 00000000..d915992a --- /dev/null +++ b/challenges/hard/83_flash_attention/starter/starter.jax.py @@ -0,0 +1,11 @@ +import jax +import jax.numpy as jnp + + +# Q, K, V are tensors on GPU +@jax.jit +def solve( + Q: jax.Array, K: jax.Array, V: jax.Array, num_heads: int, seq_len: int, head_dim: int +) -> jax.Array: + # return output tensor directly + pass diff --git a/challenges/hard/83_flash_attention/starter/starter.mojo b/challenges/hard/83_flash_attention/starter/starter.mojo new file mode 100644 index 00000000..1bedbb3f --- /dev/null +++ b/challenges/hard/83_flash_attention/starter/starter.mojo @@ -0,0 +1,9 @@ +from gpu.host import DeviceContext +from gpu.id import block_dim, block_idx, thread_idx +from memory import UnsafePointer +from math import ceildiv + +# Q, K, V, output are device pointers +@export +def solve(Q: UnsafePointer[Float32], K: UnsafePointer[Float32], V: UnsafePointer[Float32], output: UnsafePointer[Float32], num_heads: Int32, seq_len: Int32, head_dim: Int32): + pass diff --git a/challenges/hard/83_flash_attention/starter/starter.pytorch.py b/challenges/hard/83_flash_attention/starter/starter.pytorch.py new file mode 100644 index 00000000..7ae6982d --- /dev/null +++ b/challenges/hard/83_flash_attention/starter/starter.pytorch.py @@ -0,0 +1,14 @@ +import torch + + +# Q, K, V, output are tensors on the GPU +def solve( + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + output: torch.Tensor, + num_heads: int, + seq_len: int, + head_dim: int, +): + pass diff --git a/challenges/hard/83_flash_attention/starter/starter.triton.py b/challenges/hard/83_flash_attention/starter/starter.triton.py new file mode 100644 index 00000000..b0e09f23 --- /dev/null +++ b/challenges/hard/83_flash_attention/starter/starter.triton.py @@ -0,0 +1,16 @@ +import torch +import triton +import triton.language as tl + + +# Q, K, V, output are tensors on the GPU +def solve( + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + output: torch.Tensor, + num_heads: int, + seq_len: int, + head_dim: int, +): + pass From 1aa0f716d11fe73629d56d212105b19eac6dca9a Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Tue, 17 Mar 2026 04:15:56 +0000 Subject: [PATCH 2/4] Remove solution file from challenge 83 Co-Authored-By: Claude Sonnet 4.6 --- .../83_flash_attention/solution/solution.cu | 91 ------------------- 1 file changed, 91 deletions(-) delete mode 100644 challenges/hard/83_flash_attention/solution/solution.cu diff --git a/challenges/hard/83_flash_attention/solution/solution.cu b/challenges/hard/83_flash_attention/solution/solution.cu deleted file mode 100644 index 97bb44a1..00000000 --- a/challenges/hard/83_flash_attention/solution/solution.cu +++ /dev/null @@ -1,91 +0,0 @@ -#include -#include -#include - -#define BR 32 -#define BC 32 -#define MAX_HEAD_DIM 256 - -__global__ void flash_attn_kernel(const float* __restrict__ Q, const float* __restrict__ K, - const float* __restrict__ V, float* __restrict__ output, - int seq_len, int head_dim, float scale) { - int h = blockIdx.x; - int qi_block = blockIdx.y; - int ti = threadIdx.x; - int qi = qi_block * BR + ti; - if (qi >= seq_len) - return; - - const float* Qh = Q + (long)h * seq_len * head_dim; - const float* Kh = K + (long)h * seq_len * head_dim; - const float* Vh = V + (long)h * seq_len * head_dim; - float* Oh = output + (long)h * seq_len * head_dim; - - float m = -FLT_MAX; - float l = 0.0f; - float acc[MAX_HEAD_DIM]; - for (int d = 0; d < head_dim; d++) - acc[d] = 0.0f; - - for (int kj_block = 0; (long)kj_block * BC <= qi; kj_block++) { - int kj_start = kj_block * BC; - int kj_end = kj_start + BC; - if (kj_end > seq_len) - kj_end = seq_len; - if (kj_end > qi + 1) - kj_end = qi + 1; - int actual_len = kj_end - kj_start; - - float s[BC]; - for (int j = 0; j < actual_len; j++) { - int kj = kj_start + j; - float dot = 0.0f; - for (int d = 0; d < head_dim; d++) { - dot += Qh[qi * head_dim + d] * Kh[kj * head_dim + d]; - } - s[j] = dot * scale; - } - - float m_tile = -FLT_MAX; - for (int j = 0; j < actual_len; j++) { - if (s[j] > m_tile) - m_tile = s[j]; - } - - float m_new = (m > m_tile) ? m : m_tile; - float alpha = expf(m - m_new); - - float p[BC]; - float l_tile = 0.0f; - for (int j = 0; j < actual_len; j++) { - p[j] = expf(s[j] - m_new); - l_tile += p[j]; - } - float l_new = alpha * l + l_tile; - - for (int d = 0; d < head_dim; d++) { - float pv = 0.0f; - for (int j = 0; j < actual_len; j++) { - pv += p[j] * Vh[(kj_start + j) * head_dim + d]; - } - acc[d] = (alpha * l * acc[d] + pv) / l_new; - } - - m = m_new; - l = l_new; - } - - for (int d = 0; d < head_dim; d++) { - Oh[qi * head_dim + d] = acc[d]; - } -} - -extern "C" void solve(const float* Q, const float* K, const float* V, float* output, int num_heads, - int seq_len, int head_dim) { - float scale = 1.0f / sqrtf((float)head_dim); - int num_q_blocks = (seq_len + BR - 1) / BR; - dim3 grid(num_heads, num_q_blocks); - dim3 block(BR); - flash_attn_kernel<<>>(Q, K, V, output, seq_len, head_dim, scale); - cudaDeviceSynchronize(); -} From af2e359698d3dffd62d31f1bfb3478dfb95dbe41 Mon Sep 17 00:00:00 2001 From: James Song Date: Thu, 26 Mar 2026 21:26:32 -0400 Subject: [PATCH 3/4] Fix flash attention HTML: example values, O formula summation, SVG redesign - Show actual Q/K/V/output values from generate_example_test() instead of just tensor shapes - Add missing summation over k in the LaTeX O_i update formula - Redesign SVG to clearly contrast naive O(N^2) approach vs Flash Attention's tiled approach with online softmax update steps Co-Authored-By: Claude Opus 4.6 (1M context) --- .../hard/83_flash_attention/challenge.html | 236 +++++++++++------- 1 file changed, 145 insertions(+), 91 deletions(-) diff --git a/challenges/hard/83_flash_attention/challenge.html b/challenges/hard/83_flash_attention/challenge.html index 00215193..ef75ea7d 100644 --- a/challenges/hard/83_flash_attention/challenge.html +++ b/challenges/hard/83_flash_attention/challenge.html @@ -7,90 +7,117 @@ matrix is never materialized in memory.

- + + - - + + - - - - - Flash Attention: Tiled Computation - - - Q - - - Q₀ (rows 0..Br) - - - Q₁ - - - Q₂ - - - - - seq_len / Br blocks - - - - - - Kᵀ, V - - K₀ - - - K₁ - - - K₂ - - - - - seq_len / Bc blocks - - - - Tile Sᵢⱼ = Qᵢ Kⱼᵀ · scale - Apply causal mask - Update (m, l, O) online - - - - loop - over j - - - O - - O₀ (accumulated) - - - - - - - Online Softmax Update (for each K/V tile j): - m_new = max(m_prev, rowmax(Sᵢⱼ)) - l_new = exp(m_prev − m_new)·l_prev + rowsum(exp(Sᵢⱼ − m_new)) - O_new = (exp(m_prev − m_new)·l_prev·O_prev + exp(Sᵢⱼ − m_new)·Vⱼ) / l_new + + + + NAIVE (materializes full N×N matrix — too much memory!) + + + + Q + [N×d] + + + × + + + + Kᵀ + [d×N] + + + = + + + + S + [N×N] + + + + O(N²) memory + + + + + FLASH ATTENTION (process one Q-block at a time, loop over K/V blocks) + + + for each Q block i: + + + + Q block i + [Br × d] + + + + + + + for each K/V block j (where j ≤ i for causal): + + + + K block j + + + + V block j + + + + + + + S tile + [Br×Bc] + + + + + + + update m, l, O + + + + O(Br×Bc) + + + + + + Online Softmax Update (per Q-row i, for each K/V tile j): + + + init: + m = −∞, l = 0, O = 0 + + + 1. + m_new = max( m_prev, max_k( S_ik ) ) + + + 2. + l_new = exp(m_prev − m_new) · l_prev + ∑_k exp(S_ik − m_new) + + + 3. + O_new = ( exp(m_prev − m_new) · l_prev · O_prev + ∑_k exp(S_ik − m_new) · V_k ) / l_new + + + done: + output[i] = O after all tiles processed

@@ -105,7 +132,7 @@ \begin{aligned} m_i^{(j)} &= \max\!\left(m_i^{(j-1)},\; \max_k S_{ik}^{(j)}\right) \\[4pt] l_i^{(j)} &= e^{m_i^{(j-1)} - m_i^{(j)}} l_i^{(j-1)} + \sum_k e^{S_{ik}^{(j)} - m_i^{(j)}} \\[4pt] -O_i^{(j)} &= \frac{e^{m_i^{(j-1)} - m_i^{(j)}} l_i^{(j-1)} \cdot O_i^{(j-1)} + e^{S_{ik}^{(j)} - m_i^{(j)}} V^{(j)}}{l_i^{(j)}} +O_i^{(j)} &= \frac{e^{m_i^{(j-1)} - m_i^{(j)}} l_i^{(j-1)} \cdot O_i^{(j-1)} + \sum_k e^{S_{ik}^{(j)} - m_i^{(j)}} V_k^{(j)}}{l_i^{(j)}} \end{aligned} \] @@ -129,16 +156,43 @@

Implementation Requirements

Example

-

With num_heads = 2, seq_len = 3, head_dim = 4:

-
-Input:  Q.shape = (2, 3, 4)   # 2 heads, 3 tokens, dim 4
-        K.shape = (2, 3, 4)
-        V.shape = (2, 3, 4)
-Output: output.shape = (2, 3, 4)
-
-# For head 0, token 0 (first query):
-#   Attends only to key 0 (causal) → output[0][0] = V[0][0]
-
+

+ Input: num_heads = 2, seq_len = 3, head_dim = 4 +

+

+ Head 0 — \(Q_0\): + \[ + \begin{bmatrix} + 1.0 & 0.0 & 1.0 & 0.0 \\ + 0.0 & 1.0 & 0.0 & 1.0 \\ + 1.0 & 1.0 & 0.0 & 0.0 + \end{bmatrix} + \] + Head 0 — \(K_0\): + \[ + \begin{bmatrix} + 1.0 & 0.0 & 0.0 & 1.0 \\ + 0.0 & 1.0 & 1.0 & 0.0 \\ + 1.0 & 0.0 & 1.0 & 0.0 + \end{bmatrix} + \] + Head 0 — \(V_0\): + \[ + \begin{bmatrix} + 1.0 & 2.0 & 3.0 & 4.0 \\ + 5.0 & 6.0 & 7.0 & 8.0 \\ + 9.0 & 10.0 & 11.0 & 12.0 + \end{bmatrix} + \] + Head 0 — output (token 0 attends only to itself due to causal mask): + \[ + \begin{bmatrix} + 1.0000 & 2.0000 & 3.0000 & 4.0000 \\ + 3.0000 & 4.0000 & 5.0000 & 6.0000 \\ + 5.0000 & 6.0000 & 7.0000 & 8.0000 + \end{bmatrix} + \] +

Constraints

    From 10ac388747099c6b63255f662c0ff7dd077e9947 Mon Sep 17 00:00:00 2001 From: James Song Date: Thu, 26 Mar 2026 21:47:42 -0400 Subject: [PATCH 4/4] Fix LaTeX rendering: use plain underscore in \text{head_dim} The escaped form \_ renders literally as a backslash in MathJax/KaTeX when inside \text{}. Co-Authored-By: Claude Opus 4.6 (1M context) --- challenges/hard/83_flash_attention/challenge.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/challenges/hard/83_flash_attention/challenge.html b/challenges/hard/83_flash_attention/challenge.html index ef75ea7d..5a9b659f 100644 --- a/challenges/hard/83_flash_attention/challenge.html +++ b/challenges/hard/83_flash_attention/challenge.html @@ -145,7 +145,7 @@

    Implementation Requirements

  • Use only native features (external libraries are not permitted)
  • The solve function signature must remain unchanged
  • Write the result into the output tensor
  • -
  • Scale factor: \(\text{scale} = 1 / \sqrt{\text{head\_dim}}\)
  • +
  • Scale factor: \(\text{scale} = 1 / \sqrt{\text{head_dim}}\)
  • Apply a causal mask: position \(j\) is masked out (set to \(-\infty\) before softmax) whenever \(j > i\)
  • Q, K, V are stored in row-major order with shape (num_heads, seq_len, head_dim)