187 changes: 187 additions & 0 deletions challenges/medium/80_grouped_query_attention/challenge.html
@@ -0,0 +1,187 @@
<p>
Implement Grouped Query Attention (GQA), the attention mechanism used in modern large language
models such as LLaMA-3, Mistral, and Gemma. GQA reduces the KV-cache memory footprint during
inference by sharing key and value heads across groups of query heads. Given query tensor
<code>Q</code> with <code>num_q_heads</code> heads and key/value tensors <code>K</code>,
<code>V</code> each with <code>num_kv_heads</code> heads, compute scaled dot-product attention
where every group of <code>num_q_heads / num_kv_heads</code> consecutive query heads attends to
the same key and value head. All tensors use <code>float32</code>.
</p>
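The mechanics above fit in a few lines. Below is a plain NumPy sketch of GQA for intuition (CPU-only and illustrative; the actual challenge requires a GPU implementation against the `solve` signature):

```python
import numpy as np

def gqa(Q, K, V):
    """Grouped Query Attention.

    Q: (num_q_heads, seq_len, head_dim)
    K, V: (num_kv_heads, seq_len, head_dim)
    Returns: (num_q_heads, seq_len, head_dim)
    """
    num_q_heads, seq_len, head_dim = Q.shape
    num_kv_heads = K.shape[0]
    group_size = num_q_heads // num_kv_heads
    scale = 1.0 / np.sqrt(head_dim)

    out = np.empty_like(Q)
    for h in range(num_q_heads):
        kv = h // group_size                 # query head h shares KV head kv
        scores = Q[h] @ K[kv].T * scale      # (seq_len, seq_len)
        scores -= scores.max(axis=-1, keepdims=True)  # numerically stable softmax
        w = np.exp(scores)
        w /= w.sum(axis=-1, keepdims=True)
        out[h] = w @ V[kv]
    return out
```

Because the softmax weights in each row sum to 1, feeding a constant `V` returns that constant, which makes a quick sanity check.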

<svg width="700" height="260" viewBox="0 0 700 260" xmlns="http://www.w3.org/2000/svg" style="display:block; margin:20px auto;">
<rect width="700" height="260" fill="#222" rx="10"/>
<!-- Title -->
<text x="350" y="28" fill="#ccc" font-family="monospace" font-size="14" text-anchor="middle">Grouped Query Attention (num_q_heads=4, num_kv_heads=2, groups=2)</text>

<!-- Q heads -->
<text x="80" y="60" fill="#aaa" font-family="monospace" font-size="12" text-anchor="middle">Q heads</text>
<rect x="20" y="70" width="60" height="36" fill="#2563eb" rx="4"/>
<text x="50" y="93" fill="#fff" font-family="monospace" font-size="12" text-anchor="middle">Q[0]</text>
<rect x="100" y="70" width="60" height="36" fill="#2563eb" rx="4"/>
<text x="130" y="93" fill="#fff" font-family="monospace" font-size="12" text-anchor="middle">Q[1]</text>
<rect x="180" y="70" width="60" height="36" fill="#7c3aed" rx="4"/>
<text x="210" y="93" fill="#fff" font-family="monospace" font-size="12" text-anchor="middle">Q[2]</text>
<rect x="260" y="70" width="60" height="36" fill="#7c3aed" rx="4"/>
<text x="290" y="93" fill="#fff" font-family="monospace" font-size="12" text-anchor="middle">Q[3]</text>

<!-- KV heads -->
<text x="80" y="175" fill="#aaa" font-family="monospace" font-size="12" text-anchor="middle">KV heads</text>
<rect x="20" y="185" width="120" height="36" fill="#1d4ed8" rx="4"/>
<text x="80" y="208" fill="#fff" font-family="monospace" font-size="12" text-anchor="middle">K[0], V[0]</text>
<rect x="180" y="185" width="120" height="36" fill="#5b21b6" rx="4"/>
<text x="240" y="208" fill="#fff" font-family="monospace" font-size="12" text-anchor="middle">K[1], V[1]</text>

<!-- Arrows group 0 -->
<line x1="50" y1="106" x2="70" y2="185" stroke="#60a5fa" stroke-width="1.5" marker-end="url(#arr)"/>
<line x1="130" y1="106" x2="90" y2="185" stroke="#60a5fa" stroke-width="1.5" marker-end="url(#arr)"/>

<!-- Arrows group 1 -->
<line x1="210" y1="106" x2="230" y2="185" stroke="#c4b5fd" stroke-width="1.5" marker-end="url(#arr)"/>
<line x1="290" y1="106" x2="250" y2="185" stroke="#c4b5fd" stroke-width="1.5" marker-end="url(#arr)"/>

<!-- Output boxes -->
<text x="80" y="245" fill="#aaa" font-family="monospace" font-size="11" text-anchor="middle">group 0</text>
<text x="240" y="245" fill="#aaa" font-family="monospace" font-size="11" text-anchor="middle">group 1</text>

<!-- bracket labels -->
<text x="430" y="88" fill="#60a5fa" font-family="monospace" font-size="12">Q[0], Q[1] attend to K[0], V[0]</text>
<text x="430" y="112" fill="#c4b5fd" font-family="monospace" font-size="12">Q[2], Q[3] attend to K[1], V[1]</text>
<text x="430" y="150" fill="#4ade80" font-family="monospace" font-size="12">scale = 1 / sqrt(head_dim)</text>
<text x="430" y="174" fill="#4ade80" font-family="monospace" font-size="12">scores = Q @ K^T * scale</text>
<text x="430" y="198" fill="#4ade80" font-family="monospace" font-size="12">weights = softmax(scores)</text>
<text x="430" y="222" fill="#4ade80" font-family="monospace" font-size="12">output = weights @ V</text>

<defs>
<marker id="arr" markerWidth="6" markerHeight="6" refX="3" refY="3" orient="auto">
<path d="M0,0 L0,6 L6,3 z" fill="#888"/>
</marker>
</defs>
</svg>

<h2>Implementation Requirements</h2>
<ul>
<li>Implement the function <code>solve(Q, K, V, output, num_q_heads, num_kv_heads, seq_len, head_dim)</code>.</li>
<li>Do not change the function signature or use external libraries beyond the standard GPU frameworks.</li>
<li>Write the result into the provided <code>output</code> buffer.</li>
<li><code>num_q_heads</code> is always divisible by <code>num_kv_heads</code>.</li>
<li>Use scaled dot-product attention with scale factor <code>1 / sqrt(head_dim)</code> and a softmax over the key dimension.</li>
</ul>
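The grouping rule reduces to a single integer division. For the performance configuration below (32 query heads, 8 KV heads), the query-to-KV head mapping works out as:

```python
num_q_heads, num_kv_heads = 32, 8
group_size = num_q_heads // num_kv_heads  # 4 query heads per KV head

# query heads 0..3 -> KV head 0, 4..7 -> KV head 1, ..., 28..31 -> KV head 7
mapping = [h // group_size for h in range(num_q_heads)]
```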

<h2>Example</h2>
<p>
With <code>num_q_heads</code> = 4, <code>num_kv_heads</code> = 2 (groups of 2), <code>seq_len</code> = 3,
<code>head_dim</code> = 4:
</p>
<p>
<strong>Input:</strong><br>
\(Q_0\) (3&times;4):
\[
\begin{bmatrix}
1 & 0 & 0 & 1 \\
0 & 1 & 1 & 0 \\
1 & 1 & 0 & 0
\end{bmatrix}
\]
\(Q_1\) (3&times;4):
\[
\begin{bmatrix}
0 & 1 & 0 & 1 \\
1 & 0 & 1 & 0 \\
0 & 0 & 1 & 1
\end{bmatrix}
\]
\(Q_2\) (3&times;4):
\[
\begin{bmatrix}
-1 & 0 & 0.5 & 0 \\
0 & -1 & 0 & 0.5 \\
0.5 & 0 & -1 & 0
\end{bmatrix}
\]
\(Q_3\) (3&times;4):
\[
\begin{bmatrix}
0 & 0.5 & 0 & -1 \\
0.5 & 0 & 0 & -1 \\
0 & 0 & 0.5 & 0.5
\end{bmatrix}
\]
\(K_0\) (3&times;4):
\[
\begin{bmatrix}
1 & 0 & 1 & 0 \\
0 & 1 & 0 & 1 \\
1 & 1 & 1 & 1
\end{bmatrix}
\]
\(K_1\) (3&times;4):
\[
\begin{bmatrix}
0 & 1 & 0 & -1 \\
-1 & 0 & 1 & 0 \\
0 & -1 & 0 & 1
\end{bmatrix}
\]
\(V_0\) (3&times;4):
\[
\begin{bmatrix}
1 & 2 & 3 & 4 \\
5 & 6 & 7 & 8 \\
9 & 10 & 11 & 12
\end{bmatrix}
\]
\(V_1\) (3&times;4):
\[
\begin{bmatrix}
-1 & -2 & -3 & -4 \\
2 & 3 & 4 & 5 \\
6 & 7 & 8 & 9
\end{bmatrix}
\]
Groups: \(Q_0, Q_1 \to K_0, V_0\); \(Q_2, Q_3 \to K_1, V_1\)
</p>
<p>
<strong>Output</strong> (values rounded to 2 decimal places):<br>
\(\text{output}_0\) (3&times;4):
\[
\begin{bmatrix}
5.71 & 6.71 & 7.71 & 8.71 \\
5.71 & 6.71 & 7.71 & 8.71 \\
5.71 & 6.71 & 7.71 & 8.71
\end{bmatrix}
\]
\(\text{output}_1\) (3&times;4):
\[
\begin{bmatrix}
6.07 & 7.07 & 8.07 & 9.07 \\
5.00 & 6.00 & 7.00 & 8.00 \\
5.71 & 6.71 & 7.71 & 8.71
\end{bmatrix}
\]
\(\text{output}_2\) (3&times;4):
\[
\begin{bmatrix}
2.24 & 2.76 & 3.27 & 3.79 \\
3.96 & 4.70 & 5.44 & 6.17 \\
2.40 & 2.60 & 2.79 & 2.98
\end{bmatrix}
\]
\(\text{output}_3\) (3&times;4):
\[
\begin{bmatrix}
0.76 & 0.58 & 0.40 & 0.22 \\
1.17 & 1.08 & 1.00 & 0.91 \\
2.84 & 3.37 & 3.91 & 4.44
\end{bmatrix}
\]
</p>
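The first row of \(\text{output}_0\) can be reproduced by hand from the inputs above, following the three formulas in the diagram (scores, softmax, weighted sum):

```python
import math

q = [1.0, 0.0, 0.0, 1.0]                        # first row of Q_0
K0 = [[1, 0, 1, 0], [0, 1, 0, 1], [1, 1, 1, 1]]
V0 = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
scale = 1 / math.sqrt(4)                        # head_dim = 4

# scaled dot products against each key row: [0.5, 0.5, 1.0]
scores = [sum(qi * ki for qi, ki in zip(q, k)) * scale for k in K0]
m = max(scores)
e = [math.exp(s - m) for s in scores]
w = [x / sum(e) for x in e]                     # softmax over the key dimension
row = [sum(w[i] * V0[i][j] for i in range(3)) for j in range(4)]
print([round(x, 2) for x in row])               # [5.71, 6.71, 7.71, 8.71]
```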

<h2>Constraints</h2>
<ul>
<li>1 &le; <code>num_kv_heads</code> &le; <code>num_q_heads</code> &le; 64</li>
<li><code>num_q_heads</code> is divisible by <code>num_kv_heads</code></li>
<li>1 &le; <code>seq_len</code> &le; 4,096</li>
<li>8 &le; <code>head_dim</code> &le; 256; <code>head_dim</code> is a multiple of 8</li>
<li>All tensor values are <code>float32</code></li>
<li>Performance is measured with <code>num_q_heads</code> = 32, <code>num_kv_heads</code> = 8, <code>seq_len</code> = 1,024, <code>head_dim</code> = 128</li>
</ul>
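The KV-cache saving that motivates GQA is easy to quantify at the performance configuration (per layer, per sequence, `float32`):

```python
num_q_heads, num_kv_heads, seq_len, head_dim = 32, 8, 1024, 128
bytes_per_float = 4

# K and V caches with shared KV heads (GQA)
gqa_cache = 2 * num_kv_heads * seq_len * head_dim * bytes_per_float
# K and V caches if every query head had its own KV head (MHA)
mha_cache = 2 * num_q_heads * seq_len * head_dim * bytes_per_float

print(gqa_cache // 2**20, mha_cache // 2**20)  # 8 MiB vs 32 MiB: a 4x reduction
```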
179 changes: 179 additions & 0 deletions challenges/medium/80_grouped_query_attention/challenge.py
@@ -0,0 +1,179 @@
import ctypes
import math
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase


class Challenge(ChallengeBase):
def __init__(self):
super().__init__(
name="Grouped Query Attention",
atol=1e-04,
rtol=1e-04,
num_gpus=1,
access_tier="free",
)

def reference_impl(
self,
Q: torch.Tensor,
K: torch.Tensor,
V: torch.Tensor,
output: torch.Tensor,
num_q_heads: int,
num_kv_heads: int,
seq_len: int,
head_dim: int,
):
assert Q.shape == (num_q_heads, seq_len, head_dim)
assert K.shape == (num_kv_heads, seq_len, head_dim)
assert V.shape == (num_kv_heads, seq_len, head_dim)
assert output.shape == (num_q_heads, seq_len, head_dim)
assert Q.dtype == K.dtype == V.dtype == output.dtype == torch.float32
assert Q.device.type == "cuda"
assert K.device.type == "cuda"
assert V.device.type == "cuda"
assert output.device.type == "cuda"
assert num_q_heads % num_kv_heads == 0

num_groups = num_q_heads // num_kv_heads
scale = 1.0 / math.sqrt(head_dim)

# Expand K, V from (num_kv_heads, seq_len, head_dim)
# to (num_q_heads, seq_len, head_dim) by repeating each KV head num_groups times
K_expanded = K.repeat_interleave(num_groups, dim=0)
V_expanded = V.repeat_interleave(num_groups, dim=0)

# Scaled dot-product attention: (num_q_heads, seq_len, seq_len)
scores = torch.bmm(Q, K_expanded.transpose(1, 2)) * scale

# Softmax over the key dimension
attn_weights = torch.softmax(scores, dim=-1)

# Weighted sum of values: (num_q_heads, seq_len, head_dim)
output.copy_(torch.bmm(attn_weights, V_expanded))

def get_solve_signature(self) -> Dict[str, tuple]:
return {
"Q": (ctypes.POINTER(ctypes.c_float), "in"),
"K": (ctypes.POINTER(ctypes.c_float), "in"),
"V": (ctypes.POINTER(ctypes.c_float), "in"),
"output": (ctypes.POINTER(ctypes.c_float), "out"),
"num_q_heads": (ctypes.c_int, "in"),
"num_kv_heads": (ctypes.c_int, "in"),
"seq_len": (ctypes.c_int, "in"),
"head_dim": (ctypes.c_int, "in"),
}

def _make_test_case(self, num_q_heads, num_kv_heads, seq_len, head_dim, zero_inputs=False):
dtype = torch.float32
device = "cuda"
if zero_inputs:
Q = torch.zeros(num_q_heads, seq_len, head_dim, device=device, dtype=dtype)
K = torch.zeros(num_kv_heads, seq_len, head_dim, device=device, dtype=dtype)
V = torch.zeros(num_kv_heads, seq_len, head_dim, device=device, dtype=dtype)
else:
Q = torch.randn(num_q_heads, seq_len, head_dim, device=device, dtype=dtype)
K = torch.randn(num_kv_heads, seq_len, head_dim, device=device, dtype=dtype)
V = torch.randn(num_kv_heads, seq_len, head_dim, device=device, dtype=dtype)
output = torch.zeros(num_q_heads, seq_len, head_dim, device=device, dtype=dtype)
return {
"Q": Q,
"K": K,
"V": V,
"output": output,
"num_q_heads": num_q_heads,
"num_kv_heads": num_kv_heads,
"seq_len": seq_len,
"head_dim": head_dim,
}

def generate_example_test(self) -> Dict[str, Any]:
torch.manual_seed(0)
dtype = torch.float32
device = "cuda"
num_q_heads = 4
num_kv_heads = 2
seq_len = 3
head_dim = 4

Q = torch.tensor(
[
[[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0]],
[[0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0], [0.0, 0.0, 1.0, 1.0]],
[[-1.0, 0.0, 0.5, 0.0], [0.0, -1.0, 0.0, 0.5], [0.5, 0.0, -1.0, 0.0]],
[[0.0, 0.5, 0.0, -1.0], [0.5, 0.0, 0.0, -1.0], [0.0, 0.0, 0.5, 0.5]],
],
device=device,
dtype=dtype,
)
K = torch.tensor(
[
[[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0]],
[[0.0, 1.0, 0.0, -1.0], [-1.0, 0.0, 1.0, 0.0], [0.0, -1.0, 0.0, 1.0]],
],
device=device,
dtype=dtype,
)
V = torch.tensor(
[
[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]],
[[-1.0, -2.0, -3.0, -4.0], [2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0]],
],
device=device,
dtype=dtype,
)
output = torch.zeros(num_q_heads, seq_len, head_dim, device=device, dtype=dtype)
return {
"Q": Q,
"K": K,
"V": V,
"output": output,
"num_q_heads": num_q_heads,
"num_kv_heads": num_kv_heads,
"seq_len": seq_len,
"head_dim": head_dim,
}

def generate_functional_test(self) -> List[Dict[str, Any]]:
torch.manual_seed(42)
tests = []

# Edge case: MQA (num_kv_heads=1), single token
tests.append(self._make_test_case(4, 1, 1, 8))

# Edge case: GQA with groups=2, tiny seq
tests.append(self._make_test_case(2, 1, 2, 4))

# Zero inputs
tests.append(self._make_test_case(4, 2, 4, 8, zero_inputs=True))

# Power-of-2: groups=4 (LLaMA-3 style ratio)
tests.append(self._make_test_case(8, 2, 16, 32))

# Power-of-2: seq_len=32, head_dim=64
tests.append(self._make_test_case(4, 2, 32, 64))

# Non-power-of-2 seq_len
tests.append(self._make_test_case(4, 2, 30, 32))

# Non-power-of-2 seq_len, different grouping
tests.append(self._make_test_case(6, 3, 100, 32))

# MQA with 8 query heads sharing one KV head, non-power-of-2 seq_len=255
tests.append(self._make_test_case(8, 1, 255, 64))

# MHA equivalent (num_q_heads == num_kv_heads)
tests.append(self._make_test_case(8, 8, 64, 32))

# Realistic small inference batch
tests.append(self._make_test_case(8, 2, 128, 64))

return tests

def generate_performance_test(self) -> Dict[str, Any]:
torch.manual_seed(0)
# LLaMA-3 8B style: 32 Q heads, 8 KV heads, head_dim=128
return self._make_test_case(32, 8, 1024, 128)
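The reference implementation's `repeat_interleave` materializes `num_groups` copies of K and V. The same result can be computed without that copy by folding the group dimension into the query tensor instead. A NumPy sketch of the idea (CPU-only and illustrative; `gqa_grouped` is a hypothetical helper name, not part of the challenge API):

```python
import numpy as np

def gqa_grouped(Q, K, V):
    """GQA without expanding K/V: reshape Q into (kv_head, group) blocks."""
    num_q_heads, seq_len, head_dim = Q.shape
    num_kv_heads = K.shape[0]
    g = num_q_heads // num_kv_heads
    scale = 1.0 / np.sqrt(head_dim)

    # (num_kv_heads, g, seq_len, head_dim): consecutive query heads share a KV head,
    # matching the repeat_interleave layout in the reference implementation
    Qg = Q.reshape(num_kv_heads, g, seq_len, head_dim)

    # broadcast each K/V head against its g query heads
    scores = np.einsum('kgqd,ksd->kgqs', Qg, K) * scale
    scores -= scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)
    out = np.einsum('kgqs,ksd->kgqd', w, V)
    return out.reshape(num_q_heads, seq_len, head_dim)
```

The same reshape-and-broadcast trick applies in PyTorch (via `view` and batched `matmul`) to avoid the extra memory traffic of expanded K/V tensors.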
@@ -0,0 +1,5 @@
#include <cuda_runtime.h>

// Q, K, V, output are device pointers
extern "C" void solve(const float* Q, const float* K, const float* V, float* output,
int num_q_heads, int num_kv_heads, int seq_len, int head_dim) {}
@@ -0,0 +1,17 @@
import cutlass
import cutlass.cute as cute


# Q, K, V, output are tensors on the GPU
@cute.jit
def solve(
Q: cute.Tensor,
K: cute.Tensor,
V: cute.Tensor,
output: cute.Tensor,
num_q_heads: cute.Int32,
num_kv_heads: cute.Int32,
seq_len: cute.Int32,
head_dim: cute.Int32,
):
pass