[Perf] KDA kernel performance optimization #172

@0xaskr

Summary

Optimize the KDA (Key-Driven Attention) kernels to meet the project's performance target of 80% of hardware peak.

Type

  • Performance regression (was faster before)
  • Below expected performance target (not meeting 80% roofline)
  • Optimization opportunity

Kernel / Operation

KDA kernels under tops/ops/kda/ (forward and backward passes).

Observed Performance

TBD. The benchmark baseline will be established once the CPU reference implementation (#168) is complete.

Expected Performance

Target 80% of hardware theoretical peak per project standards:

  • If compute-bound: 80% of hardware compute peak (FLOPS)
  • If memory bandwidth-bound: 80% of hardware memory bandwidth peak
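The two cases above can be told apart with a simple roofline calculation. The following is a minimal sketch, not the project's tooling: the `Hardware` class, the `roofline_target` function, and all numeric values are hypothetical, and the FLOP and byte counts for the KDA kernels would need to be derived from the actual implementation.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Hardware:
    """Hypothetical hardware description; fill in from the vendor datasheet."""
    peak_flops: float      # theoretical compute peak, FLOP/s
    peak_bandwidth: float  # theoretical memory bandwidth peak, bytes/s


def roofline_target(flops: float, bytes_moved: float, hw: Hardware,
                    fraction: float = 0.8) -> dict:
    """Classify a kernel as compute- or memory-bound and derive a target time.

    `fraction` is the share of the relevant hardware peak to reach
    (0.8 matches the 80% target stated in this issue).
    """
    intensity = flops / bytes_moved            # arithmetic intensity, FLOP/byte
    ridge = hw.peak_flops / hw.peak_bandwidth  # intensity where the two roofs meet
    compute_time = flops / hw.peak_flops       # best case if compute-limited
    memory_time = bytes_moved / hw.peak_bandwidth  # best case if bandwidth-limited
    bound = "compute" if intensity >= ridge else "memory"
    ideal_time = max(compute_time, memory_time)
    return {
        "arithmetic_intensity": intensity,
        "ridge_point": ridge,
        "bound": bound,
        # Running at `fraction` of peak takes 1/fraction times the ideal time.
        "target_time_s": ideal_time / fraction,
    }


# Illustrative numbers only, not measurements of any real device or kernel.
example_hw = Hardware(peak_flops=100e12, peak_bandwidth=1e12)
result = roofline_target(flops=1e12, bytes_moved=1e11, hw=example_hw)
print(result["bound"], result["target_time_s"])
```

A kernel whose arithmetic intensity falls below the ridge point is judged against the bandwidth peak; otherwise it is judged against the compute peak.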

Environment

  • Python version:
  • JAX version:
  • Hardware: CPU / GPU (model) / TPU (version)
  • OS:

Reproduction

# TBD: benchmark script

Tasks

  • Establish performance baseline with initial KDA kernel implementation
  • Conduct roofline analysis (arithmetic intensity, bound type)
  • Profile kernel execution (Perfetto / TensorBoard Profiler)
  • Optimize tiling strategy and data movement (DMA pipeline, double buffering)
  • Verify MXU/VPU overlap effectiveness
  • Validate correctness post-optimization (error ≤ CPU reference)
  • Document results in design doc under docs/design-docs/ops/kda/
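For the baseline task, a timing harness along the following lines may be a useful starting point. This is a sketch under assumptions, not the benchmark script for this issue: `benchmark` and `achieved_fraction` are hypothetical helpers, and the callable passed in stands in for a KDA kernel invocation. With JAX, the callable should block on its result (for example by calling `.block_until_ready()` on the output array), since dispatch is asynchronous and the timer would otherwise stop early.

```python
import statistics
import time
from typing import Callable


def benchmark(fn: Callable[[], object], warmup: int = 3, iters: int = 10) -> float:
    """Return the median wall-clock time of fn() in seconds."""
    for _ in range(warmup):
        fn()  # warm-up runs absorb one-time costs such as JIT compilation
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


def achieved_fraction(ideal_time_s: float, measured_time_s: float) -> float:
    """Fraction of the hardware peak reached, given the roofline ideal time."""
    return ideal_time_s / measured_time_s


# Placeholder workload; replace with a call into the kernel under test.
median_s = benchmark(lambda: sum(range(100_000)))
print(f"median: {median_s * 1e3:.3f} ms")
```

Reporting the median rather than the mean keeps occasional slow iterations from skewing the baseline.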

Trace / Profile Data

TBD. To be collected after the baseline implementation.

Additional Context

Depends on #168 (KDA CPU reference implementation and comparison tests) and #56 (KDA implementation). Performance optimization should follow the design doc requirements in CONTRIBUTING.md.

Metadata

Labels

P1, performance (Performance issue or optimization)
