90 changes: 90 additions & 0 deletions challenges/hard/83_turboquant_attention/challenge.html
@@ -0,0 +1,90 @@
<p>
Implement attention score computation against a
<a href="https://arxiv.org/abs/2504.19874" target="_blank" style="color:#4a9eff; text-decoration:underline;">TurboQuant</a>-compressed
KV cache. TurboQuant compresses each key vector to <code>uint8</code> codebook indices plus a 1-bit
residual correction (QJL), reducing KV cache memory by up to 6x. Your task: dequantize the
compressed keys and compute dot-product attention scores against full-precision queries.
</p>

<p>
<strong>Background - how the keys were compressed</strong> (not part of the challenge):
</p>
<ol>
<li><strong>Rotate</strong>: multiply key by orthogonal matrix \(\Pi\): \(\;y = \Pi \cdot K\). This makes each
coordinate follow a Beta distribution, so a single fixed codebook works for all coordinates.</li>
<li><strong>Scalar quantize</strong>: replace each coordinate of \(y\) with the index of its nearest
codebook centroid \(\rightarrow K_\text{idx}\) (<code>uint8</code>).</li>
<li><strong>Residual correction</strong>: MSE quantization loses information. Compute the residual
\(r = K - \tilde{K}_\text{mse}\), then store:
<ul>
<li>\(\sigma = \text{sign}(S_\text{mat} \cdot r) \in \{-1,+1\}^D\) - direction (<code>int8</code>)</li>
<li>\(\gamma = \|r\|_2\) - magnitude (<code>float32</code> scalar per key)</li>
</ul>
where \(S_\text{mat} \in \mathbb{R}^{D \times D}\) is a random Gaussian projection matrix.
</li>
</ol>
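<p>
The three encoding steps above can be sketched for a single key vector. The snippet below is illustrative only: <code>encode_key</code> is a hypothetical helper written in NumPy with the column-vector convention \(y = \Pi \cdot K\); the challenge's actual encoder runs row-wise in PyTorch on the GPU.
</p>

```python
import numpy as np

def encode_key(K, Pi, S_mat, codebook):
    """Sketch of TurboQuant encoding for one key vector K of shape [D].

    Assumes Pi is orthogonal and codebook is a 1-D array of centroids.
    """
    y = Pi @ K                                     # 1. rotate into quantization space
    K_idx = np.abs(y[:, None] - codebook[None, :]).argmin(axis=1).astype(np.uint8)
    K_mse = Pi.T @ codebook[K_idx]                 # MSE reconstruction (undo rotation)
    r = K - K_mse                                  # 3. residual of the MSE stage
    gamma = np.linalg.norm(r)                      # residual magnitude (per key)
    sigma = np.where(S_mat @ r >= 0, 1, -1).astype(np.int8)  # residual direction
    return K_idx, sigma, gamma
```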

<p>
<strong>What you compute</strong> - dequantize and score:
</p>
<ol>
<li><strong>MSE dequantize</strong>: look up centroids, undo the rotation:
\[\tilde{K}_\text{mse} = \text{codebook}[K_\text{idx}] \cdot \Pi\]</li>
<li><strong>Residual dequantize</strong>: reconstruct the residual correction:
\[\tilde{K}_\text{res} = \frac{\sqrt{\pi/2}}{D} \cdot \gamma \cdot \sigma \cdot S_\text{mat}\]
The \(\sqrt{\pi/2}/D\) constant corrects for the distortion introduced by taking signs.</li>
<li><strong>Combine</strong>:
\(\tilde{K} = \tilde{K}_\text{mse} + \tilde{K}_\text{res}\)</li>
<li><strong>Dot product</strong>:
\(\text{scores}_{b,s} = Q_b \cdot \tilde{K}_s\)</li>
</ol>
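<p>
The four decode-and-score steps can be written in a few batched lines. This is a CPU-side NumPy sketch for reference (the function name and the absence of a preallocated <code>scores</code> buffer are illustrative, not the required <code>solve</code> signature):
</p>

```python
import numpy as np

def dequantize_and_score(Q, K_idx, sigma, gamma, Pi, S_mat, codebook):
    """Q: [B, D]; K_idx: [S, D]; sigma: [S, D] in {-1,+1}; gamma: [S]."""
    D = Q.shape[1]
    K_mse = codebook[K_idx] @ Pi                       # 1. centroid lookup + undo rotation
    scale = np.sqrt(np.pi / 2.0) / D
    K_res = scale * gamma[:, None] * (sigma @ S_mat)   # 2. residual reconstruction
    K_deq = K_mse + K_res                              # 3. combine
    return (Q @ K_deq.T).astype(np.float32)            # 4. scores: [B, S]
```

Running it on the worked example further down (identity \(\Pi\) and \(S_\text{mat}\), zero \(\gamma\)) reproduces the expected score matrix.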
<p>
The residual correction makes the inner product <strong>unbiased</strong>:
\(\mathbb{E}[\langle Q, \tilde{K} \rangle] = \langle Q, K \rangle\).
</p>
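<p>
A sketch of why this holds, using the standard Gaussian sign identity from the QJL literature: for \(s \sim \mathcal{N}(0, I_D)\),
\[\mathbb{E}\big[\operatorname{sign}(\langle s, r \rangle)\, s\big] = \sqrt{2/\pi}\,\frac{r}{\|r\|_2}.\]
Writing \(s_i\) for the rows of \(S_\text{mat}\) and \(\sigma_i = \operatorname{sign}(\langle s_i, r \rangle)\), and using \(\gamma = \|r\|_2\),
\[\mathbb{E}\big[\tilde{K}_\text{res}\big] = \frac{\sqrt{\pi/2}}{D}\,\gamma \sum_{i=1}^{D} \mathbb{E}[\sigma_i s_i] = \frac{\sqrt{\pi/2}}{D}\,\gamma \cdot D\,\sqrt{2/\pi}\,\frac{r}{\gamma} = r,\]
so \(\mathbb{E}[\tilde{K}] = \tilde{K}_\text{mse} + r = K\), and the expected score equals the exact one.
</p>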

<h2>Implementation Requirements</h2>
<ul>
<li>The <code>solve</code> function signature must remain unchanged.</li>
<li>Use only native features (no external libraries).</li>
<li>Store the result in <code>scores</code> as <code>float32</code>.</li>
</ul>

<h2>Example</h2>
<p>
Input: \(B=2,\; S=3,\; D=2,\; C=4\), with \(\Pi = I\), \(S_\text{mat} = I\), \(\gamma = \mathbf{0}\) (residual correction disabled),
\(\sigma = \mathbf{1}\) (all +1):
</p>
<p>
\(Q = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}\),
\(K_\text{idx} = \begin{bmatrix} 0 & 3 \\ 1 & 2 \\ 3 & 0 \end{bmatrix}\),
codebook \(= [-0.75,\; -0.25,\; 0.25,\; 0.75]\)
</p>
<p>
Step 1 - MSE lookup and rotate back (\(\Pi = I\)):
\[
\tilde{K}_\text{mse} = \begin{bmatrix} -0.75 & 0.75 \\ -0.25 & 0.25 \\ 0.75 & -0.75 \end{bmatrix}
\]
Step 2 - Residual correction is zero (\(\gamma = 0\)), so \(\tilde{K} = \tilde{K}_\text{mse}\).
</p>
<p>
Output:
\[
\text{scores} = Q \cdot \tilde{K}^T = \begin{bmatrix} -0.75 & -0.25 & 0.75 \\ 0.75 & 0.25 & -0.75 \end{bmatrix}
\]
</p>

<h2>Constraints</h2>
<ul>
<li>1 &le; <code>B</code> &le; 32</li>
<li>1 &le; <code>S</code> &le; 65,536</li>
<li>1 &le; <code>D</code> &le; 256</li>
<li>2 &le; <code>C</code> &le; 256</li>
<li>\(\Pi\) is orthogonal (\(\Pi^T \Pi = I\))</li>
<li><code>S_mat</code> has i.i.d. \(\mathcal{N}(0,1)\) entries</li>
<li><code>gamma</code> has shape \([S]\) (one \(\ell_2\) norm per key vector, <code>float32</code>)</li>
<li><code>qjl_signs</code> (\(\sigma\)) values are in \(\{-1, +1\}\) (<code>int8</code>)</li>
<li><code>K_idx</code> values are in \([0, C)\) (<code>uint8</code>)</li>
<li>All floating-point inputs are <code>float32</code></li>
<li>Performance is measured with <code>B</code> = 32, <code>S</code> = 32,768, <code>D</code> = 128, <code>C</code> = 16</li>
</ul>
221 changes: 221 additions & 0 deletions challenges/hard/83_turboquant_attention/challenge.py
@@ -0,0 +1,221 @@
import ctypes
import math
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase


class Challenge(ChallengeBase):
def __init__(self):
super().__init__(
name="TurboQuant KV Cache Attention",
atol=1e-3,
rtol=1e-3,
num_gpus=1,
access_tier="free",
)

def reference_impl(
self,
Q: torch.Tensor,
K_idx: torch.Tensor,
qjl_signs: torch.Tensor,
gamma: torch.Tensor,
Pi: torch.Tensor,
S_mat: torch.Tensor,
codebook: torch.Tensor,
scores: torch.Tensor,
B: int,
S: int,
D: int,
C: int,
):
assert Q.shape == (B, D)
assert K_idx.shape == (S, D)
assert qjl_signs.shape == (S, D)
assert gamma.shape == (S,)
assert Pi.shape == (D, D)
assert S_mat.shape == (D, D)
assert codebook.shape == (C,)
assert scores.shape == (B, S)
assert Q.dtype == torch.float32
assert K_idx.dtype == torch.uint8
assert qjl_signs.dtype == torch.int8
assert gamma.dtype == torch.float32
assert Pi.dtype == torch.float32
assert S_mat.dtype == torch.float32
assert codebook.dtype == torch.float32
assert scores.dtype == torch.float32
assert Q.device.type == "cuda"
assert K_idx.device.type == "cuda"
assert qjl_signs.device.type == "cuda"
assert gamma.device.type == "cuda"
assert Pi.device.type == "cuda"
assert S_mat.device.type == "cuda"
assert codebook.device.type == "cuda"
assert scores.device.type == "cuda"

# Stage 1: MSE dequantization — lookup centroids, rotate back
K_centroids = codebook[K_idx.long()] # [S, D]
K_mse = K_centroids @ Pi # [S, D] (row convention: ỹ @ Π = Π^T · ỹ)

# Stage 2: QJL dequantization — reconstruct residual correction
scale = math.sqrt(math.pi / 2.0) / D
K_qjl = scale * gamma.unsqueeze(1) * (qjl_signs.float() @ S_mat) # [S, D]

# Combined dequantization
K_deq = K_mse + K_qjl # [S, D]

# Attention scores
scores.copy_(Q @ K_deq.T) # [B, S]

def get_solve_signature(self) -> Dict[str, tuple]:
return {
"Q": (ctypes.POINTER(ctypes.c_float), "in"),
"K_idx": (ctypes.POINTER(ctypes.c_uint8), "in"),
"qjl_signs": (ctypes.POINTER(ctypes.c_int8), "in"),
"gamma": (ctypes.POINTER(ctypes.c_float), "in"),
"Pi": (ctypes.POINTER(ctypes.c_float), "in"),
"S_mat": (ctypes.POINTER(ctypes.c_float), "in"),
"codebook": (ctypes.POINTER(ctypes.c_float), "in"),
"scores": (ctypes.POINTER(ctypes.c_float), "out"),
"B": (ctypes.c_int, "in"),
"S": (ctypes.c_int, "in"),
"D": (ctypes.c_int, "in"),
"C": (ctypes.c_int, "in"),
}

def _make_rotation(self, D):
G = torch.randn(D, D, device="cuda")
Q, _ = torch.linalg.qr(G)
return Q

def _make_codebook(self, C, scale=1.0):
return torch.linspace(-scale, scale, C, device="cuda", dtype=torch.float32)

def _encode_keys(self, K, Pi, S_mat, codebook):
"""Simulate TurboQuant_prod encoding: rotate, quantize, compute QJL on residual."""
S, D = K.shape

# Stage 1: MSE encoding
Y = K @ Pi.T # rotate into quantization space
# Scalar quantize each coordinate to nearest centroid
diffs = Y.unsqueeze(-1) - codebook.unsqueeze(0).unsqueeze(0) # [S, D, C]
K_idx = diffs.abs().argmin(dim=-1).to(torch.uint8) # [S, D]

# MSE dequantization (to compute residual)
K_centroids = codebook[K_idx.long()] # [S, D]
K_mse = K_centroids @ Pi # [S, D]

# Stage 2: QJL encoding of residual
residual = K - K_mse # [S, D]
gamma = residual.norm(dim=1) # [S]
proj = residual @ S_mat.T # [S, D] (row convention for S · r)
qjl_signs = torch.sign(proj).to(torch.int8) # [S, D]
# Ensure no zeros (sign(0)=0 → map to +1)
qjl_signs[qjl_signs == 0] = 1

return K_idx, qjl_signs, gamma

def _make_test_case(self, B, S_seq, D, C, zero_q=False, seed=42):
torch.manual_seed(seed)
device = "cuda"

Pi = self._make_rotation(D)
S_mat = torch.randn(D, D, device=device, dtype=torch.float32)
codebook = self._make_codebook(C)

if zero_q:
Q = torch.zeros(B, D, device=device, dtype=torch.float32)
else:
Q = torch.randn(B, D, device=device, dtype=torch.float32) * 0.5

# Generate realistic keys and encode them
K = torch.randn(S_seq, D, device=device, dtype=torch.float32) * 0.3
K_idx, qjl_signs, gamma = self._encode_keys(K, Pi, S_mat, codebook)

scores = torch.zeros(B, S_seq, device=device, dtype=torch.float32)

return {
"Q": Q,
"K_idx": K_idx,
"qjl_signs": qjl_signs,
"gamma": gamma,
"Pi": Pi,
"S_mat": S_mat,
"codebook": codebook,
"scores": scores,
"B": B,
"S": S_seq,
"D": D,
"C": C,
}

def generate_example_test(self) -> Dict[str, Any]:
device = "cuda"
B, S, D, C = 2, 3, 2, 4

Q = torch.tensor([[1.0, 0.0], [0.0, 1.0]], device=device, dtype=torch.float32)
K_idx = torch.tensor([[0, 3], [1, 2], [3, 0]], device=device, dtype=torch.uint8)
# QJL signs: all +1 for simplicity
qjl_signs = torch.ones(S, D, device=device, dtype=torch.int8)
# gamma = 0: no QJL correction (reduces to MSE-only for this example)
gamma = torch.zeros(S, device=device, dtype=torch.float32)
Pi = torch.eye(D, device=device, dtype=torch.float32)
S_mat = torch.eye(D, device=device, dtype=torch.float32)
codebook = torch.tensor([-0.75, -0.25, 0.25, 0.75], device=device, dtype=torch.float32)
scores = torch.zeros(B, S, device=device, dtype=torch.float32)

return {
"Q": Q,
"K_idx": K_idx,
"qjl_signs": qjl_signs,
"gamma": gamma,
"Pi": Pi,
"S_mat": S_mat,
"codebook": codebook,
"scores": scores,
"B": B,
"S": S,
"D": D,
"C": C,
}

def generate_functional_test(self) -> List[Dict[str, Any]]:
tests = []

# Edge: single query, single key, D=1
tests.append(self._make_test_case(1, 1, 1, 2, seed=1))

# Edge: zero query
tests.append(self._make_test_case(2, 3, 4, 4, zero_q=True, seed=2))

# Edge: small with negative queries
tests.append(self._make_test_case(2, 4, 4, 4, seed=3))

# Power-of-2: B=4, S=16, D=32, C=8
tests.append(self._make_test_case(4, 16, 32, 8, seed=10))

# Power-of-2: B=8, S=64, D=64, C=16
tests.append(self._make_test_case(8, 64, 64, 16, seed=20))

# Power-of-2: B=16, S=128, D=128, C=16
tests.append(self._make_test_case(16, 128, 128, 16, seed=30))

# Non-power-of-2: B=3, S=30, D=50, C=8
tests.append(self._make_test_case(3, 30, 50, 8, seed=40))

# Non-power-of-2: B=7, S=255, D=100, C=16
tests.append(self._make_test_case(7, 255, 100, 16, seed=50))

# Realistic: B=16, S=4096, D=128, C=16
tests.append(self._make_test_case(16, 4096, 128, 16, seed=60))

# Realistic: B=32, S=8192, D=128, C=8
tests.append(self._make_test_case(32, 8192, 128, 8, seed=70))

return tests

def generate_performance_test(self) -> Dict[str, Any]:
return self._make_test_case(32, 32768, 128, 16, seed=0)
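
# The encode/decode pair above can be exercised end-to-end on the CPU. The
# following NumPy sketch is not part of the harness; it checks that with
# Pi = I and keys inside the codebook range, the MSE-only reconstruction
# error per coordinate is bounded by half the centroid spacing:
#
#   import numpy as np
#
#   rng = np.random.default_rng(0)
#   S, D, C = 64, 8, 16
#   codebook = np.linspace(-0.75, 0.75, C)        # spacing = 1.5 / (C - 1) = 0.1
#   Pi = np.eye(D)                                # identity rotation for a clean bound
#   K = rng.uniform(-0.75, 0.75, size=(S, D))     # keys inside the codebook range
#
#   # Encode: nearest-centroid index per coordinate (rotation is identity here)
#   K_idx = np.abs(K[:, :, None] - codebook[None, None, :]).argmin(axis=-1)
#   K_mse = codebook[K_idx] @ Pi                  # MSE-only reconstruction
#
#   step = codebook[1] - codebook[0]
#   assert np.abs(K - K_mse).max() <= step / 2 + 1e-12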
6 changes: 6 additions & 0 deletions challenges/hard/83_turboquant_attention/starter/starter.cu
@@ -0,0 +1,6 @@
#include <cuda_runtime.h>

// Q, K_idx, qjl_signs, gamma, Pi, S_mat, codebook, scores are device pointers
extern "C" void solve(const float* Q, const unsigned char* K_idx, const signed char* qjl_signs,
const float* gamma, const float* Pi, const float* S_mat,
const float* codebook, float* scores, int B, int S, int D, int C) {}
21 changes: 21 additions & 0 deletions challenges/hard/83_turboquant_attention/starter/starter.cute.py
@@ -0,0 +1,21 @@
import cutlass
import cutlass.cute as cute


# Q, K_idx, qjl_signs, gamma, Pi, S_mat, codebook, scores are tensors on the GPU
@cute.jit
def solve(
Q: cute.Tensor,
K_idx: cute.Tensor,
qjl_signs: cute.Tensor,
gamma: cute.Tensor,
Pi: cute.Tensor,
S_mat: cute.Tensor,
codebook: cute.Tensor,
scores: cute.Tensor,
B: cute.Int32,
S: cute.Int32,
D: cute.Int32,
C: cute.Int32,
):
pass
21 changes: 21 additions & 0 deletions challenges/hard/83_turboquant_attention/starter/starter.jax.py
@@ -0,0 +1,21 @@
import jax
import jax.numpy as jnp


# Q, K_idx, qjl_signs, gamma, Pi, S_mat, codebook are tensors on GPU
@jax.jit
def solve(
Q: jax.Array,
K_idx: jax.Array,
qjl_signs: jax.Array,
gamma: jax.Array,
Pi: jax.Array,
S_mat: jax.Array,
codebook: jax.Array,
B: int,
S: int,
D: int,
C: int,
) -> jax.Array:
# return output tensor directly
pass
@@ -0,0 +1,9 @@
from gpu.host import DeviceContext
from gpu.id import block_dim, block_idx, thread_idx
from memory import UnsafePointer
from math import ceildiv

# Q, K_idx, qjl_signs, gamma, Pi, S_mat, codebook, scores are device pointers
@export
def solve(Q: UnsafePointer[Float32], K_idx: UnsafePointer[UInt8], qjl_signs: UnsafePointer[Int8], gamma: UnsafePointer[Float32], Pi: UnsafePointer[Float32], S_mat: UnsafePointer[Float32], codebook: UnsafePointer[Float32], scores: UnsafePointer[Float32], B: Int32, S: Int32, D: Int32, C: Int32):
pass
@@ -0,0 +1,19 @@
import torch


# Q, K_idx, qjl_signs, gamma, Pi, S_mat, codebook, scores are tensors on the GPU
def solve(
Q: torch.Tensor,
K_idx: torch.Tensor,
qjl_signs: torch.Tensor,
gamma: torch.Tensor,
Pi: torch.Tensor,
S_mat: torch.Tensor,
codebook: torch.Tensor,
scores: torch.Tensor,
B: int,
S: int,
D: int,
C: int,
):
pass