Add turbo quant attention #227
Open: kunal-mansukhani wants to merge 5 commits into main from kunal/add-turbo-quant-challenge (+408 −0)
Changes from all commits
@@ -0,0 +1,90 @@ (new file: challenge description, HTML)
<p>
Implement attention score computation against a
<a href="https://arxiv.org/abs/2504.19874" target="_blank" style="color:#4a9eff; text-decoration:underline;">TurboQuant</a>-compressed
KV cache. TurboQuant compresses each key vector to <code>uint8</code> codebook indices plus a 1-bit
residual correction (QJL), reducing KV cache memory by up to 6x. Your task: dequantize the
compressed keys and compute dot-product attention scores against full-precision queries.
</p>
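As a rough sanity check on the quoted memory saving (a sketch under stated assumptions, not part of the challenge): the "up to 6x" figure is consistent with packing each index at \(\lceil \log_2 C \rceil\) bits, as in the TurboQuant paper. Note that in this challenge the indices are stored unpacked as <code>uint8</code>, so the in-memory layout here is less compact.

```python
import math

def compression_ratio(D, C, baseline_bits=32):
    """Approximate per-key compression vs. a float32 baseline, assuming
    centroid indices are packed at ceil(log2(C)) bits (paper convention;
    this challenge stores them unpacked as uint8)."""
    idx_bits = math.ceil(math.log2(C)) * D  # packed centroid indices
    sign_bits = D                           # 1-bit QJL signs
    gamma_bits = 32                         # one float32 residual norm per key
    return baseline_bits * D / (idx_bits + sign_bits + gamma_bits)
```

For the performance configuration (\(D=128\), \(C=16\)) this gives roughly 6x, matching the claim.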
<p>
<strong>Background - how the keys were compressed</strong> (not part of the challenge):
</p>
<ol>
<li><strong>Rotate</strong>: multiply the key by an orthogonal matrix \(\Pi\): \(\;y = \Pi \cdot K\). This makes each
coordinate follow a Beta distribution, so a single fixed codebook works for all coordinates.</li>
<li><strong>Scalar quantize</strong>: replace each coordinate of \(y\) with the index of its nearest
codebook centroid \(\rightarrow K_\text{idx}\) (<code>uint8</code>).</li>
<li><strong>Residual correction</strong>: MSE quantization loses information. Compute the residual
\(r = K - \tilde{K}_\text{mse}\), then store:
<ul>
<li>\(\sigma = \text{sign}(S_\text{mat} \cdot r) \in \{-1,+1\}^D\) - direction (<code>int8</code>)</li>
<li>\(\gamma = \|r\|_2\) - magnitude (<code>float32</code> scalar per key)</li>
</ul>
where \(S_\text{mat} \in \mathbb{R}^{D \times D}\) is a random Gaussian projection matrix.
</li>
</ol>
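The encoding steps above can be sketched in plain Python for a single key, assuming \(\Pi = S_\text{mat} = I\) to keep the illustration short (the real encoder uses a random rotation and a Gaussian projection, and operates on whole batches):

```python
import math

def encode_key(k, codebook):
    """Encode one key vector, assuming Pi = S_mat = I for illustration.
    Returns (centroid indices, QJL signs, residual norm)."""
    # Scalar quantize: nearest-centroid index for each coordinate of y = Pi @ k = k
    k_idx = [min(range(len(codebook)), key=lambda c: abs(x - codebook[c])) for x in k]
    k_mse = [codebook[i] for i in k_idx]       # MSE reconstruction
    # Residual correction: direction (signs) and magnitude (l2 norm)
    r = [a - b for a, b in zip(k, k_mse)]
    gamma = math.sqrt(sum(x * x for x in r))
    sigma = [1 if x >= 0 else -1 for x in r]   # sign(S_mat @ r) with S_mat = I
    return k_idx, sigma, gamma
```

With the example codebook \([-0.75, -0.25, 0.25, 0.75]\), the key \([0.8, -0.3]\) quantizes to indices \([3, 1]\) with a small residual.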
<p>
<strong>What you compute</strong> - dequantize and score:
</p>
<ol>
<li><strong>MSE dequantize</strong>: look up centroids, undo the rotation:
\[\tilde{K}_\text{mse} = \text{codebook}[K_\text{idx}] \cdot \Pi\]</li>
<li><strong>Residual dequantize</strong>: reconstruct the residual correction:
\[\tilde{K}_\text{res} = \frac{\sqrt{\pi/2}}{D} \cdot \gamma \cdot \sigma \cdot S_\text{mat}\]
The \(\sqrt{\pi/2}/D\) constant corrects for the distortion introduced by taking signs.</li>
<li><strong>Combine</strong>:
\(\tilde{K} = \tilde{K}_\text{mse} + \tilde{K}_\text{res}\)</li>
<li><strong>Dot product</strong>:
\(\text{scores}_{b,s} = Q_b \cdot \tilde{K}_s\)</li>
</ol>
<p>
The residual correction makes the inner product <strong>unbiased</strong>:
\(\mathbb{E}[\langle Q, \tilde{K} \rangle] = \langle Q, K \rangle\).
</p>
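The unbiasedness claim can be checked empirically with a small Monte Carlo experiment in plain Python (illustrative only; the vectors and dimension below are arbitrary choices, not challenge inputs). Averaging the QJL estimate, which sees only \(\text{sign}(S r)\) and \(\|r\|_2\), over many random Gaussian \(S\) matrices recovers \(\langle q, r \rangle\):

```python
import math
import random

def qjl_estimate(q, r, S):
    """One QJL estimate of <q, r> from sign(S @ r) and ||r|| (S: D x D Gaussian)."""
    D = len(S)
    gamma = math.sqrt(sum(x * x for x in r))
    sigma = [1.0 if sum(S[i][j] * r[j] for j in range(D)) >= 0 else -1.0
             for i in range(D)]
    Sq = [sum(S[i][j] * q[j] for j in range(D)) for i in range(D)]
    return math.sqrt(math.pi / 2.0) / D * gamma * sum(a * b for a, b in zip(sigma, Sq))

random.seed(0)
q, r = [1.0, 2.0, -1.0, 0.5], [0.3, 0.2, 0.1, 0.4]
exact = sum(a * b for a, b in zip(q, r))
# Average over many fresh Gaussian projection matrices
est = sum(
    qjl_estimate(q, r, [[random.gauss(0.0, 1.0) for _ in range(4)] for _ in range(4)])
    for _ in range(10000)
) / 10000
```

A single estimate is noisy, but the average converges to the exact inner product, which is what makes the corrected scores unbiased.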
<h2>Implementation Requirements</h2>
<ul>
<li>The <code>solve</code> function signature must remain unchanged.</li>
<li>Use only native features (no external libraries).</li>
<li>Store the result in <code>scores</code> as <code>float32</code>.</li>
</ul>
<h2>Example</h2>
<p>
Input: \(B=2,\; S=3,\; D=2,\; C=4\), with \(\Pi = I\), \(S_\text{mat} = I\), \(\gamma = \mathbf{0}\) (residual correction disabled),
\(\sigma = \mathbf{1}\) (all +1):
</p>
<p>
\(Q = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}\),
\(K_\text{idx} = \begin{bmatrix} 0 & 3 \\ 1 & 2 \\ 3 & 0 \end{bmatrix}\),
codebook \(= [-0.75,\; -0.25,\; 0.25,\; 0.75]\)
</p>
<p>
Step 1 - MSE lookup and rotate back (\(\Pi = I\)):
\[
\tilde{K}_\text{mse} = \begin{bmatrix} -0.75 & 0.75 \\ -0.25 & 0.25 \\ 0.75 & -0.75 \end{bmatrix}
\]
Step 2 - Residual correction is zero (\(\gamma = 0\)), so \(\tilde{K} = \tilde{K}_\text{mse}\).
</p>
<p>
Output:
\[
\text{scores} = Q \cdot \tilde{K}^T = \begin{bmatrix} -0.75 & -0.25 & 0.75 \\ 0.75 & 0.25 & -0.75 \end{bmatrix}
\]
</p>
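The worked example can be reproduced with a short plain-Python sketch of the four dequantize-and-score steps (illustrative only; a graded solution must use the provided <code>solve</code> signature and its device buffers):

```python
import math

def attention_scores(Q, K_idx, sigma, gamma, Pi, S_mat, codebook):
    """Dequantize TurboQuant-compressed keys and return Q . K~ scores.
    Row-vector convention as in the challenge: K~_mse = codebook[K_idx] @ Pi,
    residual = (sqrt(pi/2)/D) * gamma * (sigma @ S_mat)."""
    D = len(Pi)
    scale = math.sqrt(math.pi / 2.0) / D
    K_deq = []
    for s, idx_row in enumerate(K_idx):
        cent = [codebook[i] for i in idx_row]  # centroid lookup
        k_mse = [sum(cent[j] * Pi[j][d] for j in range(D)) for d in range(D)]
        k_res = [scale * gamma[s] * sum(sigma[s][j] * S_mat[j][d] for j in range(D))
                 for d in range(D)]
        K_deq.append([a + b for a, b in zip(k_mse, k_res)])
    # scores[b][s] = Q_b . K~_s
    return [[sum(q[d] * k[d] for d in range(D)) for k in K_deq] for q in Q]
```

Running it on the example inputs (\(\Pi = S_\text{mat} = I\), \(\gamma = 0\)) produces exactly the output matrix shown above.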
<h2>Constraints</h2>
<ul>
<li>1 ≤ <code>B</code> ≤ 32</li>
<li>1 ≤ <code>S</code> ≤ 65,536</li>
<li>1 ≤ <code>D</code> ≤ 256</li>
<li>2 ≤ <code>C</code> ≤ 256</li>
<li>\(\Pi\) is orthogonal (\(\Pi^T \Pi = I\))</li>
<li><code>S_mat</code> has i.i.d. \(\mathcal{N}(0,1)\) entries</li>
<li><code>gamma</code> has shape \([S]\) (one \(\ell_2\) norm per key vector, <code>float32</code>)</li>
<li><code>qjl_signs</code> (\(\sigma\)) values are in \(\{-1, +1\}\) (<code>int8</code>)</li>
<li><code>K_idx</code> values are in \([0, C)\) (<code>uint8</code>)</li>
<li>All floating-point inputs are <code>float32</code></li>
<li>Performance is measured with <code>B</code> = 32, <code>S</code> = 32,768, <code>D</code> = 128, <code>C</code> = 16</li>
</ul>
@@ -0,0 +1,221 @@ (new file: challenge definition, Python)
import ctypes
import math
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase


class Challenge(ChallengeBase):
    def __init__(self):
        super().__init__(
            name="TurboQuant KV Cache Attention",
            atol=1e-3,
            rtol=1e-3,
            num_gpus=1,
            access_tier="free",
        )

    def reference_impl(
        self,
        Q: torch.Tensor,
        K_idx: torch.Tensor,
        qjl_signs: torch.Tensor,
        gamma: torch.Tensor,
        Pi: torch.Tensor,
        S_mat: torch.Tensor,
        codebook: torch.Tensor,
        scores: torch.Tensor,
        B: int,
        S: int,
        D: int,
        C: int,
    ):
        assert Q.shape == (B, D)
        assert K_idx.shape == (S, D)
        assert qjl_signs.shape == (S, D)
        assert gamma.shape == (S,)
        assert Pi.shape == (D, D)
        assert S_mat.shape == (D, D)
        assert codebook.shape == (C,)
        assert scores.shape == (B, S)
        assert Q.dtype == torch.float32
        assert K_idx.dtype == torch.uint8
        assert qjl_signs.dtype == torch.int8
        assert gamma.dtype == torch.float32
        assert Pi.dtype == torch.float32
        assert S_mat.dtype == torch.float32
        assert codebook.dtype == torch.float32
        assert scores.dtype == torch.float32
        assert Q.device.type == "cuda"
        assert K_idx.device.type == "cuda"
        assert qjl_signs.device.type == "cuda"
        assert gamma.device.type == "cuda"
        assert Pi.device.type == "cuda"
        assert S_mat.device.type == "cuda"
        assert codebook.device.type == "cuda"
        assert scores.device.type == "cuda"

        # Stage 1: MSE dequantization — lookup centroids, rotate back
        K_centroids = codebook[K_idx.long()]  # [S, D]
        K_mse = K_centroids @ Pi  # [S, D] (row convention: ỹ @ Π = Π^T · ỹ)

        # Stage 2: QJL dequantization — reconstruct residual correction
        scale = math.sqrt(math.pi / 2.0) / D
        K_qjl = scale * gamma.unsqueeze(1) * (qjl_signs.float() @ S_mat)  # [S, D]

        # Combined dequantization
        K_deq = K_mse + K_qjl  # [S, D]

        # Attention scores
        scores.copy_(Q @ K_deq.T)  # [B, S]

    def get_solve_signature(self) -> Dict[str, tuple]:
        return {
            "Q": (ctypes.POINTER(ctypes.c_float), "in"),
            "K_idx": (ctypes.POINTER(ctypes.c_uint8), "in"),
            "qjl_signs": (ctypes.POINTER(ctypes.c_int8), "in"),
            "gamma": (ctypes.POINTER(ctypes.c_float), "in"),
            "Pi": (ctypes.POINTER(ctypes.c_float), "in"),
            "S_mat": (ctypes.POINTER(ctypes.c_float), "in"),
            "codebook": (ctypes.POINTER(ctypes.c_float), "in"),
            "scores": (ctypes.POINTER(ctypes.c_float), "out"),
            "B": (ctypes.c_int, "in"),
            "S": (ctypes.c_int, "in"),
            "D": (ctypes.c_int, "in"),
            "C": (ctypes.c_int, "in"),
        }

    def _make_rotation(self, D):
        G = torch.randn(D, D, device="cuda")
        Q, _ = torch.linalg.qr(G)
        return Q

    def _make_codebook(self, C, scale=1.0):
        return torch.linspace(-scale, scale, C, device="cuda", dtype=torch.float32)

    def _encode_keys(self, K, Pi, S_mat, codebook):
        """Simulate TurboQuant_prod encoding: rotate, quantize, compute QJL on residual."""
        S, D = K.shape

        # Stage 1: MSE encoding
        Y = K @ Pi.T  # rotate into quantization space
        # Scalar quantize each coordinate to nearest centroid
        diffs = Y.unsqueeze(-1) - codebook.unsqueeze(0).unsqueeze(0)  # [S, D, C]
        K_idx = diffs.abs().argmin(dim=-1).to(torch.uint8)  # [S, D]

        # MSE dequantization (to compute residual)
        K_centroids = codebook[K_idx.long()]  # [S, D]
        K_mse = K_centroids @ Pi  # [S, D]

        # Stage 2: QJL encoding of residual
        residual = K - K_mse  # [S, D]
        gamma = residual.norm(dim=1)  # [S]
        proj = residual @ S_mat.T  # [S, D] (row convention for S · r)
        qjl_signs = torch.sign(proj).to(torch.int8)  # [S, D]
        # Ensure no zeros (sign(0)=0 → map to +1)
        qjl_signs[qjl_signs == 0] = 1

        return K_idx, qjl_signs, gamma

    def _make_test_case(self, B, S_seq, D, C, zero_q=False, seed=42):
        torch.manual_seed(seed)
        device = "cuda"

        Pi = self._make_rotation(D)
        S_mat = torch.randn(D, D, device=device, dtype=torch.float32)
        codebook = self._make_codebook(C)

        if zero_q:
            Q = torch.zeros(B, D, device=device, dtype=torch.float32)
        else:
            Q = torch.randn(B, D, device=device, dtype=torch.float32) * 0.5

        # Generate realistic keys and encode them
        K = torch.randn(S_seq, D, device=device, dtype=torch.float32) * 0.3
        K_idx, qjl_signs, gamma = self._encode_keys(K, Pi, S_mat, codebook)

        scores = torch.zeros(B, S_seq, device=device, dtype=torch.float32)

        return {
            "Q": Q,
            "K_idx": K_idx,
            "qjl_signs": qjl_signs,
            "gamma": gamma,
            "Pi": Pi,
            "S_mat": S_mat,
            "codebook": codebook,
            "scores": scores,
            "B": B,
            "S": S_seq,
            "D": D,
            "C": C,
        }

    def generate_example_test(self) -> Dict[str, Any]:
        device = "cuda"
        B, S, D, C = 2, 3, 2, 4

        Q = torch.tensor([[1.0, 0.0], [0.0, 1.0]], device=device, dtype=torch.float32)
        K_idx = torch.tensor([[0, 3], [1, 2], [3, 0]], device=device, dtype=torch.uint8)
        # QJL signs: all +1 for simplicity
        qjl_signs = torch.ones(S, D, device=device, dtype=torch.int8)
        # gamma = 0: no QJL correction (reduces to MSE-only for this example)
        gamma = torch.zeros(S, device=device, dtype=torch.float32)
        Pi = torch.eye(D, device=device, dtype=torch.float32)
        S_mat = torch.eye(D, device=device, dtype=torch.float32)
        codebook = torch.tensor([-0.75, -0.25, 0.25, 0.75], device=device, dtype=torch.float32)
        scores = torch.zeros(B, S, device=device, dtype=torch.float32)

        return {
            "Q": Q,
            "K_idx": K_idx,
            "qjl_signs": qjl_signs,
            "gamma": gamma,
            "Pi": Pi,
            "S_mat": S_mat,
            "codebook": codebook,
            "scores": scores,
            "B": B,
            "S": S,
            "D": D,
            "C": C,
        }

    def generate_functional_test(self) -> List[Dict[str, Any]]:
        tests = []

        # Edge: single query, single key, D=1
        tests.append(self._make_test_case(1, 1, 1, 2, seed=1))

        # Edge: zero query
        tests.append(self._make_test_case(2, 3, 4, 4, zero_q=True, seed=2))

        # Edge: small with negative queries
        tests.append(self._make_test_case(2, 4, 4, 4, seed=3))

        # Power-of-2: B=4, S=16, D=32, C=8
        tests.append(self._make_test_case(4, 16, 32, 8, seed=10))

        # Power-of-2: B=8, S=64, D=64, C=16
        tests.append(self._make_test_case(8, 64, 64, 16, seed=20))

        # Power-of-2: B=16, S=128, D=128, C=16
        tests.append(self._make_test_case(16, 128, 128, 16, seed=30))

        # Non-power-of-2: B=3, S=30, D=50, C=8
        tests.append(self._make_test_case(3, 30, 50, 8, seed=40))

        # Non-power-of-2: B=7, S=255, D=100, C=16
        tests.append(self._make_test_case(7, 255, 100, 16, seed=50))

        # Realistic: B=16, S=4096, D=128, C=16
        tests.append(self._make_test_case(16, 4096, 128, 16, seed=60))

        # Realistic: B=32, S=8192, D=128, C=8
        tests.append(self._make_test_case(32, 8192, 128, 8, seed=70))

        return tests

    def generate_performance_test(self) -> Dict[str, Any]:
        return self._make_test_case(32, 32768, 128, 16, seed=0)
@@ -0,0 +1,6 @@ (new file: CUDA starter)
#include <cuda_runtime.h>

// Q, K_idx, qjl_signs, gamma, Pi, S_mat, codebook, scores are device pointers
extern "C" void solve(const float* Q, const unsigned char* K_idx, const signed char* qjl_signs,
                      const float* gamma, const float* Pi, const float* S_mat,
                      const float* codebook, float* scores, int B, int S, int D, int C) {}
challenges/hard/83_turboquant_attention/starter/starter.cute.py (21 additions)
@@ -0,0 +1,21 @@
import cutlass
import cutlass.cute as cute


# Q, K_idx, qjl_signs, gamma, Pi, S_mat, codebook, scores are tensors on the GPU
@cute.jit
def solve(
    Q: cute.Tensor,
    K_idx: cute.Tensor,
    qjl_signs: cute.Tensor,
    gamma: cute.Tensor,
    Pi: cute.Tensor,
    S_mat: cute.Tensor,
    codebook: cute.Tensor,
    scores: cute.Tensor,
    B: cute.Int32,
    S: cute.Int32,
    D: cute.Int32,
    C: cute.Int32,
):
    pass
challenges/hard/83_turboquant_attention/starter/starter.jax.py (21 additions)
@@ -0,0 +1,21 @@
import jax
import jax.numpy as jnp


# Q, K_idx, qjl_signs, gamma, Pi, S_mat, codebook are tensors on GPU
@jax.jit
def solve(
    Q: jax.Array,
    K_idx: jax.Array,
    qjl_signs: jax.Array,
    gamma: jax.Array,
    Pi: jax.Array,
    S_mat: jax.Array,
    codebook: jax.Array,
    B: int,
    S: int,
    D: int,
    C: int,
) -> jax.Array:
    # return output tensor directly
    pass
@@ -0,0 +1,9 @@ (new file: Mojo starter)
from gpu.host import DeviceContext
from gpu.id import block_dim, block_idx, thread_idx
from memory import UnsafePointer
from math import ceildiv

# Q, K_idx, qjl_signs, gamma, Pi, S_mat, codebook, scores are device pointers
@export
def solve(Q: UnsafePointer[Float32], K_idx: UnsafePointer[UInt8], qjl_signs: UnsafePointer[Int8], gamma: UnsafePointer[Float32], Pi: UnsafePointer[Float32], S_mat: UnsafePointer[Float32], codebook: UnsafePointer[Float32], scores: UnsafePointer[Float32], B: Int32, S: Int32, D: Int32, C: Int32):
    pass
challenges/hard/83_turboquant_attention/starter/starter.pytorch.py (19 additions)
@@ -0,0 +1,19 @@
import torch


# Q, K_idx, qjl_signs, gamma, Pi, S_mat, codebook, scores are tensors on the GPU
def solve(
    Q: torch.Tensor,
    K_idx: torch.Tensor,
    qjl_signs: torch.Tensor,
    gamma: torch.Tensor,
    Pi: torch.Tensor,
    S_mat: torch.Tensor,
    codebook: torch.Tensor,
    scores: torch.Tensor,
    B: int,
    S: int,
    D: int,
    C: int,
):
    pass
Review comment: nit: we can just say "not part of the challenge" and get rid of "already done for you" imo