
[Feature] TurboQuant — PolarQuant INT4 KV Cache for Qwen3-14B #544

Description

@sunghajung6688

Summary

TurboQuant — PolarQuant INT4 KV Cache for Qwen3-14B

File Layout

File                       Role
turboquant_kv.py           Lloyd-Max codebook solver + quantize function + dequant function
qwen3_14b_prefill_tq.py    Prefill end-to-end: quantize K/V → write cache → dequant for attention
qwen3_14b_decode_tq.py     Decode end-to-end: read quantized cache → dequant → attention

End-to-End Flow

┌──────────────────────────────────────────────────────────────────┐
│                        Prefill (compress)                        │
│                                                                  │
│  K/V proj ──► RoPE(K) / L2 norm ──► rotate ──► INT4 quantize     │
│                                                      │           │
│  Attention ◄── dequant ◄── write quant_k/v_cache ◄───┘           │
│                                                                  │
├──────────────────────────────────────────────────────────────────┤
│                       Decode (decompress)                        │
│                                                                  │
│  read INT4 from quant_k/v_cache ──► dequant ──► Attention        │
│                                                                  │
└──────────────────────────────────────────────────────────────────┘

Quantization

1. Lloyd-Max Codebook (computed once at module load)

solve_lloyd_max(head_dim, bits=4) in turboquant_kv.py solves for the
MSE-optimal 16-level scalar quantizer for X ~ N(0, 1/head_dim):

  • Input: distribution std σ = 1/sqrt(head_dim)
  • Algorithm: Lloyd-Max iteration (alternating boundary and centroid updates)
    • Step 1: boundaries = midpoints of adjacent centroids
    • Step 2: centroids = Gaussian conditional mean E[X | a < X < b]
      (computed analytically via torch.erf)
  • Output: _lm_centroids [16] centroids, _lm_boundaries [15] decision boundaries
  • The 15 boundary values are unpacked into scalars _b0 … _b14 for use in kernel comparisons
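A minimal pure-Python sketch of the two-step Lloyd-Max iteration described above, for N(0, σ²) with σ = 1/sqrt(head_dim). It is not the code in turboquant_kv.py: the name lloyd_max_gaussian, the even initialization over [-3σ, 3σ], and the fixed iteration count are assumptions, and it uses math.erf where the original uses torch.erf.

```python
import math


def _phi(z: float) -> float:
    """Standard normal density; evaluates to 0.0 at z = ±inf."""
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def _Phi(z: float) -> float:
    """Standard normal CDF via math.erf (erf(±inf) = ±1)."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def _midpoints(c):
    return [(a + b) / 2.0 for a, b in zip(c, c[1:])]


def lloyd_max_gaussian(head_dim: int, bits: int = 4, iters: int = 300):
    """Lloyd-Max quantizer for X ~ N(0, sigma^2), sigma = 1/sqrt(head_dim).

    Returns (centroids, boundaries) with 2**bits and 2**bits - 1 entries.
    """
    sigma = 1.0 / math.sqrt(head_dim)
    n = 2 ** bits
    # Assumed initialization: n points spread evenly over [-3 sigma, 3 sigma].
    centroids = [sigma * (-3.0 + 6.0 * (i + 0.5) / n) for i in range(n)]
    for _ in range(iters):
        # Step 1: decision boundaries are midpoints of adjacent centroids.
        edges = [-math.inf] + _midpoints(centroids) + [math.inf]
        # Step 2: each centroid becomes E[X | a < X < b], the truncated-normal
        # mean sigma * (phi(a/sigma) - phi(b/sigma)) / (Phi(b/sigma) - Phi(a/sigma)).
        new = []
        for a, b in zip(edges, edges[1:]):
            za, zb = a / sigma, b / sigma
            new.append(sigma * (_phi(za) - _phi(zb)) / (_Phi(zb) - _Phi(za)))
        centroids = new
    return centroids, _midpoints(centroids)
```

With `centroids, boundaries = lloyd_max_gaussian(128)`, `boundaries[0]` through `boundaries[14]` play the role of `_b0` … `_b14`.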

2. Quantize Kernel (turboquant_kv_prefill, decorated with @pl.jit.inline)

Executes 4 scopes per KV head during prefill:

Scope A: K RoPE + L2 norm + normalize
  ├─ K RoPE (cos/sin rotary transform)
  ├─ L2 norm = sqrt(sum(x²) + ε) per row
  ├─ normalize: x / L2_norm → BF16, write to k_norm_buf
  └─ L2 norm written to k_scales_buf

Scope B: K rotate + Lloyd-Max quantize
  ├─ k_norm_buf × rot_matrix → k_rot (BF16×BF16→FP32)
  ├─ boundary count: idx = (k_rot ≥ _b0) + (k_rot ≥ _b1) + … + (k_rot ≥ _b14),
  │   i.e. 15 elementwise GE comparisons summed → INT4 index in [0, 15]
  └─ FP32→INT32→FP16→UINT8, write to quant_k_temp

Scope C: V L2 norm + normalize (no RoPE for V)
  ├─ L2 norm = sqrt(sum(x²) + ε)
  ├─ normalize → BF16, write to v_norm_buf
  └─ L2 norm written to v_scales_buf

Scope D: V rotate + Lloyd-Max quantize
  └─ same as Scope B, write to quant_v_temp
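The per-row arithmetic of Scopes A to D can be sketched in plain Python as below. This is illustrative only: it omits RoPE (V has none anyway), uses a normalized Hadamard matrix as a stand-in orthogonal rotation so the example is self-contained, and the names quantize_row and hadamard, as well as the eps default, are hypothetical. The boundaries argument would be the 15 Lloyd-Max decision boundaries.

```python
import math


def hadamard(n: int):
    """Normalized Sylvester-Hadamard matrix; orthogonal, n must be a power of 2."""
    h = [[1.0]]
    while len(h) < n:
        h = [row + row for row in h] + [row + [-v for v in row] for row in h]
    s = 1.0 / math.sqrt(n)
    return [[v * s for v in row] for row in h]


def quantize_row(x, rot, boundaries, eps=1e-6):
    """Sketch of one row through Scopes A/B (or C/D); returns (indices, scale)."""
    # L2 norm = sqrt(sum(x^2) + eps); stored as the per-row scale.
    scale = math.sqrt(sum(v * v for v in x) + eps)
    unit = [v / scale for v in x]
    # Rotate as a row vector: x_rot[j] = sum_i unit[i] * rot[i][j].
    d = len(x)
    x_rot = [sum(unit[i] * rot[i][j] for i in range(d)) for j in range(d)]
    # Cumulative GE: (y >= b0) + (y >= b1) + ... + (y >= b14), an index in [0, 15].
    indices = [sum(1 for b in boundaries if y >= b) for y in x_rot]
    return indices, scale
```

Summing the comparisons gives the same index as `bisect.bisect_right(boundaries, y)`, but without any data-dependent branching.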

Quantized outputs:

Output            Shape                   Dtype  Description
quant_k/v_cache   [total_rows, HEAD_DIM]  UINT8  INT4 indices in [0, 15], one index per byte
quant_k/v_scales  [total_rows, 1]         FP32   per-row L2 norm

Dequantization

turboquant_kv_dequant_chunk (decorated with @pl.jit.inline)

Dequantizes one CMP_CHUNK (32 rows) of INT4 KV cache. All three steps are
inlined into the caller's pl.at scope:

Step 1: Gather (codebook lookup)
  UINT8 indices → FP16 cast → INT32 cast → pl.gather(codebook) → FP32 centroids

Step 2: Renormalize + Scale
  centroid row → × rsqrt(row_sum(c²) + ε) (back onto the unit sphere) → × stored L2 scale

Step 3: Unrotate (rotate back to original space)
  → BF16 cast → matmul(rot_slice^T) → BF16 → write to out_bf16
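The three dequant steps for a single row can be sketched as follows. This is a plain-Python illustration, not the pl kernel: dequant_row is a hypothetical name, the eps default is assumed, and BF16 casts are omitted. Because quantization computed x_rot = x @ rot with an orthogonal rot, unrotating is a multiply by rot^T.

```python
import math


def dequant_row(indices, scale, centroids, rot, eps=1e-6):
    """Sketch of Steps 1-3 for one row of a quantized K or V cache."""
    d = len(indices)
    # Step 1: codebook lookup (the kernel does this with pl.gather).
    c = [centroids[i] for i in indices]
    # Step 2: renormalize onto the unit sphere, then apply the stored L2 scale.
    inv = 1.0 / math.sqrt(sum(v * v for v in c) + eps)
    y = [v * inv * scale for v in c]
    # Step 3: unrotate. x = y @ rot^T, i.e. x[i] = sum_j y[j] * rot[i][j].
    return [sum(y[j] * rot[i][j] for j in range(d)) for i in range(d)]
```

Since rot is orthogonal, Step 3 preserves the norm set in Step 2, so the output row has (up to eps) exactly the stored L2 norm.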

Prefill QK attention (dequant K on the fly)

with pl.at("qk_dequant"):
    for each CMP_CHUNK (4 iterations for BLOCK_SIZE=128):
        chunk_out = pl.create_tensor([32, 128], BF16)
        turboquant_kv_dequant_chunk(indices, scales, codebook, rot, chunk_out)
        k_bf16_buf = pl.assemble(chunk_out)  # directly BF16, no extra cast

with pl.at("qk_matmul"):
    raw_scores = pl.matmul(q_padded, k_bf16_buf, b_trans=True)
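The chunked loop above amounts to the following, sketched in plain Python. dequant_chunk is a hypothetical callable standing in for turboquant_kv_dequant_chunk, and lists of lists stand in for pl tensors; the point is only the iteration pattern (four 32-row calls for a 128-row tile) and the b_trans=True product q @ k^T.

```python
BLOCK_SIZE = 128  # rows per attention tile (= SEQ_TILE)
CMP_CHUNK = 32    # rows per dequant call


def assemble_k_tile(dequant_chunk, block_size=BLOCK_SIZE, chunk=CMP_CHUNK):
    """Stack the rows from one dequant_chunk(row_start, n_rows) call per chunk.

    dequant_chunk is a hypothetical stand-in for turboquant_kv_dequant_chunk.
    """
    if block_size % chunk != 0:
        raise ValueError("block_size must be a multiple of chunk")
    tile = []
    for start in range(0, block_size, chunk):  # 128 / 32 -> 4 iterations
        tile.extend(dequant_chunk(start, chunk))
    return tile


def qk_scores(q_rows, k_tile):
    """raw_scores = q @ k^T, i.e. the matmul with b_trans=True."""
    return [[sum(a * b for a, b in zip(q, k)) for k in k_tile] for q in q_rows]
```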

Prefill SV attention (dequant V on the fly)

with pl.at("sv_dequant"):
    for each CMP_CHUNK:
        chunk_out = pl.create_tensor([32, 128], BF16)
        turboquant_kv_dequant_chunk(indices, scales, codebook, rot, chunk_out)
        v_tile_full = pl.assemble(chunk_out)  # directly BF16

with pl.at("sv_matmul"):
    oi_tmp = pl.matmul(exp_tile, v_tile_full)
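For reference, the exp_tile @ V step corresponds to one tile of an online-softmax attention update. The sketch below is an assumption about the surrounding computation, not taken from the source: sv_tile is a hypothetical name, and the softmax scale (commonly 1/sqrt(HEAD_DIM)) is passed in explicitly because the issue does not say where it is applied. oi_tmp is the unnormalized product; dividing by the row sum yields the softmax-weighted average of V.

```python
import math


def sv_tile(raw_scores, v_tile, scale):
    """One tile of softmax(scale * QK^T) V in online-softmax form (a sketch).

    Returns (oi_tmp, row_max, row_sum); oi_tmp[r] / row_sum[r] is the
    softmax-weighted average of the rows of v_tile for query row r.
    """
    oi_tmp, row_max, row_sum = [], [], []
    for row in raw_scores:
        m = max(s * scale for s in row)
        # exp_tile row: exp(scale * s - m) keeps every exponent <= 0.
        p = [math.exp(s * scale - m) for s in row]
        acc = [sum(pj * v[c] for pj, v in zip(p, v_tile)) for c in range(len(v_tile[0]))]
        oi_tmp.append(acc)
        row_max.append(m)
        row_sum.append(sum(p))
    return oi_tmp, row_max, row_sum
```

Keeping the running row_max and row_sum alongside oi_tmp is what allows results from successive SEQ_TILE tiles to be rescaled and merged.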

Decode dequant

Decode uses the same turboquant_kv_dequant_chunk path as prefill, with BLOCK_SIZE=128 (equal to SEQ_TILE).

Key Constants

Constant                Value  Notes
N_LEVELS                16     INT4, 16 quantization levels
CMP_CHUNK               32     Dequant sub-chunk rows (32×1B = 32-byte aligned)
SEQ_TILE / BLOCK_SIZE   128    Attention sequence tile
TOK_TILE                16     Prefill quantize token tile
CMP_TILE / CMP_TILE_SV  64     Dequant sub-tile
HEAD_DIM                128    Q/K/V head dimension (rotated coordinates ≈ Gaussian for d ≥ 64)
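A few quantities that follow directly from the table, written as Python for quick sanity checks. The derived names (BITS, CHUNKS_PER_TILE, and so on) are hypothetical and do not appear in the source; the byte counts assume the UINT8-per-index and FP32-per-row layout described under "Quantized outputs".

```python
# Values copied from the Key Constants table.
N_LEVELS = 16
CMP_CHUNK = 32
SEQ_TILE = BLOCK_SIZE = 128
TOK_TILE = 16
CMP_TILE = CMP_TILE_SV = 64
HEAD_DIM = 128

# Derived quantities (illustrative, not taken from the source code).
BITS = N_LEVELS.bit_length() - 1              # 16 levels -> 4 bits per index
assert SEQ_TILE % CMP_CHUNK == 0              # the dequant loop uses whole chunks
CHUNKS_PER_TILE = SEQ_TILE // CMP_CHUNK       # dequant calls per attention tile
INDEX_BYTES_PER_CHUNK = CMP_CHUNK * HEAD_DIM  # one UINT8 index per element
SCALE_BYTES_PER_CHUNK = CMP_CHUNK * 4         # one FP32 L2 norm per row
```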

Dependency Graph

turboquant_kv.py  ──►  qwen3_14b_prefill_tq.py  (prefill_layer_tq)
                   ──►  qwen3_14b_decode_tq.py   (decode_layer_tq)

Exported symbols:
  - solve_lloyd_max(d, bits)           # CPU-side Lloyd-Max solver
  - _lm_centroids, _lm_boundaries      # Module-level codebook
  - turboquant_kv_prefill(...)         # Quantize (prefill only)
  - turboquant_kv_dequant_chunk(...)   # Dequant (shared by prefill & decode)

Design Decisions

  1. Shared rotation matrix across prefill and decode: rot_matrices is
    generated once as the Q factor of torch.linalg.qr(torch.randn(head_dim, head_dim))
    (seed=42) and shared between quantization and dequantization so that
    both operate in the same rotated domain.

  2. BF16 dequant output: avoids extra pl.cast at call sites, saving
    one buffer and one scope per attention operation.

  3. No internal pl.at in inline dequant function:
    turboquant_kv_dequant_chunk does not contain pl.at scopes. All
    ops are inlined into the caller's scope, avoiding nested InCore
    functions that the compiler cannot resolve (call directions between
    nested scopes are undefined).

  4. Renormalization corrects quantization drift: the vector of centroids
    gathered from the codebook is not guaranteed to lie on the unit sphere.
    Dequant therefore always renormalizes via rsqrt(row_sum(sq) + EPS) before
    rescaling, so multiplying by the stored L2 scale restores the original
    row norm; EPS guards the rsqrt against a zero sum.
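Decision 1's rotation can be mimicked without torch as the Q factor of a seeded Gaussian matrix, computed here by modified Gram-Schmidt. random_rotation is a hypothetical name, and because it draws from Python's random module it will not reproduce the matrix that torch.linalg.qr(torch.randn(...)) yields for seed 42. What matters for the design is that any fixed orthogonal Q, reused by both quantize and dequant, makes x @ Q @ Q^T recover x.

```python
import math
import random


def random_rotation(d: int, seed: int = 42):
    """Q factor of a seeded d x d Gaussian matrix via modified Gram-Schmidt.

    Plays the role of the Q from torch.linalg.qr(torch.randn(d, d)), but will
    NOT match torch's matrix because the random streams differ.
    """
    rng = random.Random(seed)
    a = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(d)]
    cols = []  # orthonormal columns of Q
    for j in range(d):
        v = [a[i][j] for i in range(d)]
        for q in cols:
            # Remove the component along each earlier column (updated v each time).
            dot = sum(vi * qi for vi, qi in zip(v, q))
            v = [vi - dot * qi for vi, qi in zip(v, q)]
        norm = math.sqrt(sum(vi * vi for vi in v))
        cols.append([vi / norm for vi in v])
    # Return Q as a list of rows: Q[i][j] = cols[j][i].
    return [[cols[j][i] for j in range(d)] for i in range(d)]
```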

Motivation / Use Case

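TurboQuant compresses the KV cache to 4-bit indices, a quarter of BF16's 16 bits
per element (plus one FP32 scale per row). With the current layout, which stores
one index per UINT8 byte, the realized saving is closer to one half; reaching one
quarter requires packing two indices per byte. It is based on PolarQuant: scalar
quantization via Lloyd-Max in a randomly rotated space. After rotation by a random
orthogonal matrix, each coordinate of a d-dimensional unit vector approximately
follows N(0, 1/d) with d = head_dim (a good approximation for d >= 64), so one
fixed Lloyd-Max codebook is near-optimal for every coordinate.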
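Two quick numerical checks of the claims above, under stated assumptions: the byte counts assume one FP32 scale per row (as in the "Quantized outputs" table) and ignore any padding or alignment, and the sampling estimate relies on the fact that a fixed unit vector after a uniformly random rotation is a uniformly random unit vector. The helper name coord_variance is hypothetical.

```python
import math
import random

HEAD_DIM = 128

# Bytes per cached K (or V) row.
BF16_ROW = 2 * HEAD_DIM          # 256 B uncompressed
UINT8_ROW = HEAD_DIM + 4         # one index per byte + FP32 scale = 132 B
PACKED_ROW = HEAD_DIM // 2 + 4   # two 4-bit indices per byte + FP32 scale = 68 B


def coord_variance(d: int = HEAD_DIM, samples: int = 2000, seed: int = 0) -> float:
    """Monte Carlo estimate of Var(x_0) for x uniform on the unit sphere in R^d."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(samples):
        g = [rng.gauss(0.0, 1.0) for _ in range(d)]  # normalized Gaussian = uniform direction
        norm = math.sqrt(sum(v * v for v in g))
        total += (g[0] / norm) ** 2  # E[x_0] = 0, so E[x_0^2] is the variance
    return total / samples


print(UINT8_ROW / BF16_ROW, PACKED_ROW / BF16_ROW)  # 0.515625 0.265625
print(coord_variance() * HEAD_DIM)                   # close to 1.0, i.e. Var ≈ 1/d
```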

Proposed API / Behavior

No response

Alternatives Considered

No response

Additional Context

No response

Labels: enhancement (New feature or request)
