[Feature] TurboQuant NPU Kernel Integration #20

@sunghajung6688

TurboQuant KV Cache Compression — Design Document

Overview

TurboQuant is a 4-bit KV cache compression method for Qwen3-14B inference on
Ascend NPU (910B/C, 950). It uses PolarQuant (random rotation + optimal
scalar quantization) to compress K/V caches from BF16 (16 bits/value) to
4-bit codes per value plus one FP32 scale per vector. The reference
implementation is adapted from the ICLR 2026 paper TurboQuant.

Key insight: 4-bit PolarQuant with 16 Lloyd-Max centroids outperforms 3-bit
PolarQuant + 1-bit QJL error correction at the same bit rate. QJL is excluded
from this implementation.

Architecture

Files

| File | Lines | Role |
|---|---|---|
| turboquant_kv.py | 337 | Lloyd-Max codebook, shared KV quantize/dequant primitives |
| qwen3_14b_prefill_tq.py | 1362 | Full prefill forward with inline TQ compression |
| qwen3_14b_decode_tq.py | 1322 | Full decode forward with TQ dequant attention |
| npu_executor.py | - | Compile-time: rotation matrix generation, codebook creation |
| cpu_executor.py | - | Reference CPU executor with TQ support |
| python/core/turboquant/ | - | CPU-side compressor, rotation matrix, codebook helpers |

Runtime Pipeline

Input tokens
    │
    ├── Prefill (prompt): prefill_fwd_tq
    │     └── For each transformer layer:
    │           ├── RMSNorm → Q/K/V projection → per-head RMSNorm
    │           ├── K/V: RoPE (K only) → L2 norm → rotate → Lloyd-Max quantize → store UINT8 + FP32 scale
    │           ├── Q:   per-head RMSNorm → RoPE
    │           ├── Attention: dequant K/V from cache → QK matmul → softmax → SV matmul
    │           └── Output proj + residual + RMSNorm + SwiGLU MLP + residual
    │
    ├── Decode (autoregressive): decode_fwd_tq
    │     └── For each token, for each transformer layer:
    │           ├── RMSNorm → Q/K/V projection → per-head RMSNorm
    │           ├── K/V: RoPE (K only) → L2 norm → rotate → quantize → append to cache
    │           ├── Q:   per-head RMSNorm → RoPE (one head at a time)
    │           ├── Attention: dequant all K/V from cache → QK matmul → softmax → SV matmul
    │           └── Output proj + residual + RMSNorm + SwiGLU MLP + residual
    │
    └── Final RMSNorm → LM head → logits → sample next token

PolarQuant Algorithm

Quantize

Given: K/V vector x ∈ R^d (post-RoPE for K, raw for V)

1. L2 norm:        γ = sqrt(‖x‖² + ε)
2. Normalize:      x̂ = x / γ                    
3. Rotate:         y = x̂_bf16 @ R               (R: Haar orthogonal, BF16)
4. Lloyd-Max:      idx[j] = Σ 1(y[j] ≥ b_i)     (15 boundary comparisons → 0–15)
5. Store:          (indices as UINT8, γ as FP32 scale)
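The quantize steps can be sketched in NumPy as below. This is a float64 reference, not the kernel: it assumes ε = 1e-6, skips the BF16 cast, and hardcodes rounded 16-level Lloyd-Max centroids for N(0, 1) scaled by 1/√d. `gaussian_codebook` and `polar_quantize` are hypothetical names, not functions from turboquant_kv.py. The usage lines also let you check that counting the 15 boundary crossings selects the nearest centroid.

```python
import numpy as np

HEAD_DIM = 128
EPS = 1e-6  # assumed epsilon; the kernel's value may differ

# Rounded 16-level Lloyd-Max centroids for N(0, 1), positive half.
STD_POS = np.array([0.1284, 0.3881, 0.6568, 0.9423, 1.2562, 1.6180, 2.0690, 2.7326])


def gaussian_codebook(dim=HEAD_DIM):
    """Return 16 ascending centroids for N(0, 1/dim) and the 15 midpoint boundaries."""
    centroids = np.concatenate([-STD_POS[::-1], STD_POS]) / np.sqrt(dim)
    boundaries = 0.5 * (centroids[:-1] + centroids[1:])
    return centroids, boundaries


def polar_quantize(x, rot, boundaries):
    """Quantize rows of x ([n, dim]) -> (UINT8 indices, FP32 scales, rotated unit vectors)."""
    gamma = np.sqrt((x * x).sum(axis=-1, keepdims=True) + EPS)  # 1. L2 norm
    x_hat = x / gamma                                             # 2. normalize
    y = x_hat @ rot                                               # 3. rotate
    idx = (y[..., None] >= boundaries).sum(axis=-1)               # 4. count boundaries crossed (0..15)
    return idx.astype(np.uint8), gamma.astype(np.float32), y      # 5. store


rng = np.random.default_rng(0)
rot, _ = np.linalg.qr(rng.standard_normal((HEAD_DIM, HEAD_DIM)))
x = rng.standard_normal((4, HEAD_DIM))
centroids, boundaries = gaussian_codebook()
idx, gamma, y = polar_quantize(x, rot, boundaries)
```

Because the boundaries are midpoints between adjacent sorted centroids, the boundary count equals the index of the nearest centroid. That is why 15 comparisons can replace a full distance search.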

Dequantize

Given: UINT8 indices, FP32 scale γ, rotation matrix R

1. Gather:         ŷ = codebook[indices]          (FP32 centroids in rotated domain)
2. Renormalize:    ŷ = ŷ / ‖ŷ‖                    (correct quantization norm drift)
3. Rescale:        ŷ = ŷ · γ                      (restore original magnitude)
4. Unrotate:       x̃ = ŷ_bf16 @ R^T               (reconstruction of x in original space, BF16)
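A float64 NumPy sketch of the dequantize steps with a full round trip. The same assumptions as a plain reference apply: ε = 1e-6, rounded Lloyd-Max centroids for N(0, 1) scaled by 1/√d, and no BF16 casts. `polar_dequantize` is a hypothetical helper, not the kernel code.

```python
import numpy as np

HEAD_DIM = 128
EPS = 1e-6  # assumed epsilon

# Rounded 16-level Lloyd-Max centroids for N(0, 1), scaled to N(0, 1/HEAD_DIM).
STD_POS = np.array([0.1284, 0.3881, 0.6568, 0.9423, 1.2562, 1.6180, 2.0690, 2.7326])
CODEBOOK = np.concatenate([-STD_POS[::-1], STD_POS]) / np.sqrt(HEAD_DIM)
BOUNDARIES = 0.5 * (CODEBOOK[:-1] + CODEBOOK[1:])


def polar_dequantize(idx, gamma, rot, codebook=CODEBOOK):
    """Reconstruct rows of x from UINT8 indices, scales gamma ([n, 1]) and rotation rot."""
    y_hat = codebook[idx.astype(np.int64)]                         # 1. gather centroids
    y_hat = y_hat / np.linalg.norm(y_hat, axis=-1, keepdims=True)  # 2. renormalize to unit norm
    y_hat = y_hat * gamma                                          # 3. restore magnitude
    return y_hat @ rot.T                                           # 4. unrotate (rot is orthogonal)


# Round trip on random vectors (float64 throughout; the kernel casts to BF16).
rng = np.random.default_rng(1)
rot, _ = np.linalg.qr(rng.standard_normal((HEAD_DIM, HEAD_DIM)))
x = 3.0 * rng.standard_normal((8, HEAD_DIM))
gamma = np.sqrt((x * x).sum(axis=-1, keepdims=True) + EPS)
idx = (((x / gamma) @ rot)[..., None] >= BOUNDARIES).sum(axis=-1).astype(np.uint8)
x_rec = polar_dequantize(idx, gamma, rot)
cos = (x * x_rec).sum(axis=-1) / (np.linalg.norm(x, axis=-1) * np.linalg.norm(x_rec, axis=-1))
```

Renormalization only rescales ŷ. It makes the reconstructed norm exactly γ but does not change the direction, so the cosine to x is determined entirely by the 4-bit gather.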

Lloyd-Max Codebook

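  • Distribution: N(0, 1/d), i.e. standard deviation 1/√d; after the random rotation, the coordinates of a unit vector are approximately i.i.d. Gaussian for large d
  • Solved iteratively (Lloyd-Max), with each centroid update computed in closed form as a Gaussian conditional expectation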
  • 4-bit → 16 levels: [-0.2416, -0.1829, ..., 0.1829, 0.2416]
  • 15 boundary thresholds: midpoints between adjacent centroids
  • Computed once at module load in turboquant_kv.solve_lloyd_max()
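The solver can be sketched in pure Python as follows. The sketch assumes a symmetric uniform initialization on [-3, 3] and a 1e-12 convergence tolerance. `lloyd_max_gaussian` is a hypothetical name and does not reproduce the actual `solve_lloyd_max()`. It solves for N(0, 1), using E[X | lo < X < hi] = (φ(lo) − φ(hi)) / (Φ(hi) − Φ(lo)), and then scales by σ = 1/√d. For d = 128 the outermost centroids come out near ±0.2416.

```python
import math


def _pdf(t):
    # Standard normal density; math.exp(-inf) == 0.0 handles the open tails.
    return math.exp(-0.5 * t * t) / math.sqrt(2.0 * math.pi)


def _cdf(t):
    # Standard normal CDF via erf; erf(+-inf) == +-1.
    return 0.5 * (1.0 + math.erf(t / math.sqrt(2.0)))


def lloyd_max_gaussian(n_levels=16, dim=128, tol=1e-12, max_iter=20000):
    """Hypothetical Lloyd-Max solver for N(0, 1/dim): returns (centroids, boundaries)."""
    c = [-3.0 + 6.0 * (i + 0.5) / n_levels for i in range(n_levels)]  # symmetric init
    for _ in range(max_iter):
        b = [0.5 * (c[i] + c[i + 1]) for i in range(n_levels - 1)]     # midpoint boundaries
        edges = [-math.inf] + b + [math.inf]
        # Centroid = conditional mean of N(0, 1) inside each cell.
        new_c = [(_pdf(lo) - _pdf(hi)) / (_cdf(hi) - _cdf(lo))
                 for lo, hi in zip(edges[:-1], edges[1:])]
        delta = max(abs(u - v) for u, v in zip(new_c, c))
        c = new_c
        if delta < tol:
            break
    sigma = 1.0 / math.sqrt(dim)
    centroids = [v * sigma for v in c]
    boundaries = [0.5 * (centroids[i] + centroids[i + 1]) for i in range(n_levels - 1)]
    return centroids, boundaries


centroids, boundaries = lloyd_max_gaussian()
```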

Rotation Matrix

  • Per-layer dense orthogonal matrix (128×128, BF16)
  • Generated via QR decomposition: Q, R_qr = qr(randn(128, 128))
  • Normalized: Q = Q * sign(diag(R_qr)), which makes the factorization unique (and Q Haar-distributed) so CPU and NPU get consistent signs
  • Seed scheme: seed = 42 + layer_idx * 1000 (deterministic, per-layer)
  • Stored as [num_layers * head_dim, head_dim] stacked BF16 tensor
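The seeding, sign normalization and stacking scheme can be illustrated in NumPy. NumPy's `default_rng` is not `torch.Generator`, so these matrices will not match the torch-generated ones bit for bit. What the sketch shows is the per-layer seed, the sign fix and the stacked layout. BF16 casting is omitted, and `layer_rotation` and `stacked_rotations` are hypothetical helpers.

```python
import numpy as np

HEAD_DIM = 128


def layer_rotation(layer_idx, dim=HEAD_DIM):
    """Per-layer orthogonal matrix from QR of a seeded Gaussian matrix."""
    rng = np.random.default_rng(42 + layer_idx * 1000)  # per-layer seed scheme
    g = rng.standard_normal((dim, dim))
    q, r_qr = np.linalg.qr(g)
    # Scale column j of Q by sign(R_qr[j, j]) so R_qr's diagonal becomes positive.
    return q * np.sign(np.diag(r_qr))


def stacked_rotations(num_layers, dim=HEAD_DIM):
    """Stack per-layer matrices into a [num_layers * dim, dim] array."""
    return np.concatenate([layer_rotation(l, dim) for l in range(num_layers)], axis=0)


rots = stacked_rotations(40)
layer_3 = rots[3 * HEAD_DIM:4 * HEAD_DIM]  # the slice a kernel would read for layer 3
```

Because each layer draws from its own generator, layer l's matrix does not depend on how many layers were generated before it. That independence is the property the seed-consistency fix below relies on.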

Prefill Attention (with TQ)

prefill_fwd_tq
  │
  ├── Expand codebook: [1, 16] → [CMP_CHUNK, 16] for dequant gather
  │
  └── For each layer:
        prefill_layer_tq
          │
          ├── Scope 1: RMSNorm → Q/K/V proj → per-head Q/K RMSNorm
          │
          ├── Scope 2a: turboquant_kv_prefill()
          │     For each KV head (8 heads, batched by 8):
          │       A: K RoPE + L2 norm + normalize → k_norm_buf (BF16)
          │       B: k_norm_bf16 @ R → k_rot → 15× boundary cmp → UINT8 indices
          │       C: V L2 norm + normalize → v_norm_buf (BF16)
          │       D: v_norm_bf16 @ R → v_rot → 15× boundary cmp → UINT8 indices
          │     → quant_k_temp (UINT8), k_scales_buf (FP32), V equivalents
          │
          ├── Scope 2b: Cache write + Q RoPE + Causal Attention
          │     For each token:
          │       - Scatter quantized K/V indices + scales to paged cache
          │       - Q RoPE + pad into all_q_padded (BF16, 16 rows per KV head group)
          │       - For each Q group (8 groups total):
          │           q_padded @ k_dequant^T → raw_scores
          │           exp (unnormalized softmax) → attn_weights
          │           attn_weights @ v_dequant → oi (partial context)
          │           Online softmax accumulation (oi, li) across sequence blocks
          │       - ctx = oi / li → attention output
          │
          └── Scope 3: out_proj + residual + post RMSNorm + SwiGLU MLP
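The online softmax accumulation and the final ctx = oi / li can be written as a small NumPy reference for one token and one Q group. It assumes 1/√d score scaling (shown in the decode diagram), float64 math and already-dequantized K/V. `attend_one_token` is a hypothetical name. Padding rows are zero queries and are dropped at the end.

```python
import numpy as np

HEAD_DIM = 128
Q_HEAD_BATCH, Q_HEAD_PAD, SEQ_TILE = 5, 16, 128


def attend_one_token(q_group, k_deq, v_deq, pos, seq_tile=SEQ_TILE):
    """Causal attention for token `pos` with online softmax over sequence tiles.

    q_group: [Q_HEAD_BATCH, HEAD_DIM] queries sharing one KV head
    k_deq, v_deq: [seq_len, HEAD_DIM] dequantized K/V for that KV head
    """
    q_padded = np.zeros((Q_HEAD_PAD, HEAD_DIM))
    q_padded[:Q_HEAD_BATCH] = q_group              # pad 5 heads to 16 rows
    scale = 1.0 / np.sqrt(HEAD_DIM)                # assumed 1/sqrt(d) scaling
    mi = np.full((Q_HEAD_PAD, 1), -np.inf)         # running row max
    li = np.zeros((Q_HEAD_PAD, 1))                 # running softmax denominator
    oi = np.zeros((Q_HEAD_PAD, HEAD_DIM))          # running unnormalized context
    for start in range(0, pos + 1, seq_tile):
        end = min(start + seq_tile, pos + 1)       # causal: keys 0..pos only
        raw_scores = q_padded @ k_deq[start:end].T * scale
        m_new = np.maximum(mi, raw_scores.max(axis=1, keepdims=True))
        attn_weights = np.exp(raw_scores - m_new)  # unnormalized softmax
        corr = np.exp(mi - m_new)                  # rescale earlier partial sums
        li = li * corr + attn_weights.sum(axis=1, keepdims=True)
        oi = oi * corr + attn_weights @ v_deq[start:end]
        mi = m_new
    return (oi / li)[:Q_HEAD_BATCH]                # ctx = oi / li, drop padding rows


# Check against a single-pass softmax for a token whose keys span two SEQ_TILE blocks.
rng = np.random.default_rng(2)
q = rng.standard_normal((Q_HEAD_BATCH, HEAD_DIM))
k = rng.standard_normal((300, HEAD_DIM))
v = rng.standard_normal((300, HEAD_DIM))
pos = 250
ctx = attend_one_token(q, k, v, pos)
s = q @ k[:pos + 1].T / np.sqrt(HEAD_DIM)
w = np.exp(s - s.max(axis=1, keepdims=True))
ref = (w / w.sum(axis=1, keepdims=True)) @ v[:pos + 1]
```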

Decode Attention (with TQ)

decode_fwd_tq
  │
  ├── Expand codebook: [1, 16] → [CMP_CHUNK, 16]
  │
  └── For each layer:
        decode_layer_tq
          │
          ├── Scope 1: RMSNorm → Q/K/V proj → per-head Q/K RMSNorm
          │
          ├── Scope 2a: turboquant_kv_prefill() — quantize new token's K/V
          │     (same as prefill, padded to TOK_TILE=16 rows)
          │
          ├── Scope 2b: Cache scatter + Q RoPE (one head at a time)
          │     CRITICAL: Q_HEAD_BATCH=5 heads per KV head group
          │     Each Q head processed individually as [1, HEAD_DIM]
          │     (avoids pypto compiler bug with [5, HEAD_DIM] batch slice)
          │
          ├── Scope 2c: Attention with dequantized K/V
          │     For each Q group (8 groups) and each sequence block:
          │       qk_dequant:    gather(codebook, UINT8 indices) → BF16
          │       qk_renorm:     rsqrt + scale → BF16
          │       qk_unrotate:   matmul(rot_slice^T, out=FP32) → cast BF16
          │       qk_matmul:     q_padded @ k_dequant^T → raw_scores
          │       softmax:       exp(scores / √d) → BF16
          │       sv_dequant:    V dequant (BF16, same as K)
          │       sv_unrotate:   V unrotate (BF16 matmul)
          │       sv_matmul:     attn @ v_dequant → oi
          │       Online softmax accumulation
          │       ctx = oi / li
          │
          └── Scope 3: out_proj + residual + post RMSNorm + SwiGLU MLP
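The head grouping in Scopes 2b and 2c can be sketched in NumPy. The sketch assumes Q head h maps to KV head h // Q_HEAD_BATCH (contiguous grouping, as in standard GQA). It omits dequantization, online softmax and BF16 casts. `decode_attention_grouped` is a hypothetical reference, not the pypto kernel. It copies each Q head in as a [1, HEAD_DIM] slice, mirroring the per-head workaround, and the check below confirms that the zero padding rows up to Q_HEAD_PAD do not change the valid rows.

```python
import numpy as np

NUM_HEADS, NUM_KV_HEADS, HEAD_DIM = 40, 8, 128
Q_HEAD_BATCH = NUM_HEADS // NUM_KV_HEADS  # 5 Q heads per KV head
Q_HEAD_PAD = 16


def softmax_rows(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def decode_attention_grouped(q, k_cache, v_cache):
    """q: [NUM_HEADS, HEAD_DIM]; k_cache, v_cache: [NUM_KV_HEADS, seq_len, HEAD_DIM] (dequantized)."""
    out = np.zeros_like(q)
    for g in range(NUM_KV_HEADS):
        q_padded = np.zeros((Q_HEAD_PAD, HEAD_DIM))
        for j in range(Q_HEAD_BATCH):
            h = g * Q_HEAD_BATCH + j
            q_padded[j:j + 1] = q[h:h + 1]  # copy one head as a [1, HEAD_DIM] slice
        attn = softmax_rows(q_padded @ k_cache[g].T / np.sqrt(HEAD_DIM))
        ctx = attn @ v_cache[g]
        out[g * Q_HEAD_BATCH:(g + 1) * Q_HEAD_BATCH] = ctx[:Q_HEAD_BATCH]  # drop padding rows
    return out


rng = np.random.default_rng(3)
q = rng.standard_normal((NUM_HEADS, HEAD_DIM))
k_cache = rng.standard_normal((NUM_KV_HEADS, 37, HEAD_DIM))
v_cache = rng.standard_normal((NUM_KV_HEADS, 37, HEAD_DIM))
out = decode_attention_grouped(q, k_cache, v_cache)
```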

Dequant Pipeline (Decode Attention)

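The decode attention dequant path keeps its intermediates in BF16 to match the prefill path. Only the codebook gather output and the unrotation matmul result are FP32, and both are cast back to BF16: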

quant_k_cache[UINT8] 
  → cast FP16 → cast INT32 
  → gather(codebook) → FP32 centroids
  → cast BF16 → temp_bf16_cache

temp_bf16_cache
  → rsqrt renorm → row_expand_mul(scale) → BF16

temp_bf16_cache  
  → matmul(R^T) → FP32 unrotated → cast BF16 → temp_bf16_cache

temp_bf16_cache
  → copy to k_bf16_buf → QK matmul (BF16 @ BF16^T)
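The gather stage works on CMP_CHUNK-row sub-tiles against the expanded [CMP_CHUNK, 16] codebook. Below is a NumPy sketch of the cast, gather, renorm and scale steps. It assumes float32 in place of BF16 (NumPy has no BF16 type) and omits the R^T matmul. `dequant_rows` is a hypothetical helper. The UINT8 → FP16 → INT32 cast chain is kept only to mirror the pipeline; in NumPy a direct cast yields the same indices.

```python
import numpy as np

HEAD_DIM, N_LEVELS, CMP_CHUNK = 128, 16, 32


def dequant_rows(quant_rows, scales, codebook_row):
    """quant_rows: [n, HEAD_DIM] uint8; scales: [n, 1] float32; codebook_row: [1, N_LEVELS] float32."""
    # Expand codebook [1, 16] -> [CMP_CHUNK, 16] so each row of a sub-tile has its own copy.
    cb = np.broadcast_to(codebook_row, (CMP_CHUNK, N_LEVELS))
    idx = quant_rows.astype(np.float16).astype(np.int32)  # UINT8 -> FP16 -> INT32
    out = np.empty(quant_rows.shape, dtype=np.float32)
    for r0 in range(0, quant_rows.shape[0], CMP_CHUNK):
        tile = idx[r0:r0 + CMP_CHUNK]
        # Row-wise gather: out[r, j] = cb[r, tile[r, j]]
        out[r0:r0 + CMP_CHUNK] = np.take_along_axis(cb[:tile.shape[0]], tile, axis=1)
    inv_norm = 1.0 / np.sqrt((out * out).sum(axis=1, keepdims=True))  # rsqrt renorm
    return out * inv_norm * scales                                     # row_expand_mul(scale)


rng = np.random.default_rng(4)
codebook = (0.09 * np.sort(rng.standard_normal(N_LEVELS))).astype(np.float32)[None, :]
quant = rng.integers(0, N_LEVELS, size=(70, HEAD_DIM), dtype=np.uint8)  # 70 rows: last tile is partial
scales = rng.uniform(0.5, 2.0, size=(70, 1)).astype(np.float32)
deq = dequant_rows(quant, scales, codebook)
```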

Key Constants

| Constant | Value | Meaning |
|---|---|---|
| N_LEVELS | 16 | 4-bit quantization levels |
| HEAD_DIM | 128 | Attention head dimension |
| NUM_KV_HEADS | 8 | Number of KV heads (GQA) |
| NUM_HEADS | 40 | Number of Q heads |
| Q_HEAD_BATCH | 5 | Q heads per attention group |
| Q_HEAD_PAD | 16 | Padded Q rows (cube alignment) |
| BLOCK_SIZE | 128 | KV cache page size |
| CMP_CHUNK | 32 | Dequant gather sub-tile |
| TOK_TILE | 16 | Prefill token block size |
| SEQ_TILE | 128 | Sequence tile for prefill attention |

Cache Layout

Quantized KV cache (all layers stacked along the row dimension):
  quant_k_cache:  [num_layers * total_blocks * NUM_KV_HEADS * BLOCK_SIZE, HEAD_DIM] UINT8
  quant_v_cache:  [same] UINT8
  quant_k_scales: [same rows, 1] FP32
  quant_v_scales: [same rows, 1] FP32

Per cache row: one token's K or V for one KV head
  - 128 bytes of UINT8 indices (one 4-bit code per coordinate, each stored in a full byte)
  - 4 bytes of FP32 L2 scale
  - Total: 132 bytes vs 256 bytes for BF16 → 1.94× compression
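The per-row byte counts scale to whole-cache sizes as follows. This sketch assumes Qwen3-14B's 40 layers with 8 KV heads and a fully used max-seq-len of 512, and it ignores page rounding. The `tq_packed` layout (two 4-bit codes per byte) is hypothetical and is not what the cache currently does. Under these assumptions, K+V take 80 MiB in BF16 versus 41.25 MiB in the current UINT8 layout.

```python
NUM_LAYERS, NUM_KV_HEADS, HEAD_DIM = 40, 8, 128
MAX_SEQ_LEN = 512


def bytes_per_row(kind):
    """Bytes for one token's K (or V) of one KV head."""
    if kind == "bf16":
        return HEAD_DIM * 2        # 2 bytes per BF16 value
    if kind == "tq_uint8":
        return HEAD_DIM * 1 + 4    # one UINT8 index per coordinate + FP32 scale
    if kind == "tq_packed":
        return HEAD_DIM // 2 + 4   # hypothetical: two 4-bit indices per byte
    raise ValueError(kind)


def kv_cache_bytes(kind, seq_len=MAX_SEQ_LEN):
    rows = NUM_LAYERS * NUM_KV_HEADS * seq_len  # rows in one cache (K or V)
    return 2 * rows * bytes_per_row(kind)       # K and V together


for kind in ("bf16", "tq_uint8", "tq_packed"):
    ratio = bytes_per_row("bf16") / bytes_per_row(kind)
    print(f"{kind:10s} {kv_cache_bytes(kind) / 2**20:7.2f} MiB  ratio {ratio:.2f}x")
```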

Bug Fixes Applied

1. Rotation Matrix Seed Consistency

  • Symptom: NPU and CPU produced different quantized K/V, leading to different model output.
  • Root cause: NPU used torch.manual_seed(42) with sequential generation; CPU used seed=42+layer*1000 with per-layer independent generation.
  • Fix: Use torch.Generator().manual_seed(42 + l * 1000) for both, plus Q * sign(diag(R)) normalization matching CPU convention.

2. Decode Q RoPE — Per-Head Processing

  • Symptom: In decode attention, Q heads 0-1 produced garbage scores (~-1000 vs expected ~0-10), while heads 2-4 were correct.
  • Root cause: pypto compiler bug: pl.slice(q_block, [5, 64], ...) on a [5, 128] tile from pl.reshape(slice, [5, 128]) corrupts the first two rows.
  • Fix: Process each Q head individually as [1, HEAD_DIM] instead of batching 5 heads as [5, HEAD_DIM]. This matches the prefill's approach.
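Because RoPE acts on each row independently, the per-head form computes exactly the same values as the batched [5, HEAD_DIM] form. The workaround changes only the tile shapes the compiler sees. Below is a NumPy reference (not pypto code). It assumes Hugging Face style rotate-half RoPE and a rope base of 1e6, and `rope_rows` and `rope_per_head` are hypothetical names.

```python
import numpy as np

HEAD_DIM, Q_HEAD_BATCH = 128, 5
HALF = HEAD_DIM // 2


def rope_rows(q, cos, sin):
    """Rotate-half RoPE on rows of q ([n, HEAD_DIM]); cos/sin are [HALF] for one position."""
    x1, x2 = q[:, :HALF], q[:, HALF:]  # the [n, 64] halves a kernel would slice
    return np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=1)


def rope_per_head(q_block, cos, sin):
    """Workaround shape: apply RoPE to each head as its own [1, HEAD_DIM] tile."""
    return np.concatenate([rope_rows(q_block[h:h + 1], cos, sin)
                           for h in range(q_block.shape[0])], axis=0)


pos, theta = 7, 1_000_000.0  # assumed position and rope base
inv_freq = theta ** (-np.arange(HALF) * 2.0 / HEAD_DIM)
cos, sin = np.cos(pos * inv_freq), np.sin(pos * inv_freq)
q_block = np.random.default_rng(5).standard_normal((Q_HEAD_BATCH, HEAD_DIM))
```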

Numerical Precision

  • K/V quantization-dequantization round-trip: cos ≥ 0.9999 (observed at all 40 layers in prefill)
  • Decode BF16 attention vs CPU FP32: layer 0 hidden cos ≈ 0.999
  • Prefill BF16 attention vs CPU FP32: all 40 layers hidden cos ≥ 0.9999

Validation

Offline Comparison (NPU TQ vs CPU TQ)

Both runs use the same prompt "Huawei is" with greedy decoding (temperature=0).

| Step | Metric | Result |
|---|---|---|
| Prefill | quant_k exact match | 99.3% (all 40 layers, per-layer ≥97.6%) |
| Prefill | hidden cos (residual stream) | ≥0.9999 (all 40 layers) |
| Decode3 (1st token) | quant_k exact match | 99.8% (layer 0) |
| Decode3 (1st token) | hidden cos | 0.999 (layer 0) |

Performance Comparison (Ascend NPU 910B/C, device 15)

Prompt: "Huawei is", max-new-tokens=128, max-seq-len=512.

| Metric | FP (BF16, no compression) | TQ (4-bit) | Speedup |
|---|---|---|---|
| Prefill | 5467 ms | 5240 ms | 1.04× |
| Decode speed | 3.15 tok/s | 5.56 tok/s | 1.76× |
| Decode latency | 317.7 ms/tok | 180.0 ms/tok | 1.76× |
| Overall throughput | 2.72 tok/s | 4.29 tok/s | 1.58× |

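The speedup is attributed to the smaller KV cache, which reduces memory bandwidth in attention. As currently stored, the cache is about 1.94× smaller than BF16 (132 vs 256 bytes per row, see Cache Layout), not 4×. A 4× reduction in index storage would require packing two 4-bit codes per byte.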

Output Quality

FP output: "a Chinese multinational technology company that designs and sells consumer electronics, telecommunications equipment, and services. It is one of the world's largest and most valuable technology companies. Huawei was founded in 1987 by Ren Zhengfei..."

TQ output: "a Chinese multinational technology company headquartered in Shenzhen, Guangdong, China. It is the world's second-largest smartphone manufacturer by unit sales and the world's second-largest provider of telecommunications equipment..."

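Both outputs are factually correct and coherent. The wording differs because 4-bit KV compression slightly perturbs the logits: once a different token is chosen at some step, the two continuations diverge. Any sampling randomness between runs adds further variation.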
