[Feature] GLA CPU reference implementation and comparison tests #166

@0xaskr

Description

Motivation

Part of Epic #28. GLA (Gated Linear Attention) is the highest-priority flash-linear-attention component requiring a CPU JAX gold reference for kernel-correctness validation.

Expected Behavior

  • Complete CPU JAX implementation of GLA in tops/cpu/ops/gla/
  • Support for the GLA recurrence h_t = h_{t-1} * exp(gk_t) + k_t^T @ v_t, o_t = q_t^T @ h_t (elementwise gating of the state, rank-1 key–value update)
  • Input layout follows project convention: [B, T, H, K]
  • Comprehensive docstrings documenting tensor shapes and dimension semantics
  • Strict input assertions on shapes and dtypes
  • Comparison tests in tests/ops/gla/ validating Pallas kernels against CPU reference
  • Tolerance-based assertions using compare_tensor(name, gold, tensor, atol, rtol, max_ulp)
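The recurrence above could be sketched as a CPU JAX gold reference roughly as follows. This is a minimal sketch, assuming `[B, T, H, K]` layout for `q`/`k`/`gk` and `[B, T, H, V]` for `v`; the function name is hypothetical and not the project's actual API:

```python
import jax
import jax.numpy as jnp

def gla_recurrent_reference(q, k, v, gk):
    """Sketch of a fused-recurrent GLA gold reference (hypothetical name).

    q, k, gk: [B, T, H, K];  v: [B, T, H, V].
    Per step: h_t = h_{t-1} * exp(gk_t) + k_t^T @ v_t,  o_t = q_t^T @ h_t.
    """
    # Strict input assertions on shapes, per the issue's requirements.
    assert q.shape == k.shape == gk.shape, "q, k, gk must share [B, T, H, K]"
    assert v.shape[:3] == q.shape[:3], "v must share [B, T, H] with q"
    B, T, H, K = q.shape
    V = v.shape[-1]

    def step(h, inputs):
        q_t, k_t, v_t, gk_t = inputs  # each [B, H, K] (v_t is [B, H, V])
        # Elementwise decay of the state along K, plus the rank-1 update k_t^T v_t.
        h = h * jnp.exp(gk_t)[..., None] + k_t[..., :, None] * v_t[..., None, :]
        # Contract the query against the state over K to produce o_t: [B, H, V].
        o_t = jnp.einsum("bhk,bhkv->bhv", q_t, h)
        return h, o_t

    h0 = jnp.zeros((B, H, K, V), dtype=jnp.float32)
    # lax.scan consumes the leading axis, so move T to the front.
    xs = tuple(jnp.moveaxis(x, 1, 0) for x in (q, k, v, gk))
    _, o = jax.lax.scan(step, h0, xs)
    return jnp.moveaxis(o, 0, 1)  # back to [B, T, H, V]
```

A chunked variant would need to produce outputs bitwise-comparable (within tolerance) to this time-step scan, which is what the comparison tests then check.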

Proposed Approach

  1. Implement CPU reference in tops/cpu/ops/gla/ covering fused recurrent and chunk variants
  2. Add comparison tests in tests/ops/gla/
  3. Enforce the error-bound constraint: kernel error on TPU against the CPU gold reference must not exceed the corresponding error on GPU
  4. Export public API via tops/ops/__init__.py
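For step 2, the tolerance-based assertion could behave along these lines. The function below only mirrors the `compare_tensor(name, gold, tensor, atol, rtol, max_ulp)` signature quoted in the issue; its pass/fail semantics and the ULP computation are assumptions for illustration, not the project's actual helper:

```python
import numpy as np

def compare_tensor_sketch(name, gold, tensor, atol, rtol, max_ulp):
    """Hypothetical semantics: an element passes if it is within the
    absolute/relative tolerance OR within max_ulp float32 ULPs of gold."""
    gold = np.asarray(gold, dtype=np.float32)
    tensor = np.asarray(tensor, dtype=np.float32)
    assert gold.shape == tensor.shape, f"{name}: shape mismatch"
    abs_ok = np.abs(tensor - gold) <= atol + rtol * np.abs(gold)

    def ordered(x):
        # Map float32 bit patterns to a monotonically ordered int64 line,
        # so ULP distance is a plain integer difference.
        bits = x.view(np.int32).astype(np.int64)
        return np.where(bits < 0, np.int64(-2147483648) - bits, bits)

    ulp_ok = np.abs(ordered(tensor) - ordered(gold)) <= max_ulp
    ok = abs_ok | ulp_ok
    if not ok.all():
        idx = np.unravel_index(np.argmax(~ok), ok.shape)
        raise AssertionError(
            f"{name}: mismatch at {idx}: got {tensor[idx]}, expected {gold[idx]}"
        )
```

A test would then call this with the CPU reference output as `gold` and the Pallas kernel output as `tensor`, with per-op tolerances chosen so the TPU error bound from step 3 holds.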

Willingness to Contribute

  • I am willing to submit a PR for this feature.

Additional Context

Sub-task of #28.
