Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 204 additions & 0 deletions challenges/hard/83_flash_attention/challenge.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
<p>
Implement causal multi-head self-attention using the Flash Attention algorithm. Given query,
key, and value tensors of shape <code>(num_heads, seq_len, head_dim)</code>, compute the
attention output where each query position can only attend to positions at or before its own
index (causal mask). Your implementation must use <strong>tiled computation</strong>: process
the sequence in blocks so that the full <code>seq_len &times; seq_len</code> attention score
matrix is never materialized in memory.
</p>

<svg width="620" height="380" viewBox="0 0 620 380" xmlns="http://www.w3.org/2000/svg"
style="display:block; margin:20px auto; font-family:monospace;">
<rect width="620" height="380" fill="#222" rx="8"/>
<defs>
<marker id="ah" markerWidth="8" markerHeight="8" refX="6" refY="3" orient="auto">
<path d="M0,0 L0,6 L8,3 z" fill="#aaa"/>
</marker>
</defs>

<!-- ================================================================ -->
<!-- TOP: Naive attention (what we want to avoid) -->
<!-- ================================================================ -->
<text x="16" y="20" fill="#666" font-size="10">NAIVE (materializes full N&#xd7;N matrix &#x2014; too much memory!)</text>

<!-- Q -->
<rect x="16" y="32" width="50" height="60" rx="3" fill="#1e2d4d" stroke="#4477bb" stroke-width="1.5"/>
<text x="41" y="55" text-anchor="middle" fill="#aaccee" font-size="10">Q</text>
<text x="41" y="68" text-anchor="middle" fill="#7799bb" font-size="9">[N&#xd7;d]</text>

<!-- times -->
<text x="80" y="66" fill="#888" font-size="13" text-anchor="middle">&#xd7;</text>

<!-- K^T -->
<rect x="94" y="32" width="60" height="50" rx="3" fill="#1e3d2d" stroke="#44aa66" stroke-width="1.5"/>
<text x="124" y="55" text-anchor="middle" fill="#aaeebb" font-size="10">K&#x1d40;</text>
<text x="124" y="68" text-anchor="middle" fill="#77bbaa" font-size="9">[d&#xd7;N]</text>

<!-- = -->
<text x="168" y="66" fill="#888" font-size="13" text-anchor="middle">=</text>

<!-- S full matrix (the problem) -->
<rect x="182" y="30" width="70" height="70" rx="3" fill="#3d1e1e" stroke="#cc4444" stroke-width="1.5"/>
<text x="217" y="60" text-anchor="middle" fill="#ee6666" font-size="10">S</text>
<text x="217" y="74" text-anchor="middle" fill="#cc4444" font-size="9">[N&#xd7;N]</text>

<!-- X mark -->
<text x="270" y="70" fill="#cc4444" font-size="14" text-anchor="middle">&#x2717;</text>
<text x="340" y="70" fill="#cc4444" font-size="10" text-anchor="middle">O(N&#xb2;) memory</text>

<!-- ================================================================ -->
<!-- MIDDLE: Flash Attention tiled approach -->
<!-- ================================================================ -->
<text x="16" y="122" fill="#666" font-size="10">FLASH ATTENTION (process one Q-block at a time, loop over K/V blocks)</text>

<!-- For each Q block i: -->
<text x="16" y="148" fill="#ccc" font-size="10">for each Q block i:</text>

<!-- Q_i block -->
<rect x="30" y="158" width="60" height="36" rx="3" fill="#1e2d4d" stroke="#4477bb" stroke-width="1.5"/>
<text x="60" y="176" text-anchor="middle" fill="#aaccee" font-size="10">Q block i</text>
<text x="60" y="188" text-anchor="middle" fill="#7799bb" font-size="8">[Br &#xd7; d]</text>

<!-- arrow right -->
<line x1="92" y1="176" x2="120" y2="176" stroke="#aaa" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Inner loop box -->
<rect x="124" y="140" width="370" height="70" rx="5" fill="#2d2010" stroke="#cc8833" stroke-width="1.5"/>
<text x="309" y="158" text-anchor="middle" fill="#ffcc66" font-size="10" font-weight="bold">for each K/V block j (where j &#x2264; i for causal):</text>

<!-- K_j block -->
<rect x="140" y="168" width="58" height="30" rx="3" fill="#1e3d2d" stroke="#44aa66" stroke-width="1"/>
<text x="169" y="187" text-anchor="middle" fill="#aaeebb" font-size="9">K block j</text>

<!-- V_j block -->
<rect x="206" y="168" width="58" height="30" rx="3" fill="#1e3d2d" stroke="#44aa66" stroke-width="1"/>
<text x="235" y="187" text-anchor="middle" fill="#aaeebb" font-size="9">V block j</text>

<!-- arrow right -->
<text x="278" y="187" fill="#888" font-size="12" text-anchor="middle">&#x2192;</text>

<!-- Small tile S_ij -->
<rect x="292" y="168" width="60" height="30" rx="3" fill="#3a2a1a" stroke="#e0a040" stroke-width="1"/>
<text x="322" y="182" text-anchor="middle" fill="#e0a040" font-size="9">S tile</text>
<text x="322" y="192" text-anchor="middle" fill="#cc9944" font-size="8">[Br&#xd7;Bc]</text>

<!-- arrow right -->
<text x="366" y="187" fill="#888" font-size="12" text-anchor="middle">&#x2192;</text>

<!-- Update O -->
<rect x="380" y="168" width="100" height="30" rx="3" fill="#2d1e4d" stroke="#8844cc" stroke-width="1"/>
<text x="430" y="187" text-anchor="middle" fill="#cc88ff" font-size="9">update m, l, O</text>

<!-- Checkmark and note -->
<text x="510" y="180" fill="#44aa66" font-size="14" text-anchor="middle">&#x2713;</text>
<text x="570" y="180" fill="#44aa66" font-size="10" text-anchor="middle">O(Br&#xd7;Bc)</text>

<!-- ================================================================ -->
<!-- BOTTOM: Online softmax formulas -->
<!-- ================================================================ -->
<rect x="16" y="230" width="588" height="136" rx="5" fill="#1a1a1a" stroke="#666" stroke-width="1"/>
<text x="310" y="252" text-anchor="middle" fill="#ccc" font-size="11" font-weight="bold">Online Softmax Update (per Q-row i, for each K/V tile j):</text>

<!-- Init -->
<text x="36" y="276" fill="#666" font-size="10">init:</text>
<text x="100" y="276" fill="#888" font-size="10">m = &#x2212;&#x221e;, l = 0, O = 0</text>

<!-- m update -->
<text x="36" y="298" fill="#ffcc66" font-size="10">1.</text>
<text x="56" y="298" fill="#aaa" font-size="10">m_new = max( m_prev, max_k( S_ik ) )</text>

<!-- l update -->
<text x="36" y="318" fill="#ffcc66" font-size="10">2.</text>
<text x="56" y="318" fill="#aaa" font-size="10">l_new = exp(m_prev &#x2212; m_new) &#xb7; l_prev + &#x2211;_k exp(S_ik &#x2212; m_new)</text>

<!-- O update -->
<text x="36" y="338" fill="#ffcc66" font-size="10">3.</text>
<text x="56" y="338" fill="#aaa" font-size="10">O_new = ( exp(m_prev &#x2212; m_new) &#xb7; l_prev &#xb7; O_prev + &#x2211;_k exp(S_ik &#x2212; m_new) &#xb7; V_k ) / l_new</text>

<!-- final -->
<text x="36" y="358" fill="#666" font-size="10">done:</text>
<text x="100" y="358" fill="#888" font-size="10">output[i] = O after all tiles processed</text>
</svg>

<p>
The key insight of Flash Attention is using the <strong>online softmax</strong> algorithm: instead
of computing softmax over the full row at once, maintain running statistics — the current maximum
\(m_i\) and the running normalizer \(l_i\) — and update the output accumulator \(O_i\) as each
new K/V tile is processed. This makes the computation exact while never requiring the full
\(N \times N\) score matrix in memory.
</p>

\[
\begin{aligned}
m_i^{(j)} &= \max\!\left(m_i^{(j-1)},\; \max_k S_{ik}^{(j)}\right) \\[4pt]
l_i^{(j)} &= e^{m_i^{(j-1)} - m_i^{(j)}} l_i^{(j-1)} + \sum_k e^{S_{ik}^{(j)} - m_i^{(j)}} \\[4pt]
O_i^{(j)} &= \frac{e^{m_i^{(j-1)} - m_i^{(j)}} l_i^{(j-1)} \cdot O_i^{(j-1)} + \sum_k e^{S_{ik}^{(j)} - m_i^{(j)}} V_k^{(j)}}{l_i^{(j)}}
\end{aligned}
\]

<p>where \(S_{ik}^{(j)} = \text{scale} \cdot Q_i K_k^\top\) is the raw attention score between
query position \(i\) and key position \(k\) in tile \(j\), with the causal mask applied
(\(S_{ik} = -\infty\) for \(k > i\)).</p>

<h2>Implementation Requirements</h2>
<ul>
<li>Use only native features (external libraries are not permitted)</li>
<li>The <code>solve</code> function signature must remain unchanged</li>
<li>Write the result into the <code>output</code> tensor</li>
<li>Scale factor: \(\text{scale} = 1 / \sqrt{\text{head_dim}}\)</li>
<li>Apply a causal mask: position \(j\) is masked out (set to \(-\infty\) before softmax) whenever \(j > i\)</li>
<li>Q, K, V are stored in row-major order with shape <code>(num_heads, seq_len, head_dim)</code></li>
<li>
Implement tiled computation: iterate over blocks of the key/value sequence and accumulate
the output using the online softmax recurrence above — do not allocate a full
<code>seq_len &times; seq_len</code> intermediate matrix
</li>
</ul>

<h2>Example</h2>
<p>
Input: <code>num_heads</code> = 2, <code>seq_len</code> = 3, <code>head_dim</code> = 4
</p>
<p>
Head 0 — \(Q_0\):
\[
\begin{bmatrix}
1.0 & 0.0 & 1.0 & 0.0 \\
0.0 & 1.0 & 0.0 & 1.0 \\
1.0 & 1.0 & 0.0 & 0.0
\end{bmatrix}
\]
Head 0 — \(K_0\):
\[
\begin{bmatrix}
1.0 & 0.0 & 0.0 & 1.0 \\
0.0 & 1.0 & 1.0 & 0.0 \\
1.0 & 0.0 & 1.0 & 0.0
\end{bmatrix}
\]
Head 0 — \(V_0\):
\[
\begin{bmatrix}
1.0 & 2.0 & 3.0 & 4.0 \\
5.0 & 6.0 & 7.0 & 8.0 \\
9.0 & 10.0 & 11.0 & 12.0
\end{bmatrix}
\]
Head 0 — output (token 0 attends only to itself due to causal mask):
\[
\begin{bmatrix}
1.0000 & 2.0000 & 3.0000 & 4.0000 \\
3.0000 & 4.0000 & 5.0000 & 6.0000 \\
5.0000 & 6.0000 & 7.0000 & 8.0000
\end{bmatrix}
\]
</p>

<h2>Constraints</h2>
<ul>
<li>1 &le; <code>num_heads</code> &le; 32</li>
<li>1 &le; <code>seq_len</code> &le; 8,192</li>
<li>1 &le; <code>head_dim</code> &le; 256</li>
<li>All tensors use 32-bit floating point</li>
<li>Performance is measured with <code>num_heads</code> = 8, <code>seq_len</code> = 4,096, <code>head_dim</code> = 64</li>
</ul>
150 changes: 150 additions & 0 deletions challenges/hard/83_flash_attention/challenge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import ctypes
import math
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase


class Challenge(ChallengeBase):
    """Causal multi-head Flash Attention challenge definition.

    The reference implementation here is deliberately naive (it materializes
    the full seq_len x seq_len score matrix); submissions are expected to use
    the tiled online-softmax algorithm described in the challenge text.
    """

    def __init__(self):
        super().__init__(
            name="Flash Attention",
            atol=1e-03,
            rtol=1e-03,
            num_gpus=1,
            access_tier="free",
        )

    def reference_impl(
        self,
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        output: torch.Tensor,
        num_heads: int,
        seq_len: int,
        head_dim: int,
    ):
        """Compute causal attention the straightforward way and write into `output`.

        All four tensors must be fp32 CUDA tensors of shape
        (num_heads, seq_len, head_dim).
        """
        expected_shape = (num_heads, seq_len, head_dim)
        for tensor in (Q, K, V, output):
            assert tensor.shape == expected_shape
            assert tensor.dtype == torch.float32
            assert tensor.device.type == "cuda"

        scale = 1.0 / math.sqrt(head_dim)
        # Raw scores, (num_heads, seq_len, seq_len) — full matrix on purpose.
        scores = torch.bmm(Q, K.transpose(1, 2)) * scale
        # Causal mask: forbid attending to future positions (j > i).
        future = torch.ones(seq_len, seq_len, dtype=torch.bool, device=Q.device).triu(1)
        scores = scores.masked_fill(future.unsqueeze(0), float("-inf"))
        output.copy_(torch.bmm(torch.softmax(scores, dim=-1), V))

    def get_solve_signature(self) -> Dict[str, tuple]:
        """Describe the ctypes signature of the user's `solve` entry point.

        Insertion order matters: it is the positional argument order.
        """
        f32_ptr = ctypes.POINTER(ctypes.c_float)
        signature: Dict[str, tuple] = {
            "Q": (f32_ptr, "in"),
            "K": (f32_ptr, "in"),
            "V": (f32_ptr, "in"),
            "output": (f32_ptr, "out"),
        }
        for scalar in ("num_heads", "seq_len", "head_dim"):
            signature[scalar] = (ctypes.c_int, "in")
        return signature

    def _make_test_case(self, num_heads, seq_len, head_dim, zero_inputs=False):
        """Build one test-case dict; Q/K/V are random unless `zero_inputs`."""
        make = torch.zeros if zero_inputs else torch.randn
        shape = (num_heads, seq_len, head_dim)
        Q = make(shape, device="cuda", dtype=torch.float32)
        K = make(shape, device="cuda", dtype=torch.float32)
        V = make(shape, device="cuda", dtype=torch.float32)
        output = torch.zeros(shape, device="cuda", dtype=torch.float32)
        return {
            "Q": Q,
            "K": K,
            "V": V,
            "output": output,
            "num_heads": num_heads,
            "seq_len": seq_len,
            "head_dim": head_dim,
        }

    def generate_example_test(self) -> Dict[str, Any]:
        """Fixed 2-head, 3-token, dim-4 case matching the challenge write-up."""
        dev, f32 = "cuda", torch.float32
        Q = torch.tensor(
            [
                [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0], [1.0, 1.0, 0.0, 0.0]],
                [[-1.0, 0.5, 0.0, 1.0], [0.0, -1.0, 1.0, 0.0], [0.5, 0.0, -0.5, 1.0]],
            ],
            device=dev,
            dtype=f32,
        )
        K = torch.tensor(
            [
                [[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0], [1.0, 0.0, 1.0, 0.0]],
                [[0.5, -0.5, 1.0, 0.0], [1.0, 0.0, -1.0, 0.5], [-0.5, 1.0, 0.0, -1.0]],
            ],
            device=dev,
            dtype=f32,
        )
        V = torch.tensor(
            [
                [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]],
                [[-1.0, -2.0, -3.0, -4.0], [2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0]],
            ],
            device=dev,
            dtype=f32,
        )
        output = torch.zeros(2, 3, 4, device=dev, dtype=f32)
        return {
            "Q": Q,
            "K": K,
            "V": V,
            "output": output,
            "num_heads": 2,
            "seq_len": 3,
            "head_dim": 4,
        }

    def generate_functional_test(self) -> List[Dict[str, Any]]:
        """Functional sweep: edge cases, zeros, power-of-2 and ragged sizes."""
        torch.manual_seed(42)
        # (num_heads, seq_len, head_dim, zero_inputs) — order matters for RNG.
        specs = [
            (1, 1, 8, False),     # single token (attends only to itself)
            (1, 2, 8, False),     # two tokens
            (2, 4, 16, False),    # small multi-head
            (2, 4, 8, True),      # all-zero inputs
            (4, 16, 32, False),   # power-of-2 sizes
            (4, 64, 64, False),
            (4, 30, 32, False),   # non-power-of-2 sizes
            (4, 100, 64, False),
            (8, 128, 64, False),  # realistic inference sizes
            (8, 256, 64, False),
        ]
        return [
            self._make_test_case(h, s, d, zero_inputs=z) for h, s, d, z in specs
        ]

    def generate_performance_test(self) -> Dict[str, Any]:
        """LLM-scale benchmark: 8 heads, seq_len=4096, head_dim=64."""
        torch.manual_seed(0)
        return self._make_test_case(8, 4096, 64)
5 changes: 5 additions & 0 deletions challenges/hard/83_flash_attention/starter/starter.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#include <cuda_runtime.h>

// Entry point for the CUDA submission. Q, K, V, output are device pointers,
// each laid out row-major as (num_heads, seq_len, head_dim) in 32-bit floats.
// Starter stub: launch kernel(s) that write the causal attention result
// softmax(mask(Q K^T * scale)) V into `output`. scale = 1/sqrt(head_dim).
extern "C" void solve(const float* Q, const float* K, const float* V, float* output, int num_heads,
int seq_len, int head_dim) {}
16 changes: 16 additions & 0 deletions challenges/hard/83_flash_attention/starter/starter.cute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import cutlass
import cutlass.cute as cute


# Q, K, V, output are tensors on the GPU
@cute.jit
def solve(
    Q: cute.Tensor,
    K: cute.Tensor,
    V: cute.Tensor,
    output: cute.Tensor,
    num_heads: cute.Int32,
    seq_len: cute.Int32,
    head_dim: cute.Int32,
):
    """Write causal multi-head attention into `output`.

    Q, K, V, output have shape (num_heads, seq_len, head_dim); the challenge
    states all tensors are 32-bit floats. Starter stub — replace `pass` with
    a tiled (flash-attention) implementation.
    """
    pass
11 changes: 11 additions & 0 deletions challenges/hard/83_flash_attention/starter/starter.jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import jax
import jax.numpy as jnp


# Q, K, V are tensors on GPU
@jax.jit
def solve(
    Q: jax.Array, K: jax.Array, V: jax.Array, num_heads: int, seq_len: int, head_dim: int
) -> jax.Array:
    """Return the causal attention output, shape (num_heads, seq_len, head_dim).

    NOTE(review): unlike the C/CUDA variant there is no `output` argument —
    this starter returns the result array directly.
    Starter stub — replace `pass` with the implementation.
    """
    # return output tensor directly
    pass
Loading
Loading