[Perf] GEMM FP8 performance optimization #171
Description
Summary
The ant-pretrain repository requires pallas-kernel to optimize FP8 GEMM (General Matrix Multiply) performance for training workloads on TPU.
Background
From ant-pretrain PR primatrix/ant-pretrain#325:
- Current work applies optimal per-phase GMM tiling discovered via automated kernel evolution: fwd=(256,512,256), bwd=(256,512,128), tgmm=(1024,512,128)
- FP8 quantization (fp8_full) is restricted to GMM/ragged_dot ops only
- FP8 Tokamax training on TPU v7x achieves ~163 TFLOP/s/device steady-state throughput
Current Bottleneck
The FP8 matrix multiplication itself provides a performance gain. The dominant overhead, however, is the cost of quantization and dequantization. Attempts to fuse quantize/dequantize into the GEMM kernel revealed that the MXU cannot overlap well with the VPU, resulting in pipeline bubbles. As a result, FP8 currently underperforms the BF16 kernel baseline, which completes in roughly 1 ms.
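For context, the unfused data path looks roughly like this in plain JAX. This is an illustrative sketch, not the repository's Pallas kernel; the function names (quantize_fp8, fp8_gemm) and per-tensor scaling scheme are assumptions. It makes the quantize (VPU) → matmul (MXU) → dequantize (VPU) phases explicit; unfused, the VPU phases serialize with MXU compute instead of hiding behind it.

```python
import jax.numpy as jnp

def quantize_fp8(x, dtype=jnp.float8_e4m3fn):
    # Per-tensor scale so |x| maps onto the FP8 dynamic range (VPU work).
    fp8_max = float(jnp.finfo(dtype).max)           # 448 for E4M3
    amax = jnp.max(jnp.abs(x)).astype(jnp.float32)
    scale = fp8_max / jnp.maximum(amax, 1e-12)
    # Clip before the cast: values above fp8_max do not saturate cleanly.
    q = jnp.clip(x.astype(jnp.float32) * scale, -fp8_max, fp8_max)
    return q.astype(dtype), scale

def fp8_gemm(a, b):
    a_q, sa = quantize_fp8(a)
    b_q, sb = quantize_fp8(b)
    # On TPU the MXU consumes FP8 operands directly; we upcast to f32 here
    # only so the sketch runs on any backend.
    acc = jnp.dot(a_q.astype(jnp.float32), b_q.astype(jnp.float32))
    # Dequantize the f32 accumulator back to BF16 (more VPU work).
    return (acc / (sa * sb)).astype(jnp.bfloat16)

a = jnp.ones((128, 256), jnp.bfloat16)
b = jnp.ones((256, 64), jnp.bfloat16)
out = fp8_gemm(a, b)
print(out.shape, float(out[0, 0]))   # (128, 64) 256.0
```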
Type
- Performance regression (was faster before)
- Below expected performance target (not meeting 80% roofline)
- Optimization opportunity
Kernel / Operation
GEMM with FP8 precision (E4M3 / E5M2) on TPU.
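The two FP8 formats trade mantissa precision against exponent range, which is visible directly from their dtype metadata (illustrative check; any recent JAX exposes these dtypes via ml_dtypes):

```python
import jax.numpy as jnp

# E4M3 spends bits on mantissa (precision), E5M2 on exponent (range).
for name, dt in [("E4M3", jnp.float8_e4m3fn), ("E5M2", jnp.float8_e5m2)]:
    fi = jnp.finfo(dt)
    print(f"{name}: max={float(fi.max)}, eps={float(fi.eps)}")
```

E4M3 tops out at 448 with finer granularity; E5M2 reaches 57344 with coarser steps, which is why E5M2 is typically reserved for gradients.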
Observed Performance
- FP8 GEMM with fused quantize/dequantize: slower than BF16 baseline due to MXU-VPU overlap issue
- BF16 GEMM baseline: ~1ms
- FP8 Tokamax steady-state: ~163 TFLOP/s/device (with BF16 GMM tiling)
Expected Performance
FP8 GEMM should outperform BF16 baseline by leveraging hardware-native FP8 compute units, targeting 80% of hardware theoretical peak per project standards.
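For measuring progress against that target, achieved throughput follows from the GEMM FLOP count and wall time. A small helper, with a hypothetical problem shape (the name achieved_tflops and the example dimensions are not from this issue):

```python
# A dense GEMM performs 2*M*K*N FLOPs (one multiply + one add per
# inner-product term).
def achieved_tflops(m, k, n, seconds):
    return 2 * m * k * n / seconds / 1e12

# Example: a hypothetical 4096x4096x4096 GEMM measured at 1 ms.
print(achieved_tflops(4096, 4096, 4096, 1e-3))   # ~137.4 TFLOP/s
```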
Key Technical Challenges
- Quantize/Dequantize overhead: The cost of converting BF16 ↔ FP8 before/after GEMM negates the FP8 compute speedup
- MXU-VPU overlap: When fusing quantize/dequantize into the GEMM kernel, the MXU (matrix unit) and VPU (vector unit) cannot be effectively overlapped, creating pipeline bubbles
- Tiling strategy: Need to find tiling configurations that allow VPU quantization work to hide behind MXU compute
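One way to reason about the tiling question is per-tile arithmetic intensity: higher FLOPs-per-byte means more MXU cycles per tile behind which VPU quantization can potentially hide. A back-of-envelope estimate for the fwd tile from the Background section, assuming the tuple ordering is (M, K, N):

```python
# Arithmetic intensity of the fwd GMM tile (256, 512, 256), assuming
# (M, K, N) ordering, 1-byte FP8 operands, and an f32 tile output.
# This is an estimate, not a measurement.
M, K, N = 256, 512, 256
flops = 2 * M * K * N                            # multiply-add per output element
bytes_moved = M * K * 1 + K * N * 1 + M * N * 4  # two FP8 inputs, one f32 output
print(flops / bytes_moved)                       # → 128.0 FLOPs/byte
```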
Environment
- Hardware: TPU v7x (from ant-pretrain benchmarks)
- JAX version: TBD
- Topology: 2x2x1 (4 chips)
References
- ant-pretrain PR: primatrix/ant-pretrain#325
- Kernel evolution: sii-xinglong/Glaucis#48
Additional Context
This is a downstream requirement from ant-pretrain. Solving the MXU-VPU overlap problem is the key to unlocking FP8 performance gains on TPU.