Benchmark: BF16/FP8 mixed-precision for cuTile & flash solvers
Following up on #8, this documents benchmarks for mixed-precision (bfloat16, float8_e4m3fn) support added to all ansatzes (pz_encoding, rpz_encoding, real) for both flash (Triton) and cutile solvers.
TL;DR
- BF16/FP8 full pipeline: 2.3-2.5x faster training steps, ~45% less peak memory, and matching convergence (final loss 2.55-2.58 across all runs)
- FP8 prescaled states: additional memory savings over BF16 for backward state checkpoints
- Coalesced state layout: 15-50% backward kernel speedup for both Triton and cuTile
Key Optimizations
1. BF16/FP8 mixed-precision for all ansatzes
Previously only the real ansatz supported BF16. Now all ansatzes support c_dtype = torch.float32 | torch.bfloat16 | torch.float8_e4m3fn.
- Forward: upcast inputs to f32 (a no-op for f32 inputs), compute in f32, downcast the output
- Backward: store state checkpoints in bf16 or fp8, halving (bf16) or quartering (fp8) the dominant memory traffic relative to f32
- FP8: prescale by 224.0 before the fp8 store. Unitarity bounds QKAN states to [-1, 1], so the scaled values lie in [-224, 224], inside float8_e4m3fn's finite range (max 448) with 2x headroom, and more small amplitudes stay above fp8's subnormal threshold (2^-6). A single fixed scale suffices, so no per-block scaling is needed.
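A pure-Python sketch of why a fixed prescale works: it simulates round-to-nearest float8_e4m3fn storage (4 exponent bits, 3 mantissa bits, largest finite value 448, smallest normal 2^-6) and compares storing an amplitude directly against storing it multiplied by 224. The helpers `quantize_e4m3fn` and `fp8_roundtrip` are hypothetical illustrations, not the kernel code, and saturating at ±448 is this sketch's own choice rather than a claim about PyTorch's cast behavior.

```python
import math

FP8_MAX = 448.0      # largest finite float8_e4m3fn value
MIN_NORMAL_EXP = -6  # smallest normal exponent; below 2**-6 values are subnormal
MANTISSA_BITS = 3

def quantize_e4m3fn(x: float) -> float:
    """Round x to the nearest float8_e4m3fn value (ties to even), saturating at +/-448."""
    if x == 0.0:
        return 0.0
    sign = -1.0 if x < 0 else 1.0
    a = min(abs(x), FP8_MAX)
    _, e = math.frexp(a)                 # a = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, MIN_NORMAL_EXP)     # unbiased exponent; subnormals share 2**-6
    step = 2.0 ** (exp - MANTISSA_BITS)  # spacing between representable values here
    q = round(a / step) * step           # Python's round() is round-half-to-even
    return sign * min(q, FP8_MAX)

def fp8_roundtrip(x: float, scale: float) -> float:
    """Store x * scale as fp8, then divide the scale back out (as when a checkpoint is reloaded)."""
    return quantize_e4m3fn(x * scale) / scale

for amp in (0.9, 0.05, 1e-3):
    plain = fp8_roundtrip(amp, 1.0)
    scaled = fp8_roundtrip(amp, 224.0)
    print(f"{amp:>6}: unscaled rel err {abs(plain - amp) / amp:.3f}, "
          f"prescaled rel err {abs(scaled - amp) / amp:.3f}")
```

Amplitudes near 1 see similar relative error either way, since fp8 keeps 3 mantissa bits across its whole normal range; the prescale matters for small amplitudes, which would otherwise fall into the coarse subnormal range.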
2. Coalesced state checkpoint layout
Transposed the state checkpoint buffer from (N, S, BLOCK_B, 4) to (N, S, 4, BLOCK_B). The old layout caused stride-4 gather/scatter access along the batch dimension, wasting 75% of memory bandwidth; the new layout makes the BLOCK_B dimension contiguous. Applied to both the cuTile and Triton backends.
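The 75% figure can be reproduced with a simple model: the 32 lanes of a warp each read one component for consecutive batch indices, and memory is fetched in 32-byte sectors. The sketch below is pure Python; the warp width, sector size, and example tensor sizes are assumptions for illustration, not the kernels' actual tile configuration. It derives each layout's batch-axis stride from row-major strides and counts how many fetched bytes are actually used.

```python
def row_major_strides(shape):
    """Element strides of a contiguous row-major (C-order) tensor."""
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    return tuple(reversed(strides))

def sector_efficiency(b_stride, elem_bytes=4, lanes=32, sector_bytes=32):
    """Fraction of fetched bytes that are used when `lanes` threads each load one
    element at batch indices b = 0..lanes-1, for a single state component."""
    sectors = {(b * b_stride * elem_bytes) // sector_bytes for b in range(lanes)}
    return (lanes * elem_bytes) / (len(sectors) * sector_bytes)

N, S, BLOCK_B = 2, 8, 64                      # illustrative sizes only
old = row_major_strides((N, S, BLOCK_B, 4))   # old layout (N, S, BLOCK_B, 4)
new = row_major_strides((N, S, 4, BLOCK_B))   # new layout (N, S, 4, BLOCK_B)
old_b_stride, new_b_stride = old[2], new[3]   # stride along the batch axis

print("batch-axis stride:", old_b_stride, "->", new_b_stride)
print("fp32 efficiency:", sector_efficiency(old_b_stride), "->", sector_efficiency(new_b_stride))
print("bf16 efficiency:", sector_efficiency(old_b_stride, elem_bytes=2), "->",
      sector_efficiency(new_b_stride, elem_bytes=2))
```

With stride 4, each fetched sector carries only one useful element out of four, for both fp32 and bf16 elements; with stride 1, every fetched byte is used.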
Benchmark: GPT-2 HQKANsformer — TinyShakespeare
Model: GPT-2 (12L, 12H, 768E) with Linear(768,10) -> QKAN([10,10], reps=1) -> Linear(10,768) replacing the MLP in each block
Data: TinyShakespeare (char-level), batch size 1, block size 1024
Training: AdamW lr=0.0003, betas=(0.9,0.95), weight_decay=0.1, grad_clip=1.0, 500 iters
The c_dtype column is the QKAN compute dtype. For bf16/fp8 rows, p_dtype=torch.bfloat16 and the full model runs in bf16 (model.to(torch.bfloat16)); for f32 rows, p_dtype=torch.float32.
| Variant | Ansatz | c_dtype | Params | Init | Forward | Train Step | Speedup vs flash_pz f32 | 500 Steps | Peak Mem | Avg Mem | Final Loss |
|---|---|---|---|---|---|---|---|---|---|---|---|
| flash_pz | pz | f32 | 29,428,824 | 151.1 ms | 7.713 ms | 24.769 ms | 1.00x | 12.4 s | 681.0 MiB | 363.6 MiB | 2.5671 |
| flash_pz | pz | bf16 | 29,428,824 | 149.0 ms | 2.879 ms | 10.594 ms | 2.34x | 5.3 s | 377.3 MiB | 209.3 MiB | 2.5625 |
| flash_pz | pz | fp8 | 29,428,824 | 142.0 ms | 2.856 ms | 10.592 ms | 2.34x | 5.3 s | 375.4 MiB | 202.1 MiB | 2.5625 |
| flash_real | real | f32 | 29,425,224 | 151.2 ms | 8.257 ms | 24.792 ms | 1.00x | 12.4 s | 687.3 MiB | 368.1 MiB | 2.5787 |
| flash_real | real | bf16 | 29,425,224 | 153.1 ms | 2.812 ms | 10.471 ms | 2.37x | 5.2 s | 373.0 MiB | 204.8 MiB | 2.5625 |
| flash_real | real | fp8 | 29,425,224 | 152.8 ms | 2.391 ms | 10.190 ms | 2.43x | 5.1 s | 372.9 MiB | 200.7 MiB | 2.5625 |
| cutile_pz | pz | f32 | 29,428,824 | 133.6 ms | 7.653 ms | 24.637 ms | 1.01x | 12.3 s | 684.0 MiB | 364.8 MiB | 2.5526 |
| cutile_pz | pz | bf16 | 29,428,824 | 147.4 ms | 2.547 ms | 10.668 ms | 2.32x | 5.3 s | 376.9 MiB | 204.4 MiB | 2.5781 |
| cutile_pz | pz | fp8 | 29,428,824 | 155.0 ms | 2.123 ms | 10.812 ms | 2.29x | 5.4 s | 373.9 MiB | 202.2 MiB | 2.5625 |
| cutile_real | real | f32 | 29,425,224 | 224.9 ms | 6.715 ms | 24.981 ms | 0.99x | 12.5 s | 683.6 MiB | 364.9 MiB | 2.5705 |
| cutile_real | real | bf16 | 29,425,224 | 130.3 ms | 2.829 ms | 10.200 ms | 2.43x | 5.1 s | 379.5 MiB | 208.0 MiB | 2.5781 |
| cutile_real | real | fp8 | 29,425,224 | 153.7 ms | 2.587 ms | 10.025 ms | 2.47x | 5.0 s | 374.0 MiB | 201.4 MiB | 2.5625 |
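As a sanity check on the Params column, the sketch below counts every non-QKAN parameter under assumed nanoGPT-style defaults (a 65-character vocabulary, biases enabled, lm_head tied to the token embedding; none of these are stated in the benchmark setup) and attributes the remainder to the QKAN layers. Under those assumptions the remainder divides evenly across the 12 layers, and QKAN accounts for well under 0.1% of the model's parameters in this configuration.

```python
# Assumed nanoGPT-style config (hypothetical): char-level vocab of 65, biases on,
# lm_head tied to the token embedding so it adds no parameters.
VOCAB, BLOCK, N_LAYER, N_EMBD, QKAN_DIM = 65, 1024, 12, 768, 10

def linear(n_in, n_out, bias=True):
    return n_in * n_out + (n_out if bias else 0)

def layernorm(n):
    return 2 * n  # weight + bias

def non_qkan_params():
    per_layer = (
        layernorm(N_EMBD)               # ln_1
        + linear(N_EMBD, 3 * N_EMBD)    # attention q/k/v projection
        + linear(N_EMBD, N_EMBD)        # attention output projection
        + layernorm(N_EMBD)             # ln_2
        + linear(N_EMBD, QKAN_DIM)      # Linear(768, 10) into QKAN
        + linear(QKAN_DIM, N_EMBD)      # Linear(10, 768) out of QKAN
    )
    embeddings = VOCAB * N_EMBD + BLOCK * N_EMBD  # token + position embeddings
    return N_LAYER * per_layer + embeddings + layernorm(N_EMBD)  # + final ln_f

base = non_qkan_params()
for ansatz, reported in (("pz", 29_428_824), ("real", 29_425_224)):
    qkan = reported - base
    print(f"{ansatz}: {qkan} QKAN params, {qkan // N_LAYER} per layer "
          f"({qkan / reported:.3%} of the model)")
```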
Key Observations
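- BF16/FP8 cut memory by roughly 45%: peak 681-687 → 373-380 MiB (44-46% reduction), average 364-368 → 201-209 MiB (42-46% reduction)
- 2.29-2.47x faster training steps for every bf16/fp8 solver/ansatz combination; the f32 variants are within ~1% of each other
- Fastest config: cutile_real fp8 at 10.0 ms/step, 374 MiB peak, 2.47x speedup
- All runs reach a final loss of 2.55-2.58 after 500 steps; the bf16/fp8 losses (2.5625-2.5781) fall within the f32 spread (2.5526-2.5787)
- FP8 saves a further 2-7 MiB of average memory over BF16 (0.1-5.5 MiB of peak), from fp8 instead of bf16 state checkpoints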
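The observations above can be re-derived directly from the table. This sketch copies the Train Step, Peak Mem, and Avg Mem values and computes each bf16/fp8 run's speedup over flash_pz f32, plus its memory reduction relative to the f32 run of the same variant; `summarize` is just an illustrative helper.

```python
# (train step ms, peak MiB, avg MiB), copied from the GPT-2 table
ROWS = {
    ("flash_pz", "f32"): (24.769, 681.0, 363.6),
    ("flash_pz", "bf16"): (10.594, 377.3, 209.3),
    ("flash_pz", "fp8"): (10.592, 375.4, 202.1),
    ("flash_real", "f32"): (24.792, 687.3, 368.1),
    ("flash_real", "bf16"): (10.471, 373.0, 204.8),
    ("flash_real", "fp8"): (10.190, 372.9, 200.7),
    ("cutile_pz", "f32"): (24.637, 684.0, 364.8),
    ("cutile_pz", "bf16"): (10.668, 376.9, 204.4),
    ("cutile_pz", "fp8"): (10.812, 373.9, 202.2),
    ("cutile_real", "f32"): (24.981, 683.6, 364.9),
    ("cutile_real", "bf16"): (10.200, 379.5, 208.0),
    ("cutile_real", "fp8"): (10.025, 374.0, 201.4),
}

def summarize(rows):
    """Return (variant, dtype, speedup, peak reduction, avg reduction) for each bf16/fp8 run."""
    baseline_step = rows[("flash_pz", "f32")][0]
    out = []
    for (variant, dtype), (step, peak, avg) in rows.items():
        if dtype == "f32":
            continue
        _, f32_peak, f32_avg = rows[(variant, "f32")]
        out.append((variant, dtype, baseline_step / step,
                    1 - peak / f32_peak, 1 - avg / f32_avg))
    return out

for variant, dtype, speedup, peak_red, avg_red in summarize(ROWS):
    print(f"{variant:12s} {dtype:5s} {speedup:.2f}x  peak -{peak_red:.1%}  avg -{avg_red:.1%}")
```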
Benchmark: Isolated QKAN Kernel (B5)
Model: QKAN([100, 100], reps=3), batch=1000, so QKAN is the entire workload
| Solver | Ansatz | Dtype | Step (ms) | vs flash f32 |
|---|---|---|---|---|
| flash | pz | f32 | 7.24 | baseline |
| flash | pz | bf16 | 5.66 | 1.28x |
| flash | pz | fp8 | 5.29 | 1.37x |
| cutile | pz | f32 | 7.03 | 1.03x |
| cutile | pz | bf16 | 7.04 | 1.03x |
| cutile | pz | fp8 | 6.01 | 1.21x |
| cutile | real | bf16 | 2.65 | n/a |
| cutile | real | fp8 | 1.33 | n/a |
Coalesced layout alone improved cuTile real bf16: 0.92 ms → 0.60 ms (1.52x).
c_dtype and p_dtype
p_dtype — parameter storage dtype (torch.float32 or torch.bfloat16). Controls theta, preacts, and output precision.
c_dtype — compute dtype for quantum simulation kernels. Controls kernel I/O and backward state checkpoint precision.
Previously only torch.complex64 (default, maps to f32 compute) and torch.bfloat16 (real ansatz only) were supported. This PR adds torch.bfloat16 for pz/rpz and torch.float8_e4m3fn for all ansatzes.
| c_dtype | Kernel I/O | State checkpoints | Compute |
|---|---|---|---|
| torch.complex64 / torch.float32 | f32 | f32 | f32 |
| torch.bfloat16 | bf16 | bf16 | f32 |
| torch.float8_e4m3fn | bf16 | fp8 (prescaled) | f32 |
Note: p_dtype=torch.float8_e4m3fn is not supported. PyTorch lacks general FP8 arithmetic kernels, so optimizer updates and parameter initialization would fail. FP8 serves as a storage format for activations and checkpoints, not for parameters.
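The c_dtype table can be read as a small lookup from c_dtype to a precision plan. The sketch below transcribes it, using dtype names as strings so it runs without PyTorch; `PLAN` and `checkpoint_bytes` are hypothetical helpers, not part of qkan. It shows that bf16 and fp8 checkpoints take one half and one quarter of the f32 footprint while compute stays in f32 for every c_dtype.

```python
# c_dtype -> precision plan, transcribed from the table (dtype names as strings).
PLAN = {
    "torch.complex64":     {"kernel_io": "f32",  "checkpoints": "f32",  "compute": "f32"},
    "torch.float32":       {"kernel_io": "f32",  "checkpoints": "f32",  "compute": "f32"},
    "torch.bfloat16":      {"kernel_io": "bf16", "checkpoints": "bf16", "compute": "f32"},
    "torch.float8_e4m3fn": {"kernel_io": "bf16", "checkpoints": "fp8",  "compute": "f32"},
}
BYTES_PER_ELEMENT = {"f32": 4, "bf16": 2, "fp8": 1}

def checkpoint_bytes(c_dtype: str, n_elements: int) -> int:
    """Bytes needed to hold n_elements of backward state checkpoints for a given c_dtype."""
    return n_elements * BYTES_PER_ELEMENT[PLAN[c_dtype]["checkpoints"]]

n = 10_000_000  # illustrative checkpoint element count
for c_dtype, plan in PLAN.items():
    mib = checkpoint_bytes(c_dtype, n) / 2**20
    print(f"{c_dtype:20s} compute={plan['compute']} io={plan['kernel_io']:4s} "
          f"checkpoints={plan['checkpoints']:4s} {mib:.1f} MiB")
```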
When to Use Each Dtype
| Scenario | Recommendation |
|---|---|
| Full model training (GPT-2 etc.) | p_dtype=torch.bfloat16 + c_dtype=torch.bfloat16: 2.3x speedup, 45% less memory |
| Large QKAN dimensions (dim>=50, reps>=3) | c_dtype=torch.float8_e4m3fn: extra backward speedup from fp8 state checkpoints |
| Small QKAN (dim<=20, reps=1) in f32 model | c_dtype=torch.float32: the kernel is launch-overhead dominated, so dtype conversion only adds overhead |
| Accuracy-critical applications | c_dtype=torch.float32: full f32 precision throughout |
Usage
```python
import torch
from qkan import QKAN

# BF16 (recommended for training)
model = QKAN([100, 100], reps=3, solver="cutile", ansatz="pz_encoding",
             c_dtype=torch.bfloat16, p_dtype=torch.bfloat16)

# FP8 prescaled states (max backward throughput)
model = QKAN([100, 100], reps=3, solver="cutile", ansatz="pz_encoding",
             c_dtype=torch.float8_e4m3fn, p_dtype=torch.bfloat16)

# For a full-model BF16 pipeline (recommended), cast the whole host model so that
# p_dtype matches the surrounding layers. YourModel is a placeholder for your own
# torch.nn.Module containing QKAN layers.
model = YourModel(...).to(torch.bfloat16)
```
p_dtype vs c_dtype Performance Matrix
p_dtype controls parameter storage, c_dtype controls kernel precision. They are independent.
Isolated QKAN kernel (flash pz, dim=100, reps=3, batch=1000):
| p_dtype | c_dtype | Fwd (ms) | Bwd (ms) | Total (ms) | Speedup |
|---|---|---|---|---|---|
| f32 | f32 | 0.32 | 1.60 | 1.93 | baseline |
| f32 | bf16 | 0.42 | 1.19 | 1.60 | 1.20x |
| f32 | fp8 | 0.42 | 1.12 | 1.54 | 1.25x |
| bf16 | f32 | 0.29 | 1.56 | 1.84 | 1.04x |
| bf16 | bf16 | 0.33 | 1.09 | 1.42 | 1.36x |
| bf16 | fp8 | 0.33 | 1.06 | 1.39 | 1.38x |
- bf16/bf16 is the sweet spot: avoids dtype conversion at boundary + bf16 state checkpoints
- f32/bf16: pays f32→bf16 conversion cost but still benefits from bf16 backward
- bf16/f32: bf16 params upcast to f32 for kernel, minimal gain
- For full-model training, use model.to(torch.bfloat16) so p_dtype matches the rest of the model
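The bullets above follow from where dtype conversions occur at the kernel boundary. A minimal sketch, assuming theta and preacts enter the kernel in p_dtype, kernel I/O uses the dtype listed in the c_dtype table (bf16 for fp8), and outputs are cast back to p_dtype; `boundary_casts` is a hypothetical helper, not qkan API.

```python
# Kernel I/O dtype per c_dtype, following the c_dtype table (fp8 uses bf16 I/O).
KERNEL_IO = {"f32": "f32", "bf16": "bf16", "fp8": "bf16"}

def boundary_casts(p_dtype: str, c_dtype: str) -> list:
    """List the conversions at the QKAN kernel boundary for a (p_dtype, c_dtype) pair.
    Assumes kernel inputs (theta, preacts) arrive in p_dtype and outputs return in p_dtype."""
    io = KERNEL_IO[c_dtype]
    if io == p_dtype:
        return []  # dtypes already match: no conversion kernels launched
    return [f"inputs {p_dtype}->{io}", f"output {io}->{p_dtype}"]

for p in ("f32", "bf16"):
    for c in ("f32", "bf16", "fp8"):
        print(f"p_dtype={p:4s} c_dtype={c:4s} casts={boundary_casts(p, c)}")
```

Under these assumptions bf16/bf16 and bf16/fp8 need no boundary conversion, which matches them being the two fastest rows of the matrix.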
Hardware
- GPU: NVIDIA GeForce RTX 5090
- PyTorch 2.11.0+cu130, CUDA 13.0
- Triton 3.6.0, cuTile 1.2.0