
[Feature] KDA CPU reference implementation and comparison tests #168

@0xaskr


Motivation

Part of Epic #28. KDA (Kimi Delta Attention) is one of the three priority flash-linear-attention components that need a CPU JAX gold reference against which kernel correctness can be validated.

Expected Behavior

  • A complete CPU JAX reference implementation of KDA in tops/cpu/ops/kda/
  • Input layout follows the project convention [B, T, H, K] (batch, sequence length, heads, head dimension)
  • Comprehensive docstrings documenting tensor shapes and the meaning of each dimension
  • Strict input assertions on shapes and dtypes
  • Comparison tests in tests/ops/kda/ that validate the Pallas kernels against the CPU reference
  • Tolerance-based assertions using compare_tensor(name, gold, tensor, atol, rtol, max_ulp)
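A minimal sketch of what the CPU reference could look like, covering the layout, docstring and assertion requirements above. It assumes KDA is the gated delta rule with a per-channel (fine-grained) decay, written in the naive recurrent form used by flash-linear-attention. The function name `kda_reference`, the argument names `g` (per-channel log decay, [B, T, H, K]) and `beta` ([B, T, H]), the default `K ** -0.5` query scale and the float32 accumulation are all assumptions for illustration, not the project's API.

```python
import jax
import jax.numpy as jnp


def kda_reference(q, k, v, g, beta, scale=None):
    """Naive recurrent KDA (hypothetical sketch, not the project API).

    Shapes (project layout [B, T, H, K]):
      q, k: [B, T, H, K]  queries / keys
      v:    [B, T, H, V]  values
      g:    [B, T, H, K]  per-channel log decay, alpha_t = exp(g_t)  (assumed)
      beta: [B, T, H]     delta-rule write strength                   (assumed)
    Returns o: [B, T, H, V] in the dtype of v.
    """
    # Strict shape and dtype assertions.
    assert q.ndim == 4 and k.shape == q.shape and g.shape == q.shape, "q, k, g must be [B, T, H, K]"
    B, T, H, K = q.shape
    assert v.ndim == 4 and v.shape[:3] == (B, T, H), "v must be [B, T, H, V]"
    assert beta.shape == (B, T, H), "beta must be [B, T, H]"
    for x in (q, k, v, g, beta):
        assert jnp.issubdtype(x.dtype, jnp.floating), "inputs must be floating point"
    V = v.shape[-1]
    if scale is None:
        scale = K ** -0.5  # assumed default query scale
    out_dtype = v.dtype
    # Accumulate in float32 for a stable gold reference.
    q, k, v, g, beta = (x.astype(jnp.float32) for x in (q, k, v, g, beta))
    q = q * scale

    def step(S, inputs):
        # S: [B, H, K, V] recurrent state carried across time steps.
        q_t, k_t, v_t, g_t, b_t = inputs
        S = S * jnp.exp(g_t)[..., None]                   # per-channel decay of the state
        v_pred = jnp.einsum("bhk,bhkv->bhv", k_t, S)     # value currently stored under k_t
        delta = b_t[..., None] * (v_t - v_pred)          # delta-rule correction
        S = S + jnp.einsum("bhk,bhv->bhkv", k_t, delta)  # rank-1 state update
        o_t = jnp.einsum("bhk,bhkv->bhv", q_t, S)        # read out with the query
        return S, o_t

    S0 = jnp.zeros((B, H, K, V), jnp.float32)
    # lax.scan iterates over the leading axis, so move T to the front.
    xs = tuple(jnp.moveaxis(x, 1, 0) for x in (q, k, v, g, beta))
    _, o = jax.lax.scan(step, S0, xs)  # o: [T, B, H, V]
    return jnp.moveaxis(o, 0, 1).astype(out_dtype)
```

A sequential `lax.scan` is slow, but it mirrors the recurrence one step at a time, which arguably suits a gold reference better than a chunked form; chunked algorithms belong in the Pallas kernels under test.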

Proposed Approach

  1. Implement the CPU reference in tops/cpu/ops/kda/
  2. Follow the patterns of the existing GLA reference implementation
  3. Add comparison tests in tests/ops/kda/
  4. Error bound constraint: the TPU kernel's error must not exceed the GPU kernel's error (TPU error ≤ GPU error)
  5. Export the public API via tops/ops/__init__.py
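The comparison step and the error-bound constraint could be checked along these lines. This is a hypothetical sketch: `compare_tensor_sketch` only mimics the `compare_tensor(name, gold, tensor, atol, rtol, max_ulp)` signature and is not the project's function. Treating an element as passing when it meets either the atol/rtol bound or the ULP bound is an assumption, as is measuring both the TPU and GPU outputs against the same CPU gold using max absolute error.

```python
import numpy as np


def ulp_distance(a, b):
    """Elementwise ULP distance between two arrays, compared as float32."""
    ia = np.asarray(a, dtype=np.float32).view(np.int32).astype(np.int64)
    ib = np.asarray(b, dtype=np.float32).view(np.int32).astype(np.int64)
    # Map sign-magnitude bit patterns onto a monotonic integer line (+0 and -0 both map to 0).
    ia = np.where(ia < 0, -(ia & 0x7FFFFFFF), ia)
    ib = np.where(ib < 0, -(ib & 0x7FFFFFFF), ib)
    return np.abs(ia - ib)


def compare_tensor_sketch(name, gold, tensor, atol, rtol, max_ulp):
    """Hypothetical stand-in for compare_tensor (assumed pass rule: atol/rtol OR ULP)."""
    gold = np.asarray(gold, dtype=np.float32)
    tensor = np.asarray(tensor, dtype=np.float32)
    assert gold.shape == tensor.shape, f"{name}: shape {tensor.shape} != {gold.shape}"
    close = np.abs(tensor - gold) <= atol + rtol * np.abs(gold)
    within_ulp = ulp_distance(gold, tensor) <= max_ulp
    ok = close | within_ulp
    assert ok.all(), f"{name}: {int((~ok).sum())} elements out of tolerance"


def max_abs_err(gold, out):
    """Max absolute error of out against gold, computed in float64."""
    return float(np.max(np.abs(np.asarray(out, np.float64) - np.asarray(gold, np.float64))))


def check_error_bound(name, gold, tpu_out, gpu_out):
    """Assumed form of the constraint: TPU error <= GPU error, both vs the same gold."""
    tpu_err = max_abs_err(gold, tpu_out)
    gpu_err = max_abs_err(gold, gpu_out)
    assert tpu_err <= gpu_err, f"{name}: TPU error {tpu_err:.3e} exceeds GPU error {gpu_err:.3e}"
    return tpu_err, gpu_err
```

In a test under tests/ops/kda/, `gold` would come from the CPU reference and `tensor` from the Pallas kernel on identical inputs.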

Willingness to Contribute

  • I am willing to submit a PR for this feature.

Additional Context

Sub-task of #28.
