diff --git a/.agents/projects/grugformer.md b/.agents/projects/grugformer.md new file mode 100644 index 0000000000..1e9ea38f58 --- /dev/null +++ b/.agents/projects/grugformer.md @@ -0,0 +1,222 @@ +# Grug: Explicit-Sharding LM Trainer + +## Background + +Inspired by [grugbrain.dev](https://grugbrain.dev/) and Andrej Karpathy’s [NanoGPT](https://github.com/karpathy/nanoGPT), we want a “grug-simple” causal LM trainer that showcases JAX’s explicit sharding mode while relying only on primitive building blocks: `jax`, `einops`, and JAX dataclasses (via `jax.tree_util.register_dataclass`). Training utilities stay minimal—optax for optimization, Levanter’s data loading/trackers for ingestion + logging, HuggingFace tokenizers (or Levanter’s serializer) for text, and TensorStore serialization for checkpoints. We explicitly do **not** want Haliax abstractions or class-heavy APIs—every computation lives in straightforward top-level functions so the trainer reads like a notebook. + + +## Grug Principles + +### Software Principles + +- **Few methods**: prefer top level functions to lots of methods (i.e. torch-style class hierarchies). Use `@dataclass`es for parameter/state containers, but keep logic in functions. Accessors are okay for small helpers (e.g. `def get_attention_params(model_params): ...`). +- **No array-involved methods in jit**: avoid class methods that operate on `jax.Array` fields inside `@jax.jit`-compiled functions. There are a few corner cases where they behave surprisingly. (Mostly we need to avoid `jit(module.compute_loss)`). +- **Keep dependencies small.** Prefer `einshard`, `optax`, JAX core libs, HF tokenizers, and Levanter’s `data` + `tracker` + TensorStore serialization APIs. Grug doesn't want to reinvent wheels, but we also don't want heavy frameworks obscuring the core logic. +- **Fast kernels are allowed, but keep the surface simple.** On TPU, Grug uses JAX Splash attention (via `jax.experimental.pallas.ops.tpu.splash_attention`). We keep a separate reference implementation for debugging/regressions. +- **Serializable state.** Trainer state must round-trip through `levanter.tensorstore_serialization.tree_{de,}serialize_leaves_tensorstore` with no custom logic. +- **Consumption-ready.** Provide a `uv run python -m marin.grugpt.train` entrypoint plus tests. Tracking hooks log loss/perplexity through Levanter’s tracker interface. +- **Limit config knobs.** In general, it's better to copy-paste experimental code than to bake every possible option into the core. Keep the config surface minimal and stable. As an example, attention sinks should be added by copy-pasting/editing the attention callsite in a speedrun script, not by adding a `cfg.use_attention_sinks` flag. But changing the number of layers etc. should be easy via config. This is a judgment call. + An exception to this principle is when constructing an A/B test in a speedrun file (see docs/recipes/change_grug.md)—in that case, it's okay (even preferred) to add temporary flags to your copy-pasted grug core to toggle between two behaviors for comparison. + +## Working Agreement: How Grug Evolves + +- Canonical “best guess” lives in `lib/levanter/src/levanter/grug/`. +- The evolving speedrun entrypoint is `experiments/speedrun/grugformer_starter/grugformer_speedrun.py`. +- One-off speedruns under `experiments/speedrun/…` are snapshots/edit surfaces; they should not silently become the source of truth. 
+- When we upstream a successful experiment, we update tests and record/clean up old experiments per `docs/recipes/change_grug.md`. This involves deletion of incompatible experiments but leaving a trail. + +### JAX/ML Principles + +- **Mesh-first mental model.** We always create one mesh axis per logical unit. For now, `['data', 'model']`. Work must still run on a single GPU (mesh axes collapse to size 1) but should seamlessly extend to 4-device TPU v4-8 pods for CI. Mesh creation and validation live in one place. + - We will likely add `"replica"` and `"replica_dcn"` for partial fsdp and multislice. +- **Use good kernels when available.** Grug is happy to call out to other people's fast attention kernels (ejkernel blocksparse today) as long as the surface stays simple and the fallback reference path remains. +- **Explicit sharding everywhere.** Arrays carry `PartitionSpec`s in their types. By using `set_mesh` with `AxisType.Explicit` we'll always know every Array's sharding. Any op with ambiguous shardings must either supply `out_sharding=` or run inside `auto_axes`. Prefer the former. + +### Code Organization Principles + +- **Keep `levanter.grug` core-only.** The `levanter.grug` package should stay “notebook-like”: raw `jax.Array` inputs/outputs, top-level functions, small dataclasses, and minimal plumbing. +- **Adapters live outside grug.** Integration with Levanter model interfaces (`LmHeadModel`, `NamedArray`, etc.) lives in `levanter.models`, currently `levanter.models.grug_wrapper`. +- **Mask spec is simple.** Grug’s attention mask is a small spec (`levanter.grug.attention.AttentionMask`) storing only raw JAX arrays/ints (no NamedArray fields). Dense masks are currently accepted only by the reference attention path (TPU Splash does not support dense masks). +- **Prefer jaxtyping for the grug core.** Grug uses jaxtyping-style shape hints in `levanter.grug` (e.g. `Int[Array, "B S"]`, `Float[Array, "B Q H D"]`) to keep the core readable and to document expected conventions without introducing runtime checks. + +## Misc Style +- Use `einops.rearrange` for `[batch, seq, heads, head_dim]` manipulations where it improves clarity. +- Use `jnp.einsum` and explicit `PartitionSpec` annotations to keep matmuls consistent with the mesh layout. +- Grug core attention is a single `attention(q, k, v, mask)` that calls ejkernel block-sparse attention; `reference_attention(...)` exists for debugging/regressions but is not selected at runtime. +- Attention masks are a small spec dataclass (`levanter.grug.attention.AttentionMask`) rather than a dense mask/bias array. Grug defaults to causal masking. +- Layer norm: implement epsilon-stabilized RMSNorm using `jnp.mean/var` and explicit shardings. +- Residual dropouts remain optional and implemented via `jax.random.bernoulli` (with `out_sharding=P(None)` to keep them replicated). + +### Loss & Training Step + +- `loss_fn(params, batch, *, cfg, mesh)` runs forward pass and computes token-level cross-entropy (optionally ignoring padding tokens). Output sharding uses `(Batch @ data, None)` so reductions stay deterministic. +- Optimizer: `optax.adamw` built from `TrainingConfig`. `TrainingState` carries `opt_state`. We wrap the optimizer in `optax.apply_updates`. +- `@jax.jit` (or `jax.jit(static_argnames=("cfg",))`) compiled `train_step(state, batch, tracker)`: + 1. `loss, grads = jax.value_and_grad(loss_fn)(state.params, batch)` + 2. `grads = reshard(grads, param_spec)` so updates stay model-sharded. + 3. 
`updates, new_opt_state = optimizer.update(grads, state.opt_state, state.params)`
+  4. `new_params = optax.apply_updates(state.params, updates)`
+  5. `tracker.log({"train/loss": loss, "train/ppl": jnp.exp(loss)}, step=state.step)`
+  6. Return the updated `TrainingState(step + 1, rng=next_rng)`
+- Provide `eval_step` mirroring `train_step` without optimizer updates for validation.
+
+### Data & Tokenization
+
+- `DataConfig` picks between a HuggingFace tokenizer (`transformers.AutoTokenizer`) or a Levanter `TreeCache` path. We default to HF and fall back to Levanter’s `deserialize_leaves_tensorstore` when resuming vocab.
+- `data.py` exposes `build_dataloader(cfg, mesh)` that:
+  - Builds/loads the tokenizer.
+  - Constructs a `levanter.data.text.TokenSeqDataset` (or other AsyncDataset) using seq length + sharding info.
+  - Wraps it in a lightweight async background iterator returning dictionaries `{"input_ids": sharded_array, "labels": sharded_array}` already resharded to `(Batch@data, Seq)`.
+- Because Levanter’s loader internally uses Haliax, grugpt never imports haliax directly—all adapters convert outputs to plain `jax.Array`s via `jnp.asarray` before feeding the model.
+
+### Checkpointing & Tracking
+
+- `checkpoint_path = cfg.run_dir / f"step_{state.step:07d}"`.
+- Saving uses `tree_serialize_leaves_tensorstore(path, (state, cfg))` inside a host-0 guard; loading uses the matching `tree_deserialize...`.
+- Trackers: `levanter.tracker` configs are CLI-selectable (noop vs wandb). `train.py` enters `with current_tracker(tracker):` before training.
+- On resume, we log hyperparameters once and emit `train/time_per_step` and `train/tok_per_sec` metrics every `cfg.log_every` steps.
+
+### CLI Entrypoint
+
+- The current entrypoint lives in `lib/levanter/src/levanter/grug/main.py` with `def main(argv=None): ...` using `argparse` (no click). Steps:
+  1. Parse YAML/JSON or CLI flags into `RunConfig` (we can reuse Levanter’s draccus parser if convenient).
+  2. Build the mesh and set it globally.
+  3. Initialize tokenizer + dataloader + tracker.
+  4. Initialize or restore `TrainingState`.
+  5. Compile `train_step`/`eval_step` once per process.
+  6. Run the training loop with periodic evaluation/checkpointing.
+- Provide an `if __name__ == "__main__": main()` guard and register a console script in `pyproject.toml`.
+
+### Tests
+
+- `tests/grugpt/test_model.py` creates a toy mesh (1×1×1), initializes params, runs a forward pass, and asserts:
+  - logits shape `[batch, seq, vocab]`
+  - `jax.typeof(hidden).sharding` matches the expected `PartitionSpec`
+- `tests/grugpt/test_mesh.py` spins up synthetic mesh metadata (e.g., 4-device v4-8) to ensure axis naming/resizing works when devices >= 4.
+- `tests/grugpt/test_train_step.py` runs two train steps on random tokens, verifying that loss decreases and the tracker receives entries (with splash attention disabled for determinism).
+- All tests run under `uv run pytest tests/grugpt -n auto`; TPU-specific splash kernels get smoke-tested behind a flag if CI has a TPU runtime, otherwise they stub out gracefully.
+
+### Current Status
+
+- The current prototype lives under `lib/levanter/src/levanter/grug/` (not yet extracted to `lib/grugpt/`).
+- Attention uses ejkernel blocksparse by default; reference attention remains available for debugging.
+- Levanter integration lives in `lib/levanter/src/levanter/models/grug_wrapper.py` (kept out of `levanter.grug` intentionally).
+
+## Implementation Tasks
+
+1. 
**Data + tracker integration.** Swap the synthetic batch generator for a real tokenizer/dataset pipeline (leveraging Levanter loaders) and pipe metrics through the tracker API. +2. **Attention + masks.** Extend the structured mask surface (segments/windows) and keep the reference implementation available for debugging/regressions. +3. **Checkpointing.** Hook `tree_{de}serialize_leaves_tensorstore` into the training loop so state/rng/optimizer snapshots can be saved/restored. +4. **Evaluation + CLI polish.** Add validation hooks, flag parsing (YAML/Draccus), and a proper entry point under `uv run` instead of the current hard-coded config. +5. **Testing.** Port the sketched unit tests (`test_model`, `test_mesh`, `test_train_step`) into `tests/grugpt/` and wire them into CI. + +## Next Milestones (Current Plan) + +1) **PR what we have** + - Keep it focused: grug core simplicity + ejkernel blocksparse + adapter placement + doc updates. + +2) **Speedrun: hackable single-file gauntlet** + - Target: `experiments/speedrun/hackable_transformer_starter/hackable_transformer_attn_sink.py` (similar to `experiments/hackable_transformer_starter_template.py`). + - Copy-paste the grug core into the file and define a small “gauntlet API” (`init_fn`, `fwd_fn`, `loss_fn`, `train_step`). + - Include both `reference_attention` and ejkernel block-sparse paths and a tiny correctness check comparing them on small shapes. + +3) **Native datasets for grug** + - Add “grug-native” dataset analogs for `CausalLmDataset` and `MultiturnChatDataset` that yield plain `jax.Array` batches (`tokens`, `labels`, optional `segment_ids`). + - Reuse Levanter ingestion/tokenization internally via adapters until we decide to fully decouple. + +4) **Preemption + resume in the grug trainer** + - Follow the patterns in `lib/levanter/src/levanter/main/train_lm.py`: periodic checkpoints, resume-from-latest, and safe/atomic step dirs. + - Checkpoint `(state, cfg, tokenizer metadata)` with Levanter’s TensorStore serialization. + +5) **Lock the grug core surface with tests** + - Shapes + `jax.jit` compile sanity for `init_parameters`/`forward`. + - Sharding sanity: `PartitionSpec` expectations don’t explode when a mesh is active. + - Mask semantics for causal + sliding window + segment ids. + - (Optional) blocksparse vs reference numerical check on tiny shapes. + +## Speedrun Milestone: Hackable Transformer Gauntlet + +Goal: use Grug as a “copy-pasteable” reference implementation inside a single speedrun script (e.g. +`experiments/speedrun/hackable_transformer_starter/hackable_transformer_attn_sink.py`), so the workflow is: + +1) copy/paste the Grug core into the file +2) modify it (e.g. add attention sinks) +3) run it through a standard gauntlet (correctness + compile + throughput + memory) + +### Constraints + +- Single-file friendly: minimal imports, no class-heavy APIs, top-level functions. +- Keep an accelerated path (ejkernel blocksparse) + a reference fallback path. +- Make the “hack points” obvious: attention sinks, masks, sharding, and config. + +### Plan (concrete) + +1) **Define a “Grug-In-File” template section in the speedrun script** + - Put it near the top of the file under a clear header like `# === GRUG CORE (COPY-PASTE) ===`. 
+ - Include only: + - param dataclasses (`GrugAttentionParams`, `GrugBlockParams`, `GrugModelParameters`) + - `init_parameters(cfg, *, key)` + - `forward(params, tokens, cfg, *, mask=None)` + - `attention(q, k, v, mask)` + - `AttentionMask` (Grug-local, raw arrays only) + +2) **Make sinks a deliberate edit point** + - Do not bake “sinks”/`softmax_aux` into the stable grug API. + - In the hackable speedrun script, add sinks by directly editing the ejkernel block-sparse attention callsite (or by copy-pasting a slightly modified `attention()` helper into the file). + +3) **Keep the mask surface minimal and explicit** + - Use the Grug-local `AttentionMask` spec (causal + sliding window + segment ids). + - Avoid accepting dense boolean masks in accelerated mode (raise loudly). + - Default to causal masking in `forward()` when `mask is None`. + +4) **Define the gauntlet API in the speedrun script** + - Standardize a small set of functions the harness calls: + - `build_cfg()` (or constants) that define model hyperparams. + - `init_fn(key) -> params` + - `fwd_fn(params, batch) -> logits` + - `loss_fn(logits, labels) -> loss` + - `train_step(state, batch) -> (new_state, metrics)` + - Keep all state as plain pytrees so it’s easy to serialize/inspect. + +5) **Gauntlet checks to run every time** + - **Correctness sanity** (small shapes): + - reference vs blocksparse forward (match within a tolerance) + - gradients exist and are finite + - **Compile sanity**: + - time-to-first-step (TTFS) for `train_step` + - **Throughput sanity**: + - tokens/sec and step time for a fixed number of warmup+iters + - **Memory sanity**: + - optional: track max HBM/VMEM if available in the speedrun harness + +6) **Keep Levanter adapters out of the speedrun file** + - Do not rely on `LmHeadModel`/`NamedArray` in the hackable file. + - If you need to run the same model inside Levanter, use the adapter in `levanter.models.grug_wrapper`. + +7) **Document the expected “edit surface” inside the file** + - Put a short list at the top: + - “To add sinks: edit the ejkernel block-sparse attention call in `attention()`.” + - “To change masking: edit `AttentionMask` / `mask` construction in `forward()`.” + - “To change sharding: edit `PartitionSpec`s in init + `out_sharding=` callsites.” + +## Integration Direction + +Grug currently integrates with Levanter via a thin wrapper (`levanter.models.grug_wrapper`). That’s fine as an adapter, but it should not become a long-term architectural dependency. + +Longer term, we want one of: + +1) **A grug-native trainer** (preferred for “hackable” workflows) + - Minimal surface: plain pytrees + explicit mesh + small config. + - Owns data loading, checkpointing, and evaluation in a way that’s easy to copy/paste into speedrun scripts. + +2) **Evolve Levanter/Marin to support grug natively** + - Make Levanter’s trainer accept a “grug-style” model (pure functions + pytrees) without requiring a wrapper that reifies NamedArray/Haliax concepts. + - Goal: the core stays `jax.Array`-first, and the training stack becomes more flexible rather than forcing everything into `LmHeadModel`-style interfaces. + +The wrapper remains a pragmatic bridge, but the intended direction is to shrink/remove it over time. + +## Open Questions & Follow-Ups + +- Do we want gradient accumulation/microbatching in v1, or stick to per-step full batches? +- Should we support tensor parallel axes beyond `('model',)` at launch, or gate behind a flag until we verify the explicit-sharding rules? 
+- For tokenizer caching, is it acceptable to rely solely on HF local cache, or should we add explicit `deserialize_leaves_tensorstore` support for tokenizers stored in checkpoints? +- What minimum TPU/GPU topology should the CI tests assume (2 devices vs 4)? +- Loss: the current blockwise CE (`cross_entropy_block_size`) is a tradeoff; in the 125M speedrun we saw MFU jump from ~20 -> ~40 when disabling chunking (`block_size=None`). We need a better large-vocab loss kernel eventually. diff --git a/.github/workflows/levanter-tests.yaml b/.github/workflows/levanter-tests.yaml index f2476f6959..7e3a62f8cc 100644 --- a/.github/workflows/levanter-tests.yaml +++ b/.github/workflows/levanter-tests.yaml @@ -191,4 +191,4 @@ jobs: -v /tmp/uv-cache:/tmp/uv-cache:rw \ -w /workspace \ $DOCKER_IMAGE \ - bash -c "cp -a /workspace-src/. /workspace/ && cd /workspace && timeout --kill-after=5 --signal=TERM 890 uv run --package levanter --frozen --with 'jax[tpu]==$JAX_VERSION' pytest lib/levanter/tests -m 'not entry and not ray and not slow' -v --tb=short --log-cli-level=WARNING --durations=20" \ No newline at end of file + bash -c "cp -a /workspace-src/. /workspace/ && cd /workspace && timeout --kill-after=5 --signal=TERM 890 uv run --package levanter --frozen --with 'jax[tpu]==$JAX_VERSION' pytest lib/levanter/tests -m 'not entry and not ray and not slow' -v --tb=short --log-cli-level=WARNING --durations=20" diff --git a/AGENTS.md b/AGENTS.md index 888f2bf654..bd012b3f70 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -14,6 +14,7 @@ - Begin with the agent-friendly recipes in `docs/recipes/`. - The first step for dataset addition is schema inspection. See the [add_dataset.md](docs/recipes/add_dataset.md) recipe for details. - You can help organize experiments using the [organize_experiments.md](docs/recipes/organize_experiments.md) recipe. +- When making significant changes to Grug/Grugformer, follow [change_grug.md](docs/recipes/change_grug.md). - Follow the rules and examples in each recipe to ensure compatibility and automation-friendliness. ## Shared Coding Practices diff --git a/docs/recipes/change_grug.md b/docs/recipes/change_grug.md new file mode 100644 index 0000000000..62a73fe936 --- /dev/null +++ b/docs/recipes/change_grug.md @@ -0,0 +1,112 @@ +# Recipe: Changing Grug (Experiment → Canonical) + +Grug is meant to be “grug-simple” and easy to hack, but we still want a single, trustworthy “best guess” implementation in `levanter.grug`. + +This recipe describes the workflow for: + +1) trying changes safely in a speedrun experiment, and +2) upstreaming successful ideas into the canonical core (and cleaning up old experiments). + +## Source Of Truth vs Experiments + +- **Source of truth:** `lib/levanter/src/levanter/grug/` + - This is the “best guess” model. It should stay small, readable, and testable. +- **Evolving experiment:** `experiments/speedrun/grugformer_starter/grugformer_speedrun.py` + - This is the *living* entrypoint that is expected to evolve as we learn. +- **One-off experiments:** under `experiments/speedrun/…` + - These are snapshots / specialized edit surfaces (e.g. attention sinks). + +We try not to let one-off scripts become the canonical implementation. 
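+
+As a concrete illustration of this split (a minimal sketch; `local_swiglu_mlp` and its weight arguments are hypothetical, not code from this repo), a one-off speedrun keeps whatever it is changing as a local, copy-pasted helper instead of editing `levanter.grug` in place:
+
+```python
+# Hypothetical speedrun-local edit surface (sketch). The canonical implementation in
+# `levanter.grug` is left untouched; the variant being tested lives only in this file.
+import jax
+import jax.numpy as jnp
+
+
+def local_swiglu_mlp(x: jax.Array, w_gate: jax.Array, w_up: jax.Array, w_down: jax.Array) -> jax.Array:
+    """Copy-pasted MLP helper, edited in place for this experiment (SwiGLU variant)."""
+    return jnp.dot(jax.nn.silu(jnp.dot(x, w_gate)) * jnp.dot(x, w_up), w_down)
+```
+
+If the variant wins, it gets ported into `lib/levanter/src/levanter/grug/` following the steps below, and the local copy is deleted or archived.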
+ +## When You Want To Try Something + +### 1) Decide what you’re changing + +Most changes fall into one bucket: + +- **Attention** (masking semantics, kernels, sinks/aux, layout/sharding) +- **Block** (residual wiring, normalization order, pre/post-norm) +- **MLP** (activation, GLU variants, gating, dimension choices) +- **Loss** (large-vocab CE, z-loss, label smoothing, logit soft-cap) +- **Optimizer** (Adam, Muon, etc.) + +Try to change **one bucket at a time**. Optimizer isn't really (currently) addressed by Grug, but we'll get there. + +### 2) Create an experiment entrypoint + +Start from: + +- `experiments/speedrun/grugformer_starter/grugformer_speedrun.py` + +Recommended workflow: + +1. Copy the file to a new experiment (or branch the starter if the change is small): + - Example: `experiments/speedrun/grugformer_/grugformer_.py` +2. Keep the edit surface explicit: + - If you’re changing attention, keep the change in one local `attention()` or `attn_fn` block. + - If you’re changing the MLP, keep it local to an `mlp()` helper. +3. Avoid introducing new abstractions (this is a speedrun file; copy/paste is fine). + +### 3) Register the experiment in the archive + +Add an entry to: + +- `docs/reports/grug-archive.md` + +Record: +- the experiment path, +- the commit SHA (once known), +- what you changed and why, +- the intended “status” (`active`, `superseded`, `deleted`). + +## When You Want To Adopt Something As Canonical + +### 1) Port to `levanter.grug` + +Move the change into one of the core files: + +- `lib/levanter/src/levanter/grug/attention.py` +- `lib/levanter/src/levanter/grug/model.py` +- `lib/levanter/src/levanter/grug/loss.py` + +Keep the “grug” style: +- top-level functions, +- small dataclasses only for parameter/state containers, +- explicit sharding when needed (and loud failures otherwise). + +### 2) Update/extend tests + +Add or adjust tests to lock the intended surface: + +- `lib/levanter/tests/test_grugformer_core.py` +- `lib/levanter/tests/test_grugformer_model_loss.py` +- `lib/levanter/tests/test_grugformer_fused_loss.py` + +The goal is: +- shapes don’t regress, +- `jit` still works, +- sharding doesn’t explode, +- mask semantics remain correct. + +### 3) Clean up old experiments + +After merging a canonical improvement: + +- If an experiment is now redundant and not referenced, **delete it** and mark it `deleted` in `docs/reports/grug-archive.md`. +- If an experiment represents a meaningful historical run, keep it but mark it `superseded`, and point to the canonical change (or the new experiment). + Do this only if it's not going to be a maintenance burden. + +Prefer “archive entry + deletion” over keeping lots of old code in-tree. + +### 4) Run repo checks + +Before sending the PR: + +```sh +uv run python infra/pre-commit.py --all-files +``` + +## Notes / Inspiration + +This workflow is inspired by projects like `modded-nanogpt`: keep a small, readable core, iterate quickly via “hackable” entrypoints, and regularly upstream what works. + diff --git a/docs/reports/grug-archive.md b/docs/reports/grug-archive.md new file mode 100644 index 0000000000..4f1750e26a --- /dev/null +++ b/docs/reports/grug-archive.md @@ -0,0 +1,61 @@ +# Grug Archive: Experiments and Snapshots + +This file is a lightweight “paper trail” for Grug-related experiments, inspired by the idea of keeping a runnable history without letting a pile of one-off scripts become the de facto source of truth. 
+ +## Principles + +- **`levanter.grug` is the source of truth.** Speedrun files are snapshots/entrypoints, not the canonical implementation. +- **Every experiment should be attributable to a commit.** If an experiment is removed or superseded, it should be clear what replaced it and why. +- **Prefer deletion over permanent snapshots.** If a script is dead, delete it and record the last known-good commit here. +- **Keep diffs small.** When an experiment is kept “alive”, update it to track the current core rather than forking the entire model. + +## When Grug Core Changes + +When a change in `levanter.grug` is likely to affect results, performance, or semantics: + +1. Update the experiment(s) that should track “best guess”. +2. For experiments that no longer make sense: + - delete them, or + - mark them superseded and point to the replacement. +3. Update the corresponding entry in this archive (and any linked issue). + +## Entry Template + +Copy/paste this block for new experiments: + +```text +### +- Path: `` +- Introduced: +- Last known-good: +- Status: active | superseded | deleted +- Purpose: +- Notes: +- Superseded by: +- Issue: +``` + +## Experiments + +### grugformer-attnsink +- Path: `experiments/speedrun/grugformer_attnsink/grugformer_attn_sink.py` +- Introduced: TBD +- Last known-good: TBD +- Status: active +- Purpose: “Hackable” Grug attention-sink variant; intended edit surface for sinks/aux. +- Notes: Keep this file short; copy/paste local modifications rather than growing new abstractions. + +### grugformer-starter-speedrun +- Path: `experiments/speedrun/grugformer_starter/grugformer_speedrun.py` +- Introduced: TBD +- Last known-good: TBD +- Status: active +- Purpose: Minimal starter speedrun for Grug; convenient baseline for quick iteration. + +### grugformer-vs-hackable-125m +- Path: `experiments/speedrun/grugformer_vs_hackable_125m/grugformer_vs_hackable_125m.py` +- Introduced: TBD +- Last known-good: TBD +- Status: active +- Purpose: Head-to-head comparison between Hackable Transformer and Grugformer (no sinks). + diff --git a/experiments/simple_train_config.py b/experiments/simple_train_config.py index 0d49cefc5f..05d7b1e4f9 100644 --- a/experiments/simple_train_config.py +++ b/experiments/simple_train_config.py @@ -92,7 +92,7 @@ class SimpleTrainConfig: """Whether to run the JAX profiler during training.""" profiler_start_step: int = 5 """Which step to start profiling.""" - profiler_num_steps: int = 100 + profiler_num_steps: int = 25 """How many steps to profile for once started.""" explicit_mesh_axes: bool = False diff --git a/experiments/speedrun/grugformer_attnsink/grugformer_attn_sink.py b/experiments/speedrun/grugformer_attnsink/grugformer_attn_sink.py new file mode 100644 index 0000000000..62d3d24c6c --- /dev/null +++ b/experiments/speedrun/grugformer_attnsink/grugformer_attn_sink.py @@ -0,0 +1,504 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Grugformer speedrun with attention sinks (TPU splash attention `sinks`). 
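+
+Here a "sink" is a learned per-query-head logit (one per head, per layer) that is passed to the
+splash kernel's `sinks` argument and takes part in the attention softmax; see
+`_init_grugformer_with_sinks` and `_splash_attention_with_sink` below.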
+ +Analogue of: + experiments/speedrun/hackable_transformer_starter/hackable_transformer_attn_sink.py + +How to run: + python marin/run/ray_run.py -- \ + python -m experiments.speedrun.grugformer_attnsink.grugformer_attn_sink +""" + +# nodryrun + +import dataclasses +import functools +import logging +import math +import os + +import jax +import jax.numpy as jnp +from einops import rearrange +from fray.cluster import ResourceConfig +from haliax import Axis +from jax.sharding import PartitionSpec +from jaxtyping import PRNGKeyArray + +from levanter.grug.attention import AttentionMask, apply_rotary_embedding +from levanter.grug.model import GrugModelConfig +from levanter.grug.model import rms_norm, mlp +from levanter.models.grug_wrapper import GrugWrapper +from levanter.models.lm_model import LmConfig +from levanter.utils.flop_utils import lm_flops_per_token +from marin.execution.executor import executor_main +from marin.speedrun.speedrun import Author, SpeedrunConfig, default_speedrun + +from experiments.llama import llama3_tokenizer_vocab_size +from experiments.simple_train_config import SimpleTrainConfig + +logger = logging.getLogger("ray") + + +def _get_num_train_steps(param_count: int, batch_size: int, max_seq_len: int, tpp: int = 20) -> int: + total_tokens = param_count * tpp + return max(1, total_tokens // (batch_size * max_seq_len)) + + +def _resource_presets(use_tpu: bool = False): + if use_tpu: + return { + "130m": ResourceConfig.with_tpu("v5p-8"), + "300m": ResourceConfig.with_tpu("v5p-8"), + "520m": ResourceConfig.with_tpu("v5p-8"), + "1_2b": ResourceConfig.with_tpu("v5p-8"), + } + return { + "130m": ResourceConfig.with_gpu("A100-80G", count=1), + "300m": ResourceConfig.with_gpu("A100-80G", count=1), + "520m": ResourceConfig.with_gpu("A100-80G", count=2), + "1_2b": ResourceConfig.with_gpu("A100-80G", count=4), + } + + +def _batch_sizes() -> dict[str, int]: + return {"130m": 128, "300m": 128, "520m": 128, "1_2b": 256} + + +def _size_presets() -> dict[str, "GrugformerAttnSinkLmConfig"]: + base = dict(max_seq_len=2048, head_dim=None) + return { + "130m": GrugformerAttnSinkLmConfig( + hidden_dim=512, + intermediate_dim=1792, + num_layers=6, + num_heads=8, + num_kv_heads=8, + **base, + ), + "300m": GrugformerAttnSinkLmConfig( + hidden_dim=768, + intermediate_dim=2688, + num_layers=12, + num_heads=12, + num_kv_heads=12, + **base, + ), + "520m": GrugformerAttnSinkLmConfig( + hidden_dim=1024, + intermediate_dim=3584, + num_layers=24, + num_heads=16, + num_kv_heads=16, + **base, + ), + "1_2b": GrugformerAttnSinkLmConfig( + hidden_dim=2048, + intermediate_dim=7168, + num_layers=16, + num_heads=16, + num_kv_heads=16, + **base, + ), + } + + +@dataclasses.dataclass(frozen=True) +class GrugformerAttnSinkConfig: + """Extra knobs for the sink speedrun (kept out of grug core config).""" + + num_sinks: int = 1 + init_logit: float = 0.0 + + +@dataclasses.dataclass(frozen=True) +class GrugformerAttnSinkModelConfig: + """Config object carried by GrugWrapper for this speedrun. + + This is separate from `GrugModelConfig` to keep the sink knobs out of grug core. 
+ """ + + core: GrugModelConfig + sink: GrugformerAttnSinkConfig + + @property + def vocab_size(self) -> int: + return self.core.vocab_size + + @property + def max_seq_len(self) -> int: + return self.core.max_seq_len + + @property + def hidden_dim(self) -> int: + return self.core.hidden_dim + + @property + def cross_entropy_block_size(self): + return 32000 + + +def _init_grugformer_with_sinks(cfg: GrugformerAttnSinkModelConfig, *, key: PRNGKeyArray) -> dict: + core = _init_core(cfg.core, key) + + # Initialize sink logits per layer. + # Use a small constant init by default; users can set init_logit to control initial sink mass. + # + # NOTE: splash attention takes one sink logit per query head. + if cfg.sink.num_sinks != 1: + raise NotImplementedError("This speedrun currently supports only num_sinks=1.") + sink_init = jnp.full((cfg.core.num_heads,), cfg.sink.init_logit, dtype=jnp.float32) + sink_logits = tuple(sink_init for _ in range(cfg.core.num_layers)) + + return {"core": core, "sink_logits": sink_logits} + + +def _init_core(core_cfg: GrugModelConfig, key: PRNGKeyArray): + from levanter.grug.model import init_parameters as grug_init + + return grug_init(core_cfg, key=key) + + +def _splash_attention_with_sink( + q: jax.Array, + k: jax.Array, + v: jax.Array, + mask: AttentionMask | jax.Array | None, + *, + softmax_aux: jax.Array | None, +) -> jax.Array: + if not isinstance(mask, AttentionMask) and mask is not None: + raise NotImplementedError("Only Grug AttentionMask is supported in this speedrun.") + + if mask is None: + mask = AttentionMask.causal() + + if softmax_aux is None: + raise ValueError("softmax_aux (sink logits) must be provided for the attn sink speedrun.") + + if jax.default_backend() != "tpu": + raise NotImplementedError("This speedrun currently uses TPU splash attention.") + + from jax.experimental.pallas.ops.tpu.splash_attention import SegmentIds as SplashSegmentIds + from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel, splash_attention_mask + + mesh = jax.sharding.get_abstract_mesh() + if mesh is None: + raise RuntimeError("Splash attention requires a JAX mesh to be set.") + + _batch, q_len, num_q_heads, head_dim = q.shape + kv_len = k.shape[1] + num_kv_heads = k.shape[2] + + if q_len != kv_len: + raise NotImplementedError("Splash attention speedrun currently assumes Sq == Sk.") + + if num_kv_heads != num_q_heads: + raise NotImplementedError("This speedrun currently requires num_kv_heads == num_heads for splash attention.") + + if kv_len % 128 != 0: + raise NotImplementedError("Splash attention requires the KV sequence length to be a multiple of 128.") + + scaling_factor = 1.0 / math.sqrt(head_dim) + q = q * scaling_factor + + q_bhsd = q.transpose(0, 2, 1, 3) + k_bhsd = k.transpose(0, 2, 1, 3) + v_bhsd = v.transpose(0, 2, 1, 3) + + if mask is None: + base_mask = splash_attention_mask.FullMask(_shape=(q_len, kv_len)) + else: + base_mask = splash_attention_mask.FullMask(_shape=(q_len, kv_len)) + if mask.is_causal: + base_mask = splash_attention_mask.CausalMask((q_len, kv_len), offset=0, shard_count=1) + if mask.sliding_window is not None: + local_mask = splash_attention_mask.LocalMask( + shape=(q_len, kv_len), + window_size=(mask.sliding_window - 1, None), + offset=0, + shard_count=1, + ) + base_mask = splash_attention_mask.LogicalAnd(base_mask, local_mask) + + kernel_mask = splash_attention_mask.MultiHeadMask(masks=[base_mask for _ in range(num_q_heads)]) + kernel = splash_attention_kernel.make_splash_mha( + mask=kernel_mask, + 
head_shards=1, + q_seq_shards=1, + block_sizes=splash_attention_kernel.BlockSizes.get_default(), + attn_logits_soft_cap=None, + ) + + sinks = softmax_aux.astype(jnp.float32) + + segment_ids = None + if mask is not None and mask.segment_ids is not None: + q_seg, kv_seg = mask.segment_ids + segment_ids = SplashSegmentIds(q_seg.astype(jnp.int32), kv_seg.astype(jnp.int32)) + + q_spec = PartitionSpec(("data",), None, None, None) + k_spec = PartitionSpec(("data",), None, None, None) + v_spec = PartitionSpec(("data",), None, None, None) + sinks_spec = PartitionSpec(None) + + if segment_ids is None: + + @functools.partial( + jax.shard_map, + mesh=mesh, + in_specs=(q_spec, k_spec, v_spec, sinks_spec), + out_specs=q_spec, + check_vma=False, + ) + def _call_splash_attention(q_, k_, v_, sinks_): + return jax.vmap(lambda q_b, k_b, v_b: kernel(q_b, k_b, v_b, sinks=sinks_), in_axes=(0, 0, 0))(q_, k_, v_) + + out_bhsd = _call_splash_attention(q_bhsd, k_bhsd, v_bhsd, sinks) + else: + segment_ids_spec = SplashSegmentIds( + PartitionSpec(("data",), None), + PartitionSpec(("data",), None), + ) + + @functools.partial( + jax.shard_map, + mesh=mesh, + in_specs=(q_spec, k_spec, v_spec, segment_ids_spec, sinks_spec), + out_specs=q_spec, + check_vma=False, + ) + def _call_splash_attention(q_, k_, v_, seg_ids, sinks_): + return jax.vmap( + lambda q_b, k_b, v_b, si: kernel(q_b, k_b, v_b, segment_ids=si, sinks=sinks_), + in_axes=(0, 0, 0, 0), + )(q_, k_, v_, seg_ids) + + out_bhsd = _call_splash_attention(q_bhsd, k_bhsd, v_bhsd, segment_ids, sinks) + return out_bhsd.transpose(0, 2, 1, 3) + + +def _grug_activations_with_sinks( + params: dict, + token_ids: jax.Array, + cfg: GrugformerAttnSinkModelConfig, + *, + mask: AttentionMask | jax.Array | None = None, +) -> jax.Array: + cfg = cfg.core + head_dim = cfg.inferred_head_dim + seq_len = token_ids.shape[1] + + if mask is None: + mask = AttentionMask.causal() + elif isinstance(mask, AttentionMask) and not mask.is_causal: + mask = dataclasses.replace(mask, is_causal=True) + + core = params["core"] + hidden = core.token_embed.at[token_ids].get(out_sharding=_PBATCH) + + for block, sink_logits in zip(core.blocks, params["sink_logits"], strict=True): + attn_in = rms_norm(hidden, block.rms_attn, cfg.layer_norm_eps) + + q = rearrange(jnp.einsum("bsh,hd->bsd", attn_in, block.attn.w_q), "... (n d) -> ... n d", d=head_dim) + k = rearrange(jnp.einsum("bsh,hd->bsd", attn_in, block.attn.w_k), "... (m d) -> ... m d", d=head_dim) + v = rearrange(jnp.einsum("bsh,hd->bsd", attn_in, block.attn.w_v), "... (m d) -> ... m d", d=head_dim) + q, k = apply_rotary_embedding(q, k, seq_len=seq_len, head_dim=head_dim, rope=cfg.rope) + + attn_out = _splash_attention_with_sink(q, k, v, mask, softmax_aux=sink_logits) + attn_out = rearrange(attn_out, "... n d -> ... 
(n d)") + attn_out = jnp.einsum("bsh,hd->bsd", attn_out, block.attn.w_o, out_sharding=_PBATCH) + hidden = hidden + attn_out + + mlp_in = rms_norm(hidden, block.rms_mlp, cfg.layer_norm_eps) + mlp_out = mlp(block, mlp_in) + hidden = hidden + mlp_out + + # final rms norm + hidden = rms_norm(hidden, core.final_norm, cfg.layer_norm_eps) + return hidden + + +def _lm_head_from_sink_params(params: dict) -> jax.Array: + return params["core"].output_proj + + +_PBATCH = jax.sharding.PartitionSpec(("data",), None) + + +@LmConfig.register_subclass("grugformer_attn_sink") +@dataclasses.dataclass(frozen=True) +class GrugformerAttnSinkLmConfig(LmConfig[GrugWrapper]): + max_seq_len: int = 2048 + + hidden_dim: int = 1024 + intermediate_dim: int = 2752 + num_layers: int = 12 + num_heads: int = 16 + num_kv_heads: int = 16 + head_dim: int | None = None + rope_theta: float = 10000.0 + + sink: GrugformerAttnSinkConfig = dataclasses.field(default_factory=GrugformerAttnSinkConfig) + + @property + def model_type(self) -> type[GrugWrapper]: + return GrugWrapper + + @property + def Embed(self) -> Axis: + return Axis("embed", self.hidden_dim) + + def to_grug_model_config(self) -> GrugModelConfig: + return GrugModelConfig( + vocab_size=llama3_tokenizer_vocab_size, + hidden_dim=self.hidden_dim, + intermediate_dim=self.intermediate_dim, + num_layers=self.num_layers, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + max_seq_len=self.max_seq_len, + ) + + def build(self, Vocab: Axis, *, key: PRNGKeyArray) -> GrugWrapper: + core_cfg = self.to_grug_model_config() + cfg = GrugformerAttnSinkModelConfig(core=core_cfg, sink=self.sink) + params = _init_grugformer_with_sinks(cfg, key=key) + return GrugWrapper( + params=params, + grug_config=cfg, + init_fn=_init_grugformer_with_sinks, + forward_fn=_grug_activations_with_sinks, + lm_head_fn=_lm_head_from_sink_params, + ) + + def flops_per_token(self, vocab_size: int, context_length: int) -> float | None: + return lm_flops_per_token( + hidden_dim=self.hidden_dim, + intermediate_dim=self.intermediate_dim, + num_layers=self.num_layers, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + seq_len=context_length, + vocab_size=vocab_size, + glu=True, + ) + + def total_trainable_params(self, vocab_size: int) -> int: + head_dim = self.head_dim or (self.hidden_dim // self.num_heads) + token_embedding = vocab_size * self.hidden_dim + attn = ( + self.hidden_dim * head_dim * self.num_heads + + 2 * self.hidden_dim * head_dim * self.num_kv_heads + + head_dim * self.num_heads * self.hidden_dim + ) + mlp = 3 * self.hidden_dim * self.intermediate_dim + transformer = self.num_layers * (attn + mlp + 2 * self.hidden_dim) + self.hidden_dim + sinks = self.num_layers * self.num_heads + return int(transformer + 2 * token_embedding + sinks) + + +speedrun_config = SpeedrunConfig( + author=Author( + name="David Hall", + affiliation="OpenAthena", + url="TODO", + ), + description="Grugformer with attention sinks via TPU splash attention.", + model_config=GrugformerAttnSinkLmConfig( + max_seq_len=2048, + hidden_dim=1024, + intermediate_dim=2752, + num_layers=12, + num_heads=16, + num_kv_heads=16, + sink=GrugformerAttnSinkConfig(num_sinks=1, init_logit=0.0), + ), + train_config=SimpleTrainConfig( + ResourceConfig.with_tpu("v5p-8"), + train_batch_size=32, + num_train_steps=100, + learning_rate=3e-3, + weight_decay=0.1, + steps_per_eval=500, + steps_per_hf_export=-1, + explicit_mesh_axes=True, + ), +) + + +def build_run(size: str, *, use_tpu: bool = False) -> 
tuple[str, SpeedrunConfig]: + sizes = _size_presets() + if size not in sizes: + raise ValueError(f"Unknown size: {size}") + model_cfg = sizes[size] + + if not use_tpu: + raise ValueError("grugformer_attn_sink requires SR_USE_TPU=1 (it uses TPU splash attention sinks).") + + batch = _batch_sizes()[size] + max_seq_len = model_cfg.max_seq_len + params = int(model_cfg.total_trainable_params(llama3_tokenizer_vocab_size)) + steps = _get_num_train_steps(params, batch, max_seq_len, tpp=20) + resources = _resource_presets(use_tpu=use_tpu)[size] + + train = SimpleTrainConfig( + resources, + train_seq_len=max_seq_len, + train_batch_size=batch, + num_train_steps=steps, + learning_rate=3e-3, + weight_decay=0.1, + steps_per_eval=500, + steps_per_hf_export=-1, + explicit_mesh_axes=True, + ) + + run_name = f"grugformer_attn_sink_{size}" + desc = f"Grugformer with attention sinks via TPU splash attention ({size})." + cfg = SpeedrunConfig( + author=speedrun_config.author, + description=desc, + model_config=model_cfg, + train_config=train, + ) + return run_name, cfg + + +def main() -> None: + sizes = [ + "130m", + "300m", + "520m", + "1_2b", + ] + # sizes = ["130m", "300m", "520m", "1_2b"] + use_tpu = bool(int(os.environ.get("SR_USE_TPU", "0"))) + + steps = [] + for s in sizes: + name, cfg = build_run(s, use_tpu=use_tpu) + if cfg.vocab_size != llama3_tokenizer_vocab_size: + raise AssertionError("Speedrun vocab_size mismatch; expected llama3_tokenizer_vocab_size") + cfg.print_run_info() + steps.extend(default_speedrun(name, cfg)) + executor_main(steps=steps, description="Grugformer with attention sinks via TPU splash attention.") + + +if __name__ == "__main__": + main() diff --git a/experiments/speedrun/grugformer_starter/grugformer_speedrun.py b/experiments/speedrun/grugformer_starter/grugformer_speedrun.py new file mode 100644 index 0000000000..d2f9d9a609 --- /dev/null +++ b/experiments/speedrun/grugformer_starter/grugformer_speedrun.py @@ -0,0 +1,214 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Grugformer starter speedrun. + +This uses the grug core implementation (`levanter.grug`) and wires it into the existing +Marin speedrun harness via the `levanter.models.grug_wrapper.GrugWrapper` adapter. + +On TPU, grug uses JAX's Splash Attention; on other backends it falls back to a reference +attention implementation. 
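+
+Set SR_USE_TPU=1 to select the TPU (v5p-8) resource presets; the default presets use A100 GPUs.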
+ +How to run: + python marin/run/ray_run.py -- \ + python -m experiments.speedrun.grugformer_starter.grugformer_speedrun +""" + +# nodryrun + +import logging +import os +from dataclasses import dataclass + +from fray.cluster import ResourceConfig +from haliax import Axis +from jaxtyping import PRNGKeyArray + +from levanter.grug.model import GrugModelConfig +from levanter.models.grug_wrapper import GrugWrapper +from levanter.models.lm_model import LmConfig +from levanter.utils.flop_utils import lm_flops_per_token +from marin.execution.executor import executor_main +from marin.speedrun.speedrun import Author, SpeedrunConfig, default_speedrun + +from experiments.llama import llama3_tokenizer_vocab_size +from experiments.simple_train_config import SimpleTrainConfig + +logger = logging.getLogger("ray") + + +def _get_num_train_steps(param_count: int, batch_size: int, max_seq_len: int, tpp: int = 20) -> int: + total_tokens = param_count * tpp + return max(1, total_tokens // (batch_size * max_seq_len)) + + +def _size_presets() -> dict[str, "GrugformerConfig"]: + base = dict(max_seq_len=2048, head_dim=None) + return { + "130m": GrugformerConfig( + hidden_dim=512, intermediate_dim=1792, num_layers=6, num_heads=8, num_kv_heads=8, **base + ), + "300m": GrugformerConfig( + hidden_dim=768, intermediate_dim=2688, num_layers=12, num_heads=12, num_kv_heads=12, **base + ), + "520m": GrugformerConfig( + hidden_dim=1024, intermediate_dim=3584, num_layers=24, num_heads=16, num_kv_heads=16, **base + ), + "1_2b": GrugformerConfig( + hidden_dim=2048, intermediate_dim=7168, num_layers=16, num_heads=16, num_kv_heads=16, **base + ), + } + + +def _resource_presets(use_tpu: bool = False): + if use_tpu: + return { + "130m": ResourceConfig.with_tpu("v5p-8"), + "300m": ResourceConfig.with_tpu("v5p-8"), + "520m": ResourceConfig.with_tpu("v5p-8"), + "1_2b": ResourceConfig.with_tpu("v5p-8"), + } + return { + "130m": ResourceConfig.with_gpu("A100-80G", count=1), + "300m": ResourceConfig.with_gpu("A100-80G", count=1), + "520m": ResourceConfig.with_gpu("A100-80G", count=2), + "1_2b": ResourceConfig.with_gpu("A100-80G", count=4), + } + + +def _batch_sizes() -> dict[str, int]: + return {"130m": 128, "300m": 128, "520m": 128, "1_2b": 256} + + +@LmConfig.register_subclass("grugformer") +@dataclass(frozen=True) +class GrugformerConfig(LmConfig[GrugWrapper]): + """LmConfig wrapper around grug core hyperparameters.""" + + # LmConfig field + max_seq_len: int = 2048 + + # Grug core hyperparams + hidden_dim: int = 1024 + intermediate_dim: int = 2752 + num_layers: int = 12 + num_heads: int = 16 + num_kv_heads: int = 16 + head_dim: int | None = None + + # ---- LmConfig API ---- + @property + def model_type(self) -> type[GrugWrapper]: + return GrugWrapper + + @property + def Embed(self) -> Axis: + # Not used by GrugWrapper (it returns logits directly), but LmConfig requires it. 
+ return Axis("embed", self.hidden_dim) + + def build(self, Vocab: Axis, *, key: PRNGKeyArray) -> GrugWrapper: + grug_cfg = GrugModelConfig( + vocab_size=Vocab.size, + hidden_dim=self.hidden_dim, + intermediate_dim=self.intermediate_dim, + num_layers=self.num_layers, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + max_seq_len=self.max_seq_len, + ) + return GrugWrapper.init(Vocab, grug_cfg, key=key) + + def flops_per_token(self, vocab_size: int, context_length: int) -> float | None: + return lm_flops_per_token( + hidden_dim=self.hidden_dim, + intermediate_dim=self.intermediate_dim, + num_layers=self.num_layers, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + seq_len=context_length, + vocab_size=vocab_size, + glu=True, + ) + + def total_trainable_params(self, vocab_size: int) -> int: + head_dim = self.head_dim or (self.hidden_dim // self.num_heads) + token_embedding = vocab_size * self.hidden_dim + attn = ( + self.hidden_dim * head_dim * self.num_heads + + 2 * self.hidden_dim * head_dim * self.num_kv_heads + + head_dim * self.num_heads * self.hidden_dim + ) + mlp = 3 * self.hidden_dim * self.intermediate_dim + transformer = self.num_layers * (attn + mlp + 2 * self.hidden_dim) + self.hidden_dim + return int(transformer + 2 * token_embedding) + + +def build_run(size: str, *, use_tpu: bool = False) -> tuple[str, SpeedrunConfig]: + sizes = _size_presets() + if size not in sizes: + raise ValueError(f"Unknown size: {size}") + model_cfg = sizes[size] + + batch = _batch_sizes()[size] + max_seq_len = model_cfg.max_seq_len + params = int(model_cfg.total_trainable_params(llama3_tokenizer_vocab_size)) + steps = _get_num_train_steps(params, batch, max_seq_len, tpp=20) + resources = _resource_presets(use_tpu=use_tpu)[size] + + train = SimpleTrainConfig( + resources, + train_seq_len=max_seq_len, + train_batch_size=batch, + num_train_steps=steps, + learning_rate=3e-3, + weight_decay=0.1, + steps_per_eval=500, + steps_per_hf_export=-1, + explicit_mesh_axes=True, + ) + + run_name = f"grugformer_starter_{size}" + desc = f"Grugformer starter (ejkernel blocksparse attention) ({size})." + cfg = SpeedrunConfig( + author=Author( + name="__YOUR_NAME__", + affiliation="__YOUR_AFFILIATION__", + url="__YOUR_URL__", + ), + description=desc, + model_config=model_cfg, + train_config=train, + ) + return run_name, cfg + + +def main() -> None: + sizes = ["130m", "300m", "520m", "1_2b"] + use_tpu = bool(int(os.environ.get("SR_USE_TPU", "0"))) + + steps = [] + for s in sizes: + name, cfg = build_run(s, use_tpu=use_tpu) + if cfg.vocab_size != llama3_tokenizer_vocab_size: + raise AssertionError("Speedrun vocab_size mismatch; expected llama3_tokenizer_vocab_size") + cfg.print_run_info() + steps.extend(default_speedrun(name, cfg)) + + executor_main(steps=steps, description="Grugformer starter (ejkernel blocksparse attention).") + + +if __name__ == "__main__": + main() diff --git a/experiments/speedrun/grugformer_vs_hackable_125m/grugformer_vs_hackable_125m.py b/experiments/speedrun/grugformer_vs_hackable_125m/grugformer_vs_hackable_125m.py new file mode 100644 index 0000000000..1e7232887c --- /dev/null +++ b/experiments/speedrun/grugformer_vs_hackable_125m/grugformer_vs_hackable_125m.py @@ -0,0 +1,231 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Head-to-head speedrun: Hackable Transformer vs Grugformer (no sinks), ~125M params. + +How to run: + python marin/run/ray_run.py -- \ + python -m experiments.speedrun.grugformer_vs_hackable_125m.grugformer_vs_hackable_125m + +By default this uses GPU resource presets. Set SR_USE_TPU=1 for TPU. +""" + +# nodryrun + +import os +from dataclasses import dataclass + +from fray.cluster import ResourceConfig +from haliax import Axis +from jaxtyping import PRNGKeyArray + +from levanter.grug.model import GrugModelConfig +from levanter.models.grug_wrapper import GrugWrapper +from levanter.models.lm_model import LmConfig +from levanter.utils.flop_utils import lm_flops_per_token +from marin.execution.executor import executor_main +from marin.speedrun.speedrun import Author, SpeedrunConfig, default_speedrun + +from experiments.llama import llama3_tokenizer_vocab_size +from experiments.simple_train_config import SimpleTrainConfig +from experiments.speedrun.hackable_transformer_starter.hackable_transformer_attn_sink import HackableTransformerConfig + +AUTHOR = Author( + name="David Hall", + affiliation="Stanford University", + url="https://github.com/dlwh", +) + + +def _resource_preset(*, use_tpu: bool) -> ResourceConfig: + if use_tpu: + return ResourceConfig.with_tpu("v5p-8") + return ResourceConfig.with_gpu("A100-80G", count=1) + + +def _num_train_steps(*, param_count: int, batch_size: int, max_seq_len: int, tpp: int = 20) -> int: + total_tokens = param_count * tpp + return max(1, total_tokens // (batch_size * max_seq_len)) + + +@LmConfig.register_subclass("grugformer_h2h_125m") +@dataclass(frozen=True) +class GrugformerH2HConfig(LmConfig[GrugWrapper]): + """LmConfig wrapper around grug core hyperparameters (for head-to-head comparisons).""" + + max_seq_len: int = 2048 + + hidden_dim: int = 512 + intermediate_dim: int = 1792 + num_layers: int = 6 + num_heads: int = 8 + num_kv_heads: int = 8 + head_dim: int | None = None + + # NOTE: `None` means "single full-vocab block" in grug's blockwise CE. For the 125M head-to-head + # speedrun we observed MFU jump from ~20 -> ~40 by disabling chunking; keep it simple for now. + cross_entropy_block_size: int | None = None + + @property + def model_type(self) -> type[GrugWrapper]: + return GrugWrapper + + @property + def Embed(self) -> Axis: + # Not used by GrugWrapper (it returns logits directly), but LmConfig requires it. 
+ return Axis("embed", self.hidden_dim) + + def build(self, Vocab: Axis, *, key: PRNGKeyArray) -> GrugWrapper: + grug_cfg = GrugModelConfig( + vocab_size=Vocab.size, + hidden_dim=self.hidden_dim, + intermediate_dim=self.intermediate_dim, + num_layers=self.num_layers, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + max_seq_len=self.max_seq_len, + cross_entropy_block_size=self.cross_entropy_block_size, + ) + return GrugWrapper.init(Vocab, grug_cfg, key=key) + + def flops_per_token(self, vocab_size: int, context_length: int) -> float | None: + return lm_flops_per_token( + hidden_dim=self.hidden_dim, + intermediate_dim=self.intermediate_dim, + num_layers=self.num_layers, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + seq_len=context_length, + vocab_size=vocab_size, + glu=True, + ) + + def total_trainable_params(self, vocab_size: int) -> int: + head_dim = self.head_dim or (self.hidden_dim // self.num_heads) + token_embedding = vocab_size * self.hidden_dim + attn = ( + self.hidden_dim * head_dim * self.num_heads + + 2 * self.hidden_dim * head_dim * self.num_kv_heads + + head_dim * self.num_heads * self.hidden_dim + ) + mlp = 3 * self.hidden_dim * self.intermediate_dim + transformer = self.num_layers * (attn + mlp + 2 * self.hidden_dim) + self.hidden_dim + return int(transformer + token_embedding + token_embedding) + + +def _hackable_125m_config() -> HackableTransformerConfig: + # Match the 130m preset dims from hackable transformer starter, but use 2048 context for parity with grug defaults. + return HackableTransformerConfig( + max_seq_len=2048, + hidden_dim=512, + intermediate_dim=1792, + num_layers=6, + num_heads=8, + num_kv_heads=8, + head_dim=None, + use_attention_sink=False, + ) + + +def _grug_125m_config() -> GrugformerH2HConfig: + return GrugformerH2HConfig( + max_seq_len=2048, + hidden_dim=512, + intermediate_dim=1792, + num_layers=6, + num_heads=8, + num_kv_heads=8, + head_dim=None, + ) + + +def _train_config( + *, + use_tpu: bool, + batch_size: int, + max_seq_len: int, + num_train_steps: int, + explicit_mesh_axes: bool, +) -> SimpleTrainConfig: + return SimpleTrainConfig( + _resource_preset(use_tpu=use_tpu), + train_seq_len=max_seq_len, + train_batch_size=batch_size, + num_train_steps=num_train_steps, + learning_rate=3e-3, + weight_decay=0.1, + steps_per_hf_export=-1, + explicit_mesh_axes=explicit_mesh_axes, + profiler=True, + ) + + +def main() -> None: + use_tpu = bool(int(os.environ.get("SR_USE_TPU", "0"))) + + batch_size = 128 + max_seq_len = 2048 + + hack_cfg = _hackable_125m_config() + grug_cfg = _grug_125m_config() + + hack_params = int(hack_cfg.total_trainable_params(llama3_tokenizer_vocab_size)) + grug_params = int(grug_cfg.total_trainable_params(llama3_tokenizer_vocab_size)) + + hack_steps = _num_train_steps(param_count=hack_params, batch_size=batch_size, max_seq_len=max_seq_len) + grug_steps = _num_train_steps(param_count=grug_params, batch_size=batch_size, max_seq_len=max_seq_len) + + hack_train = _train_config( + use_tpu=use_tpu, + batch_size=batch_size, + max_seq_len=max_seq_len, + num_train_steps=hack_steps, + explicit_mesh_axes=False, + ) + grug_train = _train_config( + use_tpu=use_tpu, + batch_size=batch_size, + max_seq_len=max_seq_len, + num_train_steps=grug_steps, + explicit_mesh_axes=use_tpu, + ) + + hack_speedrun = SpeedrunConfig( + author=AUTHOR, + description="Hackable Transformer (~125M) - standard attention (no sinks).", + model_config=hack_cfg, + train_config=hack_train, + ) + grug_speedrun = 
SpeedrunConfig( + author=AUTHOR, + description="Grugformer (~125M) - TPU Splash Attention / reference fallback (no sinks).", + model_config=grug_cfg, + train_config=grug_train, + ) + + if hack_speedrun.vocab_size != llama3_tokenizer_vocab_size: + raise AssertionError("Hackable speedrun vocab_size mismatch; expected llama3_tokenizer_vocab_size") + if grug_speedrun.vocab_size != llama3_tokenizer_vocab_size: + raise AssertionError("Grug speedrun vocab_size mismatch; expected llama3_tokenizer_vocab_size") + + steps = [] + steps.extend(default_speedrun(f"hackable_compare_125m_{max_seq_len}-profile2", hack_speedrun)) + steps.extend(default_speedrun(f"grug_compare_125m_{max_seq_len}-profile2", grug_speedrun)) + executor_main(steps=steps, description="Head-to-head: hackable transformer vs grugformer (~125M, no sinks)") + + +if __name__ == "__main__": + main() diff --git a/lib/levanter/pyproject.toml b/lib/levanter/pyproject.toml index fa718e2b82..915d4643e8 100644 --- a/lib/levanter/pyproject.toml +++ b/lib/levanter/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "jax>=0.8.0", "fray", "zephyr", + "einops", "jaxtyping>=0.2.34", "tokenizers>=0.15.2", "transformers>=4.57.1,<5.0", @@ -175,5 +176,6 @@ markers = [ "entry: marks tests as entry point tests (deselect with '-m \"not entry\"')", "ray: marks tests that require Ray (deselect with '-m \"not ray\"')", "torch: mark tests that use Torch (deselect with '-m \"not torch\"')", + "tpu: mark tests that require TPU (deselect with '-m \"not tpu\"')", ] asyncio_default_fixture_loop_scope = "function" diff --git a/lib/levanter/src/levanter/__init__.py b/lib/levanter/src/levanter/__init__.py index 5159d3a52a..bd3d4ce759 100644 --- a/lib/levanter/src/levanter/__init__.py +++ b/lib/levanter/src/levanter/__init__.py @@ -15,6 +15,7 @@ "tracker", "trainer", "visualization", + "grug", "current_tracker", "initialize", ] @@ -32,6 +33,7 @@ import levanter.tracker as tracker import levanter.trainer as trainer import levanter.visualization as visualization +import levanter.grug as grug from levanter.tracker import current_tracker from levanter.trainer import initialize diff --git a/lib/levanter/src/levanter/grug/__init__.py b/lib/levanter/src/levanter/grug/__init__.py new file mode 100644 index 0000000000..8a1829572b --- /dev/null +++ b/lib/levanter/src/levanter/grug/__init__.py @@ -0,0 +1,8 @@ +# Copyright 2025 The Levanter Authors +# SPDX-License-Identifier: Apache-2.0 + +"""Grug: a grug-simple explicit-sharding transformer trainer. + +This package intentionally exposes only the "grug core" (raw-array model + kernels). +Levanter integration adapters live under `levanter.models`. 
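+
+See `.agents/projects/grugformer.md` for the design principles behind this package and
+`docs/recipes/change_grug.md` for the workflow used to evolve it.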
+""" diff --git a/lib/levanter/src/levanter/grug/attention.py b/lib/levanter/src/levanter/grug/attention.py new file mode 100644 index 0000000000..6e654ca800 --- /dev/null +++ b/lib/levanter/src/levanter/grug/attention.py @@ -0,0 +1,398 @@ +# Copyright 2025 The Levanter Authors +# SPDX-License-Identifier: Apache-2.0 +import functools +import math +from dataclasses import dataclass + +import jax +from jax import numpy as jnp +from jax.experimental.shard_map import shard_map +from jax.sharding import NamedSharding, PartitionSpec as P +from jax.tree_util import register_dataclass +from jaxtyping import Array, Bool, Float, Int + +from haliax.jax_utils import named_call +from haliax.partitioning import _get_mesh + + +@dataclass(frozen=True) +class RotaryConfig: + """Lightweight rotary embedding configuration.""" + + theta: float = 10000.0 + scaling_factor: float | None = None + + +@functools.partial(register_dataclass, data_fields=["segment_ids"], meta_fields=["is_causal", "sliding_window"]) +@dataclass(frozen=True) +class AttentionMask: + """Grug attention mask spec. + + This is deliberately simpler than `levanter.layers.attention.AttentionMask`: + - Stores raw JAX arrays (no NamedArray fields). + - Supports causal masking, sliding windows, and segment IDs. + """ + + is_causal: bool = False + segment_ids: tuple[jax.Array, jax.Array] | None = None + sliding_window: int | None = None + + @classmethod + def causal(cls, *, sliding_window: int | None = None) -> "AttentionMask": + return cls(is_causal=True, sliding_window=sliding_window) + + def with_segment_ids( + self, q_segment_ids: Int[Array, "..."], kv_segment_ids: Int[Array, "..."] | None = None + ) -> "AttentionMask": + kv_ids = q_segment_ids if kv_segment_ids is None else kv_segment_ids + return AttentionMask( + is_causal=self.is_causal, + segment_ids=(q_segment_ids, kv_ids), + sliding_window=self.sliding_window, + ) + + def with_sliding_window(self, sliding_window: int | None) -> "AttentionMask": + return AttentionMask( + is_causal=self.is_causal, + segment_ids=self.segment_ids, + sliding_window=sliding_window, + ) + + def materialize_mask(self, q_len: int, k_len: int) -> Bool[Array, "..."] | None: + """Return a boolean mask (True = allowed) or None. + + Shapes: + - If `segment_ids` is unset, returns `(q_len, k_len)` (broadcastable across batch). + - If `segment_ids` is set with per-batch IDs, returns `(batch, q_len, k_len)`. + """ + mask = None + + if self.is_causal: + q_idx = jnp.arange(q_len)[:, None] + k_idx = jnp.arange(k_len)[None, :] + allowed = k_idx <= q_idx + mask = allowed + + if self.sliding_window is not None: + if self.sliding_window <= 0: + raise ValueError(f"sliding_window must be positive, got {self.sliding_window}") + q_idx = jnp.arange(q_len)[:, None] + k_idx = jnp.arange(k_len)[None, :] + # Standard sliding-window semantics: `sliding_window=W` means "keep the last W tokens, + # including self". 
Without causality, this is "don't look too far back": + # k >= q - (W - 1) + allowed = k_idx >= q_idx - (self.sliding_window - 1) + mask = allowed if mask is None else jnp.logical_and(mask, allowed) + + if self.segment_ids is not None: + q_seg, k_seg = self.segment_ids + if q_seg.ndim != k_seg.ndim: + raise ValueError(f"segment_ids ndim mismatch: q={q_seg.ndim}, k={k_seg.ndim}") + if q_seg.ndim == 1: + allowed = q_seg[:, None] == k_seg[None, :] + elif q_seg.ndim == 2: + if q_seg.shape[0] != k_seg.shape[0]: + raise ValueError(f"segment_ids batch mismatch: q={q_seg.shape[0]}, k={k_seg.shape[0]}") + allowed = q_seg[:, :, None] == k_seg[:, None, :] + else: + raise ValueError(f"segment_ids must be 1D or 2D, got ndim={q_seg.ndim}") + mask = allowed if mask is None else jnp.logical_and(mask, allowed) + + return mask + + +def _rotary_cache(seq_len: int, head_dim: int, rope: RotaryConfig) -> tuple[Float[Array, "S D"], Float[Array, "S D"]]: + half_dim = head_dim // 2 + inv_freq = 1.0 / (rope.theta ** (jnp.arange(0, half_dim, dtype=jnp.float32) / half_dim)) + positions = jnp.arange(seq_len, dtype=jnp.float32) + angles = positions[:, None] * inv_freq[None, :] + cos = jnp.cos(angles) + sin = jnp.sin(angles) + return cos, sin + + +@named_call +def apply_rotary_embedding( + q: Float[Array, "B S H D"], + k: Float[Array, "B S H D"], + *, + seq_len: int, + head_dim: int, + rope: RotaryConfig, +) -> tuple[Float[Array, "B S H D"], Float[Array, "B S H D"]]: + cos, sin = _rotary_cache(seq_len, head_dim, rope) + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + + def _apply(x: Float[Array, "B S H D"]) -> Float[Array, "B S H D"]: + x1, x2 = jnp.split(x, 2, axis=-1) + return jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1) + + return _apply(q), _apply(k) + + +def reference_attention( + q: Float[Array, "B Q Hq D"], + k: Float[Array, "B K Hkv D"], + v: Float[Array, "B K Hkv D"], + mask: AttentionMask | Bool[Array, "B Q K"] | Float[Array, "B Q K"] | None, + *, + logits_dtype: jnp.dtype | None, +) -> Float[Array, "B Q Hq D"]: + head_dim = q.shape[-1] + num_q_heads = q.shape[2] + num_kv_heads = k.shape[2] + + if num_q_heads != num_kv_heads: + if num_q_heads % num_kv_heads != 0: + raise ValueError(f"num_heads ({num_q_heads}) must be divisible by num_kv_heads ({num_kv_heads})") + repeat = num_q_heads // num_kv_heads + k = jnp.repeat(k, repeat, axis=2) + v = jnp.repeat(v, repeat, axis=2) + + scale = 1.0 / math.sqrt(head_dim) + scores = jnp.einsum("bqhd,bkhd->bhqk", q * scale, k) + + explicit = None + if mask is None: + explicit = None + elif isinstance(mask, AttentionMask): + explicit = mask.materialize_mask(scores.shape[-2], scores.shape[-1]) + else: + explicit = mask + + if explicit is not None: + # Standardize dense masks to [B, Q, K] (bool = allowed, float = additive bias). + if explicit.ndim == 2: + explicit = explicit[None, :, :] + if explicit.ndim != 3: + raise ValueError(f"explicit mask must have shape [batch, q, k], got shape={explicit.shape}") + if explicit.shape[0] not in (1, q.shape[0]): + raise ValueError(f"explicit mask batch dim must be 1 or {q.shape[0]}, got {explicit.shape[0]}") + if explicit.shape[1] != scores.shape[-2] or explicit.shape[2] != scores.shape[-1]: + raise ValueError( + "explicit mask must match attention shapes: " + f"got mask={explicit.shape}, expected [batch,{scores.shape[-2]},{scores.shape[-1]}]" + ) + + explicit = explicit[:, None, :, :] # -> [B, 1, Q, K], broadcast across heads. 
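+        # For example, a (Q, K) boolean mask from `AttentionMask.causal().materialize_mask(Q, K)`
+        # arrives here as [1, 1, Q, K] and broadcasts over both batch and heads.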
+ if explicit.dtype == jnp.bool_: + scores = jnp.where(explicit, scores, jnp.array(-1e9, dtype=scores.dtype)) + else: + scores = scores + explicit + if logits_dtype is not None: + scores = scores.astype(logits_dtype) + weights = jax.nn.softmax(scores, axis=-1).astype(v.dtype) + ctx = jnp.einsum("bhqk,bkhd->bqhd", weights, v) + return ctx.astype(v.dtype) + + +def _spec_shard_factor(entry: str | tuple[str, ...] | None, mesh) -> int: + """ + Compute the size of the mesh axes associated with a PartitionSpec entry. + + Splash attention can handle various kinds of sequence parallelism but needs this to function + """ + if entry is None or mesh is None: + return 1 + if isinstance(entry, str): + return mesh.shape[entry] + if isinstance(entry, tuple): + factor = 1 + for name in entry: + factor *= mesh.shape[name] + return factor + raise TypeError(f"Unsupported PartitionSpec entry: {entry!r}") + + +def _tpu_splash_attention( + q: Float[Array, "B Q Hq D"], + k: Float[Array, "B K Hkv D"], + v: Float[Array, "B K Hkv D"], + mask: AttentionMask | jax.Array | None, +) -> Float[Array, "B Q Hq D"]: + from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel, splash_attention_mask + from jax.experimental.pallas.ops.tpu.splash_attention import SegmentIds as SplashSegmentIds + + # Splash attention expects BHSD. + q_ = jnp.transpose(q, (0, 2, 1, 3)) + k_ = jnp.transpose(k, (0, 2, 1, 3)) + v_ = jnp.transpose(v, (0, 2, 1, 3)) + + B, Hq, Sq, D = q_.shape + _, _, Sk, _ = k_.shape + + if Sk % 128 != 0: + raise NotImplementedError("Splash attention requires key/value sequence length to be a multiple of 128.") + + q_ = q_ * (1.0 / math.sqrt(D)) + + mesh = _get_mesh() + if mesh is None or getattr(mesh, "empty", False): + raise RuntimeError("TPU splash attention requires a JAX mesh context.") + + def _named_sharding_of(x: jax.Array, *, label: str) -> NamedSharding: + """Extract NamedSharding from a JAX value or tracer. + + In JAX, `.sharding` is not available on tracers during staging; however the sharding + is still available on the underlying abstract value in many cases. + """ + sharding = None + try: + sharding = x.sharding # type: ignore[attr-defined] + except Exception: + sharding = None + if sharding is None: + aval = getattr(x, "aval", None) + sharding = getattr(aval, "sharding", None) if aval is not None else None + if not isinstance(sharding, NamedSharding): + raise TypeError( + f"TPU splash attention expects NamedSharding on {label} under an explicit mesh; got {sharding!r}." + ) + return sharding + + q_sharding = _named_sharding_of(q_, label="q") + k_sharding = _named_sharding_of(k_, label="k") + v_sharding = _named_sharding_of(v_, label="v") + + q_pspec = q_sharding.spec + k_pspec = k_sharding.spec + v_pspec = v_sharding.spec + + # KV sequence must not be sharded for splash attention. + if k_pspec[2] is not None: + raise NotImplementedError( + "Splash attention does not support sharding the KV sequence dimension. " + f"Got KV sequence spec: {k_pspec[2]}" + ) + + head_shards = _spec_shard_factor(q_pspec[1], mesh) + q_seq_shards = _spec_shard_factor(q_pspec[2], mesh) + kv_seq_shards = _spec_shard_factor(k_pspec[2], mesh) + + # MaxText uses a block size of 512. Pick per-shard blocks that evenly divide each shard length, + # preferring multiples of 128 when possible. 
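+    # For example (hypothetical shapes): Sq=4096 over 4 q-seq shards gives shard_Sq=1024 -> block_q=512,
+    # while an unsharded Sk=640 gives block_kv=128 (the largest multiple of 128 that divides 640).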
+ block_size = 512 + + shard_Sq = max(1, Sq // max(1, q_seq_shards)) + shard_Sk = max(1, Sk // max(1, kv_seq_shards)) + + def _compatible_block(shard_len: int, max_block: int) -> int: + """Pick largest block <= max_block that divides shard_len; prefer multiples of 128.""" + if shard_len <= 0: + return max_block + cap = min(max_block, shard_len) + for step in (128, 1): + candidate = cap - (cap % step) + while candidate >= step: + if shard_len % candidate == 0: + return candidate + candidate -= step + return 1 + + block_q = _compatible_block(shard_Sq, block_size) + block_kv = _compatible_block(shard_Sk, block_size) + + block_sizes = splash_attention_kernel.BlockSizes( + block_q=block_q, + block_kv_compute=block_kv, + block_kv=block_kv, + block_q_dkv=block_q, + block_kv_dkv=block_kv, + block_kv_dkv_compute=block_q, + block_q_dq=block_q, + block_kv_dq=block_kv, + ) + + if mask is None: + base_mask = splash_attention_mask.FullMask(_shape=(Sq, Sk)) + segment_ids = None + segment_ids_axes = None + segment_batch_axis = None + elif isinstance(mask, AttentionMask): + if mask.is_causal: + base_mask = splash_attention_mask.CausalMask((Sq, Sk), offset=0, shard_count=q_seq_shards) + else: + base_mask = splash_attention_mask.FullMask(_shape=(Sq, Sk)) + + if mask.sliding_window is not None: + if mask.sliding_window <= 0: + raise ValueError(f"sliding_window must be positive, got {mask.sliding_window}") + local_mask = splash_attention_mask.LocalMask( + shape=(Sq, Sk), + # Grug's `sliding_window` matches the "lookback" semantics used in the + # reference mask materialization: allow attending to keys with + # k >= q - (sliding_window - 1) + # (and optionally combine with causal). + window_size=(mask.sliding_window - 1, None), + offset=0, + shard_count=q_seq_shards, + ) + base_mask = splash_attention_mask.LogicalAnd(base_mask, local_mask) + + if mask.segment_ids is not None: + q_segment_ids, kv_segment_ids = mask.segment_ids + q_seg_sharding = _named_sharding_of(q_segment_ids, label="segment_ids.q") + kv_seg_sharding = _named_sharding_of(kv_segment_ids, label="segment_ids.kv") + segment_ids = SplashSegmentIds(q_segment_ids, kv_segment_ids) + segment_ids_axes = SplashSegmentIds( + q_seg_sharding.spec, + kv_seg_sharding.spec, + ) + segment_batch_axis = SplashSegmentIds( + 0 if q_segment_ids.ndim == 2 else None, + 0 if kv_segment_ids.ndim == 2 else None, + ) + else: + segment_ids = None + segment_ids_axes = None + segment_batch_axis = None + else: + raise NotImplementedError("Dense masks are not supported for splash attention.") + + kernel_mask = splash_attention_mask.MultiHeadMask(masks=[base_mask for _ in range(Hq)]) + + splash_kernel = splash_attention_kernel.make_splash_mha( + mask=kernel_mask, + block_sizes=block_sizes, + head_shards=head_shards, + q_seq_shards=q_seq_shards, + ) + + kernel_sharding = NamedSharding(mesh, P(q_pspec[1], q_pspec[2])) + kernel_specs = splash_kernel.manual_sharding_spec(kernel_sharding) + + @functools.partial( + shard_map, + mesh=mesh, + in_specs=(q_pspec, k_pspec, v_pspec, segment_ids_axes, kernel_specs), + out_specs=q_pspec, + check_rep=False, + ) + def wrap(q_bhsd, k_bhsd, v_bhsd, seg_ids, kernel): + return jax.vmap(kernel, in_axes=(0, 0, 0, segment_batch_axis))(q_bhsd, k_bhsd, v_bhsd, seg_ids) + + out = wrap(q_, k_, v_, segment_ids, splash_kernel) + return jnp.transpose(out, (0, 2, 1, 3)).astype(v.dtype) + + +def attention( + q: Float[Array, "B Q Hq D"], + k: Float[Array, "B K Hkv D"], + v: Float[Array, "B K Hkv D"], + mask: AttentionMask | Bool[Array, "B Q K"] | Float[Array, 
"B Q K"] | None, +) -> Float[Array, "B Q Hq D"]: + if jax.default_backend() == "tpu": + if isinstance(mask, jax.Array): + return reference_attention(q, k, v, mask, logits_dtype=jnp.float32) + return _tpu_splash_attention(q, k, v, mask) + return reference_attention(q, k, v, mask, logits_dtype=jnp.float32) + + +__all__ = [ + "AttentionMask", + "RotaryConfig", + "apply_rotary_embedding", + "attention", + "reference_attention", +] diff --git a/lib/levanter/src/levanter/grug/data.py b/lib/levanter/src/levanter/grug/data.py new file mode 100644 index 0000000000..2911bd5858 --- /dev/null +++ b/lib/levanter/src/levanter/grug/data.py @@ -0,0 +1,71 @@ +# Copyright 2025 The Levanter Authors +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Mapping +from typing import Any + +from jax.sharding import Mesh + +from levanter.data.loader import DataLoader +from levanter.data.text import TokenSeqDataset +from levanter.store.cache import TreeCache + +# Levanter's DataLoader expects an axis name for the batch dimension. We map it to +# the data axis so each data shard loads only its local share. +DEFAULT_AXIS_MAPPING = {"batch": ("data",)} + + +def make_token_dataset(cache: TreeCache[dict], *, seq_len: int) -> TokenSeqDataset: + """Thin wrapper so callers don't touch Levanter internals directly.""" + + return TokenSeqDataset(cache, seq_len) + + +def make_dataloader( + dataset: TokenSeqDataset, + *, + batch_size: int, + mesh: Mesh, + axis_mapping: Mapping[str, tuple[str, ...]] | None = None, + max_buffered_batches: int = 64, + prefetch_size: int = 32, + pad_final_batch: bool = True, + allow_nondivisible_batch_size: bool = False, +) -> DataLoader: + """Wraps a TokenSeqDataset with Levanter's sharding-aware DataLoader.""" + + axis_resources = axis_mapping or DEFAULT_AXIS_MAPPING + return DataLoader( + dataset, + batch_size=batch_size, + mesh=mesh, + axis_resources=axis_resources, + batch_axis_name="batch", + max_buffered_batches=max_buffered_batches, + prefetch_size=prefetch_size, + pad_final_batch=pad_final_batch, + allow_nondivisible_batch_size=allow_nondivisible_batch_size, + ) + + +def build_token_loader( + *, + cache: TreeCache[dict], + seq_len: int, + batch_size: int, + mesh: Mesh, + axis_mapping: Mapping[str, tuple[str, ...]] | None = None, + loader_kwargs: Mapping[str, Any] | None = None, +) -> DataLoader: + """Convenience helper: cache -> TokenSeqDataset -> DataLoader.""" + + dataset = make_token_dataset(cache, seq_len=seq_len) + kwargs = dict(loader_kwargs or {}) + return make_dataloader(dataset, batch_size=batch_size, mesh=mesh, axis_mapping=axis_mapping, **kwargs) + + +__all__ = [ + "make_token_dataset", + "make_dataloader", + "build_token_loader", +] diff --git a/lib/levanter/src/levanter/grug/loss.py b/lib/levanter/src/levanter/grug/loss.py new file mode 100644 index 0000000000..16834c6227 --- /dev/null +++ b/lib/levanter/src/levanter/grug/loss.py @@ -0,0 +1,108 @@ +# Copyright 2025 The Levanter Authors +# SPDX-License-Identifier: Apache-2.0 + +"""Blockwise linear softmax cross-entropy for grug. + +This is the "large vocab friendly" alternative to materializing full logits +`hidden @ lm_head` with shape (batch, seq, vocab). + +Design notes: + - Works on plain `jax.Array` inputs (grug core doesn't use NamedArray). + - Computes `logsumexp` over vocab in blocks to reduce peak memory. + - Computes the correct-class logit via gather+dot (O(N*H)), avoiding a full + logits materialization. 
+""" + +from __future__ import annotations + +import jax +import jax.numpy as jnp +from jax.sharding import PartitionSpec as P + +from haliax.types import PrecisionLike +from levanter.grug.sharding import Pbatch + + +def linear_softmax_cross_entropy_loss_and_logz( + hidden: jax.Array, + lm_head: jax.Array, + labels: jax.Array, + *, + p_batch: P = Pbatch, + block_size: int | None = None, + dtype: jnp.dtype = jnp.float32, + precision: PrecisionLike = None, +) -> tuple[jax.Array, jax.Array]: + """Compute per-position cross entropy without materializing full logits. + + Args: + hidden: Array with shape (..., hidden_dim). + lm_head: Array with shape (hidden_dim, vocab_size). + labels: Integer array with shape (...,). + block_size: Vocab block size for logsumexp. + dtype: Accumulator dtype for logsumexp. + p_batch: Sharding for the gathered correct-class weights. Defaults to Pbatch. + + Returns: + (loss, logz) each with shape labels.shape. + """ + if block_size is None: + block_size = lm_head.shape[1] + + hidden_dim = hidden.shape[-1] + if lm_head.ndim != 2: + raise ValueError(f"lm_head must be 2D (hidden_dim, vocab), got shape={lm_head.shape}") + if lm_head.shape[0] != hidden_dim: + raise ValueError(f"hidden_dim mismatch: hidden={hidden_dim}, lm_head={lm_head.shape[0]}") + + vocab_size = lm_head.shape[1] + nblocks = (vocab_size + block_size - 1) // block_size + vocab_padded = nblocks * block_size + if vocab_padded != vocab_size: + # We pad the vocab dimension so we can `dynamic_slice_in_dim(..., slice_size=block_size)` + # for every block. Use `lax.pad` (rather than concatenating a freshly-created zeros array) + # so sharding stays consistent under explicit meshes. + lm_head = jax.lax.pad( + lm_head, + jnp.array(0, dtype=lm_head.dtype), + padding_config=((0, 0, 0), (0, vocab_padded - vocab_size, 0)), + ) + flat_hidden = hidden.reshape((-1, hidden_dim)).astype(dtype) + flat_labels = labels.reshape((-1,)).astype(jnp.int32) + + w_y = lm_head.T.at[flat_labels].get(out_sharding=p_batch).astype(dtype) + logit_y = jnp.sum(flat_hidden * w_y, axis=-1) + neg_inf = jnp.array(-jnp.inf, dtype=dtype) + # Match the sharding of the computed per-example logits so the loop carry types are stable. 
+ m0 = jnp.full_like(logit_y, neg_inf) + s0 = jnp.zeros_like(logit_y, dtype=dtype) + + neg_inf_logits = jnp.array(-jnp.inf, dtype=dtype) + + def _body(i: int, state: tuple[jax.Array, jax.Array]) -> tuple[jax.Array, jax.Array]: + m, s = state + start = i * block_size + w_block = jax.lax.dynamic_slice_in_dim(lm_head, start_index=start, slice_size=block_size, axis=1).astype(dtype) + logits = jax.lax.dot_general( + flat_hidden, + w_block, + dimension_numbers=(((1,), (0,)), ((), ())), + precision=precision, + ) + valid = jnp.arange(block_size) < (vocab_size - start) + logits = jnp.where(valid[None, :], logits, neg_inf_logits) + block_max = jnp.max(logits, axis=-1) + new_m = jnp.maximum(m, block_max) + s = s * jnp.exp(m - new_m) + jnp.sum(jnp.exp(logits - new_m[:, None]), axis=-1) + return new_m, s + + m, s = jax.lax.fori_loop(0, nblocks, _body, (m0, s0)) + logz = m + jnp.log(s) + loss = logz - logit_y + + return loss.reshape(labels.shape), logz.reshape(labels.shape) + + +__all__ = [ + "linear_softmax_cross_entropy_loss_and_logz", +] diff --git a/lib/levanter/src/levanter/grug/main.py b/lib/levanter/src/levanter/grug/main.py new file mode 100644 index 0000000000..4bc8cdcf07 --- /dev/null +++ b/lib/levanter/src/levanter/grug/main.py @@ -0,0 +1,212 @@ +# Copyright 2025 The Levanter Authors +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import logging +from dataclasses import dataclass, replace +from typing import Iterator + +import jax +import numpy as np +import optax +from jax import numpy as jnp +from jax.tree_util import register_dataclass +from jax.sharding import AxisType + +from levanter.store.cache import TreeCache + +from levanter.grug.data import DEFAULT_AXIS_MAPPING, build_token_loader +from levanter.grug.model import GrugModelConfig, GrugModelParameters, forward, init_parameters + + +def synthetic_batch_iterator( + *, + rng: jax.Array, + batch_size: int, + seq_len: int, + vocab_size: int, +) -> Iterator[dict[str, jax.Array]]: + """Infinite generator of random token/label pairs.""" + + def _step(key: jax.Array) -> dict[str, jax.Array]: + tokens = jax.random.randint(key, (batch_size, seq_len), 0, vocab_size) + return {"tokens": tokens[:, :-1], "labels": tokens[:, 1:]} + + while True: + rng, key = jax.random.split(rng) + yield _step(key) + + +def dataloader_iterator( + loader, + *, + seq_len: int, +) -> Iterator[dict[str, jax.Array]]: + while True: + batch = next(loader) + tokens = batch[:, :seq_len] + yield {"tokens": tokens[:, :-1], "labels": tokens[:, 1:]} + + +@dataclass(frozen=True) +class GrugTrainingConfig: + """Full training recipe, nested around a model.""" + + model: GrugModelConfig + learning_rate: float = 1e-3 + weight_decay: float = 0.01 + seed: int = 0 + steps: int = 10 + global_batch_size: int = 8 + + +def create_mesh(*, global_batch_size: int | None = None) -> jax.sharding.Mesh: + devices = jax.devices() + if not devices: + raise RuntimeError("No JAX devices available") + + # Grug uses explicit sharding. For data parallelism, sharding the batch requires the batch + # to be divisible by the data axis size. For the minimal trainer, prefer using all devices, + # but fall back to a smaller data axis if needed (e.g. tiny synthetic tests). 
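+    # For example, 8 devices with global_batch_size=12 falls back to a data axis of size 6
+    # (the largest count <= 8 that divides 12); the remaining devices are simply left unused.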
+ data_size = len(devices) + if global_batch_size is not None and global_batch_size > 0: + while data_size > 1 and global_batch_size % data_size != 0: + data_size -= 1 + if data_size < len(devices): + logging.getLogger(__name__).warning( + "global_batch_size=%s is not divisible by device count=%s; using %s data devices", + global_batch_size, + len(devices), + data_size, + ) + + mesh_devices = np.array(devices[:data_size]).reshape(data_size, 1) + mesh = jax.sharding.Mesh( + mesh_devices, + axis_names=("data", "model"), + axis_types=(AxisType.Explicit, AxisType.Explicit), + ) + return mesh + + +def make_train_step( + model_cfg: GrugModelConfig, + optimizer: optax.GradientTransformation, +): + def loss_and_metrics(params: GrugModelParameters, batch: dict[str, jax.Array]): + logits = forward(params, batch["tokens"], model_cfg, mask=None) + loss = cross_entropy_loss(logits, batch["labels"]) + metrics = {"loss": loss, "ppl": jnp.exp(loss)} + return loss, metrics + + def step(state: TrainingState, batch: dict[str, jax.Array]): + (_loss, metrics), grads = jax.value_and_grad(loss_and_metrics, has_aux=True)(state.params, batch) + updates, new_opt_state = optimizer.update(grads, state.opt_state, state.params) + new_params = optax.apply_updates(state.params, updates) + new_state = replace(state, step=state.step + 1, params=new_params, opt_state=new_opt_state) + return new_state, metrics + + return jax.jit(step) + + +def cross_entropy_loss(logits: jax.Array, labels: jax.Array) -> jax.Array: + log_probs = jax.nn.log_softmax(logits, axis=-1) + gathered = jnp.take_along_axis(log_probs, labels[..., None], axis=-1) + return -jnp.mean(gathered) + + +def load_cache(cache_dir: str, *, seq_len: int) -> TreeCache[dict]: + exemplar = {"input_ids": np.zeros(seq_len, dtype=np.int32)} + return TreeCache.load(cache_dir, exemplar) + + +def run_training(train_cfg: GrugTrainingConfig, *, cache_dir: str | None = None) -> None: + mesh = create_mesh(global_batch_size=train_cfg.global_batch_size) + + with jax.set_mesh(mesh): + rng = jax.random.key(train_cfg.seed) + rng, init_rng = jax.random.split(rng) + params = init_parameters(train_cfg.model, key=init_rng) + optimizer = optax.adamw(learning_rate=train_cfg.learning_rate, weight_decay=train_cfg.weight_decay) + opt_state = optimizer.init(params) + state = TrainingState(step=0, params=params, opt_state=opt_state) + + seq_len = train_cfg.model.max_seq_len + train_step = make_train_step(train_cfg.model, optimizer) + + if cache_dir: + cache = load_cache(cache_dir, seq_len=seq_len) + loader = build_token_loader( + cache=cache, + seq_len=seq_len, + batch_size=train_cfg.global_batch_size, + mesh=mesh, + axis_mapping=DEFAULT_AXIS_MAPPING, + ) + batch_iter = dataloader_iterator(iter(loader), seq_len=seq_len) + else: + batch_iter = synthetic_batch_iterator( + rng=rng, + batch_size=train_cfg.global_batch_size, + seq_len=seq_len, + vocab_size=train_cfg.model.vocab_size, + ) + + for _ in range(train_cfg.steps): + batch = next(batch_iter) + state, metrics = train_step(state, batch) + print(f"step={state.step:03d} loss={float(metrics['loss']):.4f} ppl={float(metrics['ppl']):.2f}") + + +@dataclass(frozen=True) +class TrainingState: + step: int + params: GrugModelParameters + opt_state: optax.OptState + + +register_dataclass(TrainingState) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run the Grug trainer.") + parser.add_argument("--cache-dir", type=str, default=None, help="Optional TreeCache directory for real data.") + 
parser.add_argument("--steps", type=int, default=10) + parser.add_argument("--batch-size", type=int, default=4) + parser.add_argument("--seq-len", type=int, default=32) + parser.add_argument("--hidden-dim", type=int, default=128) + parser.add_argument("--layers", type=int, default=2) + parser.add_argument("--heads", type=int, default=4) + parser.add_argument("--vocab-size", type=int, default=50257) + return parser.parse_args() + + +def build_training_config(args: argparse.Namespace) -> GrugTrainingConfig: + model_cfg = GrugModelConfig( + vocab_size=args.vocab_size, + hidden_dim=args.hidden_dim, + intermediate_dim=4 * args.hidden_dim, + num_layers=args.layers, + num_heads=args.heads, + num_kv_heads=args.heads, + max_seq_len=args.seq_len, + ) + train_cfg = GrugTrainingConfig( + model=model_cfg, + learning_rate=1e-3, + weight_decay=0.01, + steps=args.steps, + global_batch_size=args.batch_size, + seed=0, + ) + return train_cfg + + +def main() -> None: + args = parse_args() + cfg = build_training_config(args) + run_training(cfg, cache_dir=args.cache_dir) + + +if __name__ == "__main__": + main() diff --git a/lib/levanter/src/levanter/grug/model.py b/lib/levanter/src/levanter/grug/model.py new file mode 100644 index 0000000000..9b72b1eb3e --- /dev/null +++ b/lib/levanter/src/levanter/grug/model.py @@ -0,0 +1,303 @@ +# Copyright 2025 The Levanter Authors +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses + +from dataclasses import dataclass +from functools import partial + +import jax +import jax.numpy as jnp +from einops import rearrange +from jax import random +from jax.sharding import PartitionSpec as P, reshard +from jax.tree_util import register_dataclass +from jaxtyping import Array, Float, Int, PRNGKeyArray + +from .attention import AttentionMask, RotaryConfig, apply_rotary_embedding, attention +from .loss import linear_softmax_cross_entropy_loss_and_logz +from .sharding import Pbatch, Pvocab, unshard + + +#### Conventions + +# Mesh meanings: +# - "data": data parallel sharding axis. We also shard parameters over this axis. +# - "model": model parallel sharding axis. TP + +# Dim names: +# - B = batch +# - D = embedding / hidden dim +# - S = sequence length +# - N = num heads +# - M = num kv heads +# - H = head dim +# - I = intermediate dim + + +@dataclass(frozen=True) +class GrugModelConfig: + """Hyperparameters for the Grug Llama-style transformer.""" + + vocab_size: int + hidden_dim: int = 2048 + intermediate_dim: int = 5632 + num_layers: int = 24 + num_heads: int = 16 + num_kv_heads: int = 16 + head_dim: int | None = None + max_seq_len: int = 4096 + layer_norm_eps: float = 1e-5 + initializer_std: float = 0.02 + rope: RotaryConfig = dataclasses.field(default_factory=RotaryConfig) + + # Controls how we compute logsumexp over the vocab in `levanter.grug.loss_fn`. + # + # - `None` means "single full-vocab block" (often faster for small-ish models/vocabs). + # - Smaller values reduce peak memory, but can be significantly slower in practice. + # + # TODO(grug): Replace with a faster large-vocab CE kernel so we don't have to pick between + # speed and memory. 
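+    # Rough peak-memory sketch (hypothetical numbers): with batch*seq = 8*2048 positions and
+    # block_size = 32768, each block materializes a (16384, 32768) float32 logit tile, ~2.1 GB.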
+ cross_entropy_block_size: int | None = 32768 + + def __post_init__(self) -> None: + _ = self.inferred_head_dim + if self.num_heads % self.num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads for grouped-query attention") + if self.vocab_size <= 0: + raise ValueError("vocab_size must be positive") + if self.max_seq_len <= 0: + raise ValueError("max_seq_len must be positive") + + @property + def inferred_head_dim(self) -> int: + if self.head_dim is not None: + return self.head_dim + if self.hidden_dim % self.num_heads != 0: + raise ValueError( + f"hidden_dim={self.hidden_dim} is not divisible by num_heads={self.num_heads}; set head_dim explicitly" + ) + return self.hidden_dim // self.num_heads + + +@register_dataclass +@dataclass(frozen=True) +class GrugAttentionParams: + w_q: jax.Array + w_k: jax.Array + w_v: jax.Array + w_o: jax.Array + + +@register_dataclass +@dataclass(frozen=True) +class GrugBlockParams: + attn: GrugAttentionParams + rms_attn: jax.Array + rms_mlp: jax.Array + mlp_gate: jax.Array + mlp_up: jax.Array + mlp_down: jax.Array + + +@register_dataclass +@dataclass(frozen=True) +class GrugModelParameters: + token_embed: jax.Array + output_proj: jax.Array + blocks: tuple[GrugBlockParams, ...] + final_norm: jax.Array + + +def _init_weight(key: PRNGKeyArray, shape: tuple[int, ...], std: float) -> Float[Array, "..."]: + return std * random.truncated_normal(key, -3, 3, shape) + + +@partial(jax.jit, static_argnames=("cfg",)) +def init_parameters(cfg: GrugModelConfig, *, key: PRNGKeyArray) -> GrugModelParameters: + head_dim = cfg.inferred_head_dim + key, embed_key, out_key = random.split(key, 3) + layer_keys = random.split(key, cfg.num_layers) + + token_embed = reshard(_init_weight(embed_key, (cfg.vocab_size, cfg.hidden_dim), cfg.initializer_std), Pvocab) + output_proj = reshard(_init_weight(out_key, (cfg.hidden_dim, cfg.vocab_size), cfg.initializer_std), Pvocab) + final_norm = reshard(jnp.ones((cfg.hidden_dim,), dtype=jnp.float32), P(None)) + + blocks: list[GrugBlockParams] = [] + # extract shape sizes for brevity and consistency + D, N, M, H, I = cfg.hidden_dim, cfg.num_heads, cfg.num_kv_heads, head_dim, cfg.intermediate_dim + for i in range(cfg.num_layers): + k_q, k_k, k_v, k_o, k_gate, k_up, k_down = random.split(layer_keys[i], 7) + + attn = GrugAttentionParams( + w_q=reshard(_init_weight(k_q, (D, N * H), cfg.initializer_std), P("data", "model")), + w_k=reshard(_init_weight(k_k, (D, M * H), cfg.initializer_std), P("data", "model")), + w_v=reshard(_init_weight(k_v, (D, M * H), cfg.initializer_std), P("data", "model")), + w_o=reshard(_init_weight(k_o, (N * H, D), cfg.initializer_std), P("model", "data")), + ) + mlp_gate = reshard(_init_weight(k_gate, (D, I), cfg.initializer_std), P("data", "model")) + mlp_up = reshard(_init_weight(k_up, (D, I), cfg.initializer_std), P("data", "model")) + mlp_down = reshard(_init_weight(k_down, (I, D), cfg.initializer_std), P("model", "data")) + # keep rms replicated + rms_attn = jnp.ones((D,), dtype=jnp.float32) + rms_mlp = jnp.ones((D,), dtype=jnp.float32) + + blocks.append( + GrugBlockParams( + attn=attn, + rms_attn=rms_attn, + rms_mlp=rms_mlp, + mlp_gate=mlp_gate, + mlp_up=mlp_up, + mlp_down=mlp_down, + ) + ) + + return GrugModelParameters( + token_embed=token_embed, + output_proj=output_proj, + blocks=tuple(blocks), + final_norm=final_norm, + ) + + +def rms_norm(x: Float[Array, "... D"], weight: Float[Array, "D"], eps: float) -> Float[Array, "... 
D"]: + weight = unshard(weight) + variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True) + normed = x * jax.lax.rsqrt(variance + eps) + return normed * weight + + +def mlp(block: GrugBlockParams, x: Float[Array, "B S D"]) -> Float[Array, "B S D"]: + gate = jnp.einsum("bsh,hm->bsm", x, block.mlp_gate) + up = jnp.einsum("bsh,hm->bsm", x, block.mlp_up) + activated = jax.nn.silu(gate) * up + return jnp.einsum("bsm,mh->bsh", activated, block.mlp_down, out_sharding=Pbatch) + + +def _transformer_hidden( + params: GrugModelParameters, + token_ids: Int[Array, "B S"], + cfg: GrugModelConfig, + *, + mask: AttentionMask | jax.Array | None, +) -> Float[Array, "B S D"]: + head_dim = cfg.inferred_head_dim + seq_len = token_ids.shape[1] + + if mask is None: + mask = AttentionMask.causal() + + hidden = params.token_embed.at[token_ids].get(out_sharding=Pbatch) + + for block in params.blocks: + attn_in = rms_norm(hidden, block.rms_attn, cfg.layer_norm_eps) + q = rearrange(jnp.einsum("bsh,hd->bsd", attn_in, block.attn.w_q), "... (n d) -> ... n d", d=head_dim) + k = rearrange(jnp.einsum("bsh,hd->bsd", attn_in, block.attn.w_k), "... (m d) -> ... m d", d=head_dim) + v = rearrange(jnp.einsum("bsh,hd->bsd", attn_in, block.attn.w_v), "... (m d) -> ... m d", d=head_dim) + q, k = apply_rotary_embedding(q, k, seq_len=seq_len, head_dim=head_dim, rope=cfg.rope) + attn_out = attention(q, k, v, mask) + attn_out = rearrange(attn_out, "... n d -> ... (n d)") + attn_out = jnp.einsum("bsh,hd->bsd", attn_out, block.attn.w_o, out_sharding=Pbatch) + + hidden = hidden + attn_out + mlp_in = rms_norm(hidden, block.rms_mlp, cfg.layer_norm_eps) + mlp_out = mlp(block, mlp_in) + hidden = hidden + mlp_out + + hidden = rms_norm(hidden, params.final_norm, cfg.layer_norm_eps) + return hidden + + +def forward( + params: GrugModelParameters, + token_ids: Int[Array, "B S"], + cfg: GrugModelConfig, + *, + mask: AttentionMask | jax.Array | None = None, +) -> Float[Array, "B S V"]: + hidden = _transformer_hidden(params, token_ids, cfg, mask=mask) + logits = jnp.einsum("bsh,hd->bsd", hidden, params.output_proj, out_sharding=Pbatch) + return logits + + +def activations( + params: GrugModelParameters, + token_ids: Int[Array, "B S"], + cfg: GrugModelConfig, + *, + mask: AttentionMask | jax.Array | None = None, +) -> Float[Array, "B S D"]: + """Return final hidden states with shape (batch, seq, hidden_dim).""" + return _transformer_hidden(params, token_ids, cfg, mask=mask) + + +def loss_fn( + params: GrugModelParameters, + token_ids: Int[Array, "B S"], + loss_weight: Float[Array, "B S"], + cfg: GrugModelConfig, + *, + mask: AttentionMask | jax.Array | None = None, + reduction: str = "mean", + logsumexp_weight: float | None = None, + loss_dtype: jnp.dtype = jnp.float32, +) -> jax.Array: + """Compute next-token cross-entropy loss for a batch. + + This is the "activations vs lm_head" friendly path: it avoids materializing full logits. + + Args: + params: Model parameters. + token_ids: Integer array with shape (batch, seq). + loss_weight: Float array with shape (batch, seq), typically 1 except last position (0). + cfg: Model config (uses `cfg.cross_entropy_block_size`). + mask: Optional attention mask spec. + reduction: One of {"mean", "sum", "none"}. + logsumexp_weight: Optional z-loss weight (logsumexp^2 term). + loss_dtype: Accumulator dtype for logsumexp / loss. + + Returns: + If reduction=="none": array with shape (batch, seq). + Else: scalar array. 
+ """ + hidden = _transformer_hidden(params, token_ids, cfg, mask=mask) + labels = jnp.concatenate([token_ids[:, 1:], token_ids[:, :1] * 0], axis=1).astype(jnp.int32) + loss_weight = loss_weight.astype(loss_dtype) + + # NOTE: `block_size=None` corresponds to a single full-vocab block. On the 125M speedrun, + # disabling blockwise chunking doubled observed MFU (~20 -> ~40). We'll likely need a better + # large-vocab loss kernel eventually (esp. for sharded vocab / padding weights), but this is + # good enough for now. + block_size = cfg.cross_entropy_block_size + + per_pos_loss, logz = linear_softmax_cross_entropy_loss_and_logz( + hidden, + params.output_proj, + labels, + block_size=block_size, + dtype=loss_dtype, + ) + per_pos_loss = per_pos_loss.astype(loss_dtype) * loss_weight + if logsumexp_weight is not None and logsumexp_weight != 0.0: + per_pos_loss = per_pos_loss + logsumexp_weight * (logz.astype(loss_dtype) ** 2) * loss_weight + + if reduction == "none": + return per_pos_loss + if reduction == "sum": + return jnp.sum(per_pos_loss) + if reduction == "mean": + denom = jnp.sum(loss_weight) + return jnp.sum(per_pos_loss) / jnp.maximum(denom, jnp.array(1.0, dtype=loss_dtype)) + raise ValueError(f"Unknown reduction: {reduction}") + + +__all__ = [ + "GrugAttentionParams", + "GrugBlockParams", + "GrugModelParameters", + "init_parameters", + "activations", + "forward", + "loss_fn", +] diff --git a/lib/levanter/src/levanter/grug/sharding.py b/lib/levanter/src/levanter/grug/sharding.py new file mode 100644 index 0000000000..8d57df5bb8 --- /dev/null +++ b/lib/levanter/src/levanter/grug/sharding.py @@ -0,0 +1,15 @@ +# Copyright 2025 The Levanter Authors +# SPDX-License-Identifier: Apache-2.0 + +import jax +from jax import P +from jax.sharding import reshard + +# convenience shorthand for batch sharding. +# if this were Haliax, we'd say {"batch": ("data",)} +Pbatch = P(("data",)) +Pvocab = P(None, None) + + +def unshard(x: jax.Array) -> jax.Array: + return reshard(x, P((None,))) diff --git a/lib/levanter/src/levanter/models/grug_wrapper.py b/lib/levanter/src/levanter/models/grug_wrapper.py new file mode 100644 index 0000000000..f7c0e9143d --- /dev/null +++ b/lib/levanter/src/levanter/models/grug_wrapper.py @@ -0,0 +1,231 @@ +# Copyright 2025 The Levanter Authors +# SPDX-License-Identifier: Apache-2.0 + +# Adapter to wire the grug model into the LmHeadModel API. + +from typing import Any, Protocol, cast + +import equinox as eqx +import jax +import haliax as hax +import jax.numpy as jnp +from haliax import Axis, NamedArray +from jaxtyping import PRNGKeyArray, PyTree + +from levanter.grug.attention import AttentionMask +from levanter.grug.model import activations as grug_activations +from levanter.grug.model import init_parameters +from levanter.grug.model import loss_fn as grug_loss_fn +from levanter.layers.attention import AttentionMask as LevanterAttentionMask +from levanter.models.lm_model import LmExample, LmHeadModel + + +class GrugConfigLike(Protocol): + vocab_size: int + max_seq_len: int + hidden_dim: int + cross_entropy_block_size: int | None = None + + +class GrugForwardFn(Protocol): + def __call__( + self, + params: PyTree, + tokens: jax.Array, + cfg: GrugConfigLike, + *, + mask: AttentionMask | jax.Array | None = None, + ) -> jax.Array: ... + + +class GrugInitFn(Protocol): + def __call__(self, cfg: GrugConfigLike, *, key: PRNGKeyArray) -> PyTree: ... + + +class GrugLmHeadFn(Protocol): + def __call__(self, params: PyTree) -> jax.Array: ... 
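+
+# These Protocols are the plug points for experiments: a copy-pasted/modified grug core can be wired
+# in via `GrugWrapper.init(..., init_fn=..., forward_fn=..., lm_head_fn=...)` without subclassing
+# the wrapper.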
+ + +def _default_lm_head_fn(params: PyTree) -> jax.Array: + return params.output_proj # type: ignore[attr-defined] + + +def _mask_from_levanter(attn_mask: LevanterAttentionMask | NamedArray | None) -> AttentionMask | jax.Array | None: + mask: AttentionMask | jax.Array | None = None + if isinstance(attn_mask, LevanterAttentionMask): + if attn_mask.explicit_mask is not None: + raise NotImplementedError("Grug does not support explicit masks yet.") + if attn_mask.causal_offset is not None: + raise NotImplementedError("Grug does not support causal offsets yet.") + segment_ids = None + if attn_mask.segment_ids is not None: + q_seg, kv_seg = attn_mask.segment_ids + segment_ids = (q_seg.array, kv_seg.array) + mask = AttentionMask( + is_causal=attn_mask.is_causal, + segment_ids=segment_ids, + sliding_window=attn_mask.sliding_window, + ) + elif isinstance(attn_mask, NamedArray): + raise NotImplementedError( + "NamedArray attention masks are not supported by Grug (pass a Levanter AttentionMask)." + ) + return mask + + +class GrugWrapper(LmHeadModel[Any]): + """Minimal LmHeadModel wrapper around the standalone Grug transformer.""" + + params: PyTree + grug_config: GrugConfigLike + init_fn: GrugInitFn = eqx.field(static=True) + forward_fn: GrugForwardFn = eqx.field(static=True) + lm_head_fn: GrugLmHeadFn = eqx.field(static=True, default=_default_lm_head_fn) + + @property + def config(self) -> GrugConfigLike: + return self.grug_config + + @property + def Pos(self) -> Axis: + return Axis("position", self.grug_config.max_seq_len) + + @property + def KeyPos(self) -> Axis: + return self.Pos.alias("key_position") + + @property + def Vocab(self) -> Axis: + return Axis("vocab", self.grug_config.vocab_size) + + @property + def Embed(self) -> Axis: + return Axis("embed", self.grug_config.hidden_dim) + + @classmethod + def init( + cls, + Vocab: Axis, + config: GrugConfigLike, + *, + key: PRNGKeyArray, + init_fn: GrugInitFn | None = None, + forward_fn: GrugForwardFn | None = None, + lm_head_fn: GrugLmHeadFn | None = None, + ) -> "GrugWrapper": + cfg = config + chosen_init = init_fn or init_parameters + params = chosen_init(cfg, key=key) + return cls( + params=params, + grug_config=cfg, + init_fn=chosen_init, + forward_fn=forward_fn or grug_activations, + lm_head_fn=lm_head_fn or _default_lm_head_fn, + ) + + def activations( + self, + input_ids: NamedArray, + attn_mask: LevanterAttentionMask | NamedArray | None = None, + *, + key=None, + pos_ids: NamedArray | None = None, + ) -> NamedArray: + del key, pos_ids # unused in this lightweight wrapper + mask = _mask_from_levanter(attn_mask) + + hidden = self.forward_fn( + self.params, + input_ids.array, + self.grug_config, + mask=mask, + ) + + # Map raw hidden states to a NamedArray with the existing axes plus Embed. + axes = (*input_ids.axes, self.Embed) + return hax.named(hidden, axes) + + def compute_next_token_loss( + self, + example: LmExample, + *, + key=None, + reduction: hax.ReductionFunction | None = cast(hax.ReductionFunction | None, hax.mean), + reduction_axis: hax.AxisSelection | None = None, + logsumexp_weight: float | None = None, + loss_dtype: jnp.dtype | None = jnp.float32, + logit_soft_cap: float | None = None, + ) -> jnp.ndarray | NamedArray: + """Override to use grug's blockwise loss (avoids materializing full logits).""" + # NOTE: this wrapper is intentionally minimal; grug core currently doesn't use PRNGs. + del key + + # LmExample-ish protocol: expects `.tokens`, `.loss_weight`, `.attn_mask`. 
+        tokens = example.tokens
+        loss_weight = example.loss_weight
+        attn_mask = example.attn_mask
+
+        mask = _mask_from_levanter(attn_mask)
+        dtype = jnp.float32 if loss_dtype is None else loss_dtype
+
+        # grug's blockwise loss has no soft-cap support; fail loudly rather than silently ignoring it.
+        if logit_soft_cap is not None:
+            raise NotImplementedError("Grug's blockwise loss does not support logit soft-capping yet.")
+
+        if reduction is None:
+            per_pos = grug_loss_fn(
+                self.params,
+                tokens.array,
+                loss_weight.array,
+                self.grug_config,
+                mask=mask,
+                reduction="none",
+                logsumexp_weight=logsumexp_weight,
+                loss_dtype=dtype,
+            )
+            return hax.named(per_pos, tokens.axes)
+
+        # Fast path: scalar mean/sum reduction over all axes.
+        if reduction_axis is None and reduction is hax.mean:
+            return grug_loss_fn(
+                self.params,
+                tokens.array,
+                loss_weight.array,
+                self.grug_config,
+                mask=mask,
+                reduction="mean",
+                logsumexp_weight=logsumexp_weight,
+                loss_dtype=dtype,
+            )
+        if reduction_axis is None and reduction is hax.sum:
+            return grug_loss_fn(
+                self.params,
+                tokens.array,
+                loss_weight.array,
+                self.grug_config,
+                mask=mask,
+                reduction="sum",
+                logsumexp_weight=logsumexp_weight,
+                loss_dtype=dtype,
+            )
+
+        per_pos = grug_loss_fn(
+            self.params,
+            tokens.array,
+            loss_weight.array,
+            self.grug_config,
+            mask=mask,
+            reduction="none",
+            logsumexp_weight=logsumexp_weight,
+            loss_dtype=dtype,
+        )
+        loss = hax.named(per_pos, tokens.axes)
+
+        return reduction(loss, axis=reduction_axis)
+
+    def get_lm_head(self) -> NamedArray:
+        return hax.named(self.lm_head_fn(self.params), (self.Embed, self.Vocab))
+
+    def resize_vocab(self, new_size: int, key: PRNGKeyArray | None = None) -> "GrugWrapper":
+        raise NotImplementedError("GrugWrapper does not yet support resizing the vocabulary.")
diff --git a/lib/levanter/src/levanter/models/loss.py b/lib/levanter/src/levanter/models/loss.py
index 3dc54a4249..72efae1e24 100644
--- a/lib/levanter/src/levanter/models/loss.py
+++ b/lib/levanter/src/levanter/models/loss.py
@@ -302,9 +302,9 @@ def _block_cross_entropy_forward(
     num_blocks = vocab_size // block_size
 
     # Initialize accumulators: loss, logsumexp, max_logits
-    initial_O = hax.zeros(labels_y.axes)
-    initial_logsumexp = hax.full(labels_y.axes, -jnp.inf)
-    initial_max = hax.full(labels_y.axes, -jnp.inf)
+    initial_O = hax.auto_sharded(hax.zeros(labels_y.axes))
+    initial_logsumexp = hax.auto_sharded(hax.full(labels_y.axes, -jnp.inf))
+    initial_max = hax.auto_sharded(hax.full(labels_y.axes, -jnp.inf))
 
     # We don't need this b/c we're using one-hot targets
     # initial_sumV = hax.full(labels_y.axes, 0.0)
@@ -422,8 +422,8 @@ def _block_cross_entropy_backward(
 
     num_blocks = vocab_size // block_size
 
-    grad_embeddings = hax.zeros(pred_embeddings.axes, dtype=pred_embeddings.dtype)
-    grad_lm_head = hax.zeros(pred_lm_head.axes, dtype=pred_lm_head.dtype)
+    grad_embeddings = hax.zeros_like(pred_embeddings, dtype=pred_embeddings.dtype)
+    grad_lm_head = hax.zeros_like(pred_lm_head, dtype=pred_lm_head.dtype)
 
     def process_block(block_idx, acc, current_block_size):
         """
diff --git a/lib/levanter/tests/grug/test_grugformer.py b/lib/levanter/tests/grug/test_grugformer.py
new file mode 100644
index 0000000000..961c4de2c4
--- /dev/null
+++ b/lib/levanter/tests/grug/test_grugformer.py
@@ -0,0 +1,46 @@
+# Copyright 2025 The Levanter Authors
+# SPDX-License-Identifier: Apache-2.0
+
+# Copyright 2025 The Marin Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax + +from levanter.grug.main import GrugTrainingConfig +from levanter.grug.model import GrugModelConfig +from levanter.grug.main import run_training + + +def test_synthetic_training_step_runs(): + # On TPU, Grug uses Splash attention which requires KV sequence length to be a multiple of 128. + # `run_training` trains on tokens[:, :-1], so pick max_seq_len=129 -> token length 128. + max_seq_len = 129 if jax.default_backend() == "tpu" else 16 + cfg = GrugTrainingConfig( + model=GrugModelConfig( + vocab_size=257, + hidden_dim=64, + intermediate_dim=256, + num_layers=1, + num_heads=4, + num_kv_heads=4, + max_seq_len=max_seq_len, + ), + learning_rate=1e-3, + weight_decay=0.01, + steps=1, + global_batch_size=2, + seed=0, + ) + + run_training(cfg, cache_dir=None) diff --git a/lib/levanter/tests/grug/test_grugformer_compilation.py b/lib/levanter/tests/grug/test_grugformer_compilation.py new file mode 100644 index 0000000000..e354d35f87 --- /dev/null +++ b/lib/levanter/tests/grug/test_grugformer_compilation.py @@ -0,0 +1,76 @@ +# Copyright 2025 The Levanter Authors +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +import jax +import jax.numpy as jnp +from jax.sharding import AbstractMesh, AxisType, NamedSharding, PartitionSpec as P, use_abstract_mesh + +from jax._src import config as jax_config + +from levanter.grug.attention import AttentionMask +from levanter.grug.model import GrugModelConfig +from levanter.grug.sharding import Pbatch +from levanter.grug.model import init_parameters, loss_fn + + +def _make_abstract_mesh(*, data: int, model: int) -> AbstractMesh: + return AbstractMesh( + axis_sizes=(data, model), + axis_names=("data", "model"), + axis_types=(AxisType.Explicit, AxisType.Explicit), + ) + + +class _reset_abstract_mesh: + def __enter__(self): + self._prev = jax_config.abstract_mesh_context_manager.swap_local(jax_config.config_ext.unset) + return self + + def __exit__(self, exc_type, exc, tb): + jax_config.abstract_mesh_context_manager.set_local(self._prev) + return False + + +@pytest.mark.parametrize( + ("data", "model"), + [ + (4, 1), + (2, 2), + ], +) +def test_grug_loss_can_lower_on_abstract_4_device_mesh(data: int, model: int): + cfg = GrugModelConfig( + vocab_size=256, + hidden_dim=128, + intermediate_dim=256, + num_layers=1, + num_heads=8, + num_kv_heads=8, + max_seq_len=256, # splash attention requires KV seq multiple of 128 + cross_entropy_block_size=None, # single full-vocab block for faster compilation + ) + + mesh = _make_abstract_mesh(data=data, model=model) + # Some test setups establish a size-1 abstract mesh globally; JAX forbids changing the mesh + # size under `use_abstract_mesh`. Reset to "unset" so we can test other sizes here. + with _reset_abstract_mesh(), use_abstract_mesh(mesh): + # Build shaped params via eval_shape so we exercise init sharding rules under AbstractMesh. 
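+        # `eval_shape` only propagates shapes/shardings and never allocates device buffers, so this
+        # stays runnable on a single-host CPU even though the abstract mesh claims 4 devices.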
+ key = jax.ShapeDtypeStruct(shape=(2,), dtype=jnp.uint32, sharding=NamedSharding(mesh, P())) + params = jax.eval_shape(lambda k: init_parameters(cfg, key=k), key) + + batch = 8 + seq = 256 + + def f(p): + token_ids = jnp.zeros((batch, seq), dtype=jnp.int32) + token_ids = jax.sharding.reshard(token_ids, Pbatch) + loss_weight = jnp.ones((batch, seq), dtype=jnp.float32) + loss_weight = jax.sharding.reshard(loss_weight, Pbatch) + return loss_fn(p, token_ids, loss_weight, cfg, mask=AttentionMask.causal(), reduction="mean") + + platform = jax.devices()[0].platform if jax.devices() else jax.default_backend() + lowered = jax.jit(f).trace(params).lower(lowering_platforms=(platform,)) + # Lowering is the point of this test; don't force full compilation. + assert lowered is not None diff --git a/lib/levanter/tests/grug/test_grugformer_core.py b/lib/levanter/tests/grug/test_grugformer_core.py new file mode 100644 index 0000000000..acd4eb5d5e --- /dev/null +++ b/lib/levanter/tests/grug/test_grugformer_core.py @@ -0,0 +1,244 @@ +# Copyright 2025 The Levanter Authors +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np + +import jax +import jax.numpy as jnp +import pytest +from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec as P + +from levanter.grug.attention import AttentionMask, attention, reference_attention +from levanter.grug.model import GrugModelConfig +from levanter.grug.sharding import Pbatch +from levanter.grug.model import forward, init_parameters + + +def _make_grug_mesh() -> Mesh: + devices = jax.devices() + if not devices: + raise RuntimeError("No JAX devices available") + mesh_devices = np.array(devices).reshape(len(devices), 1) + return Mesh( + mesh_devices, + axis_names=("data", "model"), + axis_types=(AxisType.Explicit, AxisType.Explicit), + ) + + +def test_forward_shapes_and_jit_compile(): + # On TPU, Grug uses Splash attention which requires KV sequence length to be a multiple of 128. 
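+    # On CPU/GPU the dispatcher falls back to `reference_attention`, which has no such constraint,
+    # so a tiny sequence keeps the test fast there.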
+ seq = 128 if jax.default_backend() == "tpu" else 8 + cfg = GrugModelConfig( + vocab_size=101, + hidden_dim=32, + intermediate_dim=64, + num_layers=1, + num_heads=4, + num_kv_heads=4, + max_seq_len=seq, + ) + + mesh = _make_grug_mesh() + with jax.set_mesh(mesh): + params = init_parameters(cfg, key=jax.random.key(0)) + tokens = jax.random.randint(jax.random.key(1), (2, seq), 0, cfg.vocab_size) + + logits = forward(params, tokens, cfg, mask=AttentionMask.causal()) + assert logits.shape == (2, seq, cfg.vocab_size) + + jit_forward = jax.jit(forward, static_argnames=("cfg",)) + logits_jit = jit_forward(params, tokens, cfg, mask=AttentionMask.causal()) + assert logits_jit.shape == (2, seq, cfg.vocab_size) + + +def test_parameter_sharding_specs_are_named(): + cfg = GrugModelConfig( + vocab_size=101, + hidden_dim=32, + intermediate_dim=64, + num_layers=1, + num_heads=4, + num_kv_heads=4, + max_seq_len=16, + ) + + mesh = _make_grug_mesh() + with jax.set_mesh(mesh): + params = init_parameters(cfg, key=jax.random.key(0)) + + expected_vocab = P(None, None) + assert params.token_embed.sharding.spec == expected_vocab + assert params.output_proj.sharding.spec == expected_vocab + assert getattr(params.blocks[0].attn.w_q.sharding, "spec", None) == P("data", "model") + + +def test_full_like_preserves_sharding_under_mesh(): + mesh = _make_grug_mesh() + with jax.set_mesh(mesh): + sharding = NamedSharding(mesh, P(("data",), None)) + segment_ids = jax.device_put(jnp.array([[0, 0, 1, 1], [5, 5, 5, -1]], dtype=jnp.int32), sharding) + + batch_slice = segment_ids[:, 0] + init_last = jnp.full_like(batch_slice, jnp.int32(-2)) + init_pos = jnp.full_like(batch_slice, jnp.int32(-1)) + + assert getattr(batch_slice.sharding, "spec", None) == getattr(init_last.sharding, "spec", None) + assert getattr(batch_slice.sharding, "spec", None) == getattr(init_pos.sharding, "spec", None) + + +def test_attentionmask_materialize_causal(): + mask = AttentionMask.causal() + allowed = mask.materialize_mask(4, 4) + expected = jnp.array( + [ + [True, False, False, False], + [True, True, False, False], + [True, True, True, False], + [True, True, True, True], + ], + dtype=bool, + ) + assert allowed is not None + assert allowed.shape == (4, 4) + assert jnp.array_equal(allowed, expected) + + +def test_attentionmask_materialize_sliding_window_only(): + mask = AttentionMask(is_causal=False, sliding_window=1) + allowed = mask.materialize_mask(4, 4) + expected = jnp.array( + [ + [True, True, True, True], + [False, True, True, True], + [False, False, True, True], + [False, False, False, True], + ], + dtype=bool, + ) + assert allowed is not None + assert allowed.shape == (4, 4) + assert jnp.array_equal(allowed, expected) + + +def test_attentionmask_materialize_segment_ids_per_batch(): + q_seg = jnp.array([[0, 0, 1], [3, 3, 3]], dtype=jnp.int32) + k_seg = jnp.array([[0, 1, 1, 1], [3, 4, 3, 4]], dtype=jnp.int32) + mask = AttentionMask(is_causal=False, segment_ids=(q_seg, k_seg)) + allowed = mask.materialize_mask(3, 4) + expected = jnp.array( + [ + [ + [True, False, False, False], + [True, False, False, False], + [False, True, True, True], + ], + [ + [True, False, True, False], + [True, False, True, False], + [True, False, True, False], + ], + ], + dtype=bool, + ) + assert allowed is not None + assert allowed.shape == (2, 3, 4) + assert jnp.array_equal(allowed, expected) + + +@pytest.mark.parametrize("mode", ["causal", "causal_window", "causal_window_segments"]) +def test_blocksparse_attention_matches_reference_on_tiny_shapes(mode: str): + bs = 
len(jax.devices()) + seq = 128 if jax.default_backend() == "tpu" else 8 + + batch, heads, head_dim = bs, 2, 4 + # Keep logits in a reasonable range so this test checks semantics rather than softmax saturation + # differences between Splash and the reference path. + scale = 0.02 + q = jax.random.normal(jax.random.key(0), (batch, seq, heads, head_dim), dtype=jnp.float32) * scale + k = jax.random.normal(jax.random.key(1), (batch, seq, heads, head_dim), dtype=jnp.float32) * scale + v = jax.random.normal(jax.random.key(2), (batch, seq, heads, head_dim), dtype=jnp.float32) * scale + + mesh = _make_grug_mesh() + with jax.set_mesh(mesh): + if mode == "causal": + mask = AttentionMask.causal() + elif mode == "causal_window": + mask = AttentionMask.causal(sliding_window=3) + elif mode == "causal_window_segments": + # Segment every 16 tokens to test reset behavior while keeping Splash-compatible shapes on TPU. + segment_ids = (jnp.arange(seq, dtype=jnp.int32) // 16)[None, :] + segment_ids = jnp.repeat(segment_ids, repeats=bs, axis=0) + segment_ids = jax.sharding.reshard(segment_ids, Pbatch) + mask = AttentionMask.causal(sliding_window=3).with_segment_ids(segment_ids, segment_ids) + else: + raise AssertionError(f"unknown mode: {mode}") + + q, k, v = jax.sharding.reshard((q, k, v), Pbatch) + out_blocksparse = attention(q, k, v, mask) + out_ref = reference_attention(q, k, v, mask, logits_dtype=None) + + assert out_blocksparse.shape == out_ref.shape + if jax.default_backend() == "tpu": + # Splash attention has small numeric differences vs the reference path on TPU. + assert jnp.allclose(out_blocksparse, out_ref, rtol=1e-3, atol=1e-3) + else: + assert jnp.allclose(out_blocksparse, out_ref, rtol=1e-4, atol=1e-4) + + +def test_tpu_splash_attention_respects_causal_mask(): + if jax.default_backend() != "tpu": + pytest.skip("TPU only (Splash attention)") + + bs = len(jax.devices()) + seq = 128 + heads, head_dim = 2, 4 + + # Construct an adversarial setup: + # - make the last key have overwhelmingly high similarity to every query + # - set v_last to 1s, and all other v to 0s + # Then: + # - for q positions < last, causal mask forbids attending to the last key => output ~0 + # - for q position == last, attending to last is allowed => output ~1 + u = jnp.ones((head_dim,), dtype=jnp.float32) + q = jnp.broadcast_to(u, (bs, seq, heads, head_dim)) + k = jnp.zeros((bs, seq, heads, head_dim), dtype=jnp.float32).at[:, -1, :, :].set(u * 50.0) + v = jnp.zeros((bs, seq, heads, head_dim), dtype=jnp.float32).at[:, -1, :, :].set(1.0) + + mesh = _make_grug_mesh() + with jax.set_mesh(mesh): + q, k, v = jax.sharding.reshard((q, k, v), Pbatch) + out = attention(q, k, v, AttentionMask.causal()) + + assert out.shape == (bs, seq, heads, head_dim) + # Early positions can't see the last key, so output should be ~0. + assert jnp.max(jnp.abs(out[:, :-1, :, :])) < 1e-3 + # Last position can see the last key; output should be very close to 1. + assert jnp.min(out[:, -1, :, :]) > 0.999 + + +def test_tpu_splash_attention_respects_sliding_window(): + if jax.default_backend() != "tpu": + pytest.skip("TPU only (Splash attention)") + + bs = len(jax.devices()) + seq = 128 + heads, head_dim = 2, 4 + + # Standard sliding window semantics: W tokens including self. + # For q position W, keys < 1 are outside the window and must be masked. 
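+    # Concretely, with W=3: q=2 may attend to keys {0, 1, 2}, while q=3 may attend only to {1, 2, 3},
+    # so key 0 (which carries all of the value mass in this setup) is masked from position W onward.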
+ W = 3 + u = jnp.ones((head_dim,), dtype=jnp.float32) + q = jnp.broadcast_to(u, (bs, seq, heads, head_dim)) + k = jnp.zeros((bs, seq, heads, head_dim), dtype=jnp.float32).at[:, 0, :, :].set(u * 50.0) + v = jnp.zeros((bs, seq, heads, head_dim), dtype=jnp.float32).at[:, 0, :, :].set(1.0) + + mesh = _make_grug_mesh() + with jax.set_mesh(mesh): + q, k, v = jax.sharding.reshard((q, k, v), Pbatch) + out = attention(q, k, v, AttentionMask.causal(sliding_window=W)) + + # q at positions < W still include k=0 in their window, so output ~1 (and causal allows it). + assert jnp.min(out[:, :W, :, :]) > 0.999 + # q at position == W cannot see k=0 (outside window), so output should drop to ~0. + assert jnp.max(jnp.abs(out[:, W:, :, :])) < 1e-3 diff --git a/lib/levanter/tests/grug/test_grugformer_fused_loss.py b/lib/levanter/tests/grug/test_grugformer_fused_loss.py new file mode 100644 index 0000000000..30bf95e156 --- /dev/null +++ b/lib/levanter/tests/grug/test_grugformer_fused_loss.py @@ -0,0 +1,115 @@ +# Copyright 2025 The Levanter Authors +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np + +import jax +import jax.numpy as jnp +from jax.sharding import AxisType, Mesh + +from levanter.grug.loss import linear_softmax_cross_entropy_loss_and_logz + + +def _make_grug_mesh() -> Mesh: + devices = jax.devices() + if not devices: + raise RuntimeError("No JAX devices available") + # We only require a mesh context here so the loss can provide `out_sharding=...`. + mesh_devices = np.array(devices).reshape(len(devices), 1) + return Mesh( + mesh_devices, + axis_names=("data", "model"), + axis_types=(AxisType.Explicit, AxisType.Explicit), + ) + + +def _full_loss_and_logz( + hidden: jax.Array, lm_head: jax.Array, labels: jax.Array, *, precision: jax.lax.PrecisionLike = None +) -> tuple[jax.Array, jax.Array]: + logits = jax.lax.dot_general( + hidden, + lm_head, + dimension_numbers=(((hidden.ndim - 1,), (0,)), ((), ())), + precision=precision, + ) + log_probs = jax.nn.log_softmax(logits, axis=-1) + logz = jax.scipy.special.logsumexp(logits, axis=-1) + nll = -jnp.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0] + return nll, logz + + +def test_linear_softmax_cross_entropy_matches_full_logits(): + key = jax.random.key(0) + b, s, h, v = 2, 5, 8, 17 + hidden = jax.random.normal(key, (b, s, h), dtype=jnp.float32) + lm_head = jax.random.normal(jax.random.key(1), (h, v), dtype=jnp.float32) + labels = jax.random.randint(jax.random.key(2), (b, s), 0, v, dtype=jnp.int32) + + mesh = _make_grug_mesh() + with jax.set_mesh(mesh): + loss_full, logz_full = jax.jit(_full_loss_and_logz, static_argnames=("precision",))( + hidden, lm_head, labels, precision=jax.lax.Precision.HIGHEST + ) + loss_blk, logz_blk = jax.jit( + lambda x, w, y: linear_softmax_cross_entropy_loss_and_logz( + x, w, y, block_size=6, precision=jax.lax.Precision.HIGHEST + ) + )(hidden, lm_head, labels) + + assert loss_blk.shape == loss_full.shape + assert logz_blk.shape == logz_full.shape + # On TPU, the streaming logsumexp (blockwise) can differ from the full logsumexp due to + # different associativity/rounding behavior. We use HIGHEST matmul precision above to + # keep this fairly tight. 
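+ # A blockwise/streaming logsumexp typically keeps a running max m and a running sum s of
+ # exp(logit - m) per token, folding in each vocab block via
+ #   m_new = max(m, m_blk); s = s * exp(m - m_new) + s_blk * exp(m_blk - m_new),
+ # and finally logz = m + log(s). Each rescaling rounds slightly differently than a single
+ # full-vocab reduction, hence the looser TPU tolerances here.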
+ if jax.default_backend() == "tpu": + assert jnp.allclose(loss_blk, loss_full, atol=5e-3, rtol=5e-3) + assert jnp.allclose(logz_blk, logz_full, atol=5e-3, rtol=5e-3) + else: + assert jnp.allclose(loss_blk, loss_full, atol=1e-4, rtol=1e-4) + assert jnp.allclose(logz_blk, logz_full, atol=1e-4, rtol=1e-4) + + +def test_linear_softmax_cross_entropy_jittable(): + key = jax.random.key(0) + b, s, h, v = 2, 3, 8, 11 + hidden = jax.random.normal(key, (b, s, h), dtype=jnp.float32) + lm_head = jax.random.normal(jax.random.key(1), (h, v), dtype=jnp.float32) + labels = jax.random.randint(jax.random.key(2), (b, s), 0, v, dtype=jnp.int32) + + mesh = _make_grug_mesh() + with jax.set_mesh(mesh): + fn = jax.jit( + lambda x, w, y: linear_softmax_cross_entropy_loss_and_logz( + x, w, y, block_size=4, precision=jax.lax.Precision.HIGHEST + ) + ) + loss, logz = fn(hidden, lm_head, labels) + assert loss.shape == (b, s) + assert logz.shape == (b, s) + + +def test_linear_softmax_cross_entropy_grad_matches_full(): + key = jax.random.key(0) + b, s, h, v = 2, 3, 8, 13 + hidden = jax.random.normal(key, (b, s, h), dtype=jnp.float32) + lm_head = jax.random.normal(jax.random.key(1), (h, v), dtype=jnp.float32) + labels = jax.random.randint(jax.random.key(2), (b, s), 0, v, dtype=jnp.int32) + + def loss_full_fn(x): + loss, _ = _full_loss_and_logz(x, lm_head, labels, precision=jax.lax.Precision.HIGHEST) + return jnp.mean(loss) + + def loss_blk_fn(x): + loss, _ = linear_softmax_cross_entropy_loss_and_logz( + x, lm_head, labels, block_size=5, precision=jax.lax.Precision.HIGHEST + ) + return jnp.mean(loss) + + mesh = _make_grug_mesh() + with jax.set_mesh(mesh): + g_full = jax.grad(loss_full_fn)(hidden) + g_blk = jax.grad(loss_blk_fn)(hidden) + if jax.default_backend() == "tpu": + assert jnp.allclose(g_blk, g_full, atol=5e-4, rtol=5e-3) + else: + assert jnp.allclose(g_blk, g_full, atol=1e-4, rtol=1e-4) diff --git a/lib/levanter/tests/grug/test_grugformer_model_loss.py b/lib/levanter/tests/grug/test_grugformer_model_loss.py new file mode 100644 index 0000000000..6d21a02926 --- /dev/null +++ b/lib/levanter/tests/grug/test_grugformer_model_loss.py @@ -0,0 +1,101 @@ +# Copyright 2025 The Levanter Authors +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np + +import haliax as hax +import jax +import jax.numpy as jnp +from jax.sharding import AxisType, Mesh + +from levanter.grug.attention import AttentionMask +from levanter.grug.model import GrugModelConfig +from levanter.grug.model import activations, init_parameters, loss_fn +from levanter.layers.attention import AttentionMask as LevanterAttentionMask +from levanter.models.grug_wrapper import GrugWrapper +from levanter.models.lm_model import LmExample + + +def _make_grug_mesh() -> Mesh: + devices = jax.devices() + if not devices: + raise RuntimeError("No JAX devices available") + mesh_devices = np.array(devices).reshape(1, 1, 1, len(devices)) + return Mesh( + mesh_devices, + axis_names=("replica_dcn", "replica", "data", "model"), + axis_types=(AxisType.Explicit, AxisType.Explicit, AxisType.Explicit, AxisType.Explicit), + ) + + +def _full_next_token_loss(logits: jax.Array, token_ids: jax.Array, loss_weight: jax.Array) -> jax.Array: + labels = jnp.concatenate([token_ids[:, 1:], token_ids[:, :1] * 0], axis=1).astype(jnp.int32) + log_probs = jax.nn.log_softmax(logits, axis=-1) + nll = -jnp.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0] + nll = nll * loss_weight + denom = jnp.sum(loss_weight) + return jnp.sum(nll) / jnp.maximum(denom, jnp.array(1.0, 
dtype=nll.dtype)) + + +def test_grug_model_loss_fn_matches_full_logits(): + # On TPU, Grug uses Splash attention which requires KV sequence length to be a multiple of 128. + seq = 128 if jax.default_backend() == "tpu" else 9 + cfg = GrugModelConfig( + vocab_size=23, + hidden_dim=16, + intermediate_dim=32, + num_layers=1, + num_heads=4, + num_kv_heads=4, + max_seq_len=seq, + cross_entropy_block_size=8, + ) + + mesh = _make_grug_mesh() + with jax.set_mesh(mesh): + params = init_parameters(cfg, key=jax.random.key(0)) + token_ids = jax.random.randint(jax.random.key(1), (2, seq), 0, cfg.vocab_size, dtype=jnp.int32) + loss_weight = jnp.ones((2, seq), dtype=jnp.float32).at[:, -1].set(0.0) + + hidden = activations(params, token_ids, cfg, mask=AttentionMask.causal()) + logits = hidden @ params.output_proj + + ref = _full_next_token_loss(logits, token_ids, loss_weight) + got = loss_fn(params, token_ids, loss_weight, cfg, mask=AttentionMask.causal(), reduction="mean") + + assert jnp.allclose(got, ref, atol=1e-4, rtol=1e-4) + + +def test_grug_wrapper_compute_next_token_loss_uses_grug_loss_fn(): + # On TPU, Grug uses Splash attention which requires KV sequence length to be a multiple of 128. + seq = 128 if jax.default_backend() == "tpu" else 9 + cfg = GrugModelConfig( + vocab_size=29, + hidden_dim=16, + intermediate_dim=32, + num_layers=1, + num_heads=4, + num_kv_heads=4, + max_seq_len=seq, + cross_entropy_block_size=8, + ) + + mesh = _make_grug_mesh() + with jax.set_mesh(mesh): + Vocab = hax.Axis("vocab", cfg.vocab_size) + model = GrugWrapper.init(Vocab, cfg, key=jax.random.key(0)) + + Batch = hax.Axis("batch", 2) + Pos = hax.Axis("position", seq) + token_ids = hax.random.randint(jax.random.key(1), (Batch, Pos), 0, cfg.vocab_size, dtype=jnp.int32) + loss_weight = hax.ones((Batch, Pos), dtype=jnp.float32).at[Pos, Pos.size - 1].set(0.0) + example = LmExample(tokens=token_ids, loss_weight=loss_weight, attn_mask=LevanterAttentionMask.causal()) + + per_pos = model.compute_next_token_loss(example, reduction=None, reduction_axis=()) + assert isinstance(per_pos, hax.NamedArray) + assert per_pos.axes == token_ids.axes + + expected = loss_fn( + model.params, token_ids.array, loss_weight.array, cfg, mask=AttentionMask.causal(), reduction="none" + ) + assert jnp.allclose(per_pos.array, expected, atol=1e-4, rtol=1e-4) diff --git a/lib/marin/src/marin/run/ray_run.py b/lib/marin/src/marin/run/ray_run.py index 10562c20c2..97e28ff23c 100644 --- a/lib/marin/src/marin/run/ray_run.py +++ b/lib/marin/src/marin/run/ray_run.py @@ -146,7 +146,7 @@ async def submit_and_track_job( runtime_dict = { "working_dir": current_dir, "config": {"setup_timeout_seconds": 1800}, - "excludes": [".git", "tests/", "docs/", "**/*.pack", "lib/levanter/docs"], + "excludes": [".git", "docs/", "**/*.pack", "lib/levanter/docs"], } # add the TPU dependency for cluster jobs. 
diff --git a/pyproject.toml b/pyproject.toml index 9f09b8c963..c988a0495f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -184,4 +184,4 @@ markers = [ testpaths = ["tests", "experiments"] # Make sure we timeout before CI kills us, and don't run TPU or slow tests by default -addopts = "--session-timeout=480 -m 'not tpu_ci and not slow'" +#addopts = "--session-timeout=480 -m 'not tpu_ci and not slow'" diff --git a/uv.lock b/uv.lock index 1b534e0f76..8c9ddb594d 100644 --- a/uv.lock +++ b/uv.lock @@ -3641,6 +3641,7 @@ dependencies = [ { name = "datasets" }, { name = "deepdiff" }, { name = "draccus" }, + { name = "einops" }, { name = "equinox" }, { name = "filelock" }, { name = "fray" }, @@ -3745,6 +3746,7 @@ requires-dist = [ { name = "datasets", specifier = ">=3.1.0,<5.0" }, { name = "deepdiff" }, { name = "draccus", specifier = ">=0.11.5" }, + { name = "einops" }, { name = "equinox", specifier = ">=0.11.7,!=0.12.0" }, { name = "fastapi", marker = "extra == 'serve'", specifier = ">=0.100.0" }, { name = "filelock", specifier = "~=3.13" },