Grug #2171
base: main
Conversation
```
# Conflicts:
#	lib/levanter/src/levanter/data/mixture.py
#	lib/marin/pyproject.toml
#	pyproject.toml
#	uv.lock
```
This pull request has been inactive for 23 days and is marked as stale.
bump
Pull request overview
This PR introduces Grugformer: a "grug-simple" JAX LM implementation emphasizing explicit sharding and top-level functions over heavy abstractions. The implementation provides a minimal core in `levanter.grug`, an adapter for integration with Levanter's trainer (`levanter.models.grug_wrapper`), speedrun entrypoints for experimentation, and a comprehensive test suite.
Changes:
- Adds a new `levanter.grug` package with model, attention, loss, data, and config modules
- Provides a `GrugWrapper` adapter to integrate with Levanter's `LmHeadModel` interface
- Includes three speedrun experiments: starter template, attention-sink variant, and head-to-head comparison with Hackable Transformer
- Adds comprehensive test suite locking down the core API surface
- Documents design principles, change workflow, and experiment archiving strategy
Reviewed changes
Copilot reviewed 24 out of 26 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| uv.lock | Adds einops dependency and updates various package versions |
| lib/levanter/pyproject.toml | Adds einops to dependencies and TPU test marker |
| lib/levanter/src/levanter/grug/*.py | Core grug implementation: model, attention, loss, data, config |
| lib/levanter/src/levanter/models/grug_wrapper.py | Adapter bridging grug core to Levanter's LmHeadModel API |
| lib/levanter/src/levanter/models/loss.py | Updates to use auto_sharded and zeros_like for better sharding |
| lib/levanter/tests/grug/*.py | Comprehensive test suite for grug core functionality |
| experiments/speedrun/grugformer_*/*.py | Three speedrun experiments showcasing grug usage |
| docs/*.md | Documentation for principles, workflow, and experiment archiving |
```python
import equinox as eqx
import jax
import haliax as hax
```
Copilot AI · Jan 10, 2026
Module 'haliax' is imported with both 'import' and 'import from'.
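(For illustration only: a typical way to resolve this lint, assuming the file mixes both import forms; `NamedArray` is just an example of the second form.)

```python
# Before: two styles for the same module, which trips the lint.
import haliax as hax
from haliax import NamedArray

# After: a single namespace import; attribute access stays explicit.
import haliax as hax

def f(x: hax.NamedArray) -> hax.NamedArray:
    return x
```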
```python
# nodryrun
```
```python
import dataclasses
```
Copilot AI · Jan 10, 2026
Module 'dataclasses' is imported with both 'import' and 'import from'.
```python
import dataclasses
```
Copilot AI · Jan 10, 2026
Module 'dataclasses' is imported with both 'import' and 'import from'.
Pushed a fix for TPU Splash attention crashing during init on tracers (falls back to the reference attention when Splash is unavailable) + removed an unsupported arg from the Grugformer speedrun wrapper. Commit: 0ee618f.
Follow-up (previous comment had shell quoting issues): fix uses …
Added an inline note + refactor in …
This is awesome! I may be a little aggressive with the comments to delete "unused" logic/options or reduce the number of files; this is mostly in the spirit of karpathy-ish code 🙇 There are a couple of logic questions in here, and some nits as well, like the `__all__`, which I dislike.¹
Footnotes
1. I prefer protected-ish `_` prefixes, but if marin has a policy on `__all__` I'm happy to adjust. ↩
```markdown
### CLI Entrypoint

- `src/marin/grugpt/train.py` implements `def main(argv=None): ...` using `argparse` (no click). Steps:
```
Outdated path? Probably `lib/levanter/src/levanter/grug/main.py`?
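The doc bullet above describes the entrypoint shape; a minimal sketch of that `argparse` pattern (assuming the file lands at the suggested path; only the `--cache-dir` flag from this PR is shown):

```python
import argparse


def main(argv: list[str] | None = None) -> None:
    parser = argparse.ArgumentParser(description="Run the Grug trainer.")
    parser.add_argument(
        "--cache-dir", type=str, default=None,
        help="Optional TreeCache directory for real data.",
    )
    args = parser.parse_args(argv)
    # ... build configs, initialize the model, and run the training loop ...


if __name__ == "__main__":
    main()
```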
```python
global_batch_size: int = 8

__all__ = ["GrugModelConfig", "GrugTrainingConfig"]
```
Do we need the `__all__`?
```python
Levanter integration adapters live under `levanter.models`.
"""

from .attention import apply_rotary_embedding, attention
```
For the sake of simplicity: do we need the explicit `__all__` export? Could this be just a plain `__init__` file?
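For comparison, the plain-init alternative being suggested would drop `__all__` and let the imports themselves define the surface (a sketch; only the import shown above is taken from the PR):

```python
# __init__.py without __all__: the explicit imports are the public surface.
from .attention import apply_rotary_embedding, attention
```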
```python
@dataclass(frozen=True)
class GrugTrainingConfig:
```
To reduce the number of files: could this live in `main.py` (the only place, apart from tests, where this is used right now)?
```python
if global_batch_size is not None and global_batch_size > 0:
    while data_size > 1 and global_batch_size % data_size != 0:
        data_size -= 1
```
Should we log/print a warning here if the final `data_size` is less than `len(devices)`?
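Something like the following could work (a sketch around the snippet above; `choose_data_parallel_size` and the logger wiring are hypothetical, not the PR's actual code):

```python
import logging

import jax

logger = logging.getLogger(__name__)


def choose_data_parallel_size(global_batch_size: int | None) -> int:
    # Shrink from all devices until the global batch divides evenly.
    n_devices = len(jax.devices())
    data_size = n_devices
    if global_batch_size is not None and global_batch_size > 0:
        while data_size > 1 and global_batch_size % data_size != 0:
            data_size -= 1
    if data_size < n_devices:
        logger.warning(
            "data_size reduced to %d (< %d devices) to divide global_batch_size=%s; "
            "some devices will be underutilized.",
            data_size, n_devices, global_batch_size,
        )
    return data_size
```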
```toml
testpaths = ["tests", "experiments"]

# Make sure we timeout before CI kills us, and don't run TPU or slow tests by default
addopts = "--session-timeout=480 -m 'not tpu_ci and not slow'"
```
Is this intentional?
```python
runtime_dict = {
    "working_dir": current_dir,
    "config": {"setup_timeout_seconds": 1800},
    "excludes": [".git", "tests/", "docs/", "**/*.pack", "lib/levanter/docs"],
```
❓ what is the purpose of this change?
```python
    token_embed=token_embed,
    output_proj=output_proj,
    blocks=tuple(blocks),
    final_norm=jnp.ones_like(final_norm),
```
❓ isn't it already ones?
```python
num_kv_heads: int = 16
head_dim: int | None = None
max_seq_len: int = 4096
dropout_rate: float = 0.0
```
❓ where is this used?
```python
elif isinstance(mask, AttentionMask) and not mask.is_causal:
    mask = dataclasses.replace(mask, is_causal=True)
```
❓ what's the intention behind these 2 lines?
FYI, when I run the starter speedrun (130M only) in us-central1 on TPU (v5p-8), I get an OOM. I can work around this, but I wonder if that was supposed to work?
```python
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Run the Grug trainer.")
    parser.add_argument("--cache-dir", type=str, default=None, help="Optional TreeCache directory for real data.")
```
Could we make it simpler to run `main.py` in isolation, without depending on TreeCache or synthetic data? I.e., point it at a directory (object-store compatible) dump of some canonical dataset, e.g. OpenWebText, FineWeb, or even TinyStories?
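One way the suggestion could look, extending the `parse_args` shown above (an entirely hypothetical flag, not in the PR):

```python
parser.add_argument(
    "--data-dir", type=str, default=None,
    help="(Hypothetical) local or object-store path to a pre-tokenized dump of a "
         "canonical dataset (OpenWebText, FineWeb, TinyStories) for standalone runs.",
)
```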
ravwojdyla left a comment
Some more comments from experiments
```python
    num_kv_heads=self.num_kv_heads,
    head_dim=self.head_dim,
    max_seq_len=self.max_seq_len,
    tie_embeddings=self.tie_embeddings,
```
This flag, `tie_embeddings`, doesn't exist in `GrugModelConfig` (anymore?). I assume it can be completely removed from this experiment?
```python
# Grug core currently always has separate token embed + output projection; keep this knob
# for param counting / compatibility with other LmConfig-based scripts.
tie_embeddings: bool = False
```
Wouldn't it be more intuitive to not expose this config and instead hard-code the logic in `total_trainable_params`? Otherwise it may seem like this flag does something when it doesn't.
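A sketch of the suggested hard-coding (names are illustrative, not the experiment's actual code):

```python
def total_trainable_params(cfg, non_embedding_params: int) -> int:
    # Grug core always keeps token embedding and output projection separate,
    # so count both unconditionally; there is no tie_embeddings knob to consult.
    embed_params = cfg.vocab_size * cfg.hidden_dim
    return non_embedding_params + 2 * embed_params
```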
```python
    num_kv_heads=self.num_kv_heads,
    head_dim=self.head_dim,
    max_seq_len=self.max_seq_len,
    tie_embeddings=self.tie_embeddings,
```
`tie_embeddings` doesn't exist in `GrugModelConfig` (see the other comment).
```python
labels = jnp.concatenate([token_ids[:, 1:], token_ids[:, :1] * 0], axis=1).astype(jnp.int32)
loss_weight = loss_weight.astype(loss_dtype)

# NOTE: `block_size=None` corresponds to a single full-vocab block. On the 125M speedrun,
```
❓ #2315 (comment) ptal, I can't reproduce this 🙏
Btw, setting `cross_entropy_block_size` to, say, ~32k on v5p-8 OOMs in the 125M experiment.
This PR introduces Grugformer: a "grug-simple" JAX LM implementation that leans into explicit sharding and top-level functions rather than heavy abstractions. It adds a minimal core (`levanter.grug`) plus a small adapter (`levanter.models.grug_wrapper`) so it can run through the existing Levanter trainer pipeline, and it includes speedrun entrypoints + tests that lock down the intended "grug core surface".

What's Included
New Grug core (minimal, notebook-like)
- `lib/levanter/src/levanter/grug/attention.py`: Grug-local `AttentionMask` spec + attention implementation (TPU Splash when on TPU; reference fallback otherwise).
- `model.py`: parameter dataclasses + init/forward/activations/loss functions.
- `loss.py`: blockwise "large vocab friendly" CE path (avoid full logits materialization; see note below on tradeoffs). A sketch of the general technique follows this list.
- `data.py`, `main.py`: minimal training/data wiring to run in-repo.
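For context on the blockwise CE path: the idea is to compute logits one vocab block at a time with an online log-sum-exp, so the full `[N, vocab]` matrix is never materialized. The following is a minimal sketch of the general technique, not the actual `loss.py` code; names and shapes are illustrative, and the vocab is assumed divisible by the block size.

```python
import jax
import jax.numpy as jnp


def blockwise_cross_entropy(hidden, out_proj, labels, block_size=8192):
    """Mean CE loss without materializing [N, vocab] logits.

    hidden: [N, D], out_proj: [D, V], labels: [N] int32. Assumes V % block_size == 0.
    """
    vocab = out_proj.shape[1]
    assert vocab % block_size == 0
    n = hidden.shape[0]

    def body(carry, i):
        run_max, run_sum, label_logit = carry
        start = i * block_size
        w = jax.lax.dynamic_slice_in_dim(out_proj, start, block_size, axis=1)
        logits = hidden @ w  # [N, block_size] is the only live logits buffer
        # Online log-sum-exp update across blocks.
        new_max = jnp.maximum(run_max, logits.max(axis=-1))
        run_sum = run_sum * jnp.exp(run_max - new_max)
        run_sum += jnp.exp(logits - new_max[:, None]).sum(axis=-1)
        # Capture the target logit if the label falls inside this block.
        in_block = (labels >= start) & (labels < start + block_size)
        idx = jnp.clip(labels - start, 0, block_size - 1)
        picked = jnp.take_along_axis(logits, idx[:, None], axis=1)[:, 0]
        label_logit = jnp.where(in_block, picked, label_logit)
        return (new_max, run_sum, label_logit), None

    init = (jnp.full((n,), -jnp.inf), jnp.zeros((n,)), jnp.zeros((n,)))
    (m, s, label_logit), _ = jax.lax.scan(body, init, jnp.arange(vocab // block_size))
    return (m + jnp.log(s) - label_logit).mean()  # logsumexp minus target logit
```

The tradeoff referenced in the note: smaller blocks cap peak memory but add scan overhead, which is why `block_size` is exposed as a knob.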
Levanter adapter
- `lib/levanter/src/levanter/models/grug_wrapper.py`: wraps the grug core behind Levanter's `LmConfig`/trainer expectations while keeping the core itself free of NamedArray-heavy abstractions.
Speedruns / templates
- `experiments/speedrun/grugformer_starter/grugformer_speedrun.py`: a grug speedrun template for quick iteration.
- `experiments/speedrun/grugformer_attnsink/grugformer_attn_sink.py`: "hackable" grug attention-sink variant (copy/paste edit surface).
- `experiments/speedrun/grugformer_vs_hackable_125m/grugformer_vs_hackable_125m.py`: head-to-head comparison (Hackable Transformer vs Grugformer, no sinks). The Hackable path runs without explicit mesh axes for now.
Tests (lock the "grug core surface"), under `lib/levanter/tests/grug/`:
- `test_grugformer_core.py`: core API + mesh/sharding sanity.
- `test_grugformer_model_loss.py`: loss correctness vs full logits on small shapes; wrapper plumbing.
- `test_grugformer_fused_loss.py`: loss-related regression coverage.
- `test_grugformer_compilation.py`: lowers/jit-traces model + loss under `AbstractMesh` (no concrete devices required).
- `test_grugformer.py`: higher-level smoke coverage (tiny synthetic step).
Documentation
- `.agents/projects/grugformer.md`: principles, intended edit surface, and follow-ups.
- `docs/recipes/change_grug.md`: workflow for proposing changes (speedrun edit surface → adopt into canonical grug → archive old experiments).
- `docs/reports/grug-archive.md`: lightweight "experiment archive log" placeholder so we have somewhere to record removals/renames as grug evolves.

Notable Design Choices / Current Constraints
How To Try
- `python -m experiments.speedrun.grugformer_vs_hackable_125m.grugformer_vs_hackable_125m`; set `SR_USE_TPU=1` to use the TPU preset.
- Tests: `uv run pytest lib/levanter/tests/grug -q`

Follow-ups