-
Notifications
You must be signed in to change notification settings - Fork 2
feat(gla): fused forward+backward kernels for g_gamma mode #122
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
sii-xinglong
wants to merge
16
commits into
main
Choose a base branch
from
feat/chunk-gla-fused-kernels
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 5 commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
76b5444
feat(gla): add fused forward kernel for g_gamma mode
sii-xinglong 71a8429
feat(gla): add fused backward kernel for g_gamma mode
sii-xinglong 22c6c54
style(chunk_fused_kernels): use exp() utility consistently in backwar…
sii-xinglong 05c6726
feat(gla): dispatch to fused kernels for g_gamma mode on TPU
sii-xinglong d8d3523
bench(gla): add chunk_fused and chunk_fused_bwd benchmark providers
sii-xinglong 3a88c91
fix(chunk_fused_kernels): cast float32 scratch to h_ref dtype on save
sii-xinglong 3e0372b
fix(chunk_fused_kernels): make scale a static JIT argument
sii-xinglong 572ed36
fix(chunk_fused_kernels): use parallel semantics for K/V grid dims
sii-xinglong 2b71a64
fix(chunk_fused_kernels): use f32 inputs for all matmuls in Pallas ke…
sii-xinglong 1247d97
fix(gla): correct cu_seqlens parameter names in chunk kernel calls
sii-xinglong 66c19d8
fix(tests): use bfloat16 inputs for both fused and reference paths
sii-xinglong f15840f
fix(tests): use naive recurrent reference instead of non-fused Pallas
sii-xinglong 7ffeae2
fix: scale g_gamma in E2E test to prevent numerical overflow
sii-xinglong 81458e0
fix: address PR #122 review comments
sii-xinglong 7e9d63f
perf(chunk_gla): eliminate jnp.flip copies in fused backward via reve…
sii-xinglong 1de5415
bench(gla): add memory profiling script for fused vs non-fused
sii-xinglong File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,123 @@ | ||
| """Tests for the fused chunked GLA backward kernel (g_gamma mode). | ||
|
|
||
| Compares the output (dq, dk, dv) of the fused single-pallas_call backward | ||
| against the non-fused ``chunk_gla_bwd_with_pl`` reference that uses separate | ||
| dh propagation and gradient computation kernels. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import sys | ||
| from pathlib import Path | ||
|
|
||
| sys.path.insert(0, str(Path(__file__).resolve().parents[3])) | ||
|
|
||
| import jax | ||
| import jax.numpy as jnp | ||
| import numpy as np | ||
| import pytest | ||
|
|
||
| from tests.utils import compare_tensor | ||
| from tops.ops.gla.chunk import chunk_gla_bwd_with_pl | ||
| from tops.ops.gla.chunk_fused_kernels import ( | ||
| chunk_bwd_fused_g_gamma, | ||
| chunk_fwd_fused_g_gamma, | ||
| ) | ||
|
|
||
|
|
||
| def _make_test_data(B, T, H, K, V, seed=42): | ||
| """Create deterministic (q, k, v, g_gamma, do) for a GLA backward test.""" | ||
| key = jax.random.PRNGKey(seed) | ||
| k1, k2, k3, k4, k5 = jax.random.split(key, 5) | ||
| q = jax.random.normal(k1, (B, T, H, K), dtype=jnp.bfloat16) | ||
| k_arr = jax.random.normal(k2, (B, T, H, K), dtype=jnp.bfloat16) | ||
| v = jax.random.normal(k3, (B, T, H, V), dtype=jnp.bfloat16) | ||
| g_gamma = -jnp.abs(jax.random.normal(k4, (H,), dtype=jnp.float32)) * 0.1 | ||
| do = jax.random.normal(k5, (B, T, H, V), dtype=jnp.bfloat16) | ||
| return q, k_arr, v, g_gamma, do | ||
|
|
||
|
|
||
@pytest.mark.tpu_only
class TestChunkBwdFused:
    """Tests for chunk_bwd_fused_g_gamma against the non-fused reference."""

    def _run_reference(self, q, k, v, g_gamma, do, scale, chunk_size):
        """Run the non-fused chunk_gla_bwd_with_pl to get reference dq, dk, dv.

        Inputs are upcast to float32 so the reference path is evaluated at
        higher precision than the bfloat16 fused kernel under test.
        """
        ref_dq, ref_dk, ref_dv, _, _ = chunk_gla_bwd_with_pl(
            q.astype(jnp.float32),
            k.astype(jnp.float32),
            v.astype(jnp.float32),
            g=None,
            g_gamma=g_gamma.reshape(1, 1, -1, 1),
            g_cumsum=None,
            scale=scale,
            initial_state=None,
            h=None,
            A=None,
            do=do.astype(jnp.float32),
            dht=None,
            chunk_size=chunk_size,
        )
        return ref_dq, ref_dk, ref_dv

    def _check_case(self, B, T, H, K, V, C, seed):
        """Run the fused fwd+bwd for one shape and compare gradients to the reference."""
        q, k, v, g_gamma, do = _make_test_data(B, T, H, K, V, seed=seed)
        scale = K**-0.5

        # The fused backward consumes the chunk states h produced by the
        # fused forward.
        h, _ = chunk_fwd_fused_g_gamma(q, k, v, g_gamma, scale, C)

        # Fused backward kernel under test.
        dq_fused, dk_fused, dv_fused = chunk_bwd_fused_g_gamma(
            q, k, v, h, do, g_gamma, scale, C
        )

        # Non-fused reference.
        ref_dq, ref_dk, ref_dv = self._run_reference(q, k, v, g_gamma, do, scale, C)

        # Loose tolerances: the fused path runs in bfloat16, the reference
        # in float32.
        for name, ref, fused in (
            ("dq", ref_dq, dq_fused),
            ("dk", ref_dk, dk_fused),
            ("dv", ref_dv, dv_fused),
        ):
            assert compare_tensor(
                name, ref, fused, atol=1e-2, rtol=1e-2, dtype=np.float32
            )

    def test_fused_bwd_basic(self):
        """Basic fused backward: B=2, T=256, H=4, K=128, V=128, C=64."""
        self._check_case(2, 256, 4, 128, 128, 64, seed=42)

    def test_fused_bwd_al_dims(self):
        """AL model dimensions: B=2, T=4096, H=16, K=128, V=128, C=64."""
        self._check_case(2, 4096, 16, 128, 128, 64, seed=123)
|
|
||
|
|
||
# Allow running this test module directly (outside the pytest CLI).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,91 @@ | ||
| """Tests for the fused chunked GLA forward kernel (g_gamma mode). | ||
|
|
||
| Compares the output of the fused single-pallas_call forward against the | ||
| non-fused ``chunk_gla_fwd`` reference that uses three separate kernels. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import sys | ||
| from pathlib import Path | ||
|
|
||
| sys.path.insert(0, str(Path(__file__).resolve().parents[3])) | ||
|
|
||
| import jax | ||
| import jax.numpy as jnp | ||
| import numpy as np | ||
| import pytest | ||
|
|
||
| from tests.utils import compare_tensor | ||
| from tops.ops.gla.chunk import chunk_gla_fwd | ||
| from tops.ops.gla.chunk_fused_kernels import chunk_fwd_fused_g_gamma | ||
|
|
||
|
|
||
| def _make_test_data(B, T, H, K, V, seed=42): | ||
| """Create deterministic (q, k, v, g_gamma) for a GLA test case.""" | ||
| key = jax.random.PRNGKey(seed) | ||
| k1, k2, k3, k4 = jax.random.split(key, 4) | ||
| q = jax.random.normal(k1, (B, T, H, K), dtype=jnp.bfloat16) | ||
| k_arr = jax.random.normal(k2, (B, T, H, K), dtype=jnp.bfloat16) | ||
| v = jax.random.normal(k3, (B, T, H, V), dtype=jnp.bfloat16) | ||
| g_gamma = -jnp.abs(jax.random.normal(k4, (H,), dtype=jnp.float32)) * 0.1 | ||
| return q, k_arr, v, g_gamma | ||
|
|
||
|
|
||
@pytest.mark.tpu_only
class TestChunkFwdFused:
    """Tests for chunk_fwd_fused_g_gamma against the non-fused reference."""

    def _run_reference(self, q, k, v, g_gamma, scale, chunk_size):
        """Run the non-fused chunk_gla_fwd to get reference (h, o).

        Inputs are upcast to float32 so the reference path is evaluated at
        higher precision than the bfloat16 fused kernel under test.
        """
        _, _, h_ref, _, o_ref = chunk_gla_fwd(
            q.astype(jnp.float32),
            k.astype(jnp.float32),
            v.astype(jnp.float32),
            g=None,
            g_gamma=g_gamma.reshape(1, 1, -1, 1),
            g_cumsum=None,
            scale=scale,
            initial_state=None,
            output_final_state=False,
            cu_seqlens=None,
            chunk_size=chunk_size,
        )
        return h_ref, o_ref

    def _check_case(self, B, T, H, K, V, C, seed):
        """Run the fused forward for one shape and compare (h, o) to the reference."""
        q, k, v, g_gamma = _make_test_data(B, T, H, K, V, seed=seed)
        scale = K**-0.5

        # Fused kernel under test.
        h_fused, o_fused = chunk_fwd_fused_g_gamma(q, k, v, g_gamma, scale, C)

        # Non-fused reference.
        h_ref, o_ref = self._run_reference(q, k, v, g_gamma, scale, C)

        # Loose tolerances: the fused path runs in bfloat16, the reference
        # in float32.
        assert compare_tensor("o", o_ref, o_fused, atol=1e-2, rtol=1e-2, dtype=np.float32)
        assert compare_tensor("h", h_ref, h_fused, atol=1e-2, rtol=1e-2, dtype=np.float32)

    def test_fused_fwd_basic(self):
        """Basic fused forward: B=2, T=256, H=4, K=128, V=128, C=64."""
        self._check_case(2, 256, 4, 128, 128, 64, seed=42)

    def test_fused_fwd_al_dims(self):
        """AL model dimensions: B=2, T=4096, H=16, K=128, V=128, C=64."""
        self._check_case(2, 4096, 16, 128, 128, 64, seed=123)
coderabbitai[bot] marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
# Allow running this test module directly (outside the pytest CLI).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,106 @@ | ||
| """End-to-end tests for fused chunk-GLA dispatch path (g_gamma mode). | ||
|
|
||
| Verifies that ``chunk_gla_fwd`` and ``chunk_gla_bwd_with_pl`` correctly | ||
| dispatch to the fused kernels when running in g_gamma mode on TPU, and | ||
| that the results match the non-fused reference path. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import sys | ||
| from pathlib import Path | ||
|
|
||
| sys.path.insert(0, str(Path(__file__).resolve().parents[3])) | ||
|
|
||
| import jax | ||
| import jax.numpy as jnp | ||
| import numpy as np | ||
| import pytest | ||
|
|
||
| from tests.utils import compare_tensor | ||
| from tops.ops.gla.chunk import chunk_gla_fwd, chunk_gla_bwd_with_pl | ||
|
|
||
|
|
||
| def _make_test_data(B, T, H, K, V, seed=42): | ||
| """Create deterministic (q, k, v, g_gamma, do) for a GLA test case.""" | ||
| key = jax.random.PRNGKey(seed) | ||
| k1, k2, k3, k4, k5 = jax.random.split(key, 5) | ||
| q = jax.random.normal(k1, (B, T, H, K), dtype=jnp.bfloat16) | ||
| k_arr = jax.random.normal(k2, (B, T, H, K), dtype=jnp.bfloat16) | ||
| v = jax.random.normal(k3, (B, T, H, V), dtype=jnp.bfloat16) | ||
| g_gamma = -jnp.abs(jax.random.normal(k4, (H,), dtype=jnp.float32)) * 0.1 | ||
| do = jax.random.normal(k5, (B, T, H, V), dtype=jnp.bfloat16) | ||
| return q, k_arr, v, g_gamma, do | ||
|
|
||
|
|
||
@pytest.mark.tpu_only
class TestChunkGlaFusedE2E:
    """End-to-end fused dispatch vs non-fused reference."""

    @staticmethod
    def _dense_g(g_gamma, like):
        """Expand the per-head decay into a dense g tensor shaped like *like*.

        Passing a dense ``g`` forces chunk_gla_* onto the non-fused code path.
        """
        return jnp.broadcast_to(g_gamma.reshape(1, 1, -1, 1), like.shape)

    def test_fwd_dispatch_output_shape(self):
        """Forward dispatch produces correct output shapes."""
        B, T, H, K, V, C = 2, 256, 4, 128, 128, 64
        q, k, v, g_gamma, _ = _make_test_data(B, T, H, K, V)

        _, _, h, ht, o = chunk_gla_fwd(
            q,
            k,
            v,
            g=None,
            g_gamma=g_gamma,
            g_cumsum=None,
            scale=K**-0.5,
            initial_state=None,
            output_final_state=False,
            chunk_size=C,
        )
        assert o.shape == (B, T, H, V)
        assert h.shape[0] == B
        assert ht is None

    def test_fwd_dispatch_matches_reference(self):
        """Forward fused dispatch matches non-fused reference on output."""
        B, T, H, K, V, C = 2, 256, 4, 128, 128, 64
        q, k, v, g_gamma, _ = _make_test_data(B, T, H, K, V)
        scale = K**-0.5

        def run_fwd(g):
            return chunk_gla_fwd(
                q,
                k,
                v,
                g=g,
                g_gamma=g_gamma,
                g_cumsum=None,
                scale=scale,
                initial_state=None,
                output_final_state=False,
                chunk_size=C,
            )

        # g=None → g_gamma mode, which should dispatch to the fused kernel
        # on TPU.
        _, _, h_disp, _, o_disp = run_fwd(None)
        # An explicit dense g bypasses the fused dispatch entirely.
        _, _, h_ref, _, o_ref = run_fwd(self._dense_g(g_gamma, q))

        assert compare_tensor("o", o_ref, o_disp, atol=1e-2, rtol=1e-2, dtype=np.float32)
        assert compare_tensor("h", h_ref, h_disp, atol=1e-2, rtol=1e-2, dtype=np.float32)

    def test_bwd_dispatch_matches_reference(self):
        """Backward fused dispatch matches non-fused reference gradients."""
        B, T, H, K, V, C = 2, 256, 4, 128, 128, 64
        q, k, v, g_gamma, do = _make_test_data(B, T, H, K, V)
        scale = K**-0.5

        def run_bwd(g):
            return chunk_gla_bwd_with_pl(
                q,
                k,
                v,
                g=g,
                g_gamma=g_gamma,
                g_cumsum=None,
                scale=scale,
                initial_state=None,
                h=None,
                A=None,
                do=do,
                dht=None,
                chunk_size=C,
            )

        # g=None → g_gamma mode (fused dispatch on TPU).
        dq_disp, dk_disp, dv_disp, _, _ = run_bwd(None)
        # Dense g → non-fused reference path.
        dq_ref, dk_ref, dv_ref, _, _ = run_bwd(self._dense_g(g_gamma, q))

        assert compare_tensor("dq", dq_ref, dq_disp, atol=1e-2, rtol=1e-2, dtype=np.float32)
        assert compare_tensor("dk", dk_ref, dk_disp, atol=1e-2, rtol=1e-2, dtype=np.float32)
        assert compare_tensor("dv", dv_ref, dv_disp, atol=1e-2, rtol=1e-2, dtype=np.float32)
|
|
||
|
|
||
# Allow running this test module directly (outside the pytest CLI).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.