Summary
TurboQuant — PolarQuant INT4 KV Cache for Qwen3-14B
File Layout
| File | Role |
| --- | --- |
| turboquant_kv.py | Lloyd-Max codebook solver + quantize function + dequant function |
| qwen3_14b_prefill_tq.py | Prefill end-to-end: quantize K/V → write cache → dequant for attention |
| qwen3_14b_decode_tq.py | Decode end-to-end: read quantized cache → dequant → attention |
End-to-End Flow
┌─────────────────────────────────────────────────────────────────┐
│ Prefill (compress)                                              │
│                                                                 │
│ K/V proj ──► RoPE(K) / L2 norm ──► rotate ──► INT4 quantize     │
│                                                     │           │
│                           write quant_k/v_cache ◄───┘           │
│                                        │                        │
│           Attention reads ──► dequant ◄┘                        │
│                                                                 │
├─────────────────────────────────────────────────────────────────┤
│ Decode (decompress)                                             │
│                                                                 │
│ read INT4 from quant_k/v_cache ──► dequant ──► Attention        │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
Quantization
1. Lloyd-Max Codebook (computed once at module load)
solve_lloyd_max(head_dim, bits=4) in turboquant_kv.py solves for the
optimal 16-level scalar quantizer under N(0, 1/head_dim):
- Input: distribution std σ = 1/sqrt(head_dim)
- Algorithm: Lloyd-Max iteration (alternating boundary and centroid updates)
  - Step 1: boundaries = midpoints of adjacent centroids
  - Step 2: centroids = Gaussian conditional mean E[X | a < X < b] (computed analytically via torch.erf)
- Output: _lm_centroids [16] centroids, _lm_boundaries [15] decision boundaries
- The 15 boundary values are unpacked into _b0…_b14 for use in kernel comparisons
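The iteration above can be sketched with the standard library alone. This is a minimal illustration, not the code in turboquant_kv.py: the name `lloyd_max_gaussian`, the initial guess, and the fixed iteration count are assumptions, and `math.erf` stands in for `torch.erf`.

```python
import math

def _pdf(z):
    """Standard normal density."""
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)

def _cdf(z):
    """Standard normal CDF expressed through erf."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))

def lloyd_max_gaussian(sigma, bits=4, iters=200):
    """Hypothetical helper: Lloyd-Max levels for a zero-mean Gaussian with std sigma."""
    n = 1 << bits
    # Assumed initial guess: levels spread evenly over [-3*sigma, 3*sigma].
    centroids = [sigma * (-3.0 + 6.0 * (i + 0.5) / n) for i in range(n)]
    for _ in range(iters):
        # Step 1: boundaries are midpoints of adjacent centroids.
        boundaries = [(lo + hi) / 2.0 for lo, hi in zip(centroids, centroids[1:])]
        edges = [-math.inf] + boundaries + [math.inf]
        # Step 2: each centroid is E[X | a < X < b] in closed form.
        centroids = []
        for a, b in zip(edges, edges[1:]):
            za, zb = a / sigma, b / sigma
            mass = _cdf(zb) - _cdf(za)
            centroids.append(sigma * (_pdf(za) - _pdf(zb)) / mass)
    # Recompute boundaries so they match the final centroids.
    boundaries = [(lo + hi) / 2.0 for lo, hi in zip(centroids, centroids[1:])]
    return centroids, boundaries

head_dim = 128
centroids, boundaries = lloyd_max_gaussian(1.0 / math.sqrt(head_dim))
```

With `bits=4` this yields 16 centroids and 15 boundaries, symmetric about zero, which matches the shapes listed for `_lm_centroids` and `_lm_boundaries`.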
2. Quantize Kernel (turboquant_kv_prefill — @pl.jit.inline)
Executes 4 scopes per KV head during prefill:
Scope A: K RoPE + L2 norm + normalize
├─ K RoPE (cos/sin rotary transform)
├─ L2 norm = sqrt(sum(x²) + ε) per row
├─ normalize: x / L2_norm → BF16, write to k_norm_buf
└─ L2 norm written to k_scales_buf
Scope B: K rotate + Lloyd-Max quantize
├─ k_norm_buf × rot_matrix → k_rot (BF16×BF16→FP32)
├─ 15-way boundary search: cumulative GE (≥_b0) + … + (≥_b14)
│ → INT4 index in [0, 15]
└─ FP32→INT32→FP16→UINT8, write to quant_k_temp
Scope C: V L2 norm + normalize (no RoPE for V)
├─ L2 norm = sqrt(sum(x²) + ε)
├─ normalize → BF16, write to v_norm_buf
└─ L2 norm written to v_scales_buf
Scope D: V rotate + Lloyd-Max quantize
└─ same as Scope B, write to quant_v_temp
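The per-row arithmetic in the scopes above (L2 norm, normalize, rotate, cumulative boundary search) can be sketched in plain Python. This is a rough illustration under stated assumptions: `quantize_row`, the `eps` value, and the toy 2-bit boundaries are hypothetical, RoPE is omitted, and the real kernel operates on tiles with the dtypes listed above.

```python
import math

def quantize_row(x, rot, boundaries, eps=1e-6):
    """Hypothetical per-row quantizer: returns (indices, l2_norm)."""
    # L2 norm = sqrt(sum(x^2) + eps); this is the stored per-row scale.
    norm = math.sqrt(sum(v * v for v in x) + eps)
    unit = [v / norm for v in x]
    d = len(x)
    # Rotate: row vector times rotation matrix.
    rotated = [sum(unit[i] * rot[i][j] for i in range(d)) for j in range(d)]
    # Boundary search: index = number of boundaries b with value >= b.
    indices = [sum(1 for b in boundaries if r >= b) for r in rotated]
    return indices, norm

# Toy setup (assumed values): 4-dim row, identity rotation, 2-bit codebook.
identity = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
toy_boundaries = [-0.5, 0.0, 0.5]
indices, norm = quantize_row([3.0, 4.0, 0.0, 0.0], identity, toy_boundaries)
```

Summing the comparisons against sorted boundaries gives the same index as a binary search but needs no branching. With the 15 boundaries of the 4-bit codebook the sum lands in [0, 15].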
Quantized outputs:
| Output | Shape | Dtype | Description |
| --- | --- | --- | --- |
| quant_k/v_cache | [total_rows, HEAD_DIM] | UINT8 | INT4 indices, range [0, 15] |
| quant_k/v_scales | [total_rows, 1] | FP32 | per-row L2 norm |
Dequantization
turboquant_kv_dequant_chunk — @pl.jit.inline
Dequantizes one CMP_CHUNK (32 rows) of INT4 KV cache. All three steps are
inlined into the caller's pl.at scope:
Step 1: Gather (codebook lookup)
UINT8 indices → FP16 cast → INT32 cast → pl.gather(codebook) → FP32 centroids
Step 2: Renormalize + Scale
centroid → rsqrt(row_sum(sq) + ε) normalize to unit sphere → × stored L2 scale
Step 3: Unrotate (rotate back to original space)
→ BF16 cast → matmul(rot_slice^T) → BF16 → write to out_bf16
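The three steps can be written out for a single row as follows. This is only a sketch: `dequantize_row`, the `eps` value, and the toy codebook are assumptions, and the dtype casts of the real kernel are left out.

```python
import math

def dequantize_row(indices, scale, codebook, rot, eps=1e-6):
    """Hypothetical per-row dequantizer mirroring Steps 1-3."""
    # Step 1: gather centroids from the codebook.
    gathered = [codebook[i] for i in indices]
    # Step 2: renormalize to the unit sphere, then apply the stored L2 scale.
    inv_norm = 1.0 / math.sqrt(sum(c * c for c in gathered) + eps)
    scaled = [c * inv_norm * scale for c in gathered]
    # Step 3: unrotate with the transpose: out[j] = sum_i scaled[i] * rot[j][i].
    d = len(scaled)
    return [sum(scaled[i] * rot[j][i] for i in range(d)) for j in range(d)]

# Toy setup (assumed values): 2-bit codebook, identity rotation.
identity = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
toy_codebook = [-0.75, -0.25, 0.25, 0.75]
row = dequantize_row([3, 3, 2, 2], 5.0, toy_codebook, identity)
row_norm = math.sqrt(sum(v * v for v in row))
```

Because of the renormalization in Step 2, the reconstructed row has an L2 norm equal to the stored scale (up to `eps`), regardless of how far the gathered centroids were from unit length.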
Prefill QK attention (dequant K on the fly)
with pl.at("qk_dequant"):
    for each CMP_CHUNK (4 iterations for BLOCK_SIZE=128):
        chunk_out = pl.create_tensor([32, 128], BF16)
        turboquant_kv_dequant_chunk(indices, scales, codebook, rot, chunk_out)
        k_bf16_buf = pl.assemble(chunk_out)  # directly BF16, no extra cast
with pl.at("qk_matmul"):
    raw_scores = pl.matmul(q_padded, k_bf16_buf, b_trans=True)
Prefill SV attention (dequant V on the fly)
with pl.at("sv_dequant"):
    for each CMP_CHUNK:
        chunk_out = pl.create_tensor([32, 128], BF16)
        turboquant_kv_dequant_chunk(indices, scales, codebook, rot, chunk_out)
        v_tile_full = pl.assemble(chunk_out)  # directly BF16
with pl.at("sv_matmul"):
    oi_tmp = pl.matmul(exp_tile, v_tile_full)
Decode dequant
Identical to prefill, using BLOCK_SIZE=128 (same as SEQ_TILE).
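Both paths walk one 128-row block in 32-row chunks, which gives the four iterations noted earlier. A small host-side sketch of that loop structure follows; `dequant_block` and the `dequant_chunk` callback are hypothetical stand-ins for the kernel calls, not part of the `pl` API.

```python
BLOCK_SIZE = 128  # attention sequence tile
CMP_CHUNK = 32    # rows dequantized per call
HEAD_DIM = 128

def dequant_block(dequant_chunk, block_size=BLOCK_SIZE, chunk=CMP_CHUNK):
    """Assemble one block from per-chunk results.

    dequant_chunk(start, stop) is a hypothetical callback that returns the
    dequantized rows [start, stop) as a list of rows.
    """
    rows = []
    for start in range(0, block_size, chunk):
        rows.extend(dequant_chunk(start, start + chunk))
    return rows

calls = []

def fake_chunk(start, stop):
    # Placeholder that records the requested range and returns zero rows.
    calls.append((start, stop))
    return [[0.0] * HEAD_DIM for _ in range(stop - start)]

block = dequant_block(fake_chunk)
```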
Key Constants
| Constant | Value | Notes |
| --- | --- | --- |
| N_LEVELS | 16 | INT4, 16 quantization levels |
| CMP_CHUNK | 32 | Dequant sub-chunk rows (32×1B = 32-byte aligned) |
| SEQ_TILE / BLOCK_SIZE | 128 | Attention sequence tile |
| TOK_TILE | 16 | Prefill quantize token tile |
| CMP_TILE / CMP_TILE_SV | 64 | Dequant sub-tile |
| HEAD_DIM | 128 | Q/K/V head dimension (rotation → Gaussian for d≥64) |
Dependency Graph
turboquant_kv.py ──► qwen3_14b_prefill_tq.py (prefill_layer_tq)
                 ──► qwen3_14b_decode_tq.py (decode_layer_tq)
Exported symbols:
- solve_lloyd_max(d, bits) # CPU-side Lloyd-Max solver
- _lm_centroids, _lm_boundaries # Module-level codebook
- turboquant_kv_prefill(...) # Quantize (prefill only)
- turboquant_kv_dequant_chunk(...) # Dequant (shared by prefill & decode)
Design Decisions
- Shared rotation matrix across prefill and decode: rot_matrices is
  generated once via torch.linalg.qr(torch.randn(head_dim, head_dim))
  (seed=42) and shared between quantization and dequantization to ensure
  the rotated domain is consistent.
- BF16 dequant output: avoids extra pl.cast at call sites, saving
  one buffer and one scope per attention operation.
- No internal pl.at in inline dequant function:
  turboquant_kv_dequant_chunk does not contain pl.at scopes. All
  ops are inlined into the caller's scope, avoiding nested InCore
  functions that the compiler cannot resolve (call directions between
  nested scopes are undefined).
- Renormalize corrects quantization drift: centroid vectors gathered
  from the codebook are not guaranteed to lie on the unit sphere.
  Dequant always renormalizes via rsqrt(row_sum(sq) + EPS) before
  rescaling, preserving numerical stability.
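The point of sharing one orthogonal matrix is that its transpose undoes the rotation exactly. The sketch below shows this with a seeded Gram-Schmidt construction from the standard library. It is a stand-in for `torch.linalg.qr(torch.randn(...))` and does not reproduce the same matrix; `random_orthogonal`, `rotate`, and `unrotate` are hypothetical names.

```python
import math
import random

def random_orthogonal(d, seed=42):
    """Seeded random orthogonal matrix via modified Gram-Schmidt (stand-in for QR)."""
    rng = random.Random(seed)
    rows = []
    for _ in range(d):
        v = [rng.gauss(0.0, 1.0) for _ in range(d)]
        for q in rows:
            # Remove the component along each previously accepted row.
            proj = sum(a * b for a, b in zip(v, q))
            v = [a - proj * b for a, b in zip(v, q)]
        n = math.sqrt(sum(a * a for a in v))
        rows.append([a / n for a in v])
    return rows

def rotate(x, rot):
    """x @ rot (quantization side)."""
    d = len(x)
    return [sum(x[i] * rot[i][j] for i in range(d)) for j in range(d)]

def unrotate(y, rot):
    """y @ rot^T (dequantization side)."""
    d = len(y)
    return [sum(y[i] * rot[j][i] for i in range(d)) for j in range(d)]

rot = random_orthogonal(8)
x = [float(i) for i in range(8)]
back = unrotate(rotate(x, rot), rot)
max_err = max(abs(a - b) for a, b in zip(x, back))
```

If quantization and dequantization used matrices from different seeds, `unrotate` would no longer invert `rotate`. That is why the seed is fixed and the matrix is generated once.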
Motivation / Use Case
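TurboQuant compresses the KV cache to 4-bit (INT4) indices, a nominal 1/4 of
the BF16 footprint (4 bits versus 16 bits per element). In the layout listed
under Quantized outputs, each index occupies one UINT8 and each row carries one
FP32 L2 norm, so the stored cache reaches the 1/4 figure only if indices are
packed two per byte. The scheme is based on PolarQuant: Lloyd-Max scalar
quantization in a randomly rotated space. After rotation by a random orthogonal
matrix, each coordinate of a d-dimensional unit vector (d = head_dim) is
approximately distributed as N(0, 1/d) for d >= 64, so a single Gaussian
codebook is near-optimal for every coordinate.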
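A quick arithmetic check of the memory claim, using the dtypes from the Quantized outputs table. The `packed_int4` layout (two indices per byte) is an assumption added for comparison and is not described in this document. `bytes_per_row` is a hypothetical helper.

```python
HEAD_DIM = 128

def bytes_per_row(head_dim=HEAD_DIM, layout="uint8"):
    """Bytes needed to store one K or V row of one head."""
    if layout == "bf16":
        return head_dim * 2            # 2 bytes per element, no scale
    scale_bytes = 4                    # one FP32 L2 norm per row
    if layout == "uint8":
        return head_dim + scale_bytes  # one index per byte
    if layout == "packed_int4":
        return head_dim // 2 + scale_bytes  # assumed: two indices per byte
    raise ValueError(f"unknown layout: {layout}")

baseline = bytes_per_row(layout="bf16")
ratios = {
    name: bytes_per_row(layout=name) / baseline
    for name in ("uint8", "packed_int4")
}
```

For HEAD_DIM = 128 this gives 256 bytes per BF16 row, 132 bytes with one UINT8 per index (about 0.52 of BF16), and 68 bytes with packed indices (about 0.27 of BF16).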