[Perf] GEMM FP8 performance optimization #171

@0xaskr

Description

Summary

The ant-pretrain repository requires pallas-kernel to optimize FP8 GEMM (General Matrix Multiply) performance for training workloads on TPU.

Background

From ant-pretrain PR primatrix/ant-pretrain#325:

  • Current work applies optimal per-phase GMM tiling discovered via automated kernel evolution: fwd=(256,512,256), bwd=(256,512,128), tgmm=(1024,512,128)
  • FP8 quantization (fp8_full) is restricted to GMM/ragged_dot ops only
  • FP8 Tokamax training on TPU v7x achieves ~163 TFLOP/s/device steady-state throughput
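The per-phase tilings above can be captured in a small config. This is an illustrative sketch only; the `GMM_TILINGS` table, the `(tm, tk, tn)` naming, and the `select_tiling` helper are hypothetical, not the actual pallas-kernel API:

```python
# Per-phase GMM tile sizes from ant-pretrain#325, interpreted here as
# (tm, tk, tn). The dict and helper are illustrative, not the real API.
GMM_TILINGS = {
    "fwd": (256, 512, 256),
    "bwd": (256, 512, 128),
    "tgmm": (1024, 512, 128),
}

def select_tiling(phase: str) -> tuple[int, int, int]:
    """Return the (tm, tk, tn) tile sizes for a training phase."""
    return GMM_TILINGS[phase]

print(select_tiling("fwd"))  # (256, 512, 256)
```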

Current Bottleneck

The FP8 matrix multiplication itself provides a performance gain. The dominant overhead, however, lies in the quantization and dequantization steps surrounding it. Attempts to fuse quantize/dequantize into the GEMM kernel revealed that the MXU does not overlap well with the VPU, producing pipeline bubbles. As a result, FP8 currently underperforms the BF16 kernel baseline, which completes in approximately 1ms.
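The quantize → matmul → dequantize pattern whose overhead dominates can be sketched in plain numpy. This emulates E4M3 only by per-tensor max-abs scaling and clamping to the format's max normal (448); a real kernel would additionally cast to an FP8 dtype, and the scaling work here is exactly the VPU-side cost the issue describes:

```python
import numpy as np

# Sketch of the quantize -> matmul -> dequantize pattern. E4M3 is
# emulated by clamping to its max normal value (448) after per-tensor
# max-abs scaling; this illustrates where the VPU work sits, it is not
# the pallas-kernel implementation (no actual FP8 cast/rounding here).
E4M3_MAX = 448.0

def quantize(x: np.ndarray) -> tuple[np.ndarray, float]:
    """Per-tensor max-abs scaling into the E4M3 representable range."""
    scale = np.abs(x).max() / E4M3_MAX
    q = np.clip(x / scale, -E4M3_MAX, E4M3_MAX)
    return q, scale  # a real kernel would cast q to float8_e4m3

def fp8_gemm(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    qa, sa = quantize(a)  # VPU: quantize inputs
    qb, sb = quantize(b)
    # MXU: matmul on the scaled operands; VPU: rescale the accumulator.
    return (qa @ qb) * (sa * sb)

rng = np.random.default_rng(0)
a = rng.standard_normal((128, 256)).astype(np.float32)
b = rng.standard_normal((256, 64)).astype(np.float32)
err = np.abs(fp8_gemm(a, b) - a @ b).max()
```

Because the same scales multiply back out, the round trip is numerically near-exact here; the point is the extra elementwise passes over both operands and the output, which sit on the VPU's critical path.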

Type

  • Performance regression (was faster before)
  • Below expected performance target (not meeting 80% roofline)
  • Optimization opportunity

Kernel / Operation

GEMM with FP8 precision (E4M3 / E5M2) on TPU.

Observed Performance

  • FP8 GEMM with fused quantize/dequantize: slower than BF16 baseline due to MXU-VPU overlap issue
  • BF16 GEMM baseline: ~1ms
  • FP8 Tokamax steady-state: ~163 TFLOP/s/device (with BF16 GMM tiling)

Expected Performance

FP8 GEMM should outperform BF16 baseline by leveraging hardware-native FP8 compute units, targeting 80% of hardware theoretical peak per project standards.
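The 80% roofline target is a simple ratio check. The achieved number below is the steady-state figure from this issue; the peak value is a placeholder, not the real TPU v7x spec:

```python
def mxu_utilization(achieved_tflops: float, peak_tflops: float) -> float:
    """Fraction of theoretical peak achieved (project target: 0.80)."""
    return achieved_tflops / peak_tflops

# 163 TFLOP/s/device is from the ant-pretrain benchmark; the peak here
# is a placeholder value, NOT the actual TPU v7x hardware peak.
frac = mxu_utilization(achieved_tflops=163.0, peak_tflops=1000.0)
meets_target = frac >= 0.80
```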

Key Technical Challenges

  1. Quantize/Dequantize overhead: The cost of converting BF16 ↔ FP8 before/after GEMM negates the FP8 compute speedup
  2. MXU-VPU overlap: When fusing quantize/dequantize into the GEMM kernel, the MXU (matrix unit) and VPU (vector unit) cannot be effectively overlapped, creating pipeline bubbles
  3. Tiling strategy: Need to find tiling configurations that allow VPU quantization work to hide behind MXU compute
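Challenge 3 amounts to a software pipeline: quantize the next K-tile on the VPU while the MXU consumes the current one. Plain numpy executes this sequentially, so the loop below only illustrates the intended schedule, not real overlap; the tile-size choice and the `quantize_tile` stand-in are assumptions for the sketch:

```python
import numpy as np

# Schematic two-stage pipeline: "VPU" quantizes tile i+1 while the
# "MXU" accumulates tile i. On hardware these two steps would overlap;
# numpy runs them back to back, so only the schedule is illustrated.
def pipelined_gemm(a: np.ndarray, b: np.ndarray, tk: int = 64) -> np.ndarray:
    m, k = a.shape
    _, n = b.shape
    acc = np.zeros((m, n), dtype=np.float32)
    n_tiles = k // tk

    def quantize_tile(i):
        # Stand-in for the VPU quantize step (scaling omitted for brevity).
        return a[:, i * tk:(i + 1) * tk], b[i * tk:(i + 1) * tk, :]

    cur = quantize_tile(0)  # prologue: quantize tile 0
    for i in range(n_tiles):
        nxt = quantize_tile(i + 1) if i + 1 < n_tiles else None  # "VPU"
        acc += cur[0] @ cur[1]                                   # "MXU"
        cur = nxt
    return acc

rng = np.random.default_rng(1)
a = rng.standard_normal((32, 256)).astype(np.float32)
b = rng.standard_normal((256, 16)).astype(np.float32)
out = pipelined_gemm(a, b)
```

The tiling search in challenge 3 then reduces to picking tile shapes where the quantize step's VPU cycles fit under the MXU cycles of the concurrent matmul.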

Environment

  • Hardware: TPU v7x (from ant-pretrain benchmarks)
  • JAX version: TBD
  • Topology: 2x2x1 (4 chips)

Additional Context

This is a downstream requirement from ant-pretrain. Solving the MXU-VPU overlap problem is the key to unlocking FP8 performance gains on TPU.

Labels

P1, performance (Performance issue or optimization)
