Benchmark: BF16/FP8 mixed-precision for cuTile, cuTe, flash solvers #12

@Jim137

Benchmark: BF16/FP8 mixed-precision for cuTile & flash solvers

Following up on #8, this issue documents benchmarks for the mixed-precision (bfloat16, float8_e4m3fn) support added to all ansatzes (pz_encoding, rpz_encoding, real) in both the flash (Triton) and cutile solvers.

TL;DR

  • BF16/FP8 full pipeline: 2.3-2.5x faster training steps, 45% less peak memory, comparable final loss (about 2.56 after 500 iterations)
  • FP8 prescaled states: additional memory savings over BF16 for backward state checkpoints
  • Coalesced state layout: 15-50% backward kernel speedup for both Triton and cuTile

Key Optimizations

1. BF16/FP8 mixed-precision for all ansatzes

Previously only the real ansatz supported BF16. Now all ansatzes support c_dtype = torch.float32 | torch.bfloat16 | torch.float8_e4m3fn.

  • Forward: upcast inputs to f32 (free identity cast for f32), compute in f32, downcast output
  • Backward: store state checkpoints in bf16 or fp8, halving/quartering the dominant memory traffic
  • FP8: prescale by 224.0 before the fp8 store. QKAN states are bounded in [-1, 1] (unitarity), so the scaled values stay within [-224, 224], inside the finite range of float8_e4m3fn (±448). No per-block scaling is needed.
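
The effect of the prescale can be checked without a GPU. The sketch below is a pure-Python simulation of float8_e4m3fn rounding (3 mantissa bits, smallest normal exponent -6, largest finite value 448). The names quantize_e4m3fn and fp8_roundtrip are illustrative and not part of qkan, and the real kernels perform the cast on device, so treat this as an approximation of the idea rather than the actual implementation.

```python
import math

FP8_MAX = 448.0    # largest finite float8_e4m3fn value
PRESCALE = 224.0   # prescale factor quoted in the list above


def quantize_e4m3fn(x: float) -> float:
    """Round x to the nearest float8_e4m3fn value (simulated, saturating)."""
    if x == 0.0:
        return 0.0
    sign = -1.0 if x < 0 else 1.0
    a = min(abs(x), FP8_MAX)
    _, exp = math.frexp(a)      # a = m * 2**exp with 0.5 <= m < 1
    e = max(exp - 1, -6)        # below 2**-6 the format is subnormal
    step = 2.0 ** (e - 3)       # spacing of representable values at this exponent
    # Python's round() is round-half-to-even, matching the usual FP rounding mode.
    return sign * min(round(a / step) * step, FP8_MAX)


def fp8_roundtrip(x: float, scale: float = PRESCALE) -> float:
    """Scale, store as simulated fp8, then undo the scale on load."""
    return quantize_e4m3fn(x * scale) / scale


if __name__ == "__main__":
    for x in (1.0, 0.3, 0.001):
        direct = quantize_e4m3fn(x)
        scaled = fp8_roundtrip(x)
        print(f"x={x}: direct error={abs(direct - x):.2e}, prescaled error={abs(scaled - x):.2e}")
```

In this simulation the benefit shows up for small amplitudes: without the prescale they fall into the subnormal range of the format, where the absolute spacing is fixed at 2^-9, while the prescaled values keep a relative error of at most 1/16.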

2. Coalesced state checkpoint layout

Transposed state buffer from (N, S, BLOCK_B, 4) to (N, S, 4, BLOCK_B). The old layout caused stride-4 gather/scatter access, wasting 75% of memory bandwidth. The new layout makes the BLOCK_B dimension contiguous. Applied to both cuTile and Triton backends.
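
The difference between the two layouts can be seen from the flat offsets alone. The sketch below assumes that adjacent lanes handle adjacent batch indices b and read the same state component c at the same time; that access pattern, the helper names and the sizes are assumptions made for illustration, not taken from the kernels.

```python
def offset_old(n: int, s: int, b: int, c: int, S: int, BLOCK_B: int) -> int:
    """Flat element offset in the old (N, S, BLOCK_B, 4) layout."""
    return ((n * S + s) * BLOCK_B + b) * 4 + c


def offset_new(n: int, s: int, b: int, c: int, S: int, BLOCK_B: int) -> int:
    """Flat element offset in the new (N, S, 4, BLOCK_B) layout."""
    return ((n * S + s) * 4 + c) * BLOCK_B + b


S, BLOCK_B = 3, 8  # made-up sizes

# Offsets touched when lanes b = 0..7 all read component c = 1 of the same (n, s).
old = [offset_old(0, 0, b, 1, S, BLOCK_B) for b in range(BLOCK_B)]
new = [offset_new(0, 0, b, 1, S, BLOCK_B) for b in range(BLOCK_B)]

print("old layout offsets:", old)  # consecutive lanes are 4 elements apart
print("new layout offsets:", new)  # consecutive lanes are adjacent
```

With the old layout only one of every four consecutive elements is used by such an access, which is where the 75% figure comes from; with the new layout the same lanes cover one contiguous run.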

Benchmark: GPT-2 HQKANsformer — TinyShakespeare

Model: GPT-2 (12L, 12H, 768E) with Linear(768,10) -> QKAN([10,10], reps=1) -> Linear(10,768) replacing MLP
Data: TinyShakespeare (char-level), batch size 1, block size 1024
Training: AdamW lr=0.0003, betas=(0.9,0.95), weight_decay=0.1, grad_clip=1.0, 500 iters

Dtype column = c_dtype. For bf16/fp8 rows, p_dtype=torch.bfloat16 and full model runs in bf16 (model.to(torch.bfloat16)). For f32 rows, p_dtype=torch.float32.

| Variant | Ansatz | c_dtype | Params | Init | Forward | Train Step | Step vs f32_pz | 500 Steps | Peak Mem | Avg Mem | Final Loss |
|---|---|---|---|---|---|---|---|---|---|---|---|
| flash_pz | pz | f32 | 29,428,824 | 151.1 ms | 7.713 ms | 24.769 ms | 1.00x | 12.4 s | 681.0 MiB | 363.6 MiB | 2.5671 |
| flash_pz | pz | bf16 | 29,428,824 | 149.0 ms | 2.879 ms | 10.594 ms | 2.34x | 5.3 s | 377.3 MiB | 209.3 MiB | 2.5625 |
| flash_pz | pz | fp8 | 29,428,824 | 142.0 ms | 2.856 ms | 10.592 ms | 2.34x | 5.3 s | 375.4 MiB | 202.1 MiB | 2.5625 |
| flash_real | real | f32 | 29,425,224 | 151.2 ms | 8.257 ms | 24.792 ms | 1.00x | 12.4 s | 687.3 MiB | 368.1 MiB | 2.5787 |
| flash_real | real | bf16 | 29,425,224 | 153.1 ms | 2.812 ms | 10.471 ms | 2.37x | 5.2 s | 373.0 MiB | 204.8 MiB | 2.5625 |
| flash_real | real | fp8 | 29,425,224 | 152.8 ms | 2.391 ms | 10.190 ms | 2.43x | 5.1 s | 372.9 MiB | 200.7 MiB | 2.5625 |
| cutile_pz | pz | f32 | 29,428,824 | 133.6 ms | 7.653 ms | 24.637 ms | 1.01x | 12.3 s | 684.0 MiB | 364.8 MiB | 2.5526 |
| cutile_pz | pz | bf16 | 29,428,824 | 147.4 ms | 2.547 ms | 10.668 ms | 2.32x | 5.3 s | 376.9 MiB | 204.4 MiB | 2.5781 |
| cutile_pz | pz | fp8 | 29,428,824 | 155.0 ms | 2.123 ms | 10.812 ms | 2.29x | 5.4 s | 373.9 MiB | 202.2 MiB | 2.5625 |
| cutile_real | real | f32 | 29,425,224 | 224.9 ms | 6.715 ms | 24.981 ms | 0.99x | 12.5 s | 683.6 MiB | 364.9 MiB | 2.5705 |
| cutile_real | real | bf16 | 29,425,224 | 130.3 ms | 2.829 ms | 10.200 ms | 2.43x | 5.1 s | 379.5 MiB | 208.0 MiB | 2.5781 |
| cutile_real | real | fp8 | 29,425,224 | 153.7 ms | 2.587 ms | 10.025 ms | 2.47x | 5.0 s | 374.0 MiB | 201.4 MiB | 2.5625 |

Key Observations

  • BF16/FP8 nearly halves memory: peak 681→373 MiB (45% reduction), average 364→201 MiB (45% reduction)
  • 2.3-2.5x faster train step for every bf16/fp8 solver/ansatz combination
  • Best config: cutile_real fp8, at 10.0 ms/step, 374 MiB peak, 2.47x speedup
  • All dtypes reach a comparable final loss (2.55-2.58) after 500 steps
  • FP8 saves a further 2-7 MiB of average memory over BF16 (fp8 vs bf16 state checkpoints); at peak the saving is 0.1-5.5 MiB
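
The headline ratios follow directly from the table. As a quick check, the snippet below recomputes two of the speedups against the flash_pz f32 train step and the memory reductions from the rounded MiB figures quoted in the observations; the variable names are illustrative.

```python
baseline_ms = 24.769  # flash_pz f32 train step from the table

step_ms = {
    "flash_pz bf16": 10.594,
    "cutile_real fp8": 10.025,
}

# Speedup = baseline step time / variant step time, rounded as in the table.
speedup = {name: round(baseline_ms / ms, 2) for name, ms in step_ms.items()}

# Fractional memory reduction = 1 - (new / old).
peak_reduction = 1 - 373.0 / 681.0
avg_reduction = 1 - 201.0 / 364.0

print(speedup)
print(f"peak memory reduction: {peak_reduction:.1%}")
print(f"avg memory reduction:  {avg_reduction:.1%}")
```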

Benchmark: Isolated QKAN Kernel (B5)

Model: QKAN([100, 100], reps=3), batch=1000 — where QKAN IS the workload

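| Solver | Ansatz | Dtype | Step (ms) | vs flash f32 |
|---|---|---|---|---|
| flash | pz | f32 | 7.24 | baseline |
| flash | pz | bf16 | 5.66 | 1.28x |
| flash | pz | fp8 | 5.29 | 1.37x |
| cutile | pz | f32 | 7.03 | 1.03x |
| cutile | pz | bf16 | 7.04 | 1.03x |
| cutile | pz | fp8 | 6.01 | 1.21x |
| cutile | real | bf16 | 2.65 | |
| cutile | real | fp8 | 1.33 | |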

Coalesced layout alone improved cuTile real bf16: 0.92 ms → 0.60 ms (1.52x).

c_dtype and p_dtype

  • p_dtype — parameter storage dtype (torch.float32 or torch.bfloat16). Controls theta, preacts, and output precision.
  • c_dtype — compute dtype for quantum simulation kernels. Controls kernel I/O and backward state checkpoint precision.

Previously only torch.complex64 (default, maps to f32 compute) and torch.bfloat16 (real ansatz only) were supported. This PR adds torch.bfloat16 for pz/rpz and torch.float8_e4m3fn for all ansatzes.

| c_dtype | Kernel I/O | State checkpoints | Compute |
|---|---|---|---|
| torch.complex64 / torch.float32 | f32 | f32 | f32 |
| torch.bfloat16 | bf16 | bf16 | f32 |
| torch.float8_e4m3fn | bf16 | fp8 (prescaled) | f32 |
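
To make the checkpoint column concrete, the sketch below estimates the size of one state buffer for each c_dtype, using the (N, S, 4, BLOCK_B) shape from the layout section. The sizes are made up, and the mapping is written with plain strings so the snippet runs without PyTorch; none of these names come from qkan.

```python
# Checkpoint storage format per c_dtype, following the table above.
CHECKPOINT_FORMAT = {
    "torch.float32": "f32",
    "torch.bfloat16": "bf16",
    "torch.float8_e4m3fn": "fp8",
}
ITEMSIZE = {"f32": 4, "bf16": 2, "fp8": 1}  # bytes per element


def checkpoint_bytes(n: int, s: int, block_b: int, c_dtype: str) -> int:
    """Bytes needed for one state buffer of shape (n, s, 4, block_b)."""
    return n * s * 4 * block_b * ITEMSIZE[CHECKPOINT_FORMAT[c_dtype]]


shape = dict(n=3, s=100, block_b=128)  # hypothetical sizes
sizes = {d: checkpoint_bytes(**shape, c_dtype=d) for d in CHECKPOINT_FORMAT}

for d, nbytes in sizes.items():
    ratio = nbytes / sizes["torch.float32"]
    print(f"{d}: {nbytes / 1024:.0f} KiB ({ratio:.2f}x of f32)")
```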

Note: p_dtype=torch.float8_e4m3fn is not supported — PyTorch has no FP8 arithmetic kernels, so optimizer updates and parameter init would fail. FP8 is a storage format for activations/checkpoints, not parameters.

When to Use Each Dtype

| Scenario | Recommendation |
|---|---|
| Full model training (GPT-2 etc.) | p_dtype=torch.bfloat16 + c_dtype=torch.bfloat16: 2.3x speedup, 45% less memory |
| Large QKAN dimensions (dim>=50, reps>=3) | c_dtype=torch.float8_e4m3fn: extra backward speedup from fp8 state checkpoints |
| Small QKAN (dim<=20, reps=1) in f32 model | c_dtype=torch.float32: kernel is launch-overhead dominated, dtype conversion adds overhead |
| Accuracy-critical applications | c_dtype=torch.float32: full f32 precision throughout |
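
The recommendations above can be condensed into a small helper. This is only one reading of the table: the function name, the priority order of the checks, and the bf16 fallback for cases the table does not cover (for example dim between 20 and 50) are assumptions. It returns dtype names as strings so it runs without PyTorch.

```python
def recommend_c_dtype(dim: int, reps: int, model_is_bf16: bool,
                      accuracy_critical: bool = False) -> str:
    """Suggest a c_dtype name following the scenario table (hypothetical helper)."""
    if accuracy_critical:
        # Full f32 precision throughout.
        return "torch.float32"
    if dim >= 50 and reps >= 3:
        # Large QKAN: fp8 state checkpoints give extra backward speedup.
        return "torch.float8_e4m3fn"
    if dim <= 20 and reps == 1 and not model_is_bf16:
        # Small QKAN in an f32 model: dtype conversion only adds overhead.
        return "torch.float32"
    # Assumed default: bf16, matching the full-model training row.
    return "torch.bfloat16"


if __name__ == "__main__":
    print(recommend_c_dtype(dim=10, reps=1, model_is_bf16=True))    # GPT-2 benchmark setup
    print(recommend_c_dtype(dim=100, reps=3, model_is_bf16=True))   # isolated kernel setup
    print(recommend_c_dtype(dim=10, reps=1, model_is_bf16=False))   # small QKAN in an f32 model
```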

Usage

```python
import torch
from qkan import QKAN

# BF16 (recommended for training)
model = QKAN([100, 100], reps=3, solver="cutile", ansatz="pz_encoding",
             c_dtype=torch.bfloat16, p_dtype=torch.bfloat16)

# FP8 prescaled states (max backward throughput)
model = QKAN([100, 100], reps=3, solver="cutile", ansatz="pz_encoding",
             c_dtype=torch.float8_e4m3fn, p_dtype=torch.bfloat16)

# For full-model BF16 pipeline (recommended):
model = YourModel(...).to(torch.bfloat16)
```

p_dtype vs c_dtype Performance Matrix

p_dtype controls parameter storage, c_dtype controls kernel precision. They are independent.

Isolated QKAN kernel (flash pz, dim=100, reps=3, batch=1000):

| p_dtype | c_dtype | Fwd (ms) | Bwd (ms) | Total (ms) | Speedup |
|---|---|---|---|---|---|
| f32 | f32 | 0.32 | 1.60 | 1.93 | baseline |
| f32 | bf16 | 0.42 | 1.19 | 1.60 | 1.20x |
| f32 | fp8 | 0.42 | 1.12 | 1.54 | 1.25x |
| bf16 | f32 | 0.29 | 1.56 | 1.84 | 1.04x |
| bf16 | bf16 | 0.33 | 1.09 | 1.42 | 1.36x |
| bf16 | fp8 | 0.33 | 1.06 | 1.39 | 1.38x |

  • bf16/bf16 is the sweet spot: avoids dtype conversion at boundary + bf16 state checkpoints
  • f32/bf16: pays f32→bf16 conversion cost but still benefits from bf16 backward
  • bf16/f32: bf16 params upcast to f32 for kernel, minimal gain
  • For full-model training, use model.to(torch.bfloat16) so p_dtype matches the rest of the model

Hardware

  • GPU: NVIDIA GeForce RTX 5090
  • PyTorch 2.11.0+cu130, CUDA 13.0
  • Triton 3.6.0, cuTile 1.2.0
