From 7fe1867b3e3a149e9aa10648cd63084140d5ec15 Mon Sep 17 00:00:00 2001
From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com>
Date: Tue, 17 Mar 2026 04:15:42 +0000
Subject: [PATCH 1/4] Add challenge 83: Flash Attention (Hard)
Co-Authored-By: Claude Sonnet 4.6
---
.../hard/83_flash_attention/challenge.html | 150 ++++++++++++++++++
.../hard/83_flash_attention/challenge.py | 150 ++++++++++++++++++
.../83_flash_attention/solution/solution.cu | 91 +++++++++++
.../83_flash_attention/starter/starter.cu | 5 +
.../starter/starter.cute.py | 16 ++
.../83_flash_attention/starter/starter.jax.py | 11 ++
.../83_flash_attention/starter/starter.mojo | 9 ++
.../starter/starter.pytorch.py | 14 ++
.../starter/starter.triton.py | 16 ++
9 files changed, 462 insertions(+)
create mode 100644 challenges/hard/83_flash_attention/challenge.html
create mode 100644 challenges/hard/83_flash_attention/challenge.py
create mode 100644 challenges/hard/83_flash_attention/solution/solution.cu
create mode 100644 challenges/hard/83_flash_attention/starter/starter.cu
create mode 100644 challenges/hard/83_flash_attention/starter/starter.cute.py
create mode 100644 challenges/hard/83_flash_attention/starter/starter.jax.py
create mode 100644 challenges/hard/83_flash_attention/starter/starter.mojo
create mode 100644 challenges/hard/83_flash_attention/starter/starter.pytorch.py
create mode 100644 challenges/hard/83_flash_attention/starter/starter.triton.py
diff --git a/challenges/hard/83_flash_attention/challenge.html b/challenges/hard/83_flash_attention/challenge.html
new file mode 100644
index 00000000..00215193
--- /dev/null
+++ b/challenges/hard/83_flash_attention/challenge.html
@@ -0,0 +1,150 @@
+
+ Implement causal multi-head self-attention using the Flash Attention algorithm. Given query,
+ key, and value tensors of shape (num_heads, seq_len, head_dim), compute the
+ attention output where each query position can only attend to positions at or before its own
+ index (causal mask). Your implementation must use tiled computation: process
+ the sequence in blocks so that the full seq_len × seq_len attention score
+ matrix is never materialized in memory.
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Flash Attention: Tiled Computation
+
+
+ Q
+
+
+ Q₀ (rows 0..Br)
+
+
+ Q₁
+
+
+ Q₂
+
+
+ ⋮
+
+ seq_len / Br blocks
+
+
+
+
+
+ Kᵀ, V
+
+ K₀
+
+
+ K₁
+
+
+ K₂
+
+
+ …
+
+ seq_len / Bc blocks
+
+
+
+ Tile Sᵢⱼ = Qᵢ Kⱼᵀ · scale
+ Apply causal mask
+ Update (m, l, O) online
+
+
+
+ loop
+ over j
+
+
+ O
+
+ O₀ (accumulated)
+
+
+
+
+
+
+ Online Softmax Update (for each K/V tile j):
+ m_new = max(m_prev, rowmax(Sᵢⱼ))
+ l_new = exp(m_prev − m_new)·l_prev + rowsum(exp(Sᵢⱼ − m_new))
+ O_new = (exp(m_prev − m_new)·l_prev·O_prev + exp(Sᵢⱼ − m_new)·Vⱼ) / l_new
+
+
+
+ The key insight of Flash Attention is using the online softmax algorithm: instead
+ of computing softmax over the full row at once, maintain running statistics — the current maximum
+ \(m_i\) and the running normalizer \(l_i\) — and update the output accumulator \(O_i\) as each
+ new K/V tile is processed. This makes the computation exact while never requiring the full
+ \(N \times N\) score matrix in memory.
+
+
+\[
+\begin{aligned}
+m_i^{(j)} &= \max\!\left(m_i^{(j-1)},\; \max_k S_{ik}^{(j)}\right) \\[4pt]
+l_i^{(j)} &= e^{m_i^{(j-1)} - m_i^{(j)}} l_i^{(j-1)} + \sum_k e^{S_{ik}^{(j)} - m_i^{(j)}} \\[4pt]
+O_i^{(j)} &= \frac{e^{m_i^{(j-1)} - m_i^{(j)}} l_i^{(j-1)} \cdot O_i^{(j-1)} + e^{S_{ik}^{(j)} - m_i^{(j)}} V^{(j)}}{l_i^{(j)}}
+\end{aligned}
+\]
+
+where \(S_{ik}^{(j)} = \text{scale} \cdot Q_i K_k^\top\) is the raw attention score between
+query position \(i\) and key position \(k\) in tile \(j\), with the causal mask applied
+(\(S_{ik} = -\infty\) for \(k > i\)).
+
+Implementation Requirements
+
+ Use only native features (external libraries are not permitted)
+ The solve function signature must remain unchanged
+ Write the result into the output tensor
+ Scale factor: \(\text{scale} = 1 / \sqrt{\text{head\_dim}}\)
+ Apply a causal mask: position \(j\) is masked out (set to \(-\infty\) before softmax) whenever \(j > i\)
+ Q, K, V are stored in row-major order with shape (num_heads, seq_len, head_dim)
+
+ Implement tiled computation: iterate over blocks of the key/value sequence and accumulate
+ the output using the online softmax recurrence above — do not allocate a full
+ seq_len × seq_len intermediate matrix
+
+
+
+Example
+With num_heads = 2, seq_len = 3, head_dim = 4:
+
+Input: Q.shape = (2, 3, 4) # 2 heads, 3 tokens, dim 4
+ K.shape = (2, 3, 4)
+ V.shape = (2, 3, 4)
+Output: output.shape = (2, 3, 4)
+
+# For head 0, token 0 (first query):
+# Attends only to key 0 (causal) → output[0][0] = V[0][0]
+
+
+Constraints
+
+ 1 ≤ num_heads ≤ 32
+ 1 ≤ seq_len ≤ 8,192
+ 1 ≤ head_dim ≤ 256
+ All tensors use 32-bit floating point
+ Performance is measured with num_heads = 8, seq_len = 4,096, head_dim = 64
+
diff --git a/challenges/hard/83_flash_attention/challenge.py b/challenges/hard/83_flash_attention/challenge.py
new file mode 100644
index 00000000..d1e1eb38
--- /dev/null
+++ b/challenges/hard/83_flash_attention/challenge.py
@@ -0,0 +1,150 @@
+import ctypes
+import math
+from typing import Any, Dict, List
+
+import torch
+from core.challenge_base import ChallengeBase
+
+
+class Challenge(ChallengeBase):
+ def __init__(self):
+ super().__init__(
+ name="Flash Attention",
+ atol=1e-03,
+ rtol=1e-03,
+ num_gpus=1,
+ access_tier="free",
+ )
+
+ def reference_impl(
+ self,
+ Q: torch.Tensor,
+ K: torch.Tensor,
+ V: torch.Tensor,
+ output: torch.Tensor,
+ num_heads: int,
+ seq_len: int,
+ head_dim: int,
+ ):
+ assert Q.shape == (num_heads, seq_len, head_dim)
+ assert K.shape == (num_heads, seq_len, head_dim)
+ assert V.shape == (num_heads, seq_len, head_dim)
+ assert output.shape == (num_heads, seq_len, head_dim)
+ assert Q.dtype == K.dtype == V.dtype == output.dtype == torch.float32
+ assert Q.device.type == "cuda"
+ assert K.device.type == "cuda"
+ assert V.device.type == "cuda"
+ assert output.device.type == "cuda"
+
+ scale = 1.0 / math.sqrt(head_dim)
+ # scores: (num_heads, seq_len, seq_len)
+ scores = torch.bmm(Q, K.transpose(1, 2)) * scale
+ # causal mask: upper triangle (j > i) set to -inf
+ causal_mask = torch.triu(
+ torch.full((seq_len, seq_len), float("-inf"), device=Q.device, dtype=Q.dtype),
+ diagonal=1,
+ )
+ scores = scores + causal_mask.unsqueeze(0)
+ attn_weights = torch.softmax(scores, dim=-1)
+ output.copy_(torch.bmm(attn_weights, V))
+
+ def get_solve_signature(self) -> Dict[str, tuple]:
+ return {
+ "Q": (ctypes.POINTER(ctypes.c_float), "in"),
+ "K": (ctypes.POINTER(ctypes.c_float), "in"),
+ "V": (ctypes.POINTER(ctypes.c_float), "in"),
+ "output": (ctypes.POINTER(ctypes.c_float), "out"),
+ "num_heads": (ctypes.c_int, "in"),
+ "seq_len": (ctypes.c_int, "in"),
+ "head_dim": (ctypes.c_int, "in"),
+ }
+
+ def _make_test_case(self, num_heads, seq_len, head_dim, zero_inputs=False):
+ dtype = torch.float32
+ device = "cuda"
+ if zero_inputs:
+ Q = torch.zeros(num_heads, seq_len, head_dim, device=device, dtype=dtype)
+ K = torch.zeros(num_heads, seq_len, head_dim, device=device, dtype=dtype)
+ V = torch.zeros(num_heads, seq_len, head_dim, device=device, dtype=dtype)
+ else:
+ Q = torch.randn(num_heads, seq_len, head_dim, device=device, dtype=dtype)
+ K = torch.randn(num_heads, seq_len, head_dim, device=device, dtype=dtype)
+ V = torch.randn(num_heads, seq_len, head_dim, device=device, dtype=dtype)
+ output = torch.zeros(num_heads, seq_len, head_dim, device=device, dtype=dtype)
+ return {
+ "Q": Q,
+ "K": K,
+ "V": V,
+ "output": output,
+ "num_heads": num_heads,
+ "seq_len": seq_len,
+ "head_dim": head_dim,
+ }
+
+ def generate_example_test(self) -> Dict[str, Any]:
+ dtype = torch.float32
+ device = "cuda"
+ num_heads = 2
+ seq_len = 3
+ head_dim = 4
+ Q = torch.tensor(
+ [
+ [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0], [1.0, 1.0, 0.0, 0.0]],
+ [[-1.0, 0.5, 0.0, 1.0], [0.0, -1.0, 1.0, 0.0], [0.5, 0.0, -0.5, 1.0]],
+ ],
+ device=device,
+ dtype=dtype,
+ )
+ K = torch.tensor(
+ [
+ [[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0], [1.0, 0.0, 1.0, 0.0]],
+ [[0.5, -0.5, 1.0, 0.0], [1.0, 0.0, -1.0, 0.5], [-0.5, 1.0, 0.0, -1.0]],
+ ],
+ device=device,
+ dtype=dtype,
+ )
+ V = torch.tensor(
+ [
+ [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]],
+ [[-1.0, -2.0, -3.0, -4.0], [2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0]],
+ ],
+ device=device,
+ dtype=dtype,
+ )
+ output = torch.zeros(num_heads, seq_len, head_dim, device=device, dtype=dtype)
+ return {
+ "Q": Q,
+ "K": K,
+ "V": V,
+ "output": output,
+ "num_heads": num_heads,
+ "seq_len": seq_len,
+ "head_dim": head_dim,
+ }
+
+ def generate_functional_test(self) -> List[Dict[str, Any]]:
+ torch.manual_seed(42)
+ tests = []
+ # Edge: single token (only attends to itself)
+ tests.append(self._make_test_case(1, 1, 8))
+ # Edge: 2 tokens
+ tests.append(self._make_test_case(1, 2, 8))
+ # Edge: 4 tokens, 2 heads
+ tests.append(self._make_test_case(2, 4, 16))
+ # Zero inputs
+ tests.append(self._make_test_case(2, 4, 8, zero_inputs=True))
+ # Power-of-2 sizes
+ tests.append(self._make_test_case(4, 16, 32))
+ tests.append(self._make_test_case(4, 64, 64))
+ # Non-power-of-2 sizes
+ tests.append(self._make_test_case(4, 30, 32))
+ tests.append(self._make_test_case(4, 100, 64))
+ # Realistic inference sizes
+ tests.append(self._make_test_case(8, 128, 64))
+ tests.append(self._make_test_case(8, 256, 64))
+ return tests
+
+ def generate_performance_test(self) -> Dict[str, Any]:
+ torch.manual_seed(0)
+ # LLM-scale: 8 heads, seq_len=4096, head_dim=64
+ return self._make_test_case(8, 4096, 64)
diff --git a/challenges/hard/83_flash_attention/solution/solution.cu b/challenges/hard/83_flash_attention/solution/solution.cu
new file mode 100644
index 00000000..97bb44a1
--- /dev/null
+++ b/challenges/hard/83_flash_attention/solution/solution.cu
@@ -0,0 +1,91 @@
+#include <cuda_runtime.h>
+#include <float.h>
+#include <math.h>
+
+#define BR 32
+#define BC 32
+#define MAX_HEAD_DIM 256
+
+__global__ void flash_attn_kernel(const float* __restrict__ Q, const float* __restrict__ K,
+ const float* __restrict__ V, float* __restrict__ output,
+ int seq_len, int head_dim, float scale) {
+ int h = blockIdx.x;
+ int qi_block = blockIdx.y;
+ int ti = threadIdx.x;
+ int qi = qi_block * BR + ti;
+ if (qi >= seq_len)
+ return;
+
+ const float* Qh = Q + (long)h * seq_len * head_dim;
+ const float* Kh = K + (long)h * seq_len * head_dim;
+ const float* Vh = V + (long)h * seq_len * head_dim;
+ float* Oh = output + (long)h * seq_len * head_dim;
+
+ float m = -FLT_MAX;
+ float l = 0.0f;
+ float acc[MAX_HEAD_DIM];
+ for (int d = 0; d < head_dim; d++)
+ acc[d] = 0.0f;
+
+ for (int kj_block = 0; (long)kj_block * BC <= qi; kj_block++) {
+ int kj_start = kj_block * BC;
+ int kj_end = kj_start + BC;
+ if (kj_end > seq_len)
+ kj_end = seq_len;
+ if (kj_end > qi + 1)
+ kj_end = qi + 1;
+ int actual_len = kj_end - kj_start;
+
+ float s[BC];
+ for (int j = 0; j < actual_len; j++) {
+ int kj = kj_start + j;
+ float dot = 0.0f;
+ for (int d = 0; d < head_dim; d++) {
+ dot += Qh[qi * head_dim + d] * Kh[kj * head_dim + d];
+ }
+ s[j] = dot * scale;
+ }
+
+ float m_tile = -FLT_MAX;
+ for (int j = 0; j < actual_len; j++) {
+ if (s[j] > m_tile)
+ m_tile = s[j];
+ }
+
+ float m_new = (m > m_tile) ? m : m_tile;
+ float alpha = expf(m - m_new);
+
+ float p[BC];
+ float l_tile = 0.0f;
+ for (int j = 0; j < actual_len; j++) {
+ p[j] = expf(s[j] - m_new);
+ l_tile += p[j];
+ }
+ float l_new = alpha * l + l_tile;
+
+ for (int d = 0; d < head_dim; d++) {
+ float pv = 0.0f;
+ for (int j = 0; j < actual_len; j++) {
+ pv += p[j] * Vh[(kj_start + j) * head_dim + d];
+ }
+ acc[d] = (alpha * l * acc[d] + pv) / l_new;
+ }
+
+ m = m_new;
+ l = l_new;
+ }
+
+ for (int d = 0; d < head_dim; d++) {
+ Oh[qi * head_dim + d] = acc[d];
+ }
+}
+
+extern "C" void solve(const float* Q, const float* K, const float* V, float* output, int num_heads,
+ int seq_len, int head_dim) {
+ float scale = 1.0f / sqrtf((float)head_dim);
+ int num_q_blocks = (seq_len + BR - 1) / BR;
+ dim3 grid(num_heads, num_q_blocks);
+ dim3 block(BR);
+    flash_attn_kernel<<<grid, block>>>(Q, K, V, output, seq_len, head_dim, scale);
+ cudaDeviceSynchronize();
+}
diff --git a/challenges/hard/83_flash_attention/starter/starter.cu b/challenges/hard/83_flash_attention/starter/starter.cu
new file mode 100644
index 00000000..707af5d8
--- /dev/null
+++ b/challenges/hard/83_flash_attention/starter/starter.cu
@@ -0,0 +1,5 @@
+#include <cuda_runtime.h>
+
+// Q, K, V, output are device pointers
+extern "C" void solve(const float* Q, const float* K, const float* V, float* output, int num_heads,
+ int seq_len, int head_dim) {}
diff --git a/challenges/hard/83_flash_attention/starter/starter.cute.py b/challenges/hard/83_flash_attention/starter/starter.cute.py
new file mode 100644
index 00000000..06da950a
--- /dev/null
+++ b/challenges/hard/83_flash_attention/starter/starter.cute.py
@@ -0,0 +1,16 @@
+import cutlass
+import cutlass.cute as cute
+
+
+# Q, K, V, output are tensors on the GPU
+@cute.jit
+def solve(
+ Q: cute.Tensor,
+ K: cute.Tensor,
+ V: cute.Tensor,
+ output: cute.Tensor,
+ num_heads: cute.Int32,
+ seq_len: cute.Int32,
+ head_dim: cute.Int32,
+):
+ pass
diff --git a/challenges/hard/83_flash_attention/starter/starter.jax.py b/challenges/hard/83_flash_attention/starter/starter.jax.py
new file mode 100644
index 00000000..d915992a
--- /dev/null
+++ b/challenges/hard/83_flash_attention/starter/starter.jax.py
@@ -0,0 +1,11 @@
+import jax
+import jax.numpy as jnp
+
+
+# Q, K, V are tensors on GPU
+@jax.jit
+def solve(
+ Q: jax.Array, K: jax.Array, V: jax.Array, num_heads: int, seq_len: int, head_dim: int
+) -> jax.Array:
+ # return output tensor directly
+ pass
diff --git a/challenges/hard/83_flash_attention/starter/starter.mojo b/challenges/hard/83_flash_attention/starter/starter.mojo
new file mode 100644
index 00000000..1bedbb3f
--- /dev/null
+++ b/challenges/hard/83_flash_attention/starter/starter.mojo
@@ -0,0 +1,9 @@
+from gpu.host import DeviceContext
+from gpu.id import block_dim, block_idx, thread_idx
+from memory import UnsafePointer
+from math import ceildiv
+
+# Q, K, V, output are device pointers
+@export
+def solve(Q: UnsafePointer[Float32], K: UnsafePointer[Float32], V: UnsafePointer[Float32], output: UnsafePointer[Float32], num_heads: Int32, seq_len: Int32, head_dim: Int32):
+ pass
diff --git a/challenges/hard/83_flash_attention/starter/starter.pytorch.py b/challenges/hard/83_flash_attention/starter/starter.pytorch.py
new file mode 100644
index 00000000..7ae6982d
--- /dev/null
+++ b/challenges/hard/83_flash_attention/starter/starter.pytorch.py
@@ -0,0 +1,14 @@
+import torch
+
+
+# Q, K, V, output are tensors on the GPU
+def solve(
+ Q: torch.Tensor,
+ K: torch.Tensor,
+ V: torch.Tensor,
+ output: torch.Tensor,
+ num_heads: int,
+ seq_len: int,
+ head_dim: int,
+):
+ pass
diff --git a/challenges/hard/83_flash_attention/starter/starter.triton.py b/challenges/hard/83_flash_attention/starter/starter.triton.py
new file mode 100644
index 00000000..b0e09f23
--- /dev/null
+++ b/challenges/hard/83_flash_attention/starter/starter.triton.py
@@ -0,0 +1,16 @@
+import torch
+import triton
+import triton.language as tl
+
+
+# Q, K, V, output are tensors on the GPU
+def solve(
+ Q: torch.Tensor,
+ K: torch.Tensor,
+ V: torch.Tensor,
+ output: torch.Tensor,
+ num_heads: int,
+ seq_len: int,
+ head_dim: int,
+):
+ pass
From 1aa0f716d11fe73629d56d212105b19eac6dca9a Mon Sep 17 00:00:00 2001
From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com>
Date: Tue, 17 Mar 2026 04:15:56 +0000
Subject: [PATCH 2/4] Remove solution file from challenge 83
Co-Authored-By: Claude Sonnet 4.6
---
.../83_flash_attention/solution/solution.cu | 91 -------------------
1 file changed, 91 deletions(-)
delete mode 100644 challenges/hard/83_flash_attention/solution/solution.cu
diff --git a/challenges/hard/83_flash_attention/solution/solution.cu b/challenges/hard/83_flash_attention/solution/solution.cu
deleted file mode 100644
index 97bb44a1..00000000
--- a/challenges/hard/83_flash_attention/solution/solution.cu
+++ /dev/null
@@ -1,91 +0,0 @@
-#include <cuda_runtime.h>
-#include <float.h>
-#include <math.h>
-
-#define BR 32
-#define BC 32
-#define MAX_HEAD_DIM 256
-
-__global__ void flash_attn_kernel(const float* __restrict__ Q, const float* __restrict__ K,
- const float* __restrict__ V, float* __restrict__ output,
- int seq_len, int head_dim, float scale) {
- int h = blockIdx.x;
- int qi_block = blockIdx.y;
- int ti = threadIdx.x;
- int qi = qi_block * BR + ti;
- if (qi >= seq_len)
- return;
-
- const float* Qh = Q + (long)h * seq_len * head_dim;
- const float* Kh = K + (long)h * seq_len * head_dim;
- const float* Vh = V + (long)h * seq_len * head_dim;
- float* Oh = output + (long)h * seq_len * head_dim;
-
- float m = -FLT_MAX;
- float l = 0.0f;
- float acc[MAX_HEAD_DIM];
- for (int d = 0; d < head_dim; d++)
- acc[d] = 0.0f;
-
- for (int kj_block = 0; (long)kj_block * BC <= qi; kj_block++) {
- int kj_start = kj_block * BC;
- int kj_end = kj_start + BC;
- if (kj_end > seq_len)
- kj_end = seq_len;
- if (kj_end > qi + 1)
- kj_end = qi + 1;
- int actual_len = kj_end - kj_start;
-
- float s[BC];
- for (int j = 0; j < actual_len; j++) {
- int kj = kj_start + j;
- float dot = 0.0f;
- for (int d = 0; d < head_dim; d++) {
- dot += Qh[qi * head_dim + d] * Kh[kj * head_dim + d];
- }
- s[j] = dot * scale;
- }
-
- float m_tile = -FLT_MAX;
- for (int j = 0; j < actual_len; j++) {
- if (s[j] > m_tile)
- m_tile = s[j];
- }
-
- float m_new = (m > m_tile) ? m : m_tile;
- float alpha = expf(m - m_new);
-
- float p[BC];
- float l_tile = 0.0f;
- for (int j = 0; j < actual_len; j++) {
- p[j] = expf(s[j] - m_new);
- l_tile += p[j];
- }
- float l_new = alpha * l + l_tile;
-
- for (int d = 0; d < head_dim; d++) {
- float pv = 0.0f;
- for (int j = 0; j < actual_len; j++) {
- pv += p[j] * Vh[(kj_start + j) * head_dim + d];
- }
- acc[d] = (alpha * l * acc[d] + pv) / l_new;
- }
-
- m = m_new;
- l = l_new;
- }
-
- for (int d = 0; d < head_dim; d++) {
- Oh[qi * head_dim + d] = acc[d];
- }
-}
-
-extern "C" void solve(const float* Q, const float* K, const float* V, float* output, int num_heads,
- int seq_len, int head_dim) {
- float scale = 1.0f / sqrtf((float)head_dim);
- int num_q_blocks = (seq_len + BR - 1) / BR;
- dim3 grid(num_heads, num_q_blocks);
- dim3 block(BR);
-    flash_attn_kernel<<<grid, block>>>(Q, K, V, output, seq_len, head_dim, scale);
- cudaDeviceSynchronize();
-}
From af2e359698d3dffd62d31f1bfb3478dfb95dbe41 Mon Sep 17 00:00:00 2001
From: James Song
Date: Thu, 26 Mar 2026 21:26:32 -0400
Subject: [PATCH 3/4] Fix flash attention HTML: example values, O formula
summation, SVG redesign
- Show actual Q/K/V/output values from generate_example_test() instead
of just tensor shapes
- Add missing summation over k in the LaTeX O_i update formula
- Redesign SVG to clearly contrast naive O(N^2) approach vs Flash
Attention's tiled approach with online softmax update steps
Co-Authored-By: Claude Opus 4.6 (1M context)
---
.../hard/83_flash_attention/challenge.html | 236 +++++++++++-------
1 file changed, 145 insertions(+), 91 deletions(-)
diff --git a/challenges/hard/83_flash_attention/challenge.html b/challenges/hard/83_flash_attention/challenge.html
index 00215193..ef75ea7d 100644
--- a/challenges/hard/83_flash_attention/challenge.html
+++ b/challenges/hard/83_flash_attention/challenge.html
@@ -7,90 +7,117 @@
matrix is never materialized in memory.
-
+
+
-
-
+
+
-
-
-
-
- Flash Attention: Tiled Computation
-
-
- Q
-
-
- Q₀ (rows 0..Br)
-
-
- Q₁
-
-
- Q₂
-
-
- ⋮
-
- seq_len / Br blocks
-
-
-
-
-
- Kᵀ, V
-
- K₀
-
-
- K₁
-
-
- K₂
-
-
- …
-
- seq_len / Bc blocks
-
-
-
- Tile Sᵢⱼ = Qᵢ Kⱼᵀ · scale
- Apply causal mask
- Update (m, l, O) online
-
-
-
- loop
- over j
-
-
- O
-
- O₀ (accumulated)
-
-
-
-
-
-
- Online Softmax Update (for each K/V tile j):
- m_new = max(m_prev, rowmax(Sᵢⱼ))
- l_new = exp(m_prev − m_new)·l_prev + rowsum(exp(Sᵢⱼ − m_new))
- O_new = (exp(m_prev − m_new)·l_prev·O_prev + exp(Sᵢⱼ − m_new)·Vⱼ) / l_new
+
+
+
+ NAIVE (materializes full N×N matrix — too much memory!)
+
+
+
+ Q
+ [N×d]
+
+
+ ×
+
+
+
+ Kᵀ
+ [d×N]
+
+
+ =
+
+
+
+ S
+ [N×N]
+
+
+ ✗
+ O(N²) memory
+
+
+
+
+ FLASH ATTENTION (process one Q-block at a time, loop over K/V blocks)
+
+
+ for each Q block i:
+
+
+
+ Q block i
+ [Br × d]
+
+
+
+
+
+
+ for each K/V block j (where j ≤ i for causal):
+
+
+
+ K block j
+
+
+
+ V block j
+
+
+ →
+
+
+
+ S tile
+ [Br×Bc]
+
+
+ →
+
+
+
+ update m, l, O
+
+
+ ✓
+ O(Br×Bc)
+
+
+
+
+
+ Online Softmax Update (per Q-row i, for each K/V tile j):
+
+
+ init:
+ m = −∞, l = 0, O = 0
+
+
+ 1.
+ m_new = max( m_prev, max_k( S_ik ) )
+
+
+ 2.
+ l_new = exp(m_prev − m_new) · l_prev + ∑_k exp(S_ik − m_new)
+
+
+ 3.
+ O_new = ( exp(m_prev − m_new) · l_prev · O_prev + ∑_k exp(S_ik − m_new) · V_k ) / l_new
+
+
+ done:
+ output[i] = O after all tiles processed
@@ -105,7 +132,7 @@
\begin{aligned}
m_i^{(j)} &= \max\!\left(m_i^{(j-1)},\; \max_k S_{ik}^{(j)}\right) \\[4pt]
l_i^{(j)} &= e^{m_i^{(j-1)} - m_i^{(j)}} l_i^{(j-1)} + \sum_k e^{S_{ik}^{(j)} - m_i^{(j)}} \\[4pt]
-O_i^{(j)} &= \frac{e^{m_i^{(j-1)} - m_i^{(j)}} l_i^{(j-1)} \cdot O_i^{(j-1)} + e^{S_{ik}^{(j)} - m_i^{(j)}} V^{(j)}}{l_i^{(j)}}
+O_i^{(j)} &= \frac{e^{m_i^{(j-1)} - m_i^{(j)}} l_i^{(j-1)} \cdot O_i^{(j-1)} + \sum_k e^{S_{ik}^{(j)} - m_i^{(j)}} V_k^{(j)}}{l_i^{(j)}}
\end{aligned}
\]
@@ -129,16 +156,43 @@
Implementation Requirements
Example
-With num_heads = 2, seq_len = 3, head_dim = 4:
-
-Input: Q.shape = (2, 3, 4) # 2 heads, 3 tokens, dim 4
- K.shape = (2, 3, 4)
- V.shape = (2, 3, 4)
-Output: output.shape = (2, 3, 4)
-
-# For head 0, token 0 (first query):
-# Attends only to key 0 (causal) → output[0][0] = V[0][0]
-
+
+ Input: num_heads = 2, seq_len = 3, head_dim = 4
+
+
+ Head 0 — \(Q_0\):
+ \[
+ \begin{bmatrix}
+ 1.0 & 0.0 & 1.0 & 0.0 \\
+ 0.0 & 1.0 & 0.0 & 1.0 \\
+ 1.0 & 1.0 & 0.0 & 0.0
+ \end{bmatrix}
+ \]
+ Head 0 — \(K_0\):
+ \[
+ \begin{bmatrix}
+ 1.0 & 0.0 & 0.0 & 1.0 \\
+ 0.0 & 1.0 & 1.0 & 0.0 \\
+ 1.0 & 0.0 & 1.0 & 0.0
+ \end{bmatrix}
+ \]
+ Head 0 — \(V_0\):
+ \[
+ \begin{bmatrix}
+ 1.0 & 2.0 & 3.0 & 4.0 \\
+ 5.0 & 6.0 & 7.0 & 8.0 \\
+ 9.0 & 10.0 & 11.0 & 12.0
+ \end{bmatrix}
+ \]
+ Head 0 — output (token 0 attends only to itself due to causal mask):
+ \[
+ \begin{bmatrix}
+ 1.0000 & 2.0000 & 3.0000 & 4.0000 \\
+ 3.0000 & 4.0000 & 5.0000 & 6.0000 \\
+ 5.0000 & 6.0000 & 7.0000 & 8.0000
+ \end{bmatrix}
+ \]
+
Constraints
From 10ac388747099c6b63255f662c0ff7dd077e9947 Mon Sep 17 00:00:00 2001
From: James Song
Date: Thu, 26 Mar 2026 21:47:42 -0400
Subject: [PATCH 4/4] Fix LaTeX rendering: use plain underscore in
\text{head_dim}
The escaped form \_ renders literally as a backslash in MathJax/KaTeX
when inside \text{}.
Co-Authored-By: Claude Opus 4.6 (1M context)
---
challenges/hard/83_flash_attention/challenge.html | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/challenges/hard/83_flash_attention/challenge.html b/challenges/hard/83_flash_attention/challenge.html
index ef75ea7d..5a9b659f 100644
--- a/challenges/hard/83_flash_attention/challenge.html
+++ b/challenges/hard/83_flash_attention/challenge.html
@@ -145,7 +145,7 @@ Implementation Requirements
Use only native features (external libraries are not permitted)
The solve function signature must remain unchanged
Write the result into the output tensor
- Scale factor: \(\text{scale} = 1 / \sqrt{\text{head\_dim}}\)
+ Scale factor: \(\text{scale} = 1 / \sqrt{\text{head_dim}}\)
Apply a causal mask: position \(j\) is masked out (set to \(-\infty\) before softmax) whenever \(j > i\)
Q, K, V are stored in row-major order with shape (num_heads, seq_len, head_dim)