[Feature] GLA CPU reference implementation and comparison tests #166
Motivation
Part of Epic #28. GLA (Gated Linear Attention) is the highest priority flash-linear-attention component requiring a CPU JAX gold reference for kernel correctness validation.
Expected Behavior
- Complete CPU JAX implementation of GLA in `tops/cpu/ops/gla/`
- Supports the GLA recurrence: `h_t = h_{t-1} * exp(gk_t) + k_t^T @ v_t`, `o_t = q_t^T @ h_t`
- Input layout follows the project convention: `[B, T, H, K]`
- Comprehensive docstrings documenting tensor shapes and dimension semantics
- Strict input assertions on shapes and dtypes
- Comparison tests in `tests/ops/gla/` validating Pallas kernels against the CPU reference
- Tolerance-based assertions using `compare_tensor(name, gold, tensor, atol, rtol, max_ulp)`
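For reference, the recurrence above can be written as a `jax.lax.scan` over the time axis. This is a minimal sketch only: the function name `gla_recurrent_reference`, the `[B, T, H, V]` layout for `v`, and the exact placement of the gate `gk` are assumptions for illustration, not the project's actual API.

```python
import jax
import jax.numpy as jnp

def gla_recurrent_reference(q, k, v, gk):
    """Hypothetical CPU reference for the GLA recurrence.

    Shapes (assumed): q, k, gk are [B, T, H, K]; v is [B, T, H, V].
    Per step t (per batch and head, with state h of shape [K, V]):
        h_t = h_{t-1} * exp(gk_t) + k_t^T @ v_t
        o_t = q_t @ h_t
    """
    B, T, H, K = q.shape
    V = v.shape[-1]
    assert k.shape == (B, T, H, K), "k must match q layout [B, T, H, K]"
    assert gk.shape == (B, T, H, K), "gk must match q layout [B, T, H, K]"
    assert v.shape == (B, T, H, V), "v must be [B, T, H, V]"

    def step(h, inputs):
        q_t, k_t, v_t, gk_t = inputs  # [B, H, K] or [B, H, V]
        # Gate decays the state along K; outer product k_t^T v_t updates it.
        h = h * jnp.exp(gk_t)[..., None] + k_t[..., None] * v_t[..., None, :]
        o_t = jnp.einsum('bhk,bhkv->bhv', q_t, h)
        return h, o_t

    h0 = jnp.zeros((B, H, K, V), dtype=q.dtype)
    # scan consumes the leading axis, so move T to the front first
    xs = tuple(jnp.moveaxis(x, 1, 0) for x in (q, k, v, gk))
    _, o = jax.lax.scan(step, h0, xs)
    return jnp.moveaxis(o, 0, 1)  # back to [B, T, H, V]
```

A chunked variant would need to match this scan bitwise-closely; the tolerance tests below exist precisely because the kernel and the reference accumulate in different orders.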
Proposed Approach
- Implement the CPU reference in `tops/cpu/ops/gla/`, covering both the fused recurrent and chunked variants
- Add comparison tests in `tests/ops/gla/`
- Error bound constraint: TPU error ≤ GPU error
- Export the public API via `tops/ops/__init__.py`
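To make the tolerance-based assertions concrete, here is a NumPy sketch of what a comparator with the quoted `compare_tensor(name, gold, tensor, atol, rtol, max_ulp)` signature might do. The body is an assumption: the project's real `compare_tensor` may use different semantics, and the bit-pattern ULP distance shown here is only valid for finite float32 values of the same sign.

```python
import numpy as np

def compare_tensor(name, gold, tensor, atol=1e-6, rtol=1e-5, max_ulp=4):
    """Hypothetical tolerance + ULP comparator (sketch, not the project's code)."""
    gold = np.ascontiguousarray(gold, dtype=np.float32)
    tensor = np.ascontiguousarray(tensor, dtype=np.float32)
    assert gold.shape == tensor.shape, f"{name}: shape mismatch"
    # absolute/relative tolerance, numpy.allclose-style
    abs_err = np.abs(gold - tensor)
    tol = atol + rtol * np.abs(gold)
    # ULP distance via the int32 bit patterns; assumes finite,
    # same-signed float32 values
    ulp = np.abs(gold.view(np.int32).astype(np.int64)
                 - tensor.view(np.int32).astype(np.int64))
    ok = (abs_err <= tol) | (ulp <= max_ulp)
    assert ok.all(), f"{name}: max abs err {abs_err.max():.3e}"
```

A per-element OR of the tolerance and ULP checks lets tiny values pass on ULP distance while large values pass on relative error, which is one plausible reading of why the signature carries both `rtol` and `max_ulp`.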
Willingness to Contribute
- I am willing to submit a PR for this feature.
Additional Context
Sub-task of #28.