[Feature] MLA CPU reference implementation and comparison tests #167

@0xaskr

Description

Motivation

Part of Epic #28. MLA (Multi-head Latent Attention) is one of the three priority flash-linear-attention components requiring a CPU JAX gold reference for kernel correctness validation.

Expected Behavior

  • Complete CPU JAX implementation of MLA in tops/cpu/ops/mla/
  • Input layout follows project convention: [B, T, H, K]
  • Comprehensive docstrings with tensor shape and dimension semantics
  • Strict input assertions on shape and types
  • Comparison tests in tests/ops/mla/ validating Pallas kernels against CPU reference
  • Tolerance-based assertions using compare_tensor(name, gold, tensor, atol, rtol, max_ulp)
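A minimal sketch of what such a CPU reference could look like. This is not the project's actual implementation: the function name `mla_reference`, the latent rank `R`, and the weight names `w_uk`/`w_uv` are illustrative assumptions, and the sketch keeps only MLA's core idea (attention over keys/values decompressed from a shared low-rank latent) while omitting details such as the decoupled RoPE path. It follows the `[B, T, H, K]` layout, docstring, and strict-assertion conventions listed above.

```python
import jax
import jax.numpy as jnp

def mla_reference(q, c_kv, w_uk, w_uv):
    """Simplified MLA CPU reference (hypothetical sketch).

    Args:
      q:    [B, T, H, K]  per-head queries.
      c_kv: [B, T, R]     compressed KV latent (R = latent rank, assumed).
      w_uk: [R, H, K]     up-projection producing per-head keys.
      w_uv: [R, H, V]     up-projection producing per-head values.

    Returns:
      [B, T, H, V] causally masked attention output.
    """
    assert q.ndim == 4 and c_kv.ndim == 3 and w_uk.ndim == 3 and w_uv.ndim == 3
    B, T, H, K = q.shape
    R = c_kv.shape[-1]
    assert c_kv.shape == (B, T, R)
    assert w_uk.shape == (R, H, K)
    assert w_uv.shape[:2] == (R, H)
    assert q.dtype == c_kv.dtype == w_uk.dtype == w_uv.dtype

    # Decompress the shared latent into per-head keys and values.
    k = jnp.einsum('btr,rhk->bthk', c_kv, w_uk)
    v = jnp.einsum('btr,rhv->bthv', c_kv, w_uv)

    # Scaled dot-product scores with a causal mask (position s attends to t <= s).
    scores = jnp.einsum('bthk,bshk->bhts', q, k) / jnp.sqrt(jnp.float32(K))
    causal = jnp.tril(jnp.ones((T, T), dtype=bool))
    scores = jnp.where(causal[None, None, :, :], scores, -jnp.inf)

    p = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum('bhts,bshv->bthv', p, v)
```

At `t = 0` the causal mask leaves exactly one key, so the output equals the decompressed value at that position, which makes a convenient sanity check for the reference itself.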

Proposed Approach

  1. Implement CPU reference in tops/cpu/ops/mla/
  2. Follow existing GLA reference implementation patterns
  3. Add comparison tests in tests/ops/mla/
  4. Enforce the error-bound constraint: the TPU kernel's error against the CPU reference must not exceed the GPU kernel's error
  5. Export public API via tops/ops/__init__.py
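The comparison tests hinge on the `compare_tensor(name, gold, tensor, atol, rtol, max_ulp)` helper named above. Only its signature appears in this issue, so the body below is a hedged sketch of one plausible semantics: an element passes if it is within the absolute/relative tolerance or within `max_ulp` float32 ULPs of gold. The bit-twiddling ULP mapping and the default tolerances are assumptions, not the project's actual implementation.

```python
import numpy as np

def compare_tensor(name, gold, tensor, atol=1e-5, rtol=1e-5, max_ulp=4):
    """Assert `tensor` matches `gold` within atol/rtol OR within max_ulp
    float32 ULPs (sketch; semantics assumed from the signature alone)."""
    gold = np.asarray(gold, dtype=np.float32)
    tensor = np.asarray(tensor, dtype=np.float32)
    assert gold.shape == tensor.shape, (
        f"{name}: shape {tensor.shape} != {gold.shape}")

    # Standard mixed absolute/relative tolerance test.
    close = np.abs(gold - tensor) <= atol + rtol * np.abs(gold)

    # Map float32 bit patterns onto a monotonically ordered integer line,
    # so ULP distance becomes a plain integer difference.
    def ordered(x):
        u = x.view(np.uint32).astype(np.int64)
        return np.where(u >= 0x80000000, 0x80000000 - u, u)

    ulp = np.abs(ordered(gold) - ordered(tensor))
    ok = close | (ulp <= max_ulp)
    assert ok.all(), (
        f"{name}: {np.count_nonzero(~ok)} element(s) out of tolerance")
```

A test in `tests/ops/mla/` would then call this with the Pallas kernel output as `tensor` and the CPU reference output as `gold`, with per-op tolerances chosen so the TPU error bound stays at or below the GPU one.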

Willingness to Contribute

  • I am willing to submit a PR for this feature.

Additional Context

Sub-task of #28.
