Conversation

@dlwh (Member) commented Dec 5, 2025

This PR introduces Grugformer: a “grug-simple” JAX LM implementation that leans into explicit sharding and top-level functions rather than heavy abstractions. It adds a minimal core (levanter.grug) plus a small adapter (levanter.models.grug_wrapper) so it can run through the existing Levanter trainer pipeline, and it includes speedrun entrypoints + tests that lock down the intended “grug core surface”.

What’s Included

New Grug core (minimal, notebook-like)

  • New package: lib/levanter/src/levanter/grug/
    • attention.py: Grug-local AttentionMask spec + attention implementation (Splash attention on TPU; reference fallback elsewhere).
    • model.py: parameter dataclasses + init/forward/activations/loss functions.
    • loss.py: blockwise, large-vocab-friendly CE path that avoids materializing the full logits matrix (see the sketch after this list, and the note below on tradeoffs).
    • data.py, main.py: minimal training/data wiring to run in-repo.
  • Exported surface is intentionally small (functions + dataclasses; minimal mutation).
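
As a rough illustration of the blockwise CE idea: the sketch below is ours, not the actual loss.py code - the function name, argument layout, and the vocab-divisibility assumption are all for illustration only.

```python
import jax
import jax.numpy as jnp


def blockwise_cross_entropy(hidden, out_proj, labels, block_size=8192):
    """Mean CE over a large vocab without materializing [n_tokens, vocab] logits.

    hidden:   [n_tokens, d_model] final activations
    out_proj: [d_model, vocab] output projection
    labels:   [n_tokens] int32 targets
    Assumes vocab % block_size == 0 for simplicity.
    """
    n_tokens = hidden.shape[0]
    n_blocks = out_proj.shape[1] // block_size

    def scan_block(carry, i):
        running_max, sum_exp, label_logit = carry
        start = i * block_size
        block = jax.lax.dynamic_slice_in_dim(out_proj, start, block_size, axis=1)
        logits = hidden @ block  # [n_tokens, block_size]
        # Online logsumexp update, flash-attention style.
        new_max = jnp.maximum(running_max, logits.max(axis=-1))
        sum_exp = sum_exp * jnp.exp(running_max - new_max) + jnp.exp(
            logits - new_max[:, None]
        ).sum(axis=-1)
        # Grab the target logit if the label lands in this block.
        in_block = (labels >= start) & (labels < start + block_size)
        local = jnp.clip(labels - start, 0, block_size - 1)
        picked = jnp.take_along_axis(logits, local[:, None], axis=-1)[:, 0]
        label_logit = jnp.where(in_block, picked, label_logit)
        return (new_max, sum_exp, label_logit), None

    init = (
        jnp.full((n_tokens,), -jnp.inf),
        jnp.zeros((n_tokens,)),
        jnp.zeros((n_tokens,)),
    )
    (running_max, sum_exp, label_logit), _ = jax.lax.scan(
        scan_block, init, jnp.arange(n_blocks)
    )
    # log-softmax CE: logsumexp(logits) - logit[label]
    return jnp.mean(jnp.log(sum_exp) + running_max - label_logit)
```

Scanning over vocab blocks keeps peak memory at [n_tokens, block_size] instead of [n_tokens, vocab], at the cost of re-reading `hidden` once per block - which is exactly the speed/memory tradeoff noted below.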

Levanter adapter

  • lib/levanter/src/levanter/models/grug_wrapper.py: wraps grug core behind Levanter’s LmConfig/trainer expectations while keeping the core itself free of NamedArray-heavy abstractions.

Speedruns / templates

  • experiments/speedrun/grugformer_starter/grugformer_speedrun.py: a grug speedrun template for quick iteration.
  • experiments/speedrun/grugformer_attnsink/grugformer_attn_sink.py: “hackable” grug attention-sink variant (copy/paste edit surface).
  • experiments/speedrun/grugformer_vs_hackable_125m/grugformer_vs_hackable_125m.py: head-to-head comparison (Hackable Transformer vs Grugformer, no sinks). Hackable path runs without explicit mesh axes for now.

Tests (lock the “grug core surface”)

  • All Grug tests live under lib/levanter/tests/grug/:
    • test_grugformer_core.py: core API + mesh/sharding sanity.
    • test_grugformer_model_loss.py: loss correctness vs full logits on small shapes; wrapper plumbing.
    • test_grugformer_fused_loss.py: loss-related regression coverage.
    • test_grugformer_compilation.py: lowers/jit-traces model+loss under AbstractMesh (no concrete devices required).
    • test_grugformer.py: higher-level smoke coverage (tiny synthetic step).

Documentation

  • .agents/projects/grugformer.md: principles, intended edit surface, and follow-ups.
  • docs/recipes/change_grug.md: workflow for proposing changes (speedrun edit surface → adopt into canonical grug → archive old experiments).
  • docs/reports/grug-archive.md: lightweight “experiment archive log” placeholder so we have somewhere to record removals/renames as grug evolves.

Notable Design Choices / Current Constraints

  • Attention: the TPU path uses Splash attention directly; the GPU path uses the reference fallback for now (sketched below).
  • Loss: large-vocab CE is more painful than we’d like under explicit sharding; we currently use a blockwise, “flash-attention style” transform. The block-size knob is intentionally exposed; we’ve observed meaningful perf sensitivity and will likely revisit this with a better kernel later.
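
For reference, a minimal sketch of the attention dispatch described above. Names and signatures are illustrative, the Splash branch is elided (the real attention.py wires in the Pallas Splash kernel there), and the reference path assumes equal query and KV head counts:

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v, *, is_causal: bool = True):
    # q, k, v: [batch, seq, heads, head_dim]; assumes num_q_heads == num_kv_heads.
    scale = q.shape[-1] ** -0.5
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
    if is_causal:
        q_len, k_len = scores.shape[-2], scores.shape[-1]
        causal = jnp.tril(jnp.ones((q_len, k_len), dtype=bool))
        scores = jnp.where(causal, scores, -jnp.inf)
    return jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(scores, axis=-1), v)


def attention(q, k, v, *, is_causal: bool = True):
    if jax.default_backend() == "tpu":
        # The real attention.py calls the Pallas Splash kernel here; the
        # kernel plumbing is elided from this sketch.
        raise NotImplementedError("Splash path elided in sketch")
    return reference_attention(q, k, v, is_causal=is_causal)
```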

How To Try

  • Run the h2h speedrun:
    • python -m experiments.speedrun.grugformer_vs_hackable_125m.grugformer_vs_hackable_125m
    • Set SR_USE_TPU=1 to use the TPU preset.
  • Run tests:
    • uv run pytest lib/levanter/tests/grug -q

Follow-ups

  • Implement a faster large-vocab CE path that’s robust under explicit sharding (avoids the current speed/memory tradeoff).
  • Expand the speedrun “gauntlet” checks and add more minimal “edit points” for experiments.

@github-actions (Contributor) commented:

This pull request has been inactive for 23 days and is marked as stale.
If there is no further activity within 7 days, it will be automatically closed.
If you believe this PR should remain open, please add a comment or update the PR.

github-actions bot added the stale label Dec 29, 2025
@dlwh (Member, Author) commented Dec 29, 2025

bump

github-actions bot removed the stale label Dec 30, 2025
dlwh marked this pull request as ready for review January 10, 2026 07:12
Copilot AI review requested due to automatic review settings January 10, 2026 07:12
Copilot AI (Contributor) left a comment

Pull request overview

This PR introduces Grugformer: a "grug-simple" JAX LM implementation emphasizing explicit sharding and top-level functions over heavy abstractions. The implementation provides a minimal core in levanter.grug, an adapter for integration with Levanter's trainer (levanter.models.grug_wrapper), speedrun entrypoints for experimentation, and comprehensive tests.

Changes:

  • Adds new levanter.grug package with model, attention, loss, data, and config modules
  • Provides GrugWrapper adapter to integrate with Levanter's LmHeadModel interface
  • Includes three speedrun experiments: starter template, attention sink variant, and head-to-head comparison with Hackable Transformer
  • Adds comprehensive test suite locking down the core API surface
  • Documents design principles, change workflow, and experiment archiving strategy

Reviewed changes

Copilot reviewed 24 out of 26 changed files in this pull request and generated 3 comments.

Summary per file:

| File | Description |
| --- | --- |
| uv.lock | Adds einops dependency and updates various package versions |
| lib/levanter/pyproject.toml | Adds einops to dependencies and a TPU test marker |
| lib/levanter/src/levanter/grug/*.py | Core grug implementation: model, attention, loss, data, config |
| lib/levanter/src/levanter/models/grug_wrapper.py | Adapter bridging grug core to Levanter's LmHeadModel API |
| lib/levanter/src/levanter/models/loss.py | Updates to use auto_sharded and zeros_like for better sharding |
| lib/levanter/tests/grug/*.py | Comprehensive test suite for grug core functionality |
| experiments/speedrun/grugformer_*/*.py | Three speedrun experiments showcasing grug usage |
| docs/*.md | Documentation for principles, workflow, and experiment archiving |


import equinox as eqx
import jax
import haliax as hax
Copilot AI commented Jan 10, 2026:

Module 'haliax' is imported with both 'import' and 'import from'.

# nodryrun

import dataclasses
Copilot AI commented Jan 10, 2026:

Module 'dataclasses' is imported with both 'import' and 'import from'.
Comment on lines +4 to +5
import dataclasses

Copilot AI commented Jan 10, 2026:

Module 'dataclasses' is imported with both 'import' and 'import from'.

Suggested change: drop the duplicate `import dataclasses` line.
@pc0618 (Contributor) commented Jan 11, 2026:

Pushed a fix for TPU Splash attention crashing during init on tracers (falling back to x.aval.sharding when x.sharding is unavailable) + removed an unsupported arg from the Grugformer speedrun wrapper. Commit: 0ee618f.

@pc0618 (Contributor) commented Jan 11, 2026:

Follow-up (previous comment had shell quoting issues): the fix uses x.sharding when available and falls back to x.aval.sharding for tracers during staging; it also stops passing tie_embeddings into GrugModelConfig (the flag is kept only for param counting). Commit: 0ee618f.
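
A minimal sketch of the described fallback; the helper name and exact control flow are illustrative, not the committed code:

```python
def _array_sharding(x):
    # Concrete jax.Arrays expose `.sharding` directly; tracers during jit
    # staging may not, but carry sharding on their abstract value (`.aval`)
    # under explicit sharding.
    sharding = getattr(x, "sharding", None)
    return sharding if sharding is not None else x.aval.sharding
```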

@pc0618 (Contributor) commented Jan 12, 2026:

Added an inline note + refactor in levanter/grug/model.py:init_parameters to use hierarchical key splitting instead of the (3 + 7 * num_layers) “magic number” math (more robust to future parameter additions). Commit: b9756b3. Also left an in-code TODO to add a brief explanation in the PR discussion later.
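
For flavor, the shape of that refactor might look like the following; parameter shapes and field names are placeholders, not the actual model.py code:

```python
import jax


def init_parameters(num_layers: int, d_model: int, vocab: int, key):
    # Split once at the top level, then derive a per-layer key via fold_in.
    # No need to precompute how many keys the whole model needs up front
    # (the old `3 + 7 * num_layers` arithmetic).
    embed_key, blocks_key, head_key = jax.random.split(key, 3)

    def init_block(block_key):
        # Each block splits its own key further for its sub-parameters, so
        # adding a parameter to a block never shifts other layers' keys.
        wq_key, wo_key = jax.random.split(block_key, 2)
        return {
            "wq": jax.random.normal(wq_key, (d_model, d_model)),
            "wo": jax.random.normal(wo_key, (d_model, d_model)),
        }

    return {
        "token_embed": jax.random.normal(embed_key, (vocab, d_model)),
        "blocks": [
            init_block(jax.random.fold_in(blocks_key, layer))
            for layer in range(num_layers)
        ],
        "output_proj": jax.random.normal(head_key, (d_model, vocab)),
    }
```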

@ravwojdyla (Contributor) left a comment:

This is awesome! I may be a little aggressive with the comments to delete "unused" logic/options or reduce the number of files - this is mostly in the spirit of karpathy-ish code 🙇 There are a couple of logic questions in here. Some nits as well, like the __all__, which I dislike.¹

Footnotes

  ¹ I prefer protected-ish `_`, but if marin has a policy on __all__ I'm happy to adjust.


### CLI Entrypoint

- `src/marin/grugpt/train.py` implements `def main(argv=None): ...` using `argparse` (no click). Steps:
ravwojdyla:

Outdated path? Probably lib/levanter/src/levanter/grug/main.py?

global_batch_size: int = 8


__all__ = ["GrugModelConfig", "GrugTrainingConfig"]
ravwojdyla:

Do we need the __all__?

Levanter integration adapters live under `levanter.models`.
"""

from .attention import apply_rotary_embedding, attention
ravwojdyla:

For the sake of simplicity - do we need explicit __all__ export? Could this be just a plain init file?



@dataclass(frozen=True)
class GrugTrainingConfig:
ravwojdyla:

To reduce the number of files - could this live in main.py (the only place, apart from tests, where it's used right now)?

if global_batch_size is not None and global_batch_size > 0:
while data_size > 1 and global_batch_size % data_size != 0:
data_size -= 1

ravwojdyla:

Should we log/print warning here if final data_size is less than len(devices)?
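
One way the suggested warning might look, as a patch-style fragment against the quoted loop (assuming data_size starts at len(jax.devices()); illustrative only):

```python
import logging

import jax

logger = logging.getLogger(__name__)

data_size = len(jax.devices())
if global_batch_size is not None and global_batch_size > 0:
    while data_size > 1 and global_batch_size % data_size != 0:
        data_size -= 1
    if data_size < len(jax.devices()):
        # Surface silent shrinkage of the data-parallel axis.
        logger.warning(
            "global_batch_size=%d is not divisible by %d devices; "
            "shrinking data-parallel axis to %d",
            global_batch_size, len(jax.devices()), data_size,
        )
```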

testpaths = ["tests", "experiments"]

# Make sure we timeout before CI kills us, and don't run TPU or slow tests by default
addopts = "--session-timeout=480 -m 'not tpu_ci and not slow'"
ravwojdyla:

Is this intentional?

runtime_dict = {
"working_dir": current_dir,
"config": {"setup_timeout_seconds": 1800},
"excludes": [".git", "tests/", "docs/", "**/*.pack", "lib/levanter/docs"],
ravwojdyla:

❓ what is the purpose of this change?

token_embed=token_embed,
output_proj=output_proj,
blocks=tuple(blocks),
final_norm=jnp.ones_like(final_norm),
ravwojdyla:

❓ isn't it already ones?

num_kv_heads: int = 16
head_dim: int | None = None
max_seq_len: int = 4096
dropout_rate: float = 0.0
ravwojdyla:

❓ where is this used?

Comment on lines +205 to +206
elif isinstance(mask, AttentionMask) and not mask.is_causal:
mask = dataclasses.replace(mask, is_causal=True)
ravwojdyla:

❓ what's the intention behind these 2 lines?

@ravwojdyla (Contributor) commented:

FYI when I run the starter speedrun (130M only) in us-central1 on TPU (v5p-8), I get OOM:

Total hbm usage >= 101.99G:
    reserved        263.00M
    program         101.73G
    arguments            0B

I can work around this but I wonder if that was supposed to work?


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Run the Grug trainer.")
parser.add_argument("--cache-dir", type=str, default=None, help="Optional TreeCache directory for real data.")
ravwojdyla:

Could we make it simpler to run main.py in isolation, without depending on TreeCache or synthetic data? I.e. point it at a directory (object-store compatible) dump of some canonical dataset, e.g. OpenWebText, FineWeb, or even TinyStories?

@ravwojdyla mentioned this pull request Jan 16, 2026
@ravwojdyla (Contributor) left a comment:

Some more comments from experiments

num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
max_seq_len=self.max_seq_len,
tie_embeddings=self.tie_embeddings,
ravwojdyla:

This tie_embeddings flag doesn't exist in GrugModelConfig (anymore?). I assume it can be completely removed from this experiment?

Comment on lines +111 to +113
# Grug core currently always has separate token embed + output projection; keep this knob
# for param counting / compatibility with other LmConfig-based scripts.
tie_embeddings: bool = False
ravwojdyla:

Wouldn't it be more intuitive to not expose this config and instead hard code logic in total_trainable_params? Otherwise it may seem like this flag does something, when it doesn't?
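
A sketch of that suggestion; the formula is purely illustrative (the real per-layer count depends on grug's actual block shapes), and the config field names are assumptions:

```python
def total_trainable_params(config) -> int:
    # Grug always keeps separate embedding and output matrices, so count
    # both unconditionally instead of branching on a tie_embeddings flag.
    d, v = config.hidden_dim, config.vocab_size
    embed_and_unembed = 2 * d * v
    per_layer = 12 * d * d  # placeholder for attention + MLP weights
    return embed_and_unembed + config.num_layers * per_layer + d  # + final norm
```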

num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
max_seq_len=self.max_seq_len,
tie_embeddings=self.tie_embeddings,
ravwojdyla:

tie_embeddings doesn't exist in GrugModelConfig (see other comment)

labels = jnp.concatenate([token_ids[:, 1:], token_ids[:, :1] * 0], axis=1).astype(jnp.int32)
loss_weight = loss_weight.astype(loss_dtype)

# NOTE: `block_size=None` corresponds to a single full-vocab block. On the 125M speedrun,
ravwojdyla:

#2315 (comment) ptal, I can't reproduce this 🙏

ravwojdyla:

Btw, setting cross_entropy_block_size to, say, ~32k on v5p-8 OOMs in the 125M experiment.
