[Perf] KDA kernel performance optimization #172
Open
Description
Summary
Optimize KDA (Key-Driven Attention) kernel performance to meet the project's 80% hardware peak target.
Type
- Performance regression (was faster before)
- Below expected performance target (not meeting 80% roofline)
- Optimization opportunity
Kernel / Operation
KDA kernels under tops/ops/kda/ (forward and backward passes).
Observed Performance
TBD — benchmark baseline to be established after CPU reference implementation (#168) is complete.
Expected Performance
Target 80% of hardware theoretical peak per project standards:
- If compute-bound: 80% of hardware compute peak (FLOPS)
- If memory bandwidth-bound: 80% of hardware memory bandwidth peak
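The bound type in the list above can be read off a roofline model: compare the kernel's arithmetic intensity (FLOPs per byte moved) with the hardware ridge point (peak FLOP/s divided by peak bandwidth). The following is a minimal sketch under stated assumptions, not the project's tooling: the `Hardware` class, the `roofline_target` function, and the example numbers are hypothetical, and real peak values must come from the hardware datasheet.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Hardware:
    """Hypothetical hardware description; fill in from the vendor datasheet."""
    peak_flops: float      # theoretical compute peak, FLOP/s
    peak_bandwidth: float  # theoretical memory bandwidth peak, bytes/s


def roofline_target(flops, bytes_moved, hw, fraction=0.8):
    """Classify a kernel as compute- or memory-bound and return the
    runtime it must reach to hit `fraction` of the relevant peak."""
    intensity = flops / bytes_moved            # FLOPs per byte
    ridge = hw.peak_flops / hw.peak_bandwidth  # intensity where the two roofs meet
    if intensity >= ridge:
        bound = "compute"
        target_seconds = flops / (fraction * hw.peak_flops)
    else:
        bound = "memory"
        target_seconds = bytes_moved / (fraction * hw.peak_bandwidth)
    return bound, intensity, target_seconds


# Illustrative numbers only (assumed, not measured).
hw = Hardware(peak_flops=100e12, peak_bandwidth=1e12)
bound, intensity, target_seconds = roofline_target(1e12, 1e11, hw)
print(bound, intensity, target_seconds)
```

A kernel whose measured runtime is at or below `target_seconds` meets the 80% target for its bound type.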
Environment
- Python version:
- JAX version:
- Hardware: CPU / GPU (model) / TPU (version)
- OS:
Reproduction
# TBD: benchmark script
Tasks
- Establish performance baseline with initial KDA kernel implementation
- Conduct roofline analysis (arithmetic intensity, bound type)
- Profile kernel execution (Perfetto / TensorBoard Profiler)
- Optimize tiling strategy and data movement (DMA pipeline, double buffering)
- Verify MXU/VPU overlap effectiveness
- Validate correctness post-optimization (error ≤ CPU reference)
- Document results in design doc under docs/design-docs/ops/kda/
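As a starting point for the baseline task above, a timing harness might look like the sketch below. It is an assumption-laden illustration rather than the project's benchmark script: `benchmark` and `achieved_fraction` are hypothetical helpers, and the `sync` hook exists because JAX dispatches work asynchronously, so a real run would likely pass something like `jax.block_until_ready` to make sure the kernel has finished before the timer stops.

```python
import statistics
import time


def benchmark(fn, *args, warmup=3, iters=20, sync=lambda out: out):
    """Time `fn(*args)` and return (median_seconds, best_seconds).

    `sync` is called on every result; for an asynchronous runtime it
    should block until the computation has actually completed.
    """
    for _ in range(warmup):  # warm-up runs absorb compilation and cache effects
        sync(fn(*args))
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        sync(fn(*args))
        samples.append(time.perf_counter() - start)
    return statistics.median(samples), min(samples)


def achieved_fraction(work, seconds, peak_rate):
    """Fraction of a hardware peak reached: `work` is FLOPs (or bytes),
    `peak_rate` is the matching peak in FLOP/s (or bytes/s)."""
    return (work / seconds) / peak_rate


# Stand-in workload; a real run would call the KDA kernel instead.
median_s, best_s = benchmark(lambda n: sum(range(n)), 100_000)
print(f"median {median_s:.6f} s, best {best_s:.6f} s")
```

Reporting the median alongside the best sample makes it easier to spot noisy runs before comparing against the 80% target.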
Trace / Profile Data
TBD — to be collected after baseline implementation.
Additional Context
Depends on #168 (KDA CPU reference implementation and comparison tests) and #56 (KDA implementation). Performance optimization should follow the design doc requirements in CONTRIBUTING.md.