[Feature] KDA CPU reference implementation and comparison tests #168
Open
Description
Motivation
Part of Epic #28. KDA is one of the three priority flash-linear-attention components requiring a CPU JAX gold reference for kernel correctness validation.
Expected Behavior
- Complete CPU JAX implementation of KDA in `tops/cpu/ops/kda/`
- Input layout follows project convention: `[B, T, H, K]`
- Comprehensive docstrings with tensor shape and dimension semantics
- Strict input assertions on shape and types
- Comparison tests in `tests/ops/kda/` validating Pallas kernels against CPU reference
- Tolerance-based assertions using `compare_tensor(name, gold, tensor, atol, rtol, max_ulp)`
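To make the "strict input assertions" item concrete, here is a minimal sketch of what such a check could look like for the `[B, T, H, K]` layout. The helper name `check_inputs`, the argument names `q`, `k`, `v`, and the `[B, T, H, V]` value layout are assumptions for illustration only; the actual KDA reference may take different or additional inputs. The sketch uses NumPy so it runs anywhere, but it only touches `.ndim`, `.shape`, and `.dtype`, so the same checks should apply to JAX arrays (a JAX version would probably use `jnp.issubdtype` so that `bfloat16` is accepted).

```python
import numpy as np


def check_inputs(q, k, v):
    """Validate shapes and dtypes; return (B, T, H, K, V).

    Assumed layouts (hypothetical, for illustration):
      q, k: [B, T, H, K]  batch, sequence length, heads, key dim
      v:    [B, T, H, V]  batch, sequence length, heads, value dim
    """
    for name, x in (("q", q), ("k", k), ("v", v)):
        assert x.ndim == 4, f"{name} must be rank 4, got shape {x.shape}"
        assert np.issubdtype(x.dtype, np.floating), (
            f"{name} must be floating point, got {x.dtype}"
        )
    assert q.shape == k.shape, f"q {q.shape} and k {k.shape} must match"
    assert v.shape[:3] == q.shape[:3], (
        f"v {v.shape} must share [B, T, H] with q {q.shape}"
    )
    assert q.dtype == k.dtype == v.dtype, "q, k, v must share one dtype"
    B, T, H, K = q.shape
    return B, T, H, K, v.shape[-1]


if __name__ == "__main__":
    q = np.zeros((2, 16, 4, 8), dtype=np.float32)
    v = np.zeros((2, 16, 4, 32), dtype=np.float32)
    print(check_inputs(q, q, v))
```

Failing early with a message that names the offending tensor keeps layout mistakes (for example passing `[B, H, T, K]`) from surfacing later as silent numerical mismatches.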
Proposed Approach
- Implement CPU reference in `tops/cpu/ops/kda/`
- Follow existing GLA reference implementation patterns
- Add comparison tests in `tests/ops/kda/`
- Error bound constraint: TPU error ≤ GPU error
- Export public API via `tops/ops/__init__.py`
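As a rough sketch of how the comparison tests and the error bound constraint could fit together: the `compare_tensor` below borrows only its signature from this issue. Its semantics are assumed here (an element passes if it is within `atol + rtol * |gold|`, or within `max_ulp` float32 ULPs when `max_ulp` is given), and the project's real helper may behave differently. `assert_error_bound` is a hypothetical name, and measuring error as the maximum absolute deviation from the CPU gold output is also an assumption, since the issue does not specify the metric.

```python
import numpy as np


def ulp_distance(a, b):
    """Element-wise distance in float32 ULPs between a and b."""
    def ordered(x):
        bits = np.asarray(x, dtype=np.float32).view(np.int32).astype(np.int64)
        # Map sign-magnitude float bits onto a monotonically ordered integer line.
        return np.where(bits < 0, -(bits & 0x7FFFFFFF), bits)
    return np.abs(ordered(a) - ordered(b))


def compare_tensor(name, gold, tensor, atol=1e-5, rtol=1e-5, max_ulp=None):
    """Assert tensor matches gold within tolerance; return the max abs error."""
    gold = np.asarray(gold, dtype=np.float32)
    tensor = np.asarray(tensor, dtype=np.float32)
    assert gold.shape == tensor.shape, (
        f"{name}: shape mismatch {gold.shape} vs {tensor.shape}"
    )
    abs_err = np.abs(gold - tensor)
    ok = abs_err <= atol + rtol * np.abs(gold)
    if max_ulp is not None:
        # Assumed semantics: the ULP bound is an alternative way to pass.
        finite = np.isfinite(gold) & np.isfinite(tensor)
        ok |= finite & (ulp_distance(gold, tensor) <= max_ulp)
    max_err = float(abs_err.max())
    assert ok.all(), f"{name}: {int((~ok).sum())} mismatches, max abs err {max_err:.3e}"
    return max_err


def assert_error_bound(gold, tpu_out, gpu_out, atol=1e-2, rtol=1e-2):
    """Check both outputs against gold, then require TPU error <= GPU error."""
    tpu_err = compare_tensor("tpu", gold, tpu_out, atol, rtol)
    gpu_err = compare_tensor("gpu", gold, gpu_out, atol, rtol)
    assert tpu_err <= gpu_err, f"TPU error {tpu_err:.3e} > GPU error {gpu_err:.3e}"
    return tpu_err, gpu_err


if __name__ == "__main__":
    gold = np.ones((2, 16, 4, 8), dtype=np.float32)
    # Synthetic stand-ins for kernel outputs; real tests would call the kernels.
    assert_error_bound(gold, gold + np.float32(1e-4), gold + np.float32(1e-3))
```

In a real test under `tests/ops/kda/`, `gold` would come from the CPU reference and `tpu_out` from the Pallas kernel, while the GPU baseline would need to be recorded separately.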
Willingness to Contribute
- I am willing to submit a PR for this feature.
Additional Context
Sub-task of #28.