Grug #2171 (Open)

Commits (85)
0f23fb2
grugpt!
dlwh Nov 10, 2025
0c530fd
attention in its own file
dlwh Nov 10, 2025
c397bce
project description
dlwh Nov 10, 2025
72d3a8a
tweak
dlwh Nov 10, 2025
629217a
basic data loading works
dlwh Nov 10, 2025
2222843
wip
dlwh Nov 11, 2025
6963ba4
Merge remote-tracking branch 'origin/main' into grug
dlwh Dec 11, 2025
797e42e
move grug into levanter so we can use it there
dlwh Dec 11, 2025
c3b0950
grug
dlwh Dec 11, 2025
7947ecf
grug wrapper
dlwh Dec 11, 2025
478dfba
structured attentionmask
dlwh Dec 12, 2025
796771a
comment
dlwh Dec 12, 2025
078d1b8
update principles
dlwh Dec 12, 2025
6551180
wip attention
dlwh Dec 17, 2025
d7889f8
wip
dlwh Dec 29, 2025
2557d42
Merge remote-tracking branch 'origin/main' into grug
dlwh Dec 30, 2025
8ec3b14
isolate grug, use ejkernel
dlwh Dec 30, 2025
291344c
more grug-y
dlwh Dec 30, 2025
7870cf4
make AttentionBackend simpler
dlwh Dec 30, 2025
6ac56b5
new grug_wrapper
dlwh Dec 30, 2025
b712cee
update status
dlwh Dec 30, 2025
1c1a387
simpler, update plan
dlwh Dec 31, 2025
7b4af70
fix tests
dlwh Dec 31, 2025
f824fee
use our axis conventions
dlwh Jan 2, 2026
abcd2d8
grugformer
dlwh Jan 2, 2026
9bc2778
grugformer tests, updated plan
dlwh Jan 3, 2026
a05b32e
fix mesh axis name
dlwh Jan 3, 2026
b7d1dfb
expose activations rather than logits
dlwh Jan 3, 2026
221b279
update for "tensor"
dlwh Jan 3, 2026
470df8e
update for "tensor"
dlwh Jan 3, 2026
9bc4c42
fused next token loss (grug-style)
dlwh Jan 3, 2026
8df219b
grugformer speedrun template
dlwh Jan 3, 2026
4c5bf31
need lm_head
dlwh Jan 3, 2026
11e7d37
Merge remote-tracking branch 'origin/main' into grugformer-speedrun
dlwh Jan 3, 2026
783e8e1
allow levanter to make explicit meshes (needed for grug)
dlwh Jan 5, 2026
36c83bb
update haliax slicing to work with explicit axes (needed for grug)
dlwh Jan 5, 2026
2c12105
fix deserialization of grug
dlwh Jan 6, 2026
091cf6c
allow levanter to make explicit meshes (needed for grug)
dlwh Jan 5, 2026
ba54b61
update haliax slicing to work with explicit axes (needed for grug)
dlwh Jan 5, 2026
be21d34
fix deserialization of grug
dlwh Jan 6, 2026
ba86bcc
haliax.take: only set out_sharding for explicit meshes
dlwh Jan 6, 2026
cae9a01
haliax.partitioning: add get_pspec_for_manual_mesh
dlwh Jan 6, 2026
ca5e168
tensorstore: use haliax mesh helper for concretizing shardings
dlwh Jan 6, 2026
8bf1ba9
Merge branch 'pr/explicit-mesh-and-take-outsharding' into grugformer-…
dlwh Jan 6, 2026
b914a47
Merge branch 'pr/tensorstore-concrete-mesh' into grugformer-speedrun
dlwh Jan 6, 2026
4639f58
grug: replicate vocab weights on TPU to avoid sflag OOM
dlwh Jan 6, 2026
04c9fc6
Merge remote-tracking branch 'origin/main' into grugformer-speedrun
dlwh Jan 7, 2026
a2826ba
Merge remote-tracking branch 'origin/main' into grug
dlwh Jan 7, 2026
89d659f
Merge branch 'grug' into grugformer-speedrun
dlwh Jan 7, 2026
ec9d055
wip
dlwh Jan 7, 2026
2470da6
wip
dlwh Jan 8, 2026
c002aaf
push compute_next_token_loss into LmHeadModel
dlwh Jan 8, 2026
c75cecb
loss: move compute_next_token_loss into LmHeadModel
dlwh Jan 8, 2026
a33eef6
examples: remove alpaca and gsm8k lora examples
dlwh Jan 8, 2026
ce5930b
wip
dlwh Jan 8, 2026
0a55d8d
minor
dlwh Jan 8, 2026
0725a59
Merge branch 'loss_in_model' into grugformer-speedrun
dlwh Jan 8, 2026
8358f3c
grug: add blockwise loss_fn and use it in wrapper
dlwh Jan 8, 2026
628bc8e
grug loss: keep sharding stable when padding vocab blocks
dlwh Jan 8, 2026
fdd6cb5
grug loss: specify gather out_sharding for label columns
dlwh Jan 8, 2026
c0b099d
grug: stop forcing output_proj sharding; keep vocab pspec
dlwh Jan 8, 2026
c1b08ac
grug loss: pass PartitionSpec to gather out_sharding
dlwh Jan 8, 2026
1c1f1f9
grug loss: default gather out_sharding from mesh when missing
dlwh Jan 8, 2026
27d5a7e
grug loss: clamp TPU vocab block size to avoid HBM OOM
dlwh Jan 8, 2026
49f48a4
grug: use tokamax linear softmax CE on TPU; add dependency
dlwh Jan 8, 2026
aed9877
grug: parse absl flags before calling tokamax
dlwh Jan 8, 2026
5f0658d
grug: note TODO for vendored CE kernel
dlwh Jan 8, 2026
77504e0
grug: handle tokamax mosaic batch-size multiple by splitting
dlwh Jan 9, 2026
4bc1c4a
hackable tweaks
dlwh Jan 9, 2026
1dd2918
Merge remote-tracking branch 'origin/main' into grugformer-speedrun
dlwh Jan 9, 2026
ac9b095
revert change
dlwh Jan 9, 2026
2280d0a
almost ready!
dlwh Jan 10, 2026
8df8892
Merge remote-tracking branch 'origin/main' into grugformer-speedrun
dlwh Jan 10, 2026
36b0fef
move tests, cleanup
dlwh Jan 10, 2026
6d16e75
a bit of cleanup
dlwh Jan 10, 2026
5b2f594
jaxtyping
dlwh Jan 10, 2026
4ac926f
cleanup
dlwh Jan 10, 2026
b321a33
Merge branch 'main' into grug
pc0618 Jan 10, 2026
0ee618f
Fix grug TPU splash sharding for tracers
pc0618 Jan 11, 2026
b9756b3
Refactor Grug init key splitting
pc0618 Jan 12, 2026
80a509d
Merge remote-tracking branch 'origin/main' into grug
dlwh Jan 12, 2026
df2464a
Merge remote-tracking branch 'origin/grug' into grug
dlwh Jan 12, 2026
ffaa4c6
fix tests?
dlwh Jan 12, 2026
f459801
fix attention perf
dlwh Jan 20, 2026
e85f680
woo
dlwh Jan 21, 2026
Files changed
221 changes: 221 additions & 0 deletions .agents/projects/grugformer.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion .github/workflows/levanter-tests.yaml
@@ -191,4 +191,4 @@ jobs:
-v /tmp/uv-cache:/tmp/uv-cache:rw \
-w /workspace \
$DOCKER_IMAGE \
bash -c "cp -a /workspace-src/. /workspace/ && cd /workspace && timeout --kill-after=5 --signal=TERM 890 uv run --package levanter --frozen --with 'jax[tpu]==$JAX_VERSION' pytest lib/levanter/tests -m 'not entry and not ray and not slow' -v --tb=short --log-cli-level=WARNING --durations=20"
bash -c "cp -a /workspace-src/. /workspace/ && cd /workspace && timeout --kill-after=5 --signal=TERM 890 uv run --package levanter --frozen --with 'jax[tpu]==$JAX_VERSION' pytest lib/levanter/tests -m 'not entry and not ray and not slow' -v --tb=short --log-cli-level=WARNING --durations=20"
1 change: 1 addition & 0 deletions AGENTS.md
@@ -14,6 +14,7 @@
- Begin with the agent-friendly recipes in `docs/recipes/`.
- The first step for dataset addition is schema inspection. See the [add_dataset.md](docs/recipes/add_dataset.md) recipe for details.
- You can help organize experiments using the [organize_experiments.md](docs/recipes/organize_experiments.md) recipe.
+- When making significant changes to Grug/Grugformer, follow [change_grug.md](docs/recipes/change_grug.md).
- Follow the rules and examples in each recipe to ensure compatibility and automation-friendliness.

## Shared Coding Practices
112 changes: 112 additions & 0 deletions docs/recipes/change_grug.md
@@ -0,0 +1,112 @@
# Recipe: Changing Grug (Experiment → Canonical)

Grug is meant to be “grug-simple” and easy to hack, but we still want a single, trustworthy “best guess” implementation in `levanter.grug`.

This recipe describes the workflow for:

1) trying changes safely in a speedrun experiment, and
2) upstreaming successful ideas into the canonical core (and cleaning up old experiments).

## Source Of Truth vs Experiments

- **Source of truth:** `lib/levanter/src/levanter/grug/`
- This is the “best guess” model. It should stay small, readable, and testable.
- **Evolving experiment:** `experiments/speedrun/grugformer_starter/grugformer_speedrun.py`
- This is the *living* entrypoint that is expected to evolve as we learn.
- **One-off experiments:** under `experiments/speedrun/…`
- These are snapshots / specialized edit surfaces (e.g. attention sinks).

We try not to let one-off scripts become the canonical implementation.

## When You Want To Try Something

### 1) Decide what you’re changing

Most changes fall into one of these buckets:

- **Attention** (masking semantics, kernels, sinks/aux, layout/sharding)
- **Block** (residual wiring, normalization order, pre/post-norm)
- **MLP** (activation, GLU variants, gating, dimension choices)
- **Loss** (large-vocab CE, z-loss, label smoothing, logit soft-cap)
- **Optimizer** (Adam, Muon, etc.)

Try to change **one bucket at a time**. The optimizer isn't currently addressed by Grug, but we'll get there.
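
As background for the **Loss** bucket: "large-vocab CE" means computing cross-entropy without ever materializing the full `[tokens, vocab]` logit array (the commit history adds a blockwise `loss_fn` for exactly this). A rough sketch of the idea, with hypothetical names and shapes:

```python
import jax.numpy as jnp


def blockwise_ce(acts, lm_head, labels, block=1024):
    # acts: [n, hidden]; lm_head: [hidden, vocab]; labels: [n].
    # Streams over vocab blocks with a running log-sum-exp, so only
    # [n, <=block] logits exist at any one time.
    n, vocab = acts.shape[0], lm_head.shape[1]
    m = jnp.full((n,), -jnp.inf)   # running max of logits
    s = jnp.zeros((n,))            # running sum of exp(logit - m)
    label_logit = jnp.zeros((n,))
    for start in range(0, vocab, block):
        logits = acts @ lm_head[:, start:start + block]
        width = logits.shape[1]
        new_m = jnp.maximum(m, logits.max(axis=1))
        s = s * jnp.exp(m - new_m) + jnp.exp(logits - new_m[:, None]).sum(axis=1)
        m = new_m
        in_block = (labels >= start) & (labels < start + width)
        idx = jnp.clip(labels - start, 0, width - 1)
        picked = jnp.take_along_axis(logits, idx[:, None], axis=1)[:, 0]
        label_logit = jnp.where(in_block, picked, label_logit)
    # CE = logsumexp(logits) - logit[label], averaged over tokens.
    return jnp.mean(m + jnp.log(s) - label_logit)
```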

### 2) Create an experiment entrypoint

Start from:

- `experiments/speedrun/grugformer_starter/grugformer_speedrun.py`

Recommended workflow:

1. Copy the file to a new experiment (or branch the starter if the change is small):
- Example: `experiments/speedrun/grugformer_<idea>/grugformer_<idea>.py`
2. Keep the edit surface explicit:
- If you’re changing attention, keep the change in one local `attention()` or `attn_fn` block (see the sketch after this list).
- If you’re changing the MLP, keep it local to an `mlp()` helper.
3. Avoid introducing new abstractions (this is a speedrun file; copy/paste is fine).
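
A hypothetical example of such a contained edit surface (plain JAX; `attn_fn` and its shapes are illustrative, not the actual speedrun API):

```python
import jax
import jax.numpy as jnp


def attn_fn(q, k, v, mask):
    # q, k, v: [batch, seq, heads, head_dim]; mask: [seq, seq] boolean.
    # This is the whole edit surface: swap this body to try a variant
    # (sinks, soft-capping, different scaling) without touching the rest.
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    scores = jnp.where(mask, scores, -jnp.inf)
    return jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(scores, axis=-1), v)
```

Everything else in the file keeps calling `attn_fn`, so the diff against the starter stays roughly one function long.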

### 3) Register the experiment in the archive

Add an entry to:

- `docs/reports/grug-archive.md`

Record:
- the experiment path,
- the commit SHA (once known),
- what you changed and why,
- the intended “status” (`active`, `superseded`, `deleted`).

## When You Want To Adopt Something As Canonical

### 1) Port to `levanter.grug`

Move the change into one of the core files:

- `lib/levanter/src/levanter/grug/attention.py`
- `lib/levanter/src/levanter/grug/model.py`
- `lib/levanter/src/levanter/grug/loss.py`

Keep the “grug” style:
- top-level functions,
- small dataclasses only for parameter/state containers,
- explicit sharding when needed (and loud failures otherwise).
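
A hypothetical illustration of that style (not the actual `levanter.grug` code): a top-level function, a tiny frozen dataclass for parameters, and a loud sharding check instead of a silent reshard.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class MlpParams:
    w_in: jax.Array   # [hidden, mlp_dim]
    w_out: jax.Array  # [mlp_dim, hidden]


def mlp(params: MlpParams, x: jax.Array) -> jax.Array:
    # Top-level function, no module classes. (In the real code the params
    # container would also be registered as a pytree so it works under jit.)
    h = jax.nn.gelu(x @ params.w_in)
    return h @ params.w_out


def require_sharding(x: jax.Array, expected) -> None:
    # "Loud failures": refuse to continue if an array arrives with a
    # sharding we did not plan for, rather than silently resharding it.
    if x.sharding != expected:
        raise ValueError(f"unexpected sharding: {x.sharding} != {expected}")
```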

### 2) Update/extend tests

Add or adjust tests to lock the intended surface:

- `lib/levanter/tests/test_grugformer_core.py`
- `lib/levanter/tests/test_grugformer_model_loss.py`
- `lib/levanter/tests/test_grugformer_fused_loss.py`

The goal is:
- shapes don’t regress,
- `jit` still works,
- sharding doesn’t explode,
- mask semantics remain correct.
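
A minimal sketch of what such a test can look like (the attention helper and tolerances are illustrative assumptions, not copied from the test files):

```python
import jax
import jax.numpy as jnp


def _attn(q, k, v, mask):
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    scores = jnp.where(mask, scores, -jnp.inf)
    return jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(scores, axis=-1), v)


def test_causal_mask_and_shapes():
    B, T, H, D = 2, 8, 4, 16
    q, k, v = jax.random.normal(jax.random.PRNGKey(0), (3, B, T, H, D))
    causal = jnp.tril(jnp.ones((T, T), dtype=bool))

    out = jax.jit(_attn)(q, k, v, causal)  # `jit` still works
    assert out.shape == (B, T, H, D)       # shapes don't regress

    # Mask semantics: position 0 attends only to itself, so its output
    # must equal v at position 0 regardless of later tokens.
    assert jnp.allclose(out[:, 0], v[:, 0], atol=1e-5)
```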

### 3) Clean up old experiments

After merging a canonical improvement:

- If an experiment is now redundant and not referenced, **delete it** and mark it `deleted` in `docs/reports/grug-archive.md`.
- If an experiment represents a meaningful historical run, keep it but mark it `superseded` and point to the canonical change (or the new experiment). Do this only if keeping it won't become a maintenance burden.

Prefer “archive entry + deletion” over keeping lots of old code in-tree.

### 4) Run repo checks

Before sending the PR:

```sh
uv run python infra/pre-commit.py --all-files
```

## Notes / Inspiration

This workflow is inspired by projects like `modded-nanogpt`: keep a small, readable core, iterate quickly via “hackable” entrypoints, and regularly upstream what works.

61 changes: 61 additions & 0 deletions docs/reports/grug-archive.md
@@ -0,0 +1,61 @@
# Grug Archive: Experiments and Snapshots

This file is a lightweight “paper trail” for Grug-related experiments, inspired by the idea of keeping a runnable history without letting a pile of one-off scripts become the de facto source of truth.

## Principles

- **`levanter.grug` is the source of truth.** Speedrun files are snapshots/entrypoints, not the canonical implementation.
- **Every experiment should be attributable to a commit.** If an experiment is removed or superseded, it should be clear what replaced it and why.
- **Prefer deletion over permanent snapshots.** If a script is dead, delete it and record the last known-good commit here.
- **Keep diffs small.** When an experiment is kept “alive”, update it to track the current core rather than forking the entire model.

## When Grug Core Changes

When a change in `levanter.grug` is likely to affect results, performance, or semantics:

1. Update the experiment(s) that should track “best guess”.
2. For experiments that no longer make sense:
- delete them, or
- mark them superseded and point to the replacement.
3. Update the corresponding entry in this archive (and any linked issue).

## Entry Template

Copy/paste this block for new experiments:

```text
### <experiment-id>
- Path: `<repo-relative-path>`
- Introduced: <commit-sha>
- Last known-good: <commit-sha>
- Status: active | superseded | deleted
- Purpose: <one line>
- Notes: <optional; what changed, how to reproduce, caveats>
- Superseded by: <experiment-id or commit-sha; optional>
- Issue: <url or issue id; optional>
```

## Experiments

### grugformer-attnsink
- Path: `experiments/speedrun/grugformer_attnsink/grugformer_attn_sink.py`
- Introduced: TBD
- Last known-good: TBD
- Status: active
- Purpose: “Hackable” Grug attention-sink variant; intended edit surface for sinks/aux.
- Notes: Keep this file short; copy/paste local modifications rather than growing new abstractions.

### grugformer-starter-speedrun
- Path: `experiments/speedrun/grugformer_starter/grugformer_speedrun.py`
- Introduced: TBD
- Last known-good: TBD
- Status: active
- Purpose: Minimal starter speedrun for Grug; convenient baseline for quick iteration.

### grugformer-vs-hackable-125m
- Path: `experiments/speedrun/grugformer_vs_hackable_125m/grugformer_vs_hackable_125m.py`
- Introduced: TBD
- Last known-good: TBD
- Status: active
- Purpose: Head-to-head comparison between Hackable Transformer and Grugformer (no sinks).

2 changes: 1 addition & 1 deletion experiments/simple_train_config.py
@@ -92,7 +92,7 @@ class SimpleTrainConfig:
"""Whether to run the JAX profiler during training."""
profiler_start_step: int = 5
"""Which step to start profiling."""
-profiler_num_steps: int = 100
+profiler_num_steps: int = 25
"""How many steps to profile for once started."""

explicit_mesh_axes: bool = False