From 35eac90727aeb10d1c5b921d56d852b69b67d099 Mon Sep 17 00:00:00 2001 From: "Pinlin [Calvin] Xu" Date: Sat, 6 Dec 2025 03:08:11 -0800 Subject: [PATCH 01/11] rename attn_sink speedrun dir --- .../attnsink/jax/130m/speedrun_results.json | 0 .../attnsink/jax/1_2b/speedrun_results.json | 0 .../attnsink/jax/300m/speedrun_results.json | 0 .../attnsink/jax/520m/speedrun_results.json | 0 .../attnsink/splash/130m/speedrun_results.json | 0 .../attnsink/splash/1_2b/speedrun_results.json | 0 .../attnsink/splash/300m/speedrun_results.json | 0 .../attnsink/splash/520m/speedrun_results.json | 0 .../attnsink/splash/default_lr/130m/speedrun_results.json | 0 .../attnsink/splash/default_lr/1_2b/speedrun_results.json | 0 .../attnsink/splash/default_lr/300m/speedrun_results.json | 0 .../attnsink/splash/default_lr/520m/speedrun_results.json | 0 .../attnsink/splash/default_x0.25/130m/speedrun_results.json | 0 .../attnsink/splash/default_x0.25/1_2b/speedrun_results.json | 0 .../attnsink/splash/default_x0.25/300m/speedrun_results.json | 0 .../attnsink/splash/default_x0.25/520m/speedrun_results.json | 0 .../attnsink/splash/default_x0.5/130m/speedrun_results.json | 0 .../attnsink/splash/default_x0.5/1_2b/speedrun_results.json | 0 .../attnsink/splash/default_x0.5/300m/speedrun_results.json | 0 .../attnsink/splash/default_x0.5/520m/speedrun_results.json | 0 .../attnsink/splash/default_x2/130m/speedrun_results.json | 0 .../attnsink/splash/default_x2/1_2b/speedrun_results.json | 0 .../attnsink/splash/default_x2/300m/speedrun_results.json | 0 .../attnsink/splash/default_x2/520m/speedrun_results.json | 0 .../attnsink/splash/default_x4/130m/speedrun_results.json | 0 .../attnsink/splash/default_x4/1_2b/speedrun_results.json | 0 .../attnsink/splash/default_x4/300m/speedrun_results.json | 0 .../attnsink/splash/default_x4/520m/speedrun_results.json | 0 .../hackable_transformer_attn_sink.py | 0 .../std_attn/1.2b/speedrun_results.json | 0 .../std_attn/130m/speedrun_results.json | 0 .../std_attn/300m/speedrun_results.json | 0 .../std_attn/520m/speedrun_results.json | 0 33 files changed, 0 insertions(+), 0 deletions(-) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/attnsink/jax/130m/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/attnsink/jax/1_2b/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/attnsink/jax/300m/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/attnsink/jax/520m/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/attnsink/splash/130m/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/attnsink/splash/1_2b/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/attnsink/splash/300m/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/attnsink/splash/520m/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/attnsink/splash/default_lr/130m/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/attnsink/splash/default_lr/1_2b/speedrun_results.json 
(100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/attnsink/splash/default_lr/300m/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/attnsink/splash/default_lr/520m/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/attnsink/splash/default_x0.25/130m/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/attnsink/splash/default_x0.25/1_2b/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/attnsink/splash/default_x0.25/300m/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/attnsink/splash/default_x0.25/520m/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/attnsink/splash/default_x0.5/130m/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/attnsink/splash/default_x0.5/1_2b/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/attnsink/splash/default_x0.5/300m/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/attnsink/splash/default_x0.5/520m/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/attnsink/splash/default_x2/130m/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/attnsink/splash/default_x2/1_2b/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/attnsink/splash/default_x2/300m/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/attnsink/splash/default_x2/520m/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/attnsink/splash/default_x4/130m/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/attnsink/splash/default_x4/1_2b/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/attnsink/splash/default_x4/300m/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/attnsink/splash/default_x4/520m/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/hackable_transformer_attn_sink.py (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/std_attn/1.2b/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/std_attn/130m/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/std_attn/300m/speedrun_results.json (100%) rename experiments/speedrun/{hackable_transformer_starter => hackable_transformer_attn_sink}/std_attn/520m/speedrun_results.json (100%) diff --git 
a/experiments/speedrun/hackable_transformer_starter/attnsink/jax/130m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/attnsink/jax/130m/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/attnsink/jax/130m/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/attnsink/jax/130m/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/attnsink/jax/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/attnsink/jax/1_2b/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/attnsink/jax/1_2b/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/attnsink/jax/1_2b/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/attnsink/jax/300m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/attnsink/jax/300m/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/attnsink/jax/300m/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/attnsink/jax/300m/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/attnsink/jax/520m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/attnsink/jax/520m/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/attnsink/jax/520m/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/attnsink/jax/520m/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/attnsink/splash/130m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/130m/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/attnsink/splash/130m/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/130m/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/attnsink/splash/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/1_2b/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/attnsink/splash/1_2b/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/1_2b/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/attnsink/splash/300m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/300m/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/attnsink/splash/300m/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/300m/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/attnsink/splash/520m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/520m/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/attnsink/splash/520m/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/520m/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_lr/130m/speedrun_results.json 
b/experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_lr/130m/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_lr/130m/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_lr/130m/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_lr/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_lr/1_2b/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_lr/1_2b/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_lr/1_2b/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_lr/300m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_lr/300m/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_lr/300m/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_lr/300m/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_lr/520m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_lr/520m/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_lr/520m/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_lr/520m/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x0.25/130m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x0.25/130m/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x0.25/130m/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x0.25/130m/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x0.25/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x0.25/1_2b/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x0.25/1_2b/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x0.25/1_2b/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x0.25/300m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x0.25/300m/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x0.25/300m/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x0.25/300m/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x0.25/520m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x0.25/520m/speedrun_results.json similarity index 100% rename from 
experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x0.25/520m/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x0.25/520m/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x0.5/130m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x0.5/130m/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x0.5/130m/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x0.5/130m/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x0.5/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x0.5/1_2b/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x0.5/1_2b/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x0.5/1_2b/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x0.5/300m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x0.5/300m/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x0.5/300m/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x0.5/300m/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x0.5/520m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x0.5/520m/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x0.5/520m/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x0.5/520m/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x2/130m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x2/130m/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x2/130m/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x2/130m/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x2/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x2/1_2b/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x2/1_2b/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x2/1_2b/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x2/300m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x2/300m/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x2/300m/speedrun_results.json rename to 
experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x2/300m/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x2/520m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x2/520m/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x2/520m/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x2/520m/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x4/130m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x4/130m/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x4/130m/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x4/130m/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x4/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x4/1_2b/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x4/1_2b/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x4/1_2b/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x4/300m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x4/300m/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x4/300m/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x4/300m/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x4/520m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x4/520m/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/attnsink/splash/default_x4/520m/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/attnsink/splash/default_x4/520m/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/hackable_transformer_attn_sink.py b/experiments/speedrun/hackable_transformer_attn_sink/hackable_transformer_attn_sink.py similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/hackable_transformer_attn_sink.py rename to experiments/speedrun/hackable_transformer_attn_sink/hackable_transformer_attn_sink.py diff --git a/experiments/speedrun/hackable_transformer_starter/std_attn/1.2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/std_attn/1.2b/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/std_attn/1.2b/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/std_attn/1.2b/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/std_attn/130m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/std_attn/130m/speedrun_results.json similarity index 100% rename from 
experiments/speedrun/hackable_transformer_starter/std_attn/130m/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/std_attn/130m/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/std_attn/300m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/std_attn/300m/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/std_attn/300m/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/std_attn/300m/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_starter/std_attn/520m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_sink/std_attn/520m/speedrun_results.json similarity index 100% rename from experiments/speedrun/hackable_transformer_starter/std_attn/520m/speedrun_results.json rename to experiments/speedrun/hackable_transformer_attn_sink/std_attn/520m/speedrun_results.json From e399064da52ac9ddccc426d79f641ed69a58ca10 Mon Sep 17 00:00:00 2001 From: "Pinlin [Calvin] Xu" Date: Tue, 9 Dec 2025 00:49:06 -0800 Subject: [PATCH 02/11] Add Gated Attention --- lib/levanter/src/levanter/layers/attention.py | 56 +++++++++++++++++-- lib/levanter/tests/test_attention.py | 26 +++++++++ 2 files changed, 77 insertions(+), 5 deletions(-) diff --git a/lib/levanter/src/levanter/layers/attention.py b/lib/levanter/src/levanter/layers/attention.py index 6013bea27d..a910695a47 100644 --- a/lib/levanter/src/levanter/layers/attention.py +++ b/lib/levanter/src/levanter/layers/attention.py @@ -1522,7 +1522,7 @@ class AttentionConfig: scaling_factor: Optional[float] = None logits_soft_cap: Optional[float] = None qk_norm: Optional[LayerNormConfigBase] = None - """Configuration for QK normalization. 
If None, no normalization is applied.""" + gated: bool = False def __post_init__(self): assert ( @@ -1581,12 +1581,17 @@ class Attention(eqx.Module): q_norm: Optional[LayerNormBase] = None k_norm: Optional[LayerNormBase] = None rot_embs: Optional[RotaryEmbeddings] = None + gate_proj: Optional[hnn.Linear] = None @staticmethod def init(config: AttentionConfig, *, key) -> "Attention": use_bias = config.use_bias use_output_bias = config.use_output_bias if config.use_output_bias is not None else use_bias - k_q, k_k, k_v, k_o = jrandom.split(key, 4) + if config.gated: + k_q, k_k, k_v, k_o, k_g = jrandom.split(key, 5) + else: + k_q, k_k, k_v, k_o = jrandom.split(key, 4) + k_g = None q_proj = hnn.Linear.init( In=config.Embed, Out=(config.KVHeads, config.QHeadsPerGroup, config.HeadSize), @@ -1616,6 +1621,16 @@ def init(config: AttentionConfig, *, key) -> "Attention": out_first=True, ) + gate_proj = None + if config.gated: + gate_proj = hnn.Linear.init( + In=config.Embed, + Out=(config.KVHeads, config.QHeadsPerGroup, config.HeadSize), + key=k_g, + use_bias=use_bias, + out_first=True, + ) + q_norm = None k_norm = None if config.qk_norm is not None: @@ -1625,7 +1640,7 @@ def init(config: AttentionConfig, *, key) -> "Attention": # Build rotary embeddings once during initialization if configured rot_embs = config.rope.build(config.HeadSize) if config.rope is not None else None - return Attention(config, q_proj, k_proj, v_proj, o_proj, q_norm, k_norm, rot_embs) + return Attention(config, q_proj, k_proj, v_proj, o_proj, q_norm, k_norm, rot_embs, gate_proj) def empty_page_cache(self, spec: PageTableSpec, *, dtype) -> "KvPageCache": return KvPageCache.init(spec, self.config.KVHeads, self.config.HeadSize, dtype=dtype) @@ -1639,7 +1654,11 @@ def __call__( key=None, pos_ids: NamedArray | None = None, ) -> NamedArray: - key_proj, key_o = maybe_rng_split(key, 2) + if self.gate_proj is not None: + key_proj, key_o, key_g = maybe_rng_split(key, 3) + else: + key_proj, key_o = maybe_rng_split(key, 2) + key_g = None # Shared computation of q, k, v q, k, v = self._compute_qkv(x, key=key_proj, pos_ids=pos_ids) @@ -1671,6 +1690,12 @@ def __call__( prng=key, ) + if self.gate_proj is not None: + gate = self.gate_proj(x, key=key_g) + gate = gate.rearrange((..., "kv_head", "q_heads_per_group", "position", "head_size")) + gate = hax.nn.sigmoid(gate) + attn_output = attn_output * gate + # Flatten heads and apply output projection attn_output = attn_output.flatten_axes(("kv_head", "q_heads_per_group"), "heads") attn_output = attn_output.astype(x.dtype) @@ -1698,7 +1723,11 @@ def paged_decode( Currently only causal masks are supported. 
""" - key_proj, key_o = maybe_rng_split(key, 2) + if self.gate_proj is not None: + key_proj, key_o, key_g = maybe_rng_split(key, 3) + else: + key_proj, key_o = maybe_rng_split(key, 2) + key_g = None q, k, v = self._compute_qkv(x, key=key_proj, pos_ids=pos_ids) @@ -1721,6 +1750,12 @@ def paged_decode( soft_cap=self.config.logits_soft_cap, ) + if self.gate_proj is not None: + gate = self.gate_proj(x, key=key_g) + gate = gate.rearrange((..., "kv_head", "q_heads_per_group", "position", "head_size")) + gate = hax.nn.sigmoid(gate) + attn_tokens = attn_tokens * gate + attn_output = attn_tokens.flatten_axes(("kv_head", "q_heads_per_group"), "heads") attn_output = attn_output.astype(x.dtype) attn_output = self.o_proj(attn_output, key=key_o) @@ -2296,6 +2331,7 @@ def init(config: AttentionConfig, *, key) -> "AttentionWithSink": base.q_norm, base.k_norm, base.rot_embs, + base.gate_proj, sinks, ) @@ -2309,6 +2345,10 @@ def __call__( pos_ids: NamedArray | None = None, ) -> NamedArray: key_q, key_k, key_v, key_o = maybe_rng_split(key, 4) + key_g = None + + if self.gate_proj is not None: + key_q, key_k, key_v, key_o, key_g = maybe_rng_split(key, 5) q_proj = self.q_proj(x, key=key_q) k_proj = self.k_proj(x, key=key_k) @@ -2353,6 +2393,12 @@ def __call__( attn_sink=self.sinks, ) + if self.gate_proj is not None: + gate = self.gate_proj(x, key=key_g) + gate = gate.rearrange((..., "kv_head", "q_heads_per_group", "position", "head_size")) + gate = hax.nn.sigmoid(gate) + attn_output = attn_output * gate + attn_output = attn_output.flatten_axes(("kv_head", "q_heads_per_group"), "heads") attn_output = attn_output.astype(x.dtype) attn_output = self.o_proj(attn_output, key=key_o) diff --git a/lib/levanter/tests/test_attention.py b/lib/levanter/tests/test_attention.py index 563e10cb44..fb62ea6e59 100644 --- a/lib/levanter/tests/test_attention.py +++ b/lib/levanter/tests/test_attention.py @@ -19,6 +19,7 @@ from haliax.partitioning import ResourceAxis from levanter.layers.attention import ( + Attention, AttentionBackend, AttentionConfig, AttentionMask, @@ -129,6 +130,31 @@ def test_attention_with_sink_module(): assert_trees_all_close(out.array, expected) +def test_attention_with_gating_module(): + Pos = hax.Axis("position", 2) + Embed = hax.Axis("embed", 1) + + config = AttentionConfig(Embed=Embed, num_heads=1, num_kv_heads=1, use_bias=True, gated=True) + attn = Attention.init(config, key=jrandom.PRNGKey(0)) + + attn = eqx.tree_at(lambda a: a.q_proj.weight, attn, hax.zeros(attn.q_proj.weight.axes)) + attn = eqx.tree_at(lambda a: a.q_proj.bias, attn, hax.zeros(attn.q_proj.bias.axes)) + attn = eqx.tree_at(lambda a: a.k_proj.weight, attn, hax.zeros(attn.k_proj.weight.axes)) + attn = eqx.tree_at(lambda a: a.k_proj.bias, attn, hax.zeros(attn.k_proj.bias.axes)) + attn = eqx.tree_at(lambda a: a.v_proj.weight, attn, hax.zeros(attn.v_proj.weight.axes)) + attn = eqx.tree_at(lambda a: a.v_proj.bias, attn, hax.ones(attn.v_proj.bias.axes)) + attn = eqx.tree_at(lambda a: a.o_proj.weight, attn, hax.ones(attn.o_proj.weight.axes)) + attn = eqx.tree_at(lambda a: a.o_proj.bias, attn, hax.zeros(attn.o_proj.bias.axes)) + attn = eqx.tree_at(lambda a: a.gate_proj.weight, attn, hax.zeros(attn.gate_proj.weight.axes)) + attn = eqx.tree_at(lambda a: a.gate_proj.bias, attn, hax.zeros(attn.gate_proj.bias.axes)) + + x = hax.zeros((Pos, Embed)) + out = attn(x, None) + + expected = np.full((2, 1), 0.5) + assert_trees_all_close(out.array, expected) + + def test_te_bin_and_group_axes_by_function(): QPos = hax.Axis("QPos", 128) KPos = hax.Axis("KPos", 
128) From 2d8b1c09c45fe1701b39d53a4f4392f1551dcf55 Mon Sep 17 00:00:00 2001 From: "Pinlin [Calvin] Xu" Date: Tue, 9 Dec 2025 00:49:26 -0800 Subject: [PATCH 03/11] fix Paloma local download --- lib/marin/src/marin/speedrun/paloma_local_download.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/marin/src/marin/speedrun/paloma_local_download.py b/lib/marin/src/marin/speedrun/paloma_local_download.py index 9628f33345..1a7a74b39e 100644 --- a/lib/marin/src/marin/speedrun/paloma_local_download.py +++ b/lib/marin/src/marin/speedrun/paloma_local_download.py @@ -43,4 +43,4 @@ def speedrun_paloma_tokenized(tokenizer: str = llama3_tokenizer): if __name__ == "__main__": - executor_main(steps=[paloma_speedrun, *speedrun_paloma_tokenized]) + executor_main(steps=[paloma_speedrun, *speedrun_paloma_tokenized().values()]) From 4dab1b12410de23a793270b9c3afef0e4d6ef6d9 Mon Sep 17 00:00:00 2001 From: "Pinlin [Calvin] Xu" Date: Tue, 9 Dec 2025 00:49:41 -0800 Subject: [PATCH 04/11] Add Gated Attention sweep results --- .../hackable_transformer_attn_gate.py | 536 ++++++++++++++++++ .../130m/speedrun_results.json | 140 +++++ .../1_2b/speedrun_results.json | 140 +++++ .../300m/speedrun_results.json | 140 +++++ .../520m/speedrun_results.json | 140 +++++ 5 files changed, 1096 insertions(+) create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/hackable_transformer_attn_gate.py create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/130m/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/1_2b/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/300m/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/520m/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_attn_gate/hackable_transformer_attn_gate.py b/experiments/speedrun/hackable_transformer_attn_gate/hackable_transformer_attn_gate.py new file mode 100644 index 0000000000..ac06d83a0d --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/hackable_transformer_attn_gate.py @@ -0,0 +1,536 @@ +# Copyright 2025 The Marin Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Hackable transformer training speedrun sweep + +This file is intentionally self-contained: +- Defines a compact, Llama-ish transformer that implements Levanter's LmHeadModel +- Provides a ready-to-run speedrun sweep across multiple model sizes + +(this example allows comparing using / not using gated attention) + +How to run (GPU or TPU): + 1) Set env vars (WANDB_API_KEY, HF_TOKEN, etc.) as in the tutorial: + https://marin.readthedocs.io/en/latest/tutorials/submitting-speedrun/ + 2) From repo root: + python marin/run/ray_run.py -- \ + python -m experiments.speedrun.hackable_transformer_attn_gate.hackable_transformer_attn_gate + 3) Optional: SR_USE_GPU=1 to use GPU resource presets. 
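Background on the variant being swept here: the gated attention added in PATCH 02 computes an extra projection of the same input x, squashes it with a sigmoid, and multiplies it elementwise into the attention output before o_proj. A minimal sketch of that computation, assuming plain JAX arrays with simplified [seq, heads, head_dim] shapes rather than the haliax NamedArray API the actual implementation uses (gated_attn_output, w_gate, and b_gate are illustrative names, not part of the patch):

    import jax
    import jax.numpy as jnp

    def gated_attn_output(attn_out, x, w_gate, b_gate):
        # attn_out: [seq, heads, head_dim] -- result of the attention kernel
        # x:        [seq, hidden]          -- the same input that produced q, k, v
        # w_gate:   [hidden, heads, head_dim], b_gate: [heads, head_dim]
        gate = jax.nn.sigmoid(jnp.einsum("sd,dhk->shk", x, w_gate) + b_gate)
        # elementwise gate; the caller then flattens heads and applies o_proj
        return attn_out * gate

The only extra parameters are one hidden_dim x (heads * head_dim) projection (plus a bias when use_bias is set) per layer, which is why total_trainable_params below adds hidden_dim * head_size * num_heads when use_gated_attention is True. The unit test in PATCH 02 exercises exactly this path: with the gate projection zeroed, sigmoid(0) = 0.5 halves the all-ones value pathway, giving the expected output of 0.5.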
+ +The transformer is a pared-down version of levanter.models.llama; you can refer to it if you wish to +add back functionality (like inference, HF exports) + +To edit this file for your speedrun: + 1) Copy and rename the file in your location under experiments.speedrun + 2) Make changes to the architecture or configurations + 3) Add your author information + 4) Submit (see "How to run" above) +""" + +# nodryrun +import sys +import os +import dataclasses +import logging +from dataclasses import dataclass +from collections.abc import Callable + +import equinox as eqx +import jax.random as jrandom +from jaxtyping import PRNGKeyArray + +import haliax as hax +import haliax.nn as hnn +from haliax import Axis, AxisSpec, NamedArray +from haliax.jax_utils import maybe_rng_split, named_call, shaped_rng_split +from haliax.nn.scan import ScanCheckpointPolicy, Stacked +from haliax.state_dict import ModuleWithStateDictSerialization +from levanter.utils.types import BlockFoldable + +from levanter.layers import RmsNormConfig, LayerNormConfigBase +from levanter.layers.attention import Attention, AttentionConfig, AttentionMask, AttentionBackend +from levanter.layers.rotary import DefaultRotaryEmbeddingsConfig, RotaryEmbeddingsConfig +from levanter.models.lm_model import LmConfig, LmHeadModel +from levanter.utils.activation import ActivationFunctionEnum +from levanter.utils.flop_utils import lm_flops_per_token +from levanter.utils.logging import silence_transformer_nag + +from marin.speedrun.speedrun import Author, SpeedrunConfig, default_speedrun +from marin.execution.executor import executor_main +from fray.cluster import ResourceConfig +from experiments.simple_train_config import SimpleTrainConfig + +# Optional: Muon optimizer configs +from levanter.optim import MuonConfig +from experiments.llama import llama3_tokenizer_vocab_size + +logger = logging.getLogger("ray") + +_IMPORT_PATH = getattr(__spec__, "name", __name__) + +silence_transformer_nag() + +# ========================= +# Hackable config & modules +# ========================= + + +@LmConfig.register_subclass("hackable_transformer") +@dataclass(frozen=True) +class HackableTransformerConfig(LmConfig["HackableLMHeadModel"]): + # Core dims + seq_len: int = 2048 + hidden_dim: int = 4096 + intermediate_dim: int = 11008 + num_layers: int = 32 + num_heads: int = 32 + num_kv_heads: int = 32 + head_dim: int | None = None + + activation_function: ActivationFunctionEnum = ActivationFunctionEnum.silu + use_bias: bool = False + use_layer_norm_weight: bool = True + layer_norm_epsilon: float = 1e-5 + tie_word_embeddings: bool = False + input_embedding_norm: bool = False + + # Attention + use_gated_attention: bool = False + upcast_attn: bool = False + attn_backend: AttentionBackend | None = None + flash_attention_block_size: int | None = None + rope: RotaryEmbeddingsConfig = dataclasses.field(default_factory=DefaultRotaryEmbeddingsConfig) + qk_norm: LayerNormConfigBase | None = None # set to RmsNormConfig(...) 
to enable + + gradient_checkpointing: bool | ScanCheckpointPolicy | str = True + initializer_range: float = 0.02 + reference_checkpoint: str = "NousResearch/Llama-2-7b-hf" + tokenizer: str | None = None + + def __post_init__(self): + assert self.num_heads % self.num_kv_heads == 0, "num_heads must be divisible by num_kv_heads" + if self.head_dim is None: + assert self.hidden_dim % self.num_heads == 0, "hidden_dim % num_heads must be 0 when head_dim=None" + + # ---- LmConfig API ---- + @property + def model_type(self) -> type["HackableLMHeadModel"]: + return HackableLMHeadModel + + Pos = property(lambda self: Axis("position", self.seq_len)) + KeyPos = property(lambda self: self.Pos.alias("key_position")) + Embed = property(lambda self: Axis("embed", self.hidden_dim)) + Layers = property(lambda self: Axis("layers", self.num_layers)) + Mlp = property(lambda self: Axis("mlp", self.intermediate_dim)) + + @property + def norm_config(self) -> LayerNormConfigBase: + return RmsNormConfig(use_weight=self.use_layer_norm_weight, use_bias=self.use_bias, eps=self.layer_norm_epsilon) + + def mk_LayerNorm(self, axis: AxisSpec): + return self.norm_config.build(axis) + + def attention_config(self) -> AttentionConfig: + return AttentionConfig( + Embed=self.Embed, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + use_bias=self.use_bias, + upcast_attn=self.upcast_attn, + attn_backend=self.attn_backend, + flash_attention_block_size=self.flash_attention_block_size, + rope=self.rope, + qk_norm=self.qk_norm, + gated=self.use_gated_attention, + ) + + @property + def actual_head_size(self) -> int: + return self.head_dim or (self.hidden_dim // self.num_heads) + + def flops_per_token(self, vocab_size: int) -> float | None: + return lm_flops_per_token( + hidden_dim=self.hidden_dim, + intermediate_dim=self.intermediate_dim, + num_layers=self.num_layers, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + seq_len=self.seq_len, + vocab_size=vocab_size, + glu=True, + ) + + def total_trainable_params(self, vocab_size: int) -> int: + token_embedding = vocab_size * self.hidden_dim + hs = self.actual_head_size + attn = ( + self.hidden_dim * hs * self.num_heads + + 2 * self.hidden_dim * hs * self.num_kv_heads + + hs * self.num_heads * self.hidden_dim + ) + if self.use_gated_attention: + attn += self.hidden_dim * hs * self.num_heads + mlp = 3 * self.hidden_dim * self.intermediate_dim + transformer = self.num_layers * (attn + mlp + 2 * self.hidden_dim) + self.hidden_dim + if self.input_embedding_norm: + transformer += self.hidden_dim + head = 0 if self.tie_word_embeddings else token_embedding + return int(transformer + token_embedding + head) + + +class HackableMlp(eqx.Module): + """GLU MLP""" + + gate_proj: hnn.Linear + up_proj: hnn.Linear + down_proj: hnn.Linear + act: Callable = eqx.field(static=True) + + @staticmethod + def init(Embed: AxisSpec, Mlp: AxisSpec, activation_fn: ActivationFunctionEnum | Callable, *, key, use_bias=False): + k_fc, k_up_proj, k_down_proj = jrandom.split(key, 3) + gate_proj = hnn.Linear.init(Out=Mlp, In=Embed, key=k_fc, use_bias=use_bias, out_first=True) + up_proj = hnn.Linear.init(Out=Mlp, In=Embed, key=k_up_proj, use_bias=use_bias, out_first=True) + down_proj = hnn.Linear.init(Out=Embed, In=Mlp, key=k_down_proj, use_bias=use_bias, out_first=True) + if isinstance(activation_fn, ActivationFunctionEnum): + activation_fn = activation_fn.to_fn() + elif isinstance(activation_fn, str): + activation_fn = ActivationFunctionEnum(activation_fn).to_fn() + 
return HackableMlp(gate_proj, up_proj, down_proj, activation_fn) + + @named_call + def __call__(self, x: NamedArray, *, key=None) -> NamedArray: + k_gate, k_up, k_down = maybe_rng_split(key, 3) + h = self.act(self.gate_proj(x, key=k_gate)) * self.up_proj(x, key=k_up) + return self.down_proj(h, key=k_down) + + +class HackableDecoderLayer(eqx.Module): + """One transformer block.""" + + config: HackableTransformerConfig = eqx.field(static=True) + self_attn: Attention + mlp: HackableMlp + input_layernorm: hnn.RmsNorm + post_attention_layernorm: hnn.RmsNorm + post_attn_layernorm: hnn.RmsNorm | None = None + post_mlp_layernorm: hnn.RmsNorm | None = None + + @staticmethod + def init(config: HackableTransformerConfig, *, key) -> "HackableDecoderLayer": + k_attn, k_mlp = jrandom.split(key, 2) + attn_cfg = config.attention_config() + attn = Attention.init(attn_cfg, key=k_attn) + mlp = HackableMlp.init(config.Embed, config.Mlp, config.activation_function, key=k_mlp, use_bias=config.use_bias) + ln1 = config.mk_LayerNorm(config.Embed) + ln2 = config.mk_LayerNorm(config.Embed) + return HackableDecoderLayer(config, attn, mlp, ln1, ln2) + + @named_call + def __call__( + self, x: NamedArray, mask: NamedArray | AttentionMask | None, *, key=None, pos_ids: NamedArray | None = None + ): + k_attn, k_mlp = maybe_rng_split(key, 2) + # self attention and skip connection + residual = x + x = self.input_layernorm(x) + attn_output = self.self_attn(x=x, mask=mask, key=k_attn, pos_ids=pos_ids) + if self.post_attn_layernorm is not None: + attn_output = self.post_attn_layernorm(attn_output) + x = residual + attn_output + + # MLP and skip connection + residual = x + x = self.post_attention_layernorm(x) + mlp_output = self.mlp(x, key=k_mlp) + if self.post_mlp_layernorm is not None: + mlp_output = self.post_mlp_layernorm(mlp_output) + output = residual + mlp_output + return output + + +class HackableTransformer(eqx.Module): + config: HackableTransformerConfig = eqx.field(static=True) + layers: BlockFoldable[HackableDecoderLayer] + norm: hnn.RmsNorm + + @staticmethod + def init(config: HackableTransformerConfig, *, key): + S = Stacked # use BlockSeq for non-homogeneous layers + layers = S.init(config.Layers, HackableDecoderLayer, gradient_checkpointing=config.gradient_checkpointing)( + config, key=shaped_rng_split(key, config.num_layers) + ) + return HackableTransformer(config, layers, config.mk_LayerNorm(config.Embed)) + + @named_call + def __call__( + self, x: NamedArray, attn_mask: NamedArray | AttentionMask | None, *, key=None, pos_ids: NamedArray | None = None + ) -> NamedArray: + keys = maybe_rng_split(key, self.config.num_layers) if key is not None else None + x = self.layers.fold(x, mask=attn_mask, key=keys, pos_ids=pos_ids) + return self.norm(x) + + +class HackableEmbedding(ModuleWithStateDictSerialization, eqx.Module): + token_embeddings: hnn.Embedding + norm: hnn.RmsNorm | None = None + + @staticmethod + def init(Vocab: Axis, config: HackableTransformerConfig, *, key): + emb = hnn.Embedding.init(Vocab, config.Embed, key=key) + ln = config.mk_LayerNorm(config.Embed) if config.input_embedding_norm else None + return HackableEmbedding(emb, ln) + + @property + def Vocab(self) -> Axis: + return self.token_embeddings.Vocab + + @named_call + def embed(self, input_ids: NamedArray): + x = self.token_embeddings(input_ids) + return self.norm(x) if self.norm is not None else x + + +class HackableLMHeadModel( + ModuleWithStateDictSerialization, + LmHeadModel[HackableTransformerConfig], +): + """Minimal Llama-like 
implementation of LmHeadModel""" + + transformer: HackableTransformer + embeddings: HackableEmbedding + lm_head: hnn.Linear | None + + @property + def config(self) -> HackableTransformerConfig: + return self.transformer.config + + @property + def Vocab(self) -> Axis: + return self.embeddings.Vocab + + @classmethod + def init(cls, Vocab: Axis, config: HackableTransformerConfig, *, key) -> "HackableLMHeadModel": + k_t, k_e = jrandom.split(key, 2) + transformer = HackableTransformer.init(config, key=k_t) + embeddings = HackableEmbedding.init(Vocab, config, key=k_e) + lm_head = ( + None + if config.tie_word_embeddings + else hnn.Linear.init(In=config.Embed, Out=Vocab, key=k_e, use_bias=False, out_first=True) + ) + return HackableLMHeadModel(transformer, embeddings, lm_head) + + def activations( + self, + input_ids: NamedArray, + attn_mask: AttentionMask | NamedArray | None = None, + *, + key=None, + pos_ids: NamedArray | None = None, + ) -> NamedArray: + return self.transformer(self.embeddings.embed(input_ids), attn_mask=attn_mask, key=key, pos_ids=pos_ids) + + def get_lm_head(self) -> hax.NamedArray: + return self.embeddings.token_embeddings.weight if self.lm_head is None else self.lm_head.weight + + def resize_vocab(self, new_size: int, key: PRNGKeyArray | None = None) -> "HackableLMHeadModel": + pass + + +# ========================= +# Speedrun sweep definition +# ========================= + +AUTHOR = Author(name="Calvin Xu", affiliation="Stanford University", url="https://pinlinxu.com") # TODO: update me + + +def _get_num_train_steps(param_count: int, batch_size: int, seq_len: int, tpp: int = 20) -> int: + total_tokens = param_count * tpp + return max(1, total_tokens // (batch_size * seq_len)) + + +def _size_presets() -> dict[str, HackableTransformerConfig]: + base = dict( + seq_len=4096, + rope=DefaultRotaryEmbeddingsConfig(), # e.g., Llama3RotaryEmbeddingsConfig() + attn_backend=None, + qk_norm=None, # e.g. 
RmsNormConfig(use_weight=True, eps=1e-5) + tie_word_embeddings=False, + ) + return { + "130m": HackableTransformerConfig( + hidden_dim=512, intermediate_dim=1792, num_layers=6, num_heads=8, num_kv_heads=8, **base + ), + "300m": HackableTransformerConfig( + hidden_dim=768, intermediate_dim=2688, num_layers=12, num_heads=12, num_kv_heads=12, **base + ), + "520m": HackableTransformerConfig( + hidden_dim=1024, intermediate_dim=3584, num_layers=24, num_heads=16, num_kv_heads=8, **base + ), + "1_2b": HackableTransformerConfig( + hidden_dim=2048, intermediate_dim=7168, num_layers=16, num_heads=16, num_kv_heads=8, **base + ), + } + + +def _muon_presets() -> dict[str, MuonConfig]: + return { + "130m": MuonConfig( + learning_rate=0.016, + adam_lr=0.0032, + weight_decay=0.1, + min_lr_ratio=0, + warmup=0, + momentum=0.95, + beta1=0.8, + beta2=0.98, + epsilon=1e-15, + muon_epsilon=1e-5, + max_grad_norm=1, + lr_schedule="linear", + decay=0.8, + ), + "300m": MuonConfig( + learning_rate=0.008, + adam_lr=0.0024, + weight_decay=0.1, + min_lr_ratio=0, + warmup=0, + momentum=0.98, + beta1=0.8, + beta2=0.98, + epsilon=1e-15, + muon_epsilon=1e-5, + max_grad_norm=1, + lr_schedule="linear", + decay=0.8, + ), + "520m": MuonConfig( + learning_rate=0.008, + adam_lr=0.0024, + weight_decay=0.1, + min_lr_ratio=0, + warmup=0, + momentum=0.98, + beta1=0.8, + beta2=0.98, + epsilon=1e-25, + muon_epsilon=1e-5, + max_grad_norm=1, + lr_schedule="linear", + decay=1, + ), + "1_2b": MuonConfig( + learning_rate=0.004, + adam_lr=0.0012, + weight_decay=0.1, + min_lr_ratio=0, + warmup=0, + momentum=0.98, + beta1=0.8, + beta2=0.98, + epsilon=1e-15, + muon_epsilon=1e-5, + max_grad_norm=2, + lr_schedule="linear", + decay=1, + ), + } + + +def _resource_presets(use_gpu: bool = False): + if use_gpu: + return { + "130m": ResourceConfig.with_gpu("A100-80G", count=1), + "300m": ResourceConfig.with_gpu("A100-80G", count=1), + "520m": ResourceConfig.with_gpu("A100-80G", count=2), + "1_2b": ResourceConfig.with_gpu("A100-80G", count=4), + } + return { + "130m": ResourceConfig.with_tpu("v5p-32"), + "300m": ResourceConfig.with_tpu("v5p-32"), + "520m": ResourceConfig.with_tpu("v5p-32"), + "1_2b": ResourceConfig.with_tpu("v5p-32"), + } + + +def _batch_sizes() -> dict[str, int]: + return {"130m": 128, "300m": 128, "520m": 128, "1_2b": 256} + + +def build_run( + size: str, + use_gate: bool, + *, + use_gpu: bool = False, +) -> tuple[str, SpeedrunConfig]: + sizes = _size_presets() + if size not in sizes: + raise ValueError(f"Unknown size: {size}") + model_cfg = dataclasses.replace(sizes[size], use_gated_attention=use_gate) + + batch = _batch_sizes()[size] + seq_len = model_cfg.seq_len + params = int(model_cfg.total_trainable_params(llama3_tokenizer_vocab_size)) + print(params) + steps = _get_num_train_steps(params, batch, seq_len, tpp=20) + + muon = _muon_presets()[size] + resources = _resource_presets(use_gpu=use_gpu)[size] + + train = SimpleTrainConfig( + resources, + train_batch_size=batch, + num_train_steps=steps, + learning_rate=muon.learning_rate, + optimizer_config=muon, + steps_per_hf_export=-1, # disable checkpointing + ) + + run_name = f"hacktx_{size}_{'attngate' if use_gate else 'stdattn'}_{seq_len}_splash" + desc = ( + f"Hackable Transformer ({size}); " + f"{'Gated Attention' if use_gate else 'Std Attention'} (Splash)" + ) + cfg = SpeedrunConfig(author=AUTHOR, description=desc, model_config=model_cfg, train_config=train) + return run_name, cfg + + +if __name__ == "__main__": + ### + # make the current __main__ module importable under its 
canonical name + sys.modules[_IMPORT_PATH] = sys.modules[__name__] + # allow the workers to import the classes + for _cls in ( + HackableTransformerConfig, + HackableMlp, + HackableDecoderLayer, + HackableTransformer, + HackableEmbedding, + HackableLMHeadModel, + ): + _cls.__module__ = _IMPORT_PATH + ### + + sizes = ["130m", "300m", "520m", "1_2b"] + use_gpu = bool(int(os.environ.get("SR_USE_GPU", "0"))) + use_gate = True + steps = [] + for s in sizes: + name, cfg = build_run(s, use_gate, use_gpu=use_gpu) + steps.extend(default_speedrun(name, cfg)) + executor_main(steps=steps, description="Hackable transformer gated-attention sweep") diff --git a/experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/130m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/130m/speedrun_results.json new file mode 100644 index 0000000000..349d7d46f9 --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/130m/speedrun_results.json @@ -0,0 +1,140 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (130m); Gated Attention (Splash)", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 1.1564404964447021, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 512, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 1792, + "layer_norm_epsilon": 1e-05, + "num_heads": 8, + "num_kv_heads": 8, + "num_layers": 6, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "seq_len": 4096, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 2.1289341996446515e+18, + "model_flops_per_token": 227868672.0, + "model_size": 155720192, + "num_chips": 16, + "num_devices": 16, + "resources": { + "cpu": 1, + "device": { + "topology": null, + "type": "v5p-32" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-06 16:35:18 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 3114270720, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.016, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 5940, + "optimizer_config": { + "adam_lr": 0.0032, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 0.8, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.016, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 1, + "min_lr_ratio": 0, + "momentum": 0.95, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + 
"weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 128, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 1.2199985674433722e+19, + "training_time": 1661.2180929239819, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_130m_attngate_4096_splash-3c6cbd" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/1_2b/speedrun_results.json new file mode 100644 index 0000000000..60172d9a5a --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/1_2b/speedrun_results.json @@ -0,0 +1,140 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (1_2b); Gated Attention (Splash)", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 0.9263789653778076, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 2048, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 7168, + "layer_norm_epsilon": 1e-05, + "num_heads": 16, + "num_kv_heads": 8, + "num_layers": 16, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "seq_len": 4096, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 2.5869176757309093e+20, + "model_flops_per_token": 2877292544.0, + "model_size": 1498482688, + "num_chips": 16, + "num_devices": 16, + "resources": { + "cpu": 1, + "device": { + "topology": null, + "type": "v5p-32" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-07 15:35:49 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 29969350656, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.004, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 28581, + "optimizer_config": { + "adam_lr": 0.0012, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 1, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 
0.004, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 2, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 256, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 5.928678262278687e+20, + "training_time": 80728.18984584269, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_4096_splash-1544b5" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/300m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/300m/speedrun_results.json new file mode 100644 index 0000000000..31f5452cc7 --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/300m/speedrun_results.json @@ -0,0 +1,140 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (300m); Gated Attention (Splash)", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 1.0539624691009521, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 768, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 2688, + "layer_norm_epsilon": 1e-05, + "num_heads": 12, + "num_kv_heads": 12, + "num_layers": 12, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "seq_len": 4096, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 1.021384111077458e+19, + "model_flops_per_token": 555024384.0, + "model_size": 306727680, + "num_chips": 16, + "num_devices": 16, + "resources": { + "cpu": 1, + "device": { + "topology": null, + "type": "v5p-32" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-06 18:26:07 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 6134169600, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.008, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 11700, + "optimizer_config": { + "adam_lr": 0.0024, + 
"adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 0.8, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.008, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 1, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 128, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 5.2696079184663675e+19, + "training_time": 7175.392045842003, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_300m_attngate_4096_splash-a5e290" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/520m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/520m/speedrun_results.json new file mode 100644 index 0000000000..5ff5fd85f9 --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/520m/speedrun_results.json @@ -0,0 +1,140 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (520m); Gated Attention (Splash)", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 0.9801320433616638, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 1024, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 3584, + "layer_norm_epsilon": 1e-05, + "num_heads": 16, + "num_kv_heads": 8, + "num_layers": 24, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "seq_len": 4096, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 5.081733891346976e+19, + "model_flops_per_token": 1349517312.0, + "model_size": 627622912, + "num_chips": 16, + "num_devices": 16, + "resources": { + "cpu": 1, + "device": { + "topology": null, + "type": "v5p-32" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-07 01:18:57 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 12551979008, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + 
"initialize_from_hf": null, + "int8": false, + "learning_rate": 0.008, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 23941, + "optimizer_config": { + "adam_lr": 0.0024, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 1, + "default_weight_decay_mask": null, + "epsilon": 1e-25, + "haps": null, + "learning_rate": 0.008, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 1, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 128, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 2.2878030460560127e+20, + "training_time": 31152.002261111284, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_520m_attngate_4096_splash-5794f3" + } + } + ] +} \ No newline at end of file From 59ec26c403f3f20e111a67becac50dc72d20c810 Mon Sep 17 00:00:00 2001 From: "Pinlin [Calvin] Xu" Date: Tue, 9 Dec 2025 13:15:29 -0800 Subject: [PATCH 05/11] Improved Gated Attention impl + LR sweep --- .../hackable_transformer_attn_gate.py | 36 +++- lib/levanter/src/levanter/layers/attention.py | 164 ++++++++++-------- lib/levanter/tests/test_attention.py | 55 +++++- 3 files changed, 166 insertions(+), 89 deletions(-) diff --git a/experiments/speedrun/hackable_transformer_attn_gate/hackable_transformer_attn_gate.py b/experiments/speedrun/hackable_transformer_attn_gate/hackable_transformer_attn_gate.py index ac06d83a0d..55c6f01177 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/hackable_transformer_attn_gate.py +++ b/experiments/speedrun/hackable_transformer_attn_gate/hackable_transformer_attn_gate.py @@ -48,6 +48,7 @@ from collections.abc import Callable import equinox as eqx +import numpy as np import jax.random as jrandom from jaxtyping import PRNGKeyArray @@ -472,11 +473,24 @@ def _batch_sizes() -> dict[str, int]: return {"130m": 128, "300m": 128, "520m": 128, "1_2b": 256} +def _lr_multipliers(start: float = 1.0, stop: float = 2.5, step: float = 0.5) -> list[float]: + """Generate LR multipliers for sweep. Paper suggests training with increased LR.""" + vals = np.arange(start, stop + step / 2, step) # +step/2 to include stop + return [float(v) for v in vals] + + +def _format_multiplier_label(mult: float) -> str: + s = f"{mult:.6g}" + s = s.rstrip("0").rstrip(".") if "." 
in s else s + return s.replace(".", "_") + + def build_run( size: str, use_gate: bool, *, use_gpu: bool = False, + lr_multiplier: float | None = None, ) -> tuple[str, SpeedrunConfig]: sizes = _size_presets() if size not in sizes: @@ -486,10 +500,15 @@ def build_run( batch = _batch_sizes()[size] seq_len = model_cfg.seq_len params = int(model_cfg.total_trainable_params(llama3_tokenizer_vocab_size)) - print(params) steps = _get_num_train_steps(params, batch, seq_len, tpp=20) muon = _muon_presets()[size] + if lr_multiplier is not None: + muon = dataclasses.replace( + muon, + learning_rate=muon.learning_rate * lr_multiplier, + adam_lr=muon.adam_lr * lr_multiplier, + ) resources = _resource_presets(use_gpu=use_gpu)[size] train = SimpleTrainConfig( @@ -501,10 +520,12 @@ def build_run( steps_per_hf_export=-1, # disable checkpointing ) - run_name = f"hacktx_{size}_{'attngate' if use_gate else 'stdattn'}_{seq_len}_splash" + lr_tag = f"_lr_x{_format_multiplier_label(lr_multiplier)}" if lr_multiplier is not None else "" + run_name = f"hacktx_{size}_{'attngate' if use_gate else 'stdattn'}_{seq_len}_splash_lr_sweep{lr_tag}" desc = ( f"Hackable Transformer ({size}); " - f"{'Gated Attention' if use_gate else 'Std Attention'} (Splash)" + f"{'Gated Attention' if use_gate else 'Std Attention'} (Splash); " + f"LR sweep multiplier={lr_multiplier if lr_multiplier is not None else 1.0:g}" ) cfg = SpeedrunConfig(author=AUTHOR, description=desc, model_config=model_cfg, train_config=train) return run_name, cfg @@ -530,7 +551,10 @@ def build_run( use_gpu = bool(int(os.environ.get("SR_USE_GPU", "0"))) use_gate = True steps = [] + # Sweep LR from 1x to 4x at 0.5x increments (paper suggests higher LR for gated attention) + lr_mults = _lr_multipliers(start=1.0, stop=4.0, step=0.5) for s in sizes: - name, cfg = build_run(s, use_gate, use_gpu=use_gpu) - steps.extend(default_speedrun(name, cfg)) - executor_main(steps=steps, description="Hackable transformer gated-attention sweep") + for m in lr_mults: + name, cfg = build_run(s, use_gate, use_gpu=use_gpu, lr_multiplier=m) + steps.extend(default_speedrun(name, cfg)) + executor_main(steps=steps, description="Hackable transformer gated-attention LR sweep") diff --git a/lib/levanter/src/levanter/layers/attention.py b/lib/levanter/src/levanter/layers/attention.py index a910695a47..8eeff7623e 100644 --- a/lib/levanter/src/levanter/layers/attention.py +++ b/lib/levanter/src/levanter/layers/attention.py @@ -9,7 +9,7 @@ from dataclasses import dataclass from enum import Enum from numbers import Integral -from typing import Optional, Union, cast, overload +from typing import Literal, Optional, Union, cast, overload import equinox as eqx import jax @@ -1522,12 +1522,12 @@ class AttentionConfig: scaling_factor: Optional[float] = None logits_soft_cap: Optional[float] = None qk_norm: Optional[LayerNormConfigBase] = None - gated: bool = False + gated: Literal["none", "headwise", "elementwise"] = "none" def __post_init__(self): - assert ( - self.num_heads % self.num_kv_heads == 0 - ), f"num_heads={self.num_heads} not divisible by num_kv_heads={self.num_kv_heads}." + assert self.num_heads % self.num_kv_heads == 0, ( + f"num_heads={self.num_heads} not divisible by num_kv_heads={self.num_kv_heads}." 
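# --- Illustrative sketch (not part of this patch): how the LR-sweep multiplier introduced in
# --- build_run above is meant to scale both the Muon learning rate and the auxiliary Adam LR.
# --- ToyMuonConfig and its values are assumptions for illustration only, not the real MuonConfig.
import dataclasses

@dataclasses.dataclass(frozen=True)
class ToyMuonConfig:
    learning_rate: float
    adam_lr: float

def apply_lr_multiplier(cfg: ToyMuonConfig, mult: float) -> ToyMuonConfig:
    # Scale both optimizer learning rates by the sweep multiplier, mirroring dataclasses.replace above.
    return dataclasses.replace(cfg, learning_rate=cfg.learning_rate * mult, adam_lr=cfg.adam_lr * mult)

base = ToyMuonConfig(learning_rate=0.008, adam_lr=0.0024)
swept = [apply_lr_multiplier(base, m) for m in (1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0)]
assert abs(swept[2].learning_rate - 0.016) < 1e-12  # 2x multiplier doubles both rates
assert abs(swept[2].adam_lr - 0.0048) < 1e-12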
+ ) @property def head_size(self) -> int: @@ -1563,6 +1563,20 @@ def use_flash_attention(self) -> bool: return default_attention_type() != AttentionBackend.VANILLA return self.attn_backend != AttentionBackend.VANILLA + @property + def GateSize(self) -> Axis: + """Axis for the gate output size based on gating mode. + + For headwise gating, returns an axis of size 1 (one scalar per head). + For elementwise gating, returns an axis of size head_size (one value per element). + + The axis is always named "gate_size" for consistency. + """ + if self.gated == "headwise": + return Axis("gate_size", 1) + else: # elementwise + return Axis("gate_size", self.head_size) + class Attention(eqx.Module): """A multi-head attention layer that uses dot product attention. @@ -1570,7 +1584,7 @@ class Attention(eqx.Module): This is a general-purpose attention layer that can be used in various transformer architectures. It supports multi-head attention (MHA), multi-query attention (MQA), and grouped-query attention (GQA). - Supports ROPE and QK normalization. We should probably not add much more stuff. + Supports ROPE, QK normalization, and gated attention (headwise or elementwise). """ config: AttentionConfig = eqx.field(static=True) @@ -1581,20 +1595,26 @@ class Attention(eqx.Module): q_norm: Optional[LayerNormBase] = None k_norm: Optional[LayerNormBase] = None rot_embs: Optional[RotaryEmbeddings] = None - gate_proj: Optional[hnn.Linear] = None @staticmethod def init(config: AttentionConfig, *, key) -> "Attention": use_bias = config.use_bias use_output_bias = config.use_output_bias if config.use_output_bias is not None else use_bias - if config.gated: - k_q, k_k, k_v, k_o, k_g = jrandom.split(key, 5) + k_q, k_k, k_v, k_o = jrandom.split(key, 4) + + # For gated attention, the gate is fused with Q projection (following the paper). + # The Q projection outputs [KVHeads, QHeadsPerGroup, HeadSize + GateSize]. + # For headwise gating: GateSize = 1 (one scalar per head) + # For elementwise gating: GateSize = HeadSize (one value per element) + if config.gated != "none": + QGateAxis = Axis("q_gate_combined", config.HeadSize.size + config.GateSize.size) + q_out_axes = (config.KVHeads, config.QHeadsPerGroup, QGateAxis) else: - k_q, k_k, k_v, k_o = jrandom.split(key, 4) - k_g = None + q_out_axes = (config.KVHeads, config.QHeadsPerGroup, config.HeadSize) + q_proj = hnn.Linear.init( In=config.Embed, - Out=(config.KVHeads, config.QHeadsPerGroup, config.HeadSize), + Out=q_out_axes, key=k_q, use_bias=use_bias, out_first=True, @@ -1621,16 +1641,6 @@ def init(config: AttentionConfig, *, key) -> "Attention": out_first=True, ) - gate_proj = None - if config.gated: - gate_proj = hnn.Linear.init( - In=config.Embed, - Out=(config.KVHeads, config.QHeadsPerGroup, config.HeadSize), - key=k_g, - use_bias=use_bias, - out_first=True, - ) - q_norm = None k_norm = None if config.qk_norm is not None: @@ -1640,7 +1650,27 @@ def init(config: AttentionConfig, *, key) -> "Attention": # Build rotary embeddings once during initialization if configured rot_embs = config.rope.build(config.HeadSize) if config.rope is not None else None - return Attention(config, q_proj, k_proj, v_proj, o_proj, q_norm, k_norm, rot_embs, gate_proj) + return Attention(config, q_proj, k_proj, v_proj, o_proj, q_norm, k_norm, rot_embs) + + def _split_q_and_gate(self, q_combined: NamedArray) -> tuple[NamedArray, NamedArray | None]: + """Split the combined Q+gate projection into Q and gate components. 
+ + Args: + q_combined: The combined output from q_proj with shape [..., q_gate_combined]. + + Returns: + A tuple of (q, gate) where: + - q has shape [..., head_size] + - gate has shape [..., gate_size] (1 for headwise, head_size for elementwise) + or None if gating is disabled. + """ + if self.config.gated == "none": + return q_combined, None + + combined_axis = q_combined.resolve_axis("q_gate_combined") + q, gate = hax.split(q_combined, axis=combined_axis, new_axes=[self.config.HeadSize, self.config.GateSize]) + + return q, gate def empty_page_cache(self, spec: PageTableSpec, *, dtype) -> "KvPageCache": return KvPageCache.init(spec, self.config.KVHeads, self.config.HeadSize, dtype=dtype) @@ -1654,14 +1684,10 @@ def __call__( key=None, pos_ids: NamedArray | None = None, ) -> NamedArray: - if self.gate_proj is not None: - key_proj, key_o, key_g = maybe_rng_split(key, 3) - else: - key_proj, key_o = maybe_rng_split(key, 2) - key_g = None + key_proj, key_o = maybe_rng_split(key, 2) - # Shared computation of q, k, v - q, k, v = self._compute_qkv(x, key=key_proj, pos_ids=pos_ids) + # Shared computation of q, k, v (and gate if gated) + q, k, v, gate = self._compute_qkv(x, key=key_proj, pos_ids=pos_ids) # Reshape for attention kernels (convert embed → heads/head_size) q = q.rearrange((..., "kv_head", "q_heads_per_group", "position", "head_size")) @@ -1690,10 +1716,10 @@ def __call__( prng=key, ) - if self.gate_proj is not None: - gate = self.gate_proj(x, key=key_g) - gate = gate.rearrange((..., "kv_head", "q_heads_per_group", "position", "head_size")) + if gate is not None: + gate = gate.rearrange((..., "kv_head", "q_heads_per_group", "position", "gate_size")) gate = hax.nn.sigmoid(gate) + gate = gate.rename({"gate_size": "head_size"}) attn_output = attn_output * gate # Flatten heads and apply output projection @@ -1722,14 +1748,9 @@ def paged_decode( describes where the new keys and values should be written in ``kv_cache``. Currently only causal masks are supported. """ + key_proj, key_o = maybe_rng_split(key, 2) - if self.gate_proj is not None: - key_proj, key_o, key_g = maybe_rng_split(key, 3) - else: - key_proj, key_o = maybe_rng_split(key, 2) - key_g = None - - q, k, v = self._compute_qkv(x, key=key_proj, pos_ids=pos_ids) + q, k, v, gate = self._compute_qkv(x, key=key_proj, pos_ids=pos_ids) kv_cache = kv_cache.update(batch_info, k, v) @@ -1750,10 +1771,10 @@ def paged_decode( soft_cap=self.config.logits_soft_cap, ) - if self.gate_proj is not None: - gate = self.gate_proj(x, key=key_g) - gate = gate.rearrange((..., "kv_head", "q_heads_per_group", "position", "head_size")) + if gate is not None: + gate = gate.rearrange((..., "kv_head", "q_heads_per_group", "position", "gate_size")) gate = hax.nn.sigmoid(gate) + gate = gate.rename({"gate_size": "head_size"}) attn_tokens = attn_tokens * gate attn_output = attn_tokens.flatten_axes(("kv_head", "q_heads_per_group"), "heads") @@ -1769,30 +1790,37 @@ def _compute_qkv( *, key, pos_ids: NamedArray | None = None, - ) -> tuple[NamedArray, NamedArray, NamedArray]: - """Project *x* to Q, K and V and apply all per-head processing.""" + ) -> tuple[NamedArray, NamedArray, NamedArray, NamedArray | None]: + """Project *x* to Q, K and V (and gate if gated) and apply all per-head processing. + + Returns: + A tuple of (q, k, v, gate) where gate is None if gating is disabled. 
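# --- Illustrative sketch (not part of this patch): the fused Q+gate split and sigmoid gating in
# --- plain jax.numpy, without Haliax named axes. Shapes and the toy attention output are assumed
# --- purely for illustration; the real code operates on NamedArrays as in the diff above.
import jax
import jax.numpy as jnp

def split_q_and_gate(q_combined: jnp.ndarray, head_size: int) -> tuple[jnp.ndarray, jnp.ndarray]:
    # q_combined: [..., head_size + gate_size]; the first head_size entries are Q, the rest the gate
    # (gate_size = 1 for headwise gating, gate_size = head_size for elementwise gating).
    return q_combined[..., :head_size], q_combined[..., head_size:]

head_size, pos, heads = 4, 2, 8
attn_output = jnp.ones((pos, heads, head_size))        # stand-in for the pre-gate attention result

# Elementwise gating: one gate value per head dimension.
q_elem, gate_elem = split_q_and_gate(jnp.zeros((pos, heads, 2 * head_size)), head_size)
gated_elem = attn_output * jax.nn.sigmoid(gate_elem)   # sigmoid(0) = 0.5 everywhere

# Headwise gating: one scalar gate per head, broadcast across head_size.
q_head, gate_head = split_q_and_gate(jnp.zeros((pos, heads, head_size + 1)), head_size)
gated_head = attn_output * jax.nn.sigmoid(gate_head)   # broadcasts [pos, heads, 1] over head_size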
+ """ # Split the projection key into three – one for each of Q, K, V key_q, key_k, key_v = maybe_rng_split(key, 3) # Linear projections - q = self.q_proj(x, key=key_q) + q_combined = self.q_proj(x, key=key_q) k = self.k_proj(x, key=key_k) v = self.v_proj(x, key=key_v) - # Optional QK layer-norm + # Split Q and gate if gated attention is enabled + q, gate = self._split_q_and_gate(q_combined) + + # Optional QK layer-norm (applied only to Q, not gate) if self.config.qk_norm is not None: q = self.q_norm(q) # type: ignore[misc] k = self.k_norm(k) # type: ignore[misc] - # Apply rotary embeddings if configured + # Apply rotary embeddings if configured (applied only to Q, not gate) if self.rot_embs is not None: if pos_ids is None: pos_ids = hax.arange(x.resolve_axis("position")) q = self.rot_embs(q, pos_ids).astype(q.dtype) k = self.rot_embs(k, pos_ids).astype(k.dtype) - return q, k, v + return q, k, v, gate @named_call @@ -2256,9 +2284,9 @@ def __call__( if self.config.q_lora_rank is None: q = self.q_proj(x, key=k_q_a) else: - assert ( - self.q_a_proj is not None and self.q_a_norm is not None and self.q_b_proj is not None - ), "q_lora_rank defined, but LoRA matrices are not." + assert self.q_a_proj is not None and self.q_a_norm is not None and self.q_b_proj is not None, ( + "q_lora_rank defined, but LoRA matrices are not." + ) q = self.q_a_proj(x, key=k_q_a) q = self.q_a_norm(q) q = self.q_b_proj(q, key=k_q_b) @@ -2331,7 +2359,6 @@ def init(config: AttentionConfig, *, key) -> "AttentionWithSink": base.q_norm, base.k_norm, base.rot_embs, - base.gate_proj, sinks, ) @@ -2344,33 +2371,15 @@ def __call__( key=None, pos_ids: NamedArray | None = None, ) -> NamedArray: - key_q, key_k, key_v, key_o = maybe_rng_split(key, 4) - key_g = None + key_proj, key_o = maybe_rng_split(key, 2) - if self.gate_proj is not None: - key_q, key_k, key_v, key_o, key_g = maybe_rng_split(key, 5) - - q_proj = self.q_proj(x, key=key_q) - k_proj = self.k_proj(x, key=key_k) - v = self.v_proj(x, key=key_v) - - if self.config.qk_norm is not None: - q = self.q_norm(q_proj) # type: ignore[misc] - k = self.k_norm(k_proj) # type: ignore[misc] - else: - q = q_proj - k = k_proj + # Compute q, k, v (and gate if gated) + q, k, v, gate = self._compute_qkv(x, key=key_proj, pos_ids=pos_ids) q = q.rearrange((..., "kv_head", "q_heads_per_group", "position", "head_size")) k = k.rearrange((..., "kv_head", "position", "head_size")) v = v.rearrange((..., "kv_head", "position", "head_size")) - if self.rot_embs is not None: - if pos_ids is None: - pos_ids = hax.arange(x.resolve_axis("position"), dtype=jnp.int32) - q = self.rot_embs(q, pos_ids) - k = self.rot_embs(k, pos_ids) - k = k.rename({"position": "key_position"}) v = v.rename({"position": "key_position"}) @@ -2393,10 +2402,11 @@ def __call__( attn_sink=self.sinks, ) - if self.gate_proj is not None: - gate = self.gate_proj(x, key=key_g) - gate = gate.rearrange((..., "kv_head", "q_heads_per_group", "position", "head_size")) + if gate is not None: + gate = gate.rearrange((..., "kv_head", "q_heads_per_group", "position", "gate_size")) gate = hax.nn.sigmoid(gate) + # Rename gate_size to head_size for proper broadcasting/multiplication + gate = gate.rename({"gate_size": "head_size"}) attn_output = attn_output * gate attn_output = attn_output.flatten_axes(("kv_head", "q_heads_per_group"), "heads") diff --git a/lib/levanter/tests/test_attention.py b/lib/levanter/tests/test_attention.py index 64f7886fb3..37943237cd 100644 --- a/lib/levanter/tests/test_attention.py +++ 
b/lib/levanter/tests/test_attention.py @@ -131,12 +131,57 @@ def test_attention_with_sink_module(): def test_attention_with_gating_module(): + """Test elementwise gated attention. + + The gate is fused with q_proj. When gated="elementwise", q_proj outputs + [head_size + head_size] and the second half is the gate. + + With zero weights/biases for Q (and gate), the gate output is sigmoid(0) = 0.5. + With v_proj bias=1 and o_proj weight=1, the attention output before gating is 1. + After gating: 1 * 0.5 = 0.5 + """ + Pos = hax.Axis("position", 2) + Embed = hax.Axis("embed", 1) + + config = AttentionConfig(Embed=Embed, num_heads=1, num_kv_heads=1, use_bias=True, gated="elementwise") + attn = Attention.init(config, key=jrandom.PRNGKey(0)) + + # q_proj now has shape [embed, kv_head, q_heads_per_group, head_size*2] for elementwise gating + # The first half is Q, the second half is the gate + attn = eqx.tree_at(lambda a: a.q_proj.weight, attn, hax.zeros(attn.q_proj.weight.axes)) + attn = eqx.tree_at(lambda a: a.q_proj.bias, attn, hax.zeros(attn.q_proj.bias.axes)) + attn = eqx.tree_at(lambda a: a.k_proj.weight, attn, hax.zeros(attn.k_proj.weight.axes)) + attn = eqx.tree_at(lambda a: a.k_proj.bias, attn, hax.zeros(attn.k_proj.bias.axes)) + attn = eqx.tree_at(lambda a: a.v_proj.weight, attn, hax.zeros(attn.v_proj.weight.axes)) + attn = eqx.tree_at(lambda a: a.v_proj.bias, attn, hax.ones(attn.v_proj.bias.axes)) + attn = eqx.tree_at(lambda a: a.o_proj.weight, attn, hax.ones(attn.o_proj.weight.axes)) + attn = eqx.tree_at(lambda a: a.o_proj.bias, attn, hax.zeros(attn.o_proj.bias.axes)) + + x = hax.zeros((Pos, Embed)) + out = attn(x, None) + + expected = np.full((2, 1), 0.5) + assert_trees_all_close(out.array, expected) + + +def test_attention_with_headwise_gating_module(): + """Test headwise gated attention. + + The gate is fused with q_proj. When gated="headwise", q_proj outputs + [head_size + 1] and the last value is the gate (one scalar per head). + + With zero weights/biases for Q (and gate), the gate output is sigmoid(0) = 0.5. + With v_proj bias=1 and o_proj weight=1, the attention output before gating is 1. 
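# --- Quick numeric check (sketch, not part of the test file) of the expected value used by both
# --- gating tests above: with all Q/gate weights and biases zero the gate is sigmoid(0) = 0.5, and
# --- with v_proj bias = 1 and o_proj weight = 1 the pre-gate attention output is 1, giving 0.5.
import math
gate = 1.0 / (1.0 + math.exp(0.0))   # sigmoid(0)
assert gate == 0.5
assert 1.0 * gate == 0.5             # pre-gate output of 1, gated down to 0.5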
+ After gating: 1 * 0.5 = 0.5 + """ Pos = hax.Axis("position", 2) Embed = hax.Axis("embed", 1) - config = AttentionConfig(Embed=Embed, num_heads=1, num_kv_heads=1, use_bias=True, gated=True) + config = AttentionConfig(Embed=Embed, num_heads=1, num_kv_heads=1, use_bias=True, gated="headwise") attn = Attention.init(config, key=jrandom.PRNGKey(0)) + # q_proj now has shape [embed, kv_head, q_heads_per_group, head_size+1] for headwise gating + # The first head_size values are Q, the last value is the gate attn = eqx.tree_at(lambda a: a.q_proj.weight, attn, hax.zeros(attn.q_proj.weight.axes)) attn = eqx.tree_at(lambda a: a.q_proj.bias, attn, hax.zeros(attn.q_proj.bias.axes)) attn = eqx.tree_at(lambda a: a.k_proj.weight, attn, hax.zeros(attn.k_proj.weight.axes)) @@ -145,8 +190,6 @@ def test_attention_with_gating_module(): attn = eqx.tree_at(lambda a: a.v_proj.bias, attn, hax.ones(attn.v_proj.bias.axes)) attn = eqx.tree_at(lambda a: a.o_proj.weight, attn, hax.ones(attn.o_proj.weight.axes)) attn = eqx.tree_at(lambda a: a.o_proj.bias, attn, hax.zeros(attn.o_proj.bias.axes)) - attn = eqx.tree_at(lambda a: a.gate_proj.weight, attn, hax.zeros(attn.gate_proj.weight.axes)) - attn = eqx.tree_at(lambda a: a.gate_proj.bias, attn, hax.zeros(attn.gate_proj.bias.axes)) x = hax.zeros((Pos, Embed)) out = attn(x, None) @@ -517,9 +560,9 @@ def test_causal_offset_cross_attention(impl): precision=Precision.HIGHEST, ) - assert not jnp.allclose( - offset_out.array, wrong_out.array, atol=2e-3, rtol=2e-3 - ), "Output should differ without offset" + assert not jnp.allclose(offset_out.array, wrong_out.array, atol=2e-3, rtol=2e-3), ( + "Output should differ without offset" + ) # This is a bottleneck in tests From 98bd063fc1bf41d2fc9f7238cb987c5fd42a5c8e Mon Sep 17 00:00:00 2001 From: "Pinlin [Calvin] Xu" Date: Sat, 20 Dec 2025 17:19:43 -1000 Subject: [PATCH 06/11] update w/ main --- .../hackable_transformer_attn_gate.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/experiments/speedrun/hackable_transformer_attn_gate/hackable_transformer_attn_gate.py b/experiments/speedrun/hackable_transformer_attn_gate/hackable_transformer_attn_gate.py index 55c6f01177..0e88118aec 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/hackable_transformer_attn_gate.py +++ b/experiments/speedrun/hackable_transformer_attn_gate/hackable_transformer_attn_gate.py @@ -162,14 +162,14 @@ def attention_config(self) -> AttentionConfig: def actual_head_size(self) -> int: return self.head_dim or (self.hidden_dim // self.num_heads) - def flops_per_token(self, vocab_size: int) -> float | None: + def flops_per_token(self, vocab_size: int, context_length: int) -> float | None: return lm_flops_per_token( hidden_dim=self.hidden_dim, intermediate_dim=self.intermediate_dim, num_layers=self.num_layers, num_kv_heads=self.num_kv_heads, num_heads=self.num_heads, - seq_len=self.seq_len, + seq_len=context_length, vocab_size=vocab_size, glu=True, ) @@ -366,7 +366,7 @@ def _get_num_train_steps(param_count: int, batch_size: int, seq_len: int, tpp: i def _size_presets() -> dict[str, HackableTransformerConfig]: base = dict( - seq_len=4096, + max_seq_len=4096, rope=DefaultRotaryEmbeddingsConfig(), # e.g., Llama3RotaryEmbeddingsConfig() attn_backend=None, qk_norm=None, # e.g. 
RmsNormConfig(use_weight=True, eps=1e-5) @@ -465,7 +465,7 @@ def _resource_presets(use_gpu: bool = False): "130m": ResourceConfig.with_tpu("v5p-32"), "300m": ResourceConfig.with_tpu("v5p-32"), "520m": ResourceConfig.with_tpu("v5p-32"), - "1_2b": ResourceConfig.with_tpu("v5p-32"), + "1_2b": ResourceConfig.with_tpu("v5p-64"), } @@ -547,7 +547,8 @@ def build_run( _cls.__module__ = _IMPORT_PATH ### - sizes = ["130m", "300m", "520m", "1_2b"] + # sizes = ["130m", "300m", "520m", "1_2b"] + sizes = ["1_2b"] use_gpu = bool(int(os.environ.get("SR_USE_GPU", "0"))) use_gate = True steps = [] From 25e5a9441af705e334f826dc3e51038e591f8085 Mon Sep 17 00:00:00 2001 From: "Pinlin [Calvin] Xu" Date: Sat, 27 Dec 2025 19:01:40 -0800 Subject: [PATCH 07/11] Initial LR sweep results --- .../hackable_transformer_attn_gate.py | 11 +- .../lr_x1/130m/speedrun_results.json | 141 +++++++++++++++++ .../lr_x1/1_2b/speedrun_results.json | 143 ++++++++++++++++++ .../lr_x1/300m/speedrun_results.json | 141 +++++++++++++++++ .../lr_x1/520m/speedrun_results.json | 141 +++++++++++++++++ .../lr_x1_5/130m/speedrun_results.json | 141 +++++++++++++++++ .../lr_x1_5/1_2b/speedrun_results.json | 143 ++++++++++++++++++ .../lr_x1_5/300m/speedrun_results.json | 141 +++++++++++++++++ .../lr_x1_5/520m/speedrun_results.json | 141 +++++++++++++++++ .../lr_x2/130m/speedrun_results.json | 141 +++++++++++++++++ .../lr_x2/1_2b/speedrun_results.json | 143 ++++++++++++++++++ .../lr_x2/300m/speedrun_results.json | 141 +++++++++++++++++ .../lr_x2/520m/speedrun_results.json | 141 +++++++++++++++++ .../lr_x2_5/130m/speedrun_results.json | 141 +++++++++++++++++ .../lr_x2_5/1_2b/speedrun_results.json | 143 ++++++++++++++++++ .../lr_x2_5/300m/speedrun_results.json | 141 +++++++++++++++++ .../lr_x2_5/520m/speedrun_results.json | 143 ++++++++++++++++++ .../lr_x3/130m/speedrun_results.json | 141 +++++++++++++++++ .../lr_x3/1_2b/speedrun_results.json | 143 ++++++++++++++++++ .../lr_x3/300m/speedrun_results.json | 141 +++++++++++++++++ .../lr_x3/520m/speedrun_results.json | 141 +++++++++++++++++ .../lr_x3_5/130m/speedrun_results.json | 141 +++++++++++++++++ .../lr_x3_5/1_2b/speedrun_results.json | 143 ++++++++++++++++++ .../lr_x3_5/300m/speedrun_results.json | 141 +++++++++++++++++ .../lr_x3_5/520m/speedrun_results.json | 141 +++++++++++++++++ .../lr_x4/130m/speedrun_results.json | 141 +++++++++++++++++ .../lr_x4/1_2b/speedrun_results.json | 143 ++++++++++++++++++ .../lr_x4/300m/speedrun_results.json | 141 +++++++++++++++++ .../lr_x4/520m/speedrun_results.json | 141 +++++++++++++++++ 29 files changed, 3972 insertions(+), 3 deletions(-) create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x1/130m/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x1/1_2b/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x1/300m/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x1/520m/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/130m/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/1_2b/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/300m/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/520m/speedrun_results.json create mode 100644 
experiments/speedrun/hackable_transformer_attn_gate/lr_x2/130m/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x2/1_2b/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x2/300m/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x2/520m/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/130m/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/1_2b/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/300m/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/520m/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x3/130m/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x3/1_2b/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x3/300m/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x3/520m/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/130m/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/1_2b/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/300m/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/520m/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x4/130m/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x4/1_2b/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x4/300m/speedrun_results.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x4/520m/speedrun_results.json diff --git a/experiments/speedrun/hackable_transformer_attn_gate/hackable_transformer_attn_gate.py b/experiments/speedrun/hackable_transformer_attn_gate/hackable_transformer_attn_gate.py index 0e88118aec..b0bf58d458 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/hackable_transformer_attn_gate.py +++ b/experiments/speedrun/hackable_transformer_attn_gate/hackable_transformer_attn_gate.py @@ -46,6 +46,7 @@ import logging from dataclasses import dataclass from collections.abc import Callable +from typing import Literal import equinox as eqx import numpy as np @@ -108,7 +109,7 @@ class HackableTransformerConfig(LmConfig["HackableLMHeadModel"]): input_embedding_norm: bool = False # Attention - use_gated_attention: bool = False + use_gated_attention: Literal["none", "headwise", "elementwise"] = "none" upcast_attn: bool = False attn_backend: AttentionBackend | None = None flash_attention_block_size: int | None = None @@ -182,8 +183,12 @@ def total_trainable_params(self, vocab_size: int) -> int: + 2 * self.hidden_dim * hs * self.num_kv_heads + hs * self.num_heads * self.hidden_dim ) - if self.use_gated_attention: + if self.use_gated_attention == "headwise": + attn += self.hidden_dim * self.num_heads + elif self.use_gated_attention == "elementwise": attn += self.hidden_dim * hs * self.num_heads + else: + raise ValueError(f"Unknown gated attention mode: {self.use_gated_attention}") mlp = 3 * self.hidden_dim * self.intermediate_dim transformer = 
self.num_layers * (attn + mlp + 2 * self.hidden_dim) + self.hidden_dim if self.input_embedding_norm: @@ -550,7 +555,7 @@ def build_run( # sizes = ["130m", "300m", "520m", "1_2b"] sizes = ["1_2b"] use_gpu = bool(int(os.environ.get("SR_USE_GPU", "0"))) - use_gate = True + use_gate = "elementwise" steps = [] # Sweep LR from 1x to 4x at 0.5x increments (paper suggests higher LR for gated attention) lr_mults = _lr_multipliers(start=1.0, stop=4.0, step=0.5) diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/130m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/130m/speedrun_results.json new file mode 100644 index 0000000000..e4b12437cc --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/130m/speedrun_results.json @@ -0,0 +1,141 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (130m); Gated Attention (Splash); LR sweep multiplier=1", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 1.1556898355484009, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 512, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 1792, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 8, + "num_kv_heads": 8, + "num_layers": 6, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 2.1289341996446515e+18, + "model_flops_per_token": 227868672.0, + "model_size": 155720192, + "num_chips": 16, + "num_devices": 16, + "resources": { + "cpu": 1, + "device": { + "topology": null, + "type": "v5p-32" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-09 10:26:57 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 3114270720, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.016, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 5940, + "optimizer_config": { + "adam_lr": 0.0032, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 0.8, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.016, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 1, + "min_lr_ratio": 0, + "momentum": 0.95, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": 
null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 128, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 1.1670497124689082e+19, + "training_time": 1589.1199788519991, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_130m_attngate_4096_splash_lr_sweep_lr_x1-88f232" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/1_2b/speedrun_results.json new file mode 100644 index 0000000000..b5a75910bf --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/1_2b/speedrun_results.json @@ -0,0 +1,143 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (1_2b); Gated Attention (Splash); LR sweep multiplier=1", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 0.9160401225090027, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 2048, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 7168, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 16, + "num_kv_heads": 8, + "num_layers": 16, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "seq_len": 2048, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 5.1738353514618185e+20, + "model_flops_per_token": 2877292544.0, + "model_size": 1498482688, + "num_chips": 32, + "num_devices": 32, + "resources": { + "cpu": 1, + "device": { + "kind": "tpu", + "topology": null, + "variant": "v5p-64" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-17 04:43:09 UTC", + "tokenized_dataset": "gs://marin-us-east5/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 59938701312, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.004, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 57162, + "optimizer_config": { + "adam_lr": 0.0012, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 1, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.004, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 2, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 
1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 256, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 1.80638282968129e+21, + "training_time": 122983.58045215753, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_2048_splash_lr_sweep_lr_x1-ecb416" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/300m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/300m/speedrun_results.json new file mode 100644 index 0000000000..df4850adfe --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/300m/speedrun_results.json @@ -0,0 +1,141 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (300m); Gated Attention (Splash); LR sweep multiplier=1", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 1.0535272359848022, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 768, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 2688, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 12, + "num_kv_heads": 12, + "num_layers": 12, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 1.021384111077458e+19, + "model_flops_per_token": 555024384.0, + "model_size": 306727680, + "num_chips": 16, + "num_devices": 16, + "resources": { + "cpu": 1, + "device": { + "topology": null, + "type": "v5p-32" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-09 12:44:00 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 6134169600, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.008, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 11700, + "optimizer_config": { + "adam_lr": 0.0024, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + 
"cycle_length": null, + "cycles": null, + "decay": 0.8, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.008, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 1, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 128, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 5.068078886783368e+19, + "training_time": 6900.97887633901, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_300m_attngate_4096_splash_lr_sweep_lr_x1-25ee3b" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/520m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/520m/speedrun_results.json new file mode 100644 index 0000000000..a2423e58ac --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/520m/speedrun_results.json @@ -0,0 +1,141 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (520m); Gated Attention (Splash); LR sweep multiplier=1", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 0.9799903035163879, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 1024, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 3584, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 16, + "num_kv_heads": 8, + "num_layers": 24, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 5.081733891346976e+19, + "model_flops_per_token": 1349517312.0, + "model_size": 627622912, + "num_chips": 16, + "num_devices": 16, + "resources": { + "cpu": 1, + "device": { + "topology": null, + "type": "v5p-32" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-09 19:07:36 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 12551979008, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.008, 
+ "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 23941, + "optimizer_config": { + "adam_lr": 0.0024, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 1, + "default_weight_decay_mask": null, + "epsilon": 1e-25, + "haps": null, + "learning_rate": 0.008, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 1, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 128, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 2.2386170836914753e+20, + "training_time": 30482.258764862137, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_520m_attngate_4096_splash_lr_sweep_lr_x1-bca683" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/130m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/130m/speedrun_results.json new file mode 100644 index 0000000000..8ded8f4651 --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/130m/speedrun_results.json @@ -0,0 +1,141 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (130m); Gated Attention (Splash); LR sweep multiplier=1.5", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 1.1563303470611572, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 512, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 1792, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 8, + "num_kv_heads": 8, + "num_layers": 6, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 2.1289341996446515e+18, + "model_flops_per_token": 227868672.0, + "model_size": 155720192, + "num_chips": 16, + "num_devices": 16, + "resources": { + "cpu": 1, + "device": { + "topology": null, + "type": "v5p-32" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-09 11:00:02 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 3114270720, + "train_config": { + 
"allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.024, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 5940, + "optimizer_config": { + "adam_lr": 0.0048000000000000004, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 0.8, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.024, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 1, + "min_lr_ratio": 0, + "momentum": 0.95, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 128, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 1.1669741753874995e+19, + "training_time": 1589.0171233489916, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_130m_attngate_4096_splash_lr_sweep_lr_x1_5-c1c1a5" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/1_2b/speedrun_results.json new file mode 100644 index 0000000000..9988e91b5c --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/1_2b/speedrun_results.json @@ -0,0 +1,143 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (1_2b); Gated Attention (Splash); LR sweep multiplier=1.5", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 0.9122039675712585, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 2048, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 7168, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 16, + "num_kv_heads": 8, + "num_layers": 16, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "seq_len": 2048, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 5.1738353514618185e+20, + "model_flops_per_token": 2877292544.0, + "model_size": 1498482688, + "num_chips": 32, + "num_devices": 32, + "resources": { + "cpu": 1, + "device": { + "kind": 
"tpu", + "topology": null, + "variant": "v5p-64" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-17 00:31:25 UTC", + "tokenized_dataset": "gs://marin-us-east5/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 59938701312, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.006, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 57162, + "optimizer_config": { + "adam_lr": 0.0018, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 1, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.006, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 2, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 256, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 1.8044518078036129e+21, + "training_time": 122852.11109774053, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_2048_splash_lr_sweep_lr_x1_5-e5d647" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/300m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/300m/speedrun_results.json new file mode 100644 index 0000000000..dd9ffa16fa --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/300m/speedrun_results.json @@ -0,0 +1,141 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (300m); Gated Attention (Splash); LR sweep multiplier=1.5", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 1.0536423921585083, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 768, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 2688, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 12, + "num_kv_heads": 12, + "num_layers": 12, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "tie_word_embeddings": false, + "tokenizer": null, + 
"upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 1.021384111077458e+19, + "model_flops_per_token": 555024384.0, + "model_size": 306727680, + "num_chips": 16, + "num_devices": 16, + "resources": { + "cpu": 1, + "device": { + "topology": null, + "type": "v5p-32" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-09 12:24:57 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 6134169600, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.012, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 11700, + "optimizer_config": { + "adam_lr": 0.0036, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 0.8, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.012, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 1, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 128, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 5.0759425868293325e+19, + "training_time": 6911.686528907043, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_300m_attngate_4096_splash_lr_sweep_lr_x1_5-1394e4" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/520m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/520m/speedrun_results.json new file mode 100644 index 0000000000..fa5babaacb --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/520m/speedrun_results.json @@ -0,0 +1,141 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (520m); Gated Attention (Splash); LR sweep multiplier=1.5", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 0.9817151427268982, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 1024, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 3584, + 
"layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 16, + "num_kv_heads": 8, + "num_layers": 24, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 5.081733891346976e+19, + "model_flops_per_token": 1349517312.0, + "model_size": 627622912, + "num_chips": 16, + "num_devices": 16, + "resources": { + "cpu": 1, + "device": { + "topology": null, + "type": "v5p-32" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-09 19:12:38 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 12551979008, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.012, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 23941, + "optimizer_config": { + "adam_lr": 0.0036, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 1, + "default_weight_decay_mask": null, + "epsilon": 1e-25, + "haps": null, + "learning_rate": 0.012, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 1, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 128, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 2.2333461727415042e+20, + "training_time": 30410.487101600003, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_520m_attngate_4096_splash_lr_sweep_lr_x1_5-76e777" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/130m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/130m/speedrun_results.json new file mode 100644 index 0000000000..c8dc916e77 --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/130m/speedrun_results.json @@ -0,0 +1,141 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (130m); Gated Attention (Splash); LR sweep multiplier=2", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 1.1576088666915894, + "model_config": { + 
"activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 512, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 1792, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 8, + "num_kv_heads": 8, + "num_layers": 6, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 2.1289341996446515e+18, + "model_flops_per_token": 227868672.0, + "model_size": 155720192, + "num_chips": 16, + "num_devices": 16, + "resources": { + "cpu": 1, + "device": { + "topology": null, + "type": "v5p-32" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-09 10:57:26 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 3114270720, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.032, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 5940, + "optimizer_config": { + "adam_lr": 0.0064, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 0.8, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.032, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 1, + "min_lr_ratio": 0, + "momentum": 0.95, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 128, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 1.1642901007026072e+19, + "training_time": 1585.3623375580164, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_130m_attngate_4096_splash_lr_sweep_lr_x2-8abf41" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/1_2b/speedrun_results.json new file mode 100644 index 0000000000..b89184f05f --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/1_2b/speedrun_results.json @@ -0,0 +1,143 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford 
University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (1_2b); Gated Attention (Splash); LR sweep multiplier=2", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 0.9119555354118347, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 2048, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 7168, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 16, + "num_kv_heads": 8, + "num_layers": 16, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "seq_len": 2048, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 5.1738353514618185e+20, + "model_flops_per_token": 2877292544.0, + "model_size": 1498482688, + "num_chips": 32, + "num_devices": 32, + "resources": { + "cpu": 1, + "device": { + "kind": "tpu", + "topology": null, + "variant": "v5p-64" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-17 01:10:26 UTC", + "tokenized_dataset": "gs://marin-us-east5/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 59938701312, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.008, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 57162, + "optimizer_config": { + "adam_lr": 0.0024, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 1, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.008, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 2, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 256, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 1.8017056211590617e+21, + "training_time": 122665.14305276837, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_2048_splash_lr_sweep_lr_x2-be36f3" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/300m/speedrun_results.json 
b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/300m/speedrun_results.json new file mode 100644 index 0000000000..63fa3f7a72 --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/300m/speedrun_results.json @@ -0,0 +1,141 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (300m); Gated Attention (Splash); LR sweep multiplier=2", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 1.0551481246948242, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 768, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 2688, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 12, + "num_kv_heads": 12, + "num_layers": 12, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 1.021384111077458e+19, + "model_flops_per_token": 555024384.0, + "model_size": 306727680, + "num_chips": 16, + "num_devices": 16, + "resources": { + "cpu": 1, + "device": { + "topology": null, + "type": "v5p-32" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-09 12:26:21 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 6134169600, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.016, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 11700, + "optimizer_config": { + "adam_lr": 0.0048, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 0.8, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.016, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 1, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 128, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 5.0833465300533e+19, + "training_time": 
6921.768150944036, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_300m_attngate_4096_splash_lr_sweep_lr_x2-03a06d" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/520m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/520m/speedrun_results.json new file mode 100644 index 0000000000..69d348ea33 --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/520m/speedrun_results.json @@ -0,0 +1,141 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (520m); Gated Attention (Splash); LR sweep multiplier=2", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 0.9846972823143005, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 1024, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 3584, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 16, + "num_kv_heads": 8, + "num_layers": 24, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 5.081733891346976e+19, + "model_flops_per_token": 1349517312.0, + "model_size": 627622912, + "num_chips": 16, + "num_devices": 16, + "resources": { + "cpu": 1, + "device": { + "topology": null, + "type": "v5p-32" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-09 19:47:23 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 12551979008, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.016, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 23941, + "optimizer_config": { + "adam_lr": 0.0048, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 1, + "default_weight_decay_mask": null, + "epsilon": 1e-25, + "haps": null, + "learning_rate": 0.016, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 1, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 128, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + 
"include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 2.2363047585747508e+20, + "training_time": 30450.772856410007, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_520m_attngate_4096_splash_lr_sweep_lr_x2-083666" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/130m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/130m/speedrun_results.json new file mode 100644 index 0000000000..995167ee58 --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/130m/speedrun_results.json @@ -0,0 +1,141 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (130m); Gated Attention (Splash); LR sweep multiplier=2.5", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 1.1613965034484863, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 512, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 1792, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 8, + "num_kv_heads": 8, + "num_layers": 6, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 2.1289341996446515e+18, + "model_flops_per_token": 227868672.0, + "model_size": 155720192, + "num_chips": 16, + "num_devices": 16, + "resources": { + "cpu": 1, + "device": { + "topology": null, + "type": "v5p-32" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-09 10:57:45 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 3114270720, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.04, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 5940, + "optimizer_config": { + "adam_lr": 0.008, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 0.8, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.04, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 1, + "min_lr_ratio": 0, + "momentum": 0.95, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": 
true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 128, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 1.1718583294415854e+19, + "training_time": 1595.6676599150128, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_130m_attngate_4096_splash_lr_sweep_lr_x2_5-01984c" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/1_2b/speedrun_results.json new file mode 100644 index 0000000000..e4f78d80a9 --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/1_2b/speedrun_results.json @@ -0,0 +1,143 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (1_2b); Gated Attention (Splash); LR sweep multiplier=2.5", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 0.9125918745994568, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 2048, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 7168, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 16, + "num_kv_heads": 8, + "num_layers": 16, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "seq_len": 2048, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 5.1738353514618185e+20, + "model_flops_per_token": 2877292544.0, + "model_size": 1498482688, + "num_chips": 32, + "num_devices": 32, + "resources": { + "cpu": 1, + "device": { + "kind": "tpu", + "topology": null, + "variant": "v5p-64" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-16 23:49:35 UTC", + "tokenized_dataset": "gs://marin-us-east5/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 59938701312, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.01, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 57162, + "optimizer_config": { + "adam_lr": 0.0029999999999999996, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 1, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.01, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 2, + 
"min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 256, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 1.8052209610227266e+21, + "training_time": 122904.47719381309, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_2048_splash_lr_sweep_lr_x2_5-2f4194" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/300m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/300m/speedrun_results.json new file mode 100644 index 0000000000..c709179068 --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/300m/speedrun_results.json @@ -0,0 +1,141 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (300m); Gated Attention (Splash); LR sweep multiplier=2.5", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 1.05768620967865, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 768, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 2688, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 12, + "num_kv_heads": 12, + "num_layers": 12, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 1.021384111077458e+19, + "model_flops_per_token": 555024384.0, + "model_size": 306727680, + "num_chips": 16, + "num_devices": 16, + "resources": { + "cpu": 1, + "device": { + "topology": null, + "type": "v5p-32" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-09 12:42:55 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 6134169600, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.02, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 11700, + "optimizer_config": { + "adam_lr": 0.005999999999999999, + "adam_weight_decay": 
null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 0.8, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.02, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 1, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 128, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 5.071905773521359e+19, + "training_time": 6906.18977876002, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_300m_attngate_4096_splash_lr_sweep_lr_x2_5-13d2c4" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/520m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/520m/speedrun_results.json new file mode 100644 index 0000000000..19845cb864 --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/520m/speedrun_results.json @@ -0,0 +1,143 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (520m); Gated Attention (Splash); LR sweep multiplier=2.5", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 0.9764590263366699, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 1024, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 3584, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 16, + "num_kv_heads": 8, + "num_layers": 24, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "seq_len": 2048, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 1.0163680043413694e+20, + "model_flops_per_token": 1349517312.0, + "model_size": 627622912, + "num_chips": 16, + "num_devices": 16, + "resources": { + "cpu": 1, + "device": { + "kind": "tpu", + "topology": null, + "variant": "v5p-32" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-16 02:37:19 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 25104482304, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + 
"epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.02, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 47883, + "optimizer_config": { + "adam_lr": 0.005999999999999999, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 1, + "default_weight_decay_mask": null, + "epsilon": 1e-25, + "haps": null, + "learning_rate": 0.02, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 1, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 128, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 4.474616058782678e+20, + "training_time": 60928.86790281424, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_520m_attngate_2048_splash_lr_sweep_lr_x2_5-71c9e9" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/130m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/130m/speedrun_results.json new file mode 100644 index 0000000000..564e5b8078 --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/130m/speedrun_results.json @@ -0,0 +1,141 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (130m); Gated Attention (Splash); LR sweep multiplier=3", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 1.1638695001602173, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 512, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 1792, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 8, + "num_kv_heads": 8, + "num_layers": 6, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 2.1289341996446515e+18, + "model_flops_per_token": 227868672.0, + "model_size": 155720192, + "num_chips": 16, + "num_devices": 16, + "resources": { + "cpu": 1, + "device": { + "topology": null, + "type": "v5p-32" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-09 10:45:13 UTC", + 
"tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 3114270720, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.048, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 5940, + "optimizer_config": { + "adam_lr": 0.009600000000000001, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 0.8, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.048, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 1, + "min_lr_ratio": 0, + "momentum": 0.95, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 128, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 1.1665914456968133e+19, + "training_time": 1588.4959772560094, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_130m_attngate_4096_splash_lr_sweep_lr_x3-70d7ec" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/1_2b/speedrun_results.json new file mode 100644 index 0000000000..b71df6db06 --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/1_2b/speedrun_results.json @@ -0,0 +1,143 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (1_2b); Gated Attention (Splash); LR sweep multiplier=3", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 0.9134978652000427, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 2048, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 7168, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 16, + "num_kv_heads": 8, + "num_layers": 16, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "seq_len": 2048, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 5.1738353514618185e+20, + "model_flops_per_token": 
2877292544.0, + "model_size": 1498482688, + "num_chips": 32, + "num_devices": 32, + "resources": { + "cpu": 1, + "device": { + "kind": "tpu", + "topology": null, + "variant": "v5p-64" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-17 00:34:36 UTC", + "tokenized_dataset": "gs://marin-us-east5/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 59938701312, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.012, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 57162, + "optimizer_config": { + "adam_lr": 0.0036, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 1, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.012, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 2, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 256, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 1.81031625465692e+21, + "training_time": 123251.37899352668, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_2048_splash_lr_sweep_lr_x3-e8942d" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/300m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/300m/speedrun_results.json new file mode 100644 index 0000000000..95d2277545 --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/300m/speedrun_results.json @@ -0,0 +1,141 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (300m); Gated Attention (Splash); LR sweep multiplier=3", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 1.060373306274414, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 768, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 2688, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 12, + "num_kv_heads": 12, + "num_layers": 12, + "qk_norm": null, + "reference_checkpoint": 
"NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 1.021384111077458e+19, + "model_flops_per_token": 555024384.0, + "model_size": 306727680, + "num_chips": 16, + "num_devices": 16, + "resources": { + "cpu": 1, + "device": { + "topology": null, + "type": "v5p-32" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-09 12:42:27 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 6134169600, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.024, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 11700, + "optimizer_config": { + "adam_lr": 0.0072, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 0.8, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.024, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 1, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 128, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 5.078015592931944e+19, + "training_time": 6914.509249635, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_300m_attngate_4096_splash_lr_sweep_lr_x3-6cc06b" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/520m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/520m/speedrun_results.json new file mode 100644 index 0000000000..85cc77d30d --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/520m/speedrun_results.json @@ -0,0 +1,141 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (520m); Gated Attention (Splash); LR sweep multiplier=3", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 0.9944517016410828, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": 
null, + "hidden_dim": 1024, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 3584, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 16, + "num_kv_heads": 8, + "num_layers": 24, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 5.081733891346976e+19, + "model_flops_per_token": 1349517312.0, + "model_size": 627622912, + "num_chips": 16, + "num_devices": 16, + "resources": { + "cpu": 1, + "device": { + "topology": null, + "type": "v5p-32" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-09 19:40:08 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 12551979008, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.024, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 23941, + "optimizer_config": { + "adam_lr": 0.0072, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 1, + "default_weight_decay_mask": null, + "epsilon": 1e-25, + "haps": null, + "learning_rate": 0.024, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 1, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 128, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 2.2342612489445596e+20, + "training_time": 30422.947289550102, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_520m_attngate_4096_splash_lr_sweep_lr_x3-325c68" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/130m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/130m/speedrun_results.json new file mode 100644 index 0000000000..6727da96ff --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/130m/speedrun_results.json @@ -0,0 +1,141 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (130m); Gated Attention (Splash); LR sweep 
multiplier=3.5", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 1.16841721534729, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 512, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 1792, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 8, + "num_kv_heads": 8, + "num_layers": 6, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 2.1289341996446515e+18, + "model_flops_per_token": 227868672.0, + "model_size": 155720192, + "num_chips": 16, + "num_devices": 16, + "resources": { + "cpu": 1, + "device": { + "topology": null, + "type": "v5p-32" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-09 10:53:41 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 3114270720, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.056, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 5940, + "optimizer_config": { + "adam_lr": 0.0112, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 0.8, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.056, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 1, + "min_lr_ratio": 0, + "momentum": 0.95, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 128, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 1.1712847361037679e+19, + "training_time": 1594.886623234978, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_130m_attngate_4096_splash_lr_sweep_lr_x3_5-4faf11" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/1_2b/speedrun_results.json new file mode 100644 index 0000000000..107d9b78b3 --- /dev/null +++ 
b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/1_2b/speedrun_results.json @@ -0,0 +1,143 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (1_2b); Gated Attention (Splash); LR sweep multiplier=3.5", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 0.9148565530776978, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 2048, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 7168, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 16, + "num_kv_heads": 8, + "num_layers": 16, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "seq_len": 2048, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 5.1738353514618185e+20, + "model_flops_per_token": 2877292544.0, + "model_size": 1498482688, + "num_chips": 32, + "num_devices": 32, + "resources": { + "cpu": 1, + "device": { + "kind": "tpu", + "topology": null, + "variant": "v5p-64" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-16 22:33:13 UTC", + "tokenized_dataset": "gs://marin-us-east5/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 59938701312, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.014, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 57162, + "optimizer_config": { + "adam_lr": 0.0042, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 1, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.014, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 2, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 256, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 1.8114652296478665e+21, + "training_time": 123329.6044150236, + "wandb_run_link": 
"https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_2048_splash_lr_sweep_lr_x3_5-b0a3b2" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/300m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/300m/speedrun_results.json new file mode 100644 index 0000000000..e638bc8495 --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/300m/speedrun_results.json @@ -0,0 +1,141 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (300m); Gated Attention (Splash); LR sweep multiplier=3.5", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 1.0644729137420654, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 768, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 2688, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 12, + "num_kv_heads": 12, + "num_layers": 12, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 1.021384111077458e+19, + "model_flops_per_token": 555024384.0, + "model_size": 306727680, + "num_chips": 16, + "num_devices": 16, + "resources": { + "cpu": 1, + "device": { + "topology": null, + "type": "v5p-32" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-09 12:02:21 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 6134169600, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.028, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 11700, + "optimizer_config": { + "adam_lr": 0.0084, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 0.8, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.028, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 1, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 128, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + 
"include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 5.073143302509889e+19, + "training_time": 6907.874867252028, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_300m_attngate_4096_splash_lr_sweep_lr_x3_5-1a6ee3" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/520m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/520m/speedrun_results.json new file mode 100644 index 0000000000..137bcda734 --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/520m/speedrun_results.json @@ -0,0 +1,141 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (520m); Gated Attention (Splash); LR sweep multiplier=3.5", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 0.9976317882537842, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 1024, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 3584, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 16, + "num_kv_heads": 8, + "num_layers": 24, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 5.081733891346976e+19, + "model_flops_per_token": 1349517312.0, + "model_size": 627622912, + "num_chips": 16, + "num_devices": 16, + "resources": { + "cpu": 1, + "device": { + "topology": null, + "type": "v5p-32" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-09 19:33:19 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 12551979008, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.028, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 23941, + "optimizer_config": { + "adam_lr": 0.0084, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 1, + "default_weight_decay_mask": null, + "epsilon": 1e-25, + "haps": null, + "learning_rate": 0.028, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 1, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": 
null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 128, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 2.2274891561232957e+20, + "training_time": 30330.73469666797, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_520m_attngate_4096_splash_lr_sweep_lr_x3_5-ea774d" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/130m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/130m/speedrun_results.json new file mode 100644 index 0000000000..99da32d436 --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/130m/speedrun_results.json @@ -0,0 +1,141 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (130m); Gated Attention (Splash); LR sweep multiplier=4", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 1.171983242034912, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 512, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 1792, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 8, + "num_kv_heads": 8, + "num_layers": 6, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 2.1289341996446515e+18, + "model_flops_per_token": 227868672.0, + "model_size": 155720192, + "num_chips": 16, + "num_devices": 16, + "resources": { + "cpu": 1, + "device": { + "topology": null, + "type": "v5p-32" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-09 10:44:46 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 3114270720, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.064, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 5940, + "optimizer_config": { + "adam_lr": 0.0128, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 0.8, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.064, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 1, + "min_lr_ratio": 0, + "momentum": 0.95, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 
0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 128, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 1.1647166759460198e+19, + "training_time": 1585.9431862010074, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_130m_attngate_4096_splash_lr_sweep_lr_x4-b1cb5b" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/1_2b/speedrun_results.json new file mode 100644 index 0000000000..49e1cc6431 --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/1_2b/speedrun_results.json @@ -0,0 +1,143 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (1_2b); Gated Attention (Splash); LR sweep multiplier=4", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 0.9164798259735107, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 2048, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 7168, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 16, + "num_kv_heads": 8, + "num_layers": 16, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "seq_len": 2048, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 5.1738353514618185e+20, + "model_flops_per_token": 2877292544.0, + "model_size": 1498482688, + "num_chips": 32, + "num_devices": 32, + "resources": { + "cpu": 1, + "device": { + "kind": "tpu", + "topology": null, + "variant": "v5p-64" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-17 02:26:48 UTC", + "tokenized_dataset": "gs://marin-us-east5/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 59938701312, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.016, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 57162, + "optimizer_config": { + "adam_lr": 0.0048, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, 
+ "cycle_length": null, + "cycles": null, + "decay": 1, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.016, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 2, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 256, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 2.051479930331274e+21, + "training_time": 139670.4745595911, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_2048_splash_lr_sweep_lr_x4-f80807" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/300m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/300m/speedrun_results.json new file mode 100644 index 0000000000..41091e8dd5 --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/300m/speedrun_results.json @@ -0,0 +1,141 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (300m); Gated Attention (Splash); LR sweep multiplier=4", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 1.0660113096237183, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 768, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 2688, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 12, + "num_kv_heads": 12, + "num_layers": 12, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 1.021384111077458e+19, + "model_flops_per_token": 555024384.0, + "model_size": 306727680, + "num_chips": 16, + "num_devices": 16, + "resources": { + "cpu": 1, + "device": { + "topology": null, + "type": "v5p-32" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-09 12:26:35 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 6134169600, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.032, + 
"lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 11700, + "optimizer_config": { + "adam_lr": 0.0096, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 0.8, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.032, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 1, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 128, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 5.059991395179448e+19, + "training_time": 6889.966496704042, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_300m_attngate_4096_splash_lr_sweep_lr_x4-85c5fd" + } + } + ] +} \ No newline at end of file diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/520m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/520m/speedrun_results.json new file mode 100644 index 0000000000..aa91796760 --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/520m/speedrun_results.json @@ -0,0 +1,141 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (520m); Gated Attention (Splash); LR sweep multiplier=4", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 1.0015015602111816, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 1024, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 3584, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 16, + "num_kv_heads": 8, + "num_layers": 24, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 5.081733891346976e+19, + "model_flops_per_token": 1349517312.0, + "model_size": 627622912, + "num_chips": 16, + "num_devices": 16, + "resources": { + "cpu": 1, + "device": { + "topology": null, + "type": "v5p-32" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-09 19:01:46 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 12551979008, + "train_config": { + 
"allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.032, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 23941, + "optimizer_config": { + "adam_lr": 0.0096, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 1, + "default_weight_decay_mask": null, + "epsilon": 1e-25, + "haps": null, + "learning_rate": 0.032, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 1, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 128, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 2.236153294934872e+20, + "training_time": 30448.710443012962, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_520m_attngate_4096_splash_lr_sweep_lr_x4-f42244" + } + } + ] +} \ No newline at end of file From 7d38157d04a0f9ec85f413d06aaac134c9f1e05c Mon Sep 17 00:00:00 2001 From: "Pinlin [Calvin] Xu" Date: Mon, 29 Dec 2025 21:52:46 -0800 Subject: [PATCH 08/11] tweak --- .../hackable_transformer_attn_gate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/experiments/speedrun/hackable_transformer_attn_gate/hackable_transformer_attn_gate.py b/experiments/speedrun/hackable_transformer_attn_gate/hackable_transformer_attn_gate.py index b0bf58d458..2b767f2c96 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/hackable_transformer_attn_gate.py +++ b/experiments/speedrun/hackable_transformer_attn_gate/hackable_transformer_attn_gate.py @@ -470,7 +470,7 @@ def _resource_presets(use_gpu: bool = False): "130m": ResourceConfig.with_tpu("v5p-32"), "300m": ResourceConfig.with_tpu("v5p-32"), "520m": ResourceConfig.with_tpu("v5p-32"), - "1_2b": ResourceConfig.with_tpu("v5p-64"), + "1_2b": ResourceConfig.with_tpu("v5p-32"), } @@ -526,7 +526,7 @@ def build_run( ) lr_tag = f"_lr_x{_format_multiplier_label(lr_multiplier)}" if lr_multiplier is not None else "" - run_name = f"hacktx_{size}_{'attngate' if use_gate else 'stdattn'}_{seq_len}_splash_lr_sweep{lr_tag}" + run_name = f"hacktx_{size}_{'attngate' if use_gate else 'stdattn'}_{seq_len}_splash_lr_sweep{lr_tag}_v5p32" desc = ( f"Hackable Transformer ({size}); " f"{'Gated Attention' if use_gate else 'Std Attention'} (Splash); " From 03da33019398f7fe445b0747948f3912cfd59b44 Mon Sep 17 00:00:00 2001 From: "Pinlin [Calvin] Xu" Date: Sun, 4 Jan 2026 20:57:17 -0800 Subject: [PATCH 09/11] precommit --- .../lr_x1/130m/speedrun_results.json | 2 +- 
.../lr_x1/1_2b/speedrun_results.json | 2 +- .../lr_x1/300m/speedrun_results.json | 2 +- .../lr_x1/520m/speedrun_results.json | 2 +- .../lr_x1_5/130m/speedrun_results.json | 2 +- .../lr_x1_5/1_2b/speedrun_results.json | 2 +- .../lr_x1_5/300m/speedrun_results.json | 2 +- .../lr_x1_5/520m/speedrun_results.json | 2 +- .../lr_x2/130m/speedrun_results.json | 2 +- .../lr_x2/1_2b/speedrun_results.json | 2 +- .../lr_x2/300m/speedrun_results.json | 2 +- .../lr_x2/520m/speedrun_results.json | 2 +- .../lr_x2_5/130m/speedrun_results.json | 2 +- .../lr_x2_5/1_2b/speedrun_results.json | 2 +- .../lr_x2_5/300m/speedrun_results.json | 2 +- .../lr_x2_5/520m/speedrun_results.json | 2 +- .../lr_x3/130m/speedrun_results.json | 2 +- .../lr_x3/1_2b/speedrun_results.json | 2 +- .../lr_x3/300m/speedrun_results.json | 2 +- .../lr_x3/520m/speedrun_results.json | 2 +- .../lr_x3_5/130m/speedrun_results.json | 2 +- .../lr_x3_5/1_2b/speedrun_results.json | 2 +- .../lr_x3_5/300m/speedrun_results.json | 2 +- .../lr_x3_5/520m/speedrun_results.json | 2 +- .../lr_x4/130m/speedrun_results.json | 2 +- .../lr_x4/1_2b/speedrun_results.json | 2 +- .../lr_x4/300m/speedrun_results.json | 2 +- .../lr_x4/520m/speedrun_results.json | 2 +- .../naive_no_fuse_q/130m/speedrun_results.json | 2 +- .../naive_no_fuse_q/1_2b/speedrun_results.json | 2 +- .../naive_no_fuse_q/300m/speedrun_results.json | 2 +- .../naive_no_fuse_q/520m/speedrun_results.json | 2 +- lib/levanter/src/levanter/layers/attention.py | 12 ++++++------ lib/levanter/tests/test_attention.py | 6 +++--- 34 files changed, 41 insertions(+), 41 deletions(-) diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/130m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/130m/speedrun_results.json index e4b12437cc..eb3f224ba6 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/130m/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/130m/speedrun_results.json @@ -138,4 +138,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/1_2b/speedrun_results.json index b5a75910bf..9629799af5 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/1_2b/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/1_2b/speedrun_results.json @@ -140,4 +140,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/300m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/300m/speedrun_results.json index df4850adfe..1f74395a36 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/300m/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/300m/speedrun_results.json @@ -138,4 +138,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/520m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/520m/speedrun_results.json index a2423e58ac..b53ac37350 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/520m/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/520m/speedrun_results.json @@ -138,4 +138,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/130m/speedrun_results.json 
b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/130m/speedrun_results.json index 8ded8f4651..3a9e7c6474 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/130m/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/130m/speedrun_results.json @@ -138,4 +138,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/1_2b/speedrun_results.json index 9988e91b5c..a7a662d547 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/1_2b/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/1_2b/speedrun_results.json @@ -140,4 +140,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/300m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/300m/speedrun_results.json index dd9ffa16fa..59fc868322 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/300m/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/300m/speedrun_results.json @@ -138,4 +138,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/520m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/520m/speedrun_results.json index fa5babaacb..01540fd63b 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/520m/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/520m/speedrun_results.json @@ -138,4 +138,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/130m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/130m/speedrun_results.json index c8dc916e77..cd1d1c9195 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/130m/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/130m/speedrun_results.json @@ -138,4 +138,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/1_2b/speedrun_results.json index b89184f05f..bc85070a0a 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/1_2b/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/1_2b/speedrun_results.json @@ -140,4 +140,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/300m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/300m/speedrun_results.json index 63fa3f7a72..5c89c5e54e 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/300m/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/300m/speedrun_results.json @@ -138,4 +138,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/520m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/520m/speedrun_results.json index 69d348ea33..065735371c 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/520m/speedrun_results.json +++ 
b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/520m/speedrun_results.json @@ -138,4 +138,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/130m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/130m/speedrun_results.json index 995167ee58..29236ac9b2 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/130m/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/130m/speedrun_results.json @@ -138,4 +138,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/1_2b/speedrun_results.json index e4f78d80a9..1e3262390d 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/1_2b/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/1_2b/speedrun_results.json @@ -140,4 +140,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/300m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/300m/speedrun_results.json index c709179068..3396ba42e9 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/300m/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/300m/speedrun_results.json @@ -138,4 +138,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/520m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/520m/speedrun_results.json index 19845cb864..0996663372 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/520m/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/520m/speedrun_results.json @@ -140,4 +140,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/130m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/130m/speedrun_results.json index 564e5b8078..d89452adff 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/130m/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/130m/speedrun_results.json @@ -138,4 +138,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/1_2b/speedrun_results.json index b71df6db06..2539d30ffa 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/1_2b/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/1_2b/speedrun_results.json @@ -140,4 +140,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/300m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/300m/speedrun_results.json index 95d2277545..d89eee804d 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/300m/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/300m/speedrun_results.json @@ -138,4 +138,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/520m/speedrun_results.json 
b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/520m/speedrun_results.json index 85cc77d30d..27112f0710 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/520m/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/520m/speedrun_results.json @@ -138,4 +138,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/130m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/130m/speedrun_results.json index 6727da96ff..5b6ee02988 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/130m/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/130m/speedrun_results.json @@ -138,4 +138,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/1_2b/speedrun_results.json index 107d9b78b3..79c7c4fa23 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/1_2b/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/1_2b/speedrun_results.json @@ -140,4 +140,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/300m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/300m/speedrun_results.json index e638bc8495..898ff7c6fe 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/300m/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/300m/speedrun_results.json @@ -138,4 +138,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/520m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/520m/speedrun_results.json index 137bcda734..2d42e75484 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/520m/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/520m/speedrun_results.json @@ -138,4 +138,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/130m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/130m/speedrun_results.json index 99da32d436..c5015256b4 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/130m/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/130m/speedrun_results.json @@ -138,4 +138,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/1_2b/speedrun_results.json index 49e1cc6431..817bcd9636 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/1_2b/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/1_2b/speedrun_results.json @@ -140,4 +140,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/300m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/300m/speedrun_results.json index 41091e8dd5..f1344d0d7d 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/300m/speedrun_results.json +++ 
b/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/300m/speedrun_results.json @@ -138,4 +138,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/520m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/520m/speedrun_results.json index aa91796760..0be866a11f 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/520m/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/520m/speedrun_results.json @@ -138,4 +138,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/130m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/130m/speedrun_results.json index 349d7d46f9..6ef86d1991 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/130m/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/130m/speedrun_results.json @@ -137,4 +137,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/1_2b/speedrun_results.json index 60172d9a5a..c4b81a736f 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/1_2b/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/1_2b/speedrun_results.json @@ -137,4 +137,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/300m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/300m/speedrun_results.json index 31f5452cc7..9e2fe8f2ea 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/300m/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/300m/speedrun_results.json @@ -137,4 +137,4 @@ } } ] -} \ No newline at end of file +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/520m/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/520m/speedrun_results.json index 5ff5fd85f9..b2960ec991 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/520m/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/naive_no_fuse_q/520m/speedrun_results.json @@ -137,4 +137,4 @@ } } ] -} \ No newline at end of file +} diff --git a/lib/levanter/src/levanter/layers/attention.py b/lib/levanter/src/levanter/layers/attention.py index a5953db0f5..77eac09921 100644 --- a/lib/levanter/src/levanter/layers/attention.py +++ b/lib/levanter/src/levanter/layers/attention.py @@ -1572,9 +1572,9 @@ class AttentionConfig: gated: Literal["none", "headwise", "elementwise"] = "none" def __post_init__(self): - assert self.num_heads % self.num_kv_heads == 0, ( - f"num_heads={self.num_heads} not divisible by num_kv_heads={self.num_kv_heads}." - ) + assert ( + self.num_heads % self.num_kv_heads == 0 + ), f"num_heads={self.num_heads} not divisible by num_kv_heads={self.num_kv_heads}." 
@property def head_size(self) -> int: @@ -2331,9 +2331,9 @@ def __call__( if self.config.q_lora_rank is None: q = self.q_proj(x, key=k_q_a) else: - assert self.q_a_proj is not None and self.q_a_norm is not None and self.q_b_proj is not None, ( - "q_lora_rank defined, but LoRA matrices are not." - ) + assert ( + self.q_a_proj is not None and self.q_a_norm is not None and self.q_b_proj is not None + ), "q_lora_rank defined, but LoRA matrices are not." q = self.q_a_proj(x, key=k_q_a) q = self.q_a_norm(q) q = self.q_b_proj(q, key=k_q_b) diff --git a/lib/levanter/tests/test_attention.py b/lib/levanter/tests/test_attention.py index 42d7ea6777..e571767efc 100644 --- a/lib/levanter/tests/test_attention.py +++ b/lib/levanter/tests/test_attention.py @@ -605,9 +605,9 @@ def test_causal_offset_cross_attention(impl): precision=Precision.HIGHEST, ) - assert not jnp.allclose(offset_out.array, wrong_out.array, atol=2e-3, rtol=2e-3), ( - "Output should differ without offset" - ) + assert not jnp.allclose( + offset_out.array, wrong_out.array, atol=2e-3, rtol=2e-3 + ), "Output should differ without offset" # This is a bottleneck in tests From 91566925805f1f73a3102972f494fb8c809826b7 Mon Sep 17 00:00:00 2001 From: "Pinlin [Calvin] Xu" Date: Mon, 19 Jan 2026 16:41:04 -0800 Subject: [PATCH 10/11] Check in results on v5p32 --- .../lr_x1/1_2b/speedrun_results.json | 21 +-- .../lr_x1/1_2b/speedrun_results_v5p64.json | 143 ++++++++++++++++++ .../lr_x1_5/1_2b/speedrun_results.json | 21 +-- .../lr_x1_5/1_2b/speedrun_results_v5p64.json | 143 ++++++++++++++++++ .../lr_x2/1_2b/speedrun_results.json | 21 +-- .../lr_x2/1_2b/speedrun_results_v5p64.json | 143 ++++++++++++++++++ .../lr_x2_5/1_2b/speedrun_results.json | 21 +-- .../lr_x2_5/1_2b/speedrun_results_v5p64.json | 143 ++++++++++++++++++ .../lr_x3/1_2b/speedrun_results.json | 21 +-- .../lr_x3/1_2b/speedrun_results_v5p64.json | 143 ++++++++++++++++++ .../lr_x3_5/1_2b/speedrun_results.json | 21 +-- .../lr_x3_5/1_2b/speedrun_results_v5p64.json | 143 ++++++++++++++++++ .../lr_x4/1_2b/speedrun_results.json | 21 +-- .../lr_x4/1_2b/speedrun_results_v5p64.json | 143 ++++++++++++++++++ 14 files changed, 1078 insertions(+), 70 deletions(-) create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x1/1_2b/speedrun_results_v5p64.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/1_2b/speedrun_results_v5p64.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x2/1_2b/speedrun_results_v5p64.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/1_2b/speedrun_results_v5p64.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x3/1_2b/speedrun_results_v5p64.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/1_2b/speedrun_results_v5p64.json create mode 100644 experiments/speedrun/hackable_transformer_attn_gate/lr_x4/1_2b/speedrun_results_v5p64.json diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/1_2b/speedrun_results.json index 9629799af5..fd5f073b43 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/1_2b/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/1_2b/speedrun_results.json @@ -9,7 +9,7 @@ }, "description": "Hackable Transformer (1_2b); Gated Attention (Splash); LR sweep multiplier=1", "device_flops": 459000000000000.0, - "eval/paloma/c4_en/bpb": 
0.9160401225090027, + "eval/paloma/c4_en/bpb": 0.9160435795783997, "model_config": { "activation_function": "silu", "attn_backend": null, @@ -37,20 +37,20 @@ "tokenizer": null, "upcast_attn": false, "use_bias": false, - "use_gated_attention": true, + "use_gated_attention": "elementwise", "use_layer_norm_weight": true }, "model_flops": 5.1738353514618185e+20, "model_flops_per_token": 2877292544.0, "model_size": 1498482688, - "num_chips": 32, - "num_devices": 32, + "num_chips": 16, + "num_devices": 16, "resources": { "cpu": 1, "device": { "kind": "tpu", "topology": null, - "variant": "v5p-64" + "variant": "v5p-32" }, "disk": "1g", "preemptible": true, @@ -58,8 +58,8 @@ "regions": null, "replicas": 1 }, - "run_completion_timestamp": "2025-12-17 04:43:09 UTC", - "tokenized_dataset": "gs://marin-us-east5/tokenized/subcache/fineweb-edu-10B-ac65f6", + "run_completion_timestamp": "2026-01-07 23:16:03 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", "total_tokens": 59938701312, "train_config": { "allow_partial_checkpoint": false, @@ -70,6 +70,7 @@ "decay": null, "ema_beta": null, "epsilon": null, + "explicit_mesh_axes": false, "initialize_from_checkpoint_path": null, "initialize_from_hf": null, "int8": false, @@ -134,9 +135,9 @@ "weight_decay": null, "z_loss_weight": null }, - "training_hardware_flops": 1.80638282968129e+21, - "training_time": 122983.58045215753, - "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_2048_splash_lr_sweep_lr_x1-ecb416" + "training_hardware_flops": 1.1808594941354136e+21, + "training_time": 160792.41477878726, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_2048_splash_lr_sweep_lr_x1_v5p32-ec656c" } } ] diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/1_2b/speedrun_results_v5p64.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/1_2b/speedrun_results_v5p64.json new file mode 100644 index 0000000000..9629799af5 --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1/1_2b/speedrun_results_v5p64.json @@ -0,0 +1,143 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (1_2b); Gated Attention (Splash); LR sweep multiplier=1", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 0.9160401225090027, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 2048, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 7168, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 16, + "num_kv_heads": 8, + "num_layers": 16, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "seq_len": 2048, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 5.1738353514618185e+20, + "model_flops_per_token": 2877292544.0, + "model_size": 1498482688, + "num_chips": 32, + "num_devices": 32, + "resources": { + "cpu": 1, + "device": { + "kind": "tpu", + "topology": null, + "variant": "v5p-64" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + 
"replicas": 1 + }, + "run_completion_timestamp": "2025-12-17 04:43:09 UTC", + "tokenized_dataset": "gs://marin-us-east5/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 59938701312, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.004, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 57162, + "optimizer_config": { + "adam_lr": 0.0012, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 1, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.004, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 2, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 256, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 1.80638282968129e+21, + "training_time": 122983.58045215753, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_2048_splash_lr_sweep_lr_x1-ecb416" + } + } + ] +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/1_2b/speedrun_results.json index a7a662d547..10977b1ea4 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/1_2b/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/1_2b/speedrun_results.json @@ -9,7 +9,7 @@ }, "description": "Hackable Transformer (1_2b); Gated Attention (Splash); LR sweep multiplier=1.5", "device_flops": 459000000000000.0, - "eval/paloma/c4_en/bpb": 0.9122039675712585, + "eval/paloma/c4_en/bpb": 0.9129290580749512, "model_config": { "activation_function": "silu", "attn_backend": null, @@ -37,20 +37,20 @@ "tokenizer": null, "upcast_attn": false, "use_bias": false, - "use_gated_attention": true, + "use_gated_attention": "elementwise", "use_layer_norm_weight": true }, "model_flops": 5.1738353514618185e+20, "model_flops_per_token": 2877292544.0, "model_size": 1498482688, - "num_chips": 32, - "num_devices": 32, + "num_chips": 16, + "num_devices": 16, "resources": { "cpu": 1, "device": { "kind": "tpu", "topology": null, - "variant": "v5p-64" + "variant": "v5p-32" }, "disk": "1g", "preemptible": true, @@ -58,8 +58,8 @@ "regions": null, "replicas": 1 }, - "run_completion_timestamp": "2025-12-17 00:31:25 UTC", - "tokenized_dataset": "gs://marin-us-east5/tokenized/subcache/fineweb-edu-10B-ac65f6", + "run_completion_timestamp": 
"2026-01-10 22:53:32 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", "total_tokens": 59938701312, "train_config": { "allow_partial_checkpoint": false, @@ -70,6 +70,7 @@ "decay": null, "ema_beta": null, "epsilon": null, + "explicit_mesh_axes": false, "initialize_from_checkpoint_path": null, "initialize_from_hf": null, "int8": false, @@ -134,9 +135,9 @@ "weight_decay": null, "z_loss_weight": null }, - "training_hardware_flops": 1.8044518078036129e+21, - "training_time": 122852.11109774053, - "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_2048_splash_lr_sweep_lr_x1_5-e5d647" + "training_hardware_flops": 1.1933404335361132e+21, + "training_time": 162491.88909805464, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_2048_splash_lr_sweep_lr_x1_5_v5p32-f366f3" } } ] diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/1_2b/speedrun_results_v5p64.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/1_2b/speedrun_results_v5p64.json new file mode 100644 index 0000000000..a7a662d547 --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x1_5/1_2b/speedrun_results_v5p64.json @@ -0,0 +1,143 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (1_2b); Gated Attention (Splash); LR sweep multiplier=1.5", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 0.9122039675712585, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 2048, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 7168, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 16, + "num_kv_heads": 8, + "num_layers": 16, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "seq_len": 2048, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 5.1738353514618185e+20, + "model_flops_per_token": 2877292544.0, + "model_size": 1498482688, + "num_chips": 32, + "num_devices": 32, + "resources": { + "cpu": 1, + "device": { + "kind": "tpu", + "topology": null, + "variant": "v5p-64" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-17 00:31:25 UTC", + "tokenized_dataset": "gs://marin-us-east5/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 59938701312, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.006, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 57162, + "optimizer_config": { + "adam_lr": 0.0018, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 1, + "default_weight_decay_mask": null, 
+ "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.006, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 2, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 256, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 1.8044518078036129e+21, + "training_time": 122852.11109774053, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_2048_splash_lr_sweep_lr_x1_5-e5d647" + } + } + ] +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/1_2b/speedrun_results.json index bc85070a0a..3f3bbec3e4 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/1_2b/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/1_2b/speedrun_results.json @@ -9,7 +9,7 @@ }, "description": "Hackable Transformer (1_2b); Gated Attention (Splash); LR sweep multiplier=2", "device_flops": 459000000000000.0, - "eval/paloma/c4_en/bpb": 0.9119555354118347, + "eval/paloma/c4_en/bpb": 0.911785364151001, "model_config": { "activation_function": "silu", "attn_backend": null, @@ -37,20 +37,20 @@ "tokenizer": null, "upcast_attn": false, "use_bias": false, - "use_gated_attention": true, + "use_gated_attention": "elementwise", "use_layer_norm_weight": true }, "model_flops": 5.1738353514618185e+20, "model_flops_per_token": 2877292544.0, "model_size": 1498482688, - "num_chips": 32, - "num_devices": 32, + "num_chips": 16, + "num_devices": 16, "resources": { "cpu": 1, "device": { "kind": "tpu", "topology": null, - "variant": "v5p-64" + "variant": "v5p-32" }, "disk": "1g", "preemptible": true, @@ -58,8 +58,8 @@ "regions": null, "replicas": 1 }, - "run_completion_timestamp": "2025-12-17 01:10:26 UTC", - "tokenized_dataset": "gs://marin-us-east5/tokenized/subcache/fineweb-edu-10B-ac65f6", + "run_completion_timestamp": "2026-01-07 23:37:29 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", "total_tokens": 59938701312, "train_config": { "allow_partial_checkpoint": false, @@ -70,6 +70,7 @@ "decay": null, "ema_beta": null, "epsilon": null, + "explicit_mesh_axes": false, "initialize_from_checkpoint_path": null, "initialize_from_hf": null, "int8": false, @@ -134,9 +135,9 @@ "weight_decay": null, "z_loss_weight": null }, - "training_hardware_flops": 1.8017056211590617e+21, - "training_time": 122665.14305276837, - "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_2048_splash_lr_sweep_lr_x2-be36f3" + "training_hardware_flops": 1.18325574649179e+21, + "training_time": 161118.70186435047, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_2048_splash_lr_sweep_lr_x2_v5p32-0b6010" } 
} ] diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/1_2b/speedrun_results_v5p64.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/1_2b/speedrun_results_v5p64.json new file mode 100644 index 0000000000..bc85070a0a --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2/1_2b/speedrun_results_v5p64.json @@ -0,0 +1,143 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (1_2b); Gated Attention (Splash); LR sweep multiplier=2", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 0.9119555354118347, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 2048, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 7168, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 16, + "num_kv_heads": 8, + "num_layers": 16, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "seq_len": 2048, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 5.1738353514618185e+20, + "model_flops_per_token": 2877292544.0, + "model_size": 1498482688, + "num_chips": 32, + "num_devices": 32, + "resources": { + "cpu": 1, + "device": { + "kind": "tpu", + "topology": null, + "variant": "v5p-64" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-17 01:10:26 UTC", + "tokenized_dataset": "gs://marin-us-east5/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 59938701312, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.008, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 57162, + "optimizer_config": { + "adam_lr": 0.0024, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 1, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.008, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 2, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 256, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ 
+ "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 1.8017056211590617e+21, + "training_time": 122665.14305276837, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_2048_splash_lr_sweep_lr_x2-be36f3" + } + } + ] +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/1_2b/speedrun_results.json index 1e3262390d..35a23973f5 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/1_2b/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/1_2b/speedrun_results.json @@ -9,7 +9,7 @@ }, "description": "Hackable Transformer (1_2b); Gated Attention (Splash); LR sweep multiplier=2.5", "device_flops": 459000000000000.0, - "eval/paloma/c4_en/bpb": 0.9125918745994568, + "eval/paloma/c4_en/bpb": 0.9130340218544006, "model_config": { "activation_function": "silu", "attn_backend": null, @@ -37,20 +37,20 @@ "tokenizer": null, "upcast_attn": false, "use_bias": false, - "use_gated_attention": true, + "use_gated_attention": "elementwise", "use_layer_norm_weight": true }, "model_flops": 5.1738353514618185e+20, "model_flops_per_token": 2877292544.0, "model_size": 1498482688, - "num_chips": 32, - "num_devices": 32, + "num_chips": 16, + "num_devices": 16, "resources": { "cpu": 1, "device": { "kind": "tpu", "topology": null, - "variant": "v5p-64" + "variant": "v5p-32" }, "disk": "1g", "preemptible": true, @@ -58,8 +58,8 @@ "regions": null, "replicas": 1 }, - "run_completion_timestamp": "2025-12-16 23:49:35 UTC", - "tokenized_dataset": "gs://marin-us-east5/tokenized/subcache/fineweb-edu-10B-ac65f6", + "run_completion_timestamp": "2026-01-07 23:35:44 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", "total_tokens": 59938701312, "train_config": { "allow_partial_checkpoint": false, @@ -70,6 +70,7 @@ "decay": null, "ema_beta": null, "epsilon": null, + "explicit_mesh_axes": false, "initialize_from_checkpoint_path": null, "initialize_from_hf": null, "int8": false, @@ -134,9 +135,9 @@ "weight_decay": null, "z_loss_weight": null }, - "training_hardware_flops": 1.8052209610227266e+21, - "training_time": 122904.47719381309, - "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_2048_splash_lr_sweep_lr_x2_5-2f4194" + "training_hardware_flops": 1.18302728568909e+21, + "training_time": 161087.5933672508, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_2048_splash_lr_sweep_lr_x2_5_v5p32-909ccf" } } ] diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/1_2b/speedrun_results_v5p64.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/1_2b/speedrun_results_v5p64.json new file mode 100644 index 0000000000..1e3262390d --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x2_5/1_2b/speedrun_results_v5p64.json @@ -0,0 +1,143 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (1_2b); Gated Attention (Splash); LR sweep multiplier=2.5", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 0.9125918745994568, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + 
"gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 2048, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 7168, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 16, + "num_kv_heads": 8, + "num_layers": 16, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "seq_len": 2048, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 5.1738353514618185e+20, + "model_flops_per_token": 2877292544.0, + "model_size": 1498482688, + "num_chips": 32, + "num_devices": 32, + "resources": { + "cpu": 1, + "device": { + "kind": "tpu", + "topology": null, + "variant": "v5p-64" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-16 23:49:35 UTC", + "tokenized_dataset": "gs://marin-us-east5/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 59938701312, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.01, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 57162, + "optimizer_config": { + "adam_lr": 0.0029999999999999996, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 1, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.01, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 2, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 256, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 1.8052209610227266e+21, + "training_time": 122904.47719381309, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_2048_splash_lr_sweep_lr_x2_5-2f4194" + } + } + ] +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/1_2b/speedrun_results.json index 2539d30ffa..5069a17d45 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/1_2b/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/1_2b/speedrun_results.json @@ -9,7 +9,7 @@ }, "description": "Hackable Transformer (1_2b); Gated Attention (Splash); LR sweep multiplier=3", "device_flops": 
459000000000000.0, - "eval/paloma/c4_en/bpb": 0.9134978652000427, + "eval/paloma/c4_en/bpb": 0.9175548553466797, "model_config": { "activation_function": "silu", "attn_backend": null, @@ -37,20 +37,20 @@ "tokenizer": null, "upcast_attn": false, "use_bias": false, - "use_gated_attention": true, + "use_gated_attention": "elementwise", "use_layer_norm_weight": true }, "model_flops": 5.1738353514618185e+20, "model_flops_per_token": 2877292544.0, "model_size": 1498482688, - "num_chips": 32, - "num_devices": 32, + "num_chips": 16, + "num_devices": 16, "resources": { "cpu": 1, "device": { "kind": "tpu", "topology": null, - "variant": "v5p-64" + "variant": "v5p-32" }, "disk": "1g", "preemptible": true, @@ -58,8 +58,8 @@ "regions": null, "replicas": 1 }, - "run_completion_timestamp": "2025-12-17 00:34:36 UTC", - "tokenized_dataset": "gs://marin-us-east5/tokenized/subcache/fineweb-edu-10B-ac65f6", + "run_completion_timestamp": "2026-01-07 06:17:21 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", "total_tokens": 59938701312, "train_config": { "allow_partial_checkpoint": false, @@ -70,6 +70,7 @@ "decay": null, "ema_beta": null, "epsilon": null, + "explicit_mesh_axes": false, "initialize_from_checkpoint_path": null, "initialize_from_hf": null, "int8": false, @@ -134,9 +135,9 @@ "weight_decay": null, "z_loss_weight": null }, - "training_hardware_flops": 1.81031625465692e+21, - "training_time": 123251.37899352668, - "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_2048_splash_lr_sweep_lr_x3-e8942d" + "training_hardware_flops": 1.178280086680337e+21, + "training_time": 160441.18827346637, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_2048_splash_lr_sweep_lr_x3_v5p32-4286b4" } } ] diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/1_2b/speedrun_results_v5p64.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/1_2b/speedrun_results_v5p64.json new file mode 100644 index 0000000000..2539d30ffa --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3/1_2b/speedrun_results_v5p64.json @@ -0,0 +1,143 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (1_2b); Gated Attention (Splash); LR sweep multiplier=3", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 0.9134978652000427, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 2048, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 7168, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 16, + "num_kv_heads": 8, + "num_layers": 16, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "seq_len": 2048, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 5.1738353514618185e+20, + "model_flops_per_token": 2877292544.0, + "model_size": 1498482688, + "num_chips": 32, + "num_devices": 32, + "resources": { + "cpu": 1, + "device": { + "kind": "tpu", + "topology": null, + "variant": "v5p-64" + }, + "disk": "1g", + "preemptible": 
true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-17 00:34:36 UTC", + "tokenized_dataset": "gs://marin-us-east5/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 59938701312, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.012, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 57162, + "optimizer_config": { + "adam_lr": 0.0036, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 1, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.012, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 2, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 256, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 1.81031625465692e+21, + "training_time": 123251.37899352668, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_2048_splash_lr_sweep_lr_x3-e8942d" + } + } + ] +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/1_2b/speedrun_results.json index 79c7c4fa23..fef8dae8a7 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/1_2b/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/1_2b/speedrun_results.json @@ -9,7 +9,7 @@ }, "description": "Hackable Transformer (1_2b); Gated Attention (Splash); LR sweep multiplier=3.5", "device_flops": 459000000000000.0, - "eval/paloma/c4_en/bpb": 0.9148565530776978, + "eval/paloma/c4_en/bpb": 0.9142972230911255, "model_config": { "activation_function": "silu", "attn_backend": null, @@ -37,20 +37,20 @@ "tokenizer": null, "upcast_attn": false, "use_bias": false, - "use_gated_attention": true, + "use_gated_attention": "elementwise", "use_layer_norm_weight": true }, "model_flops": 5.1738353514618185e+20, "model_flops_per_token": 2877292544.0, "model_size": 1498482688, - "num_chips": 32, - "num_devices": 32, + "num_chips": 16, + "num_devices": 16, "resources": { "cpu": 1, "device": { "kind": "tpu", "topology": null, - "variant": "v5p-64" + "variant": "v5p-32" }, "disk": "1g", "preemptible": true, @@ -58,8 +58,8 @@ "regions": null, "replicas": 1 }, - "run_completion_timestamp": "2025-12-16 22:33:13 UTC", - "tokenized_dataset": 
"gs://marin-us-east5/tokenized/subcache/fineweb-edu-10B-ac65f6", + "run_completion_timestamp": "2026-01-07 06:32:05 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", "total_tokens": 59938701312, "train_config": { "allow_partial_checkpoint": false, @@ -70,6 +70,7 @@ "decay": null, "ema_beta": null, "epsilon": null, + "explicit_mesh_axes": false, "initialize_from_checkpoint_path": null, "initialize_from_hf": null, "int8": false, @@ -134,9 +135,9 @@ "weight_decay": null, "z_loss_weight": null }, - "training_hardware_flops": 1.8114652296478665e+21, - "training_time": 123329.6044150236, - "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_2048_splash_lr_sweep_lr_x3_5-b0a3b2" + "training_hardware_flops": 1.1912866285370585e+21, + "training_time": 162212.23155461036, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_2048_splash_lr_sweep_lr_x3_5_v5p32-1038f8" } } ] diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/1_2b/speedrun_results_v5p64.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/1_2b/speedrun_results_v5p64.json new file mode 100644 index 0000000000..79c7c4fa23 --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x3_5/1_2b/speedrun_results_v5p64.json @@ -0,0 +1,143 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (1_2b); Gated Attention (Splash); LR sweep multiplier=3.5", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 0.9148565530776978, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 2048, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 7168, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 16, + "num_kv_heads": 8, + "num_layers": 16, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "seq_len": 2048, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 5.1738353514618185e+20, + "model_flops_per_token": 2877292544.0, + "model_size": 1498482688, + "num_chips": 32, + "num_devices": 32, + "resources": { + "cpu": 1, + "device": { + "kind": "tpu", + "topology": null, + "variant": "v5p-64" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-16 22:33:13 UTC", + "tokenized_dataset": "gs://marin-us-east5/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 59938701312, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.014, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 57162, + "optimizer_config": { + "adam_lr": 0.0042, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": 
null, + "cycle_length": null, + "cycles": null, + "decay": 1, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.014, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 2, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 256, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 1.8114652296478665e+21, + "training_time": 123329.6044150236, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_2048_splash_lr_sweep_lr_x3_5-b0a3b2" + } + } + ] +} diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/1_2b/speedrun_results.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/1_2b/speedrun_results.json index 817bcd9636..2192aa2862 100644 --- a/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/1_2b/speedrun_results.json +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/1_2b/speedrun_results.json @@ -9,7 +9,7 @@ }, "description": "Hackable Transformer (1_2b); Gated Attention (Splash); LR sweep multiplier=4", "device_flops": 459000000000000.0, - "eval/paloma/c4_en/bpb": 0.9164798259735107, + "eval/paloma/c4_en/bpb": 0.9157812595367432, "model_config": { "activation_function": "silu", "attn_backend": null, @@ -37,20 +37,20 @@ "tokenizer": null, "upcast_attn": false, "use_bias": false, - "use_gated_attention": true, + "use_gated_attention": "elementwise", "use_layer_norm_weight": true }, "model_flops": 5.1738353514618185e+20, "model_flops_per_token": 2877292544.0, "model_size": 1498482688, - "num_chips": 32, - "num_devices": 32, + "num_chips": 16, + "num_devices": 16, "resources": { "cpu": 1, "device": { "kind": "tpu", "topology": null, - "variant": "v5p-64" + "variant": "v5p-32" }, "disk": "1g", "preemptible": true, @@ -58,8 +58,8 @@ "regions": null, "replicas": 1 }, - "run_completion_timestamp": "2025-12-17 02:26:48 UTC", - "tokenized_dataset": "gs://marin-us-east5/tokenized/subcache/fineweb-edu-10B-ac65f6", + "run_completion_timestamp": "2026-01-07 23:47:52 UTC", + "tokenized_dataset": "gs://marin-us-central1/tokenized/subcache/fineweb-edu-10B-ac65f6", "total_tokens": 59938701312, "train_config": { "allow_partial_checkpoint": false, @@ -70,6 +70,7 @@ "decay": null, "ema_beta": null, "epsilon": null, + "explicit_mesh_axes": false, "initialize_from_checkpoint_path": null, "initialize_from_hf": null, "int8": false, @@ -134,9 +135,9 @@ "weight_decay": null, "z_loss_weight": null }, - "training_hardware_flops": 2.051479930331274e+21, - "training_time": 139670.4745595911, - "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_2048_splash_lr_sweep_lr_x4-f80807" + "training_hardware_flops": 1.1860119667024e+21, + "training_time": 161494.00418061003, + "wandb_run_link": 
"https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_2048_splash_lr_sweep_lr_x4_v5p32-0d60e6" } } ] diff --git a/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/1_2b/speedrun_results_v5p64.json b/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/1_2b/speedrun_results_v5p64.json new file mode 100644 index 0000000000..817bcd9636 --- /dev/null +++ b/experiments/speedrun/hackable_transformer_attn_gate/lr_x4/1_2b/speedrun_results_v5p64.json @@ -0,0 +1,143 @@ +{ + "runs": [ + { + "run_info": { + "author": { + "affiliation": "Stanford University", + "name": "Calvin Xu", + "url": "https://pinlinxu.com" + }, + "description": "Hackable Transformer (1_2b); Gated Attention (Splash); LR sweep multiplier=4", + "device_flops": 459000000000000.0, + "eval/paloma/c4_en/bpb": 0.9164798259735107, + "model_config": { + "activation_function": "silu", + "attn_backend": null, + "cross_entropy_block_size": null, + "flash_attention_block_size": null, + "gradient_checkpointing": true, + "head_dim": null, + "hidden_dim": 2048, + "initializer_range": 0.02, + "input_embedding_norm": false, + "intermediate_dim": 7168, + "layer_norm_epsilon": 1e-05, + "max_seq_len": 4096, + "num_heads": 16, + "num_kv_heads": 8, + "num_layers": 16, + "qk_norm": null, + "reference_checkpoint": "NousResearch/Llama-2-7b-hf", + "rope": { + "factor": 1.0, + "theta": 10000 + }, + "seq_len": 2048, + "tie_word_embeddings": false, + "tokenizer": null, + "upcast_attn": false, + "use_bias": false, + "use_gated_attention": true, + "use_layer_norm_weight": true + }, + "model_flops": 5.1738353514618185e+20, + "model_flops_per_token": 2877292544.0, + "model_size": 1498482688, + "num_chips": 32, + "num_devices": 32, + "resources": { + "cpu": 1, + "device": { + "kind": "tpu", + "topology": null, + "variant": "v5p-64" + }, + "disk": "1g", + "preemptible": true, + "ram": "128m", + "regions": null, + "replicas": 1 + }, + "run_completion_timestamp": "2025-12-17 02:26:48 UTC", + "tokenized_dataset": "gs://marin-us-east5/tokenized/subcache/fineweb-edu-10B-ac65f6", + "total_tokens": 59938701312, + "train_config": { + "allow_partial_checkpoint": false, + "beta1": null, + "beta2": null, + "cycle_length": null, + "data_seed": null, + "decay": null, + "ema_beta": null, + "epsilon": null, + "initialize_from_checkpoint_path": null, + "initialize_from_hf": null, + "int8": false, + "learning_rate": 0.016, + "lr_schedule": null, + "max_eval_batches": null, + "max_grad_norm": null, + "min_lr_ratio": null, + "num_train_steps": 57162, + "optimizer_config": { + "adam_lr": 0.0048, + "adam_weight_decay": null, + "backend_steps": 5, + "beta1": 0.8, + "beta2": 0.98, + "cooldown": null, + "cycle_length": null, + "cycles": null, + "decay": 1, + "default_weight_decay_mask": null, + "epsilon": 1e-15, + "haps": null, + "learning_rate": 0.016, + "lr": 0.02, + "lr_schedule": "linear", + "max_grad_norm": 2, + "min_lr_ratio": 0, + "momentum": 0.98, + "muon_epsilon": 1e-05, + "nesterov": true, + "rewarmup": 0.0, + "use_kimi_scaling": false, + "warmup": 0, + "weight_decay": 0.1, + "weight_decay_modules": null + }, + "per_device_eval_parallelism": null, + "profiler": false, + "profiler_num_steps": 100, + "profiler_start_step": 5, + "reset_data_loader_on_init": true, + "rewarmup": null, + "skip_bad_steps": false, + "steps_per_eval": null, + "steps_per_export": 10000, + "steps_per_hf_export": -1, + "steps_per_task_eval": null, + "train_batch_size": 256, + "train_seq_len": null, + "warmup": null, + "watch": { + "include_histograms": false, + "include_norms": 
true, + "include_per_parameter_norms": true, + "interval": 10, + "split_scan_layers": true, + "watch_targets": [ + "grads", + "params" + ] + }, + "weight_decay": null, + "z_loss_weight": null + }, + "training_hardware_flops": 2.051479930331274e+21, + "training_time": 139670.4745595911, + "wandb_run_link": "https://wandb.ai/marin-community/marin/runs/hacktx_1_2b_attngate_2048_splash_lr_sweep_lr_x4-f80807" + } + } + ] +} From 7b4b128b72c35f09e7a59c0f62952f794e97316d Mon Sep 17 00:00:00 2001 From: "Pinlin [Calvin] Xu" Date: Mon, 19 Jan 2026 17:01:57 -0800 Subject: [PATCH 11/11] revert back to separate gate proj let XLA do its magic & not mess up dimension size alignment --- lib/levanter/src/levanter/layers/attention.py | 63 ++++++++----------- lib/levanter/tests/test_attention.py | 25 +++++--- 2 files changed, 40 insertions(+), 48 deletions(-) diff --git a/lib/levanter/src/levanter/layers/attention.py b/lib/levanter/src/levanter/layers/attention.py index 2faa2a95ed..3c102adad9 100644 --- a/lib/levanter/src/levanter/layers/attention.py +++ b/lib/levanter/src/levanter/layers/attention.py @@ -1647,26 +1647,17 @@ class Attention(eqx.Module): q_norm: Optional[LayerNormBase] = None k_norm: Optional[LayerNormBase] = None rot_embs: Optional[RotaryEmbeddings] = None + gate_proj: Optional[hnn.Linear] = None @staticmethod def init(config: AttentionConfig, *, key) -> "Attention": use_bias = config.use_bias use_output_bias = config.use_output_bias if config.use_output_bias is not None else use_bias - k_q, k_k, k_v, k_o = jrandom.split(key, 4) - - # For gated attention, the gate is fused with Q projection (following the paper). - # The Q projection outputs [KVHeads, QHeadsPerGroup, HeadSize + GateSize]. - # For headwise gating: GateSize = 1 (one scalar per head) - # For elementwise gating: GateSize = HeadSize (one value per element) - if config.gated != "none": - QGateAxis = Axis("q_gate_combined", config.HeadSize.size + config.GateSize.size) - q_out_axes = (config.KVHeads, config.QHeadsPerGroup, QGateAxis) - else: - q_out_axes = (config.KVHeads, config.QHeadsPerGroup, config.HeadSize) + k_q, k_k, k_v, k_o, k_g = jrandom.split(key, 5) q_proj = hnn.Linear.init( In=config.Embed, - Out=q_out_axes, + Out=(config.KVHeads, config.QHeadsPerGroup, config.HeadSize), key=k_q, use_bias=use_bias, out_first=True, @@ -1693,6 +1684,19 @@ def init(config: AttentionConfig, *, key) -> "Attention": out_first=True, ) + # For gated attention, create a separate gate projection. + # For headwise gating: GateSize = 1 (one scalar per head) + # For elementwise gating: GateSize = HeadSize (one value per element) + gate_proj = None + if config.gated != "none": + gate_proj = hnn.Linear.init( + In=config.Embed, + Out=(config.KVHeads, config.QHeadsPerGroup, config.GateSize), + key=k_g, + use_bias=use_bias, + out_first=True, + ) + q_norm = None k_norm = None if config.qk_norm is not None: @@ -1702,27 +1706,7 @@ def init(config: AttentionConfig, *, key) -> "Attention": # Build rotary embeddings once during initialization if configured rot_embs = config.rope.build(config.HeadSize) if config.rope is not None else None - return Attention(config, q_proj, k_proj, v_proj, o_proj, q_norm, k_norm, rot_embs) - - def _split_q_and_gate(self, q_combined: NamedArray) -> tuple[NamedArray, NamedArray | None]: - """Split the combined Q+gate projection into Q and gate components. - - Args: - q_combined: The combined output from q_proj with shape [..., q_gate_combined]. 
- - Returns: - A tuple of (q, gate) where: - - q has shape [..., head_size] - - gate has shape [..., gate_size] (1 for headwise, head_size for elementwise) - or None if gating is disabled. - """ - if self.config.gated == "none": - return q_combined, None - - combined_axis = q_combined.resolve_axis("q_gate_combined") - q, gate = hax.split(q_combined, axis=combined_axis, new_axes=[self.config.HeadSize, self.config.GateSize]) - - return q, gate + return Attention(config, q_proj, k_proj, v_proj, o_proj, q_norm, k_norm, rot_embs, gate_proj) def empty_page_cache(self, spec: PageTableSpec, *, dtype) -> "KvPageCache": return KvPageCache.init(spec, self.config.KVHeads, self.config.HeadSize, dtype=dtype) @@ -1852,16 +1836,18 @@ def _compute_qkv( A tuple of (q, k, v, gate) where gate is None if gating is disabled. """ - # Split the projection key into three – one for each of Q, K, V - key_q, key_k, key_v = maybe_rng_split(key, 3) + # Split the projection key into four – one for each of Q, K, V, and gate + key_q, key_k, key_v, key_g = maybe_rng_split(key, 4) # Linear projections - q_combined = self.q_proj(x, key=key_q) + q = self.q_proj(x, key=key_q) k = self.k_proj(x, key=key_k) v = self.v_proj(x, key=key_v) - # Split Q and gate if gated attention is enabled - q, gate = self._split_q_and_gate(q_combined) + # Compute gate if gated attention is enabled + gate = None + if self.gate_proj is not None: + gate = self.gate_proj(x, key=key_g) # Optional QK layer-norm (applied only to Q, not gate) if self.config.qk_norm is not None: @@ -2414,6 +2400,7 @@ def init(config: AttentionConfig, *, key) -> "AttentionWithSink": base.q_norm, base.k_norm, base.rot_embs, + base.gate_proj, sinks, ) diff --git a/lib/levanter/tests/test_attention.py b/lib/levanter/tests/test_attention.py index e571767efc..5bcd5eb9aa 100644 --- a/lib/levanter/tests/test_attention.py +++ b/lib/levanter/tests/test_attention.py @@ -135,10 +135,9 @@ def test_attention_with_sink_module(): def test_attention_with_gating_module(): """Test elementwise gated attention. - The gate is fused with q_proj. When gated="elementwise", q_proj outputs - [head_size + head_size] and the second half is the gate. + When gated="elementwise", a separate gate_proj outputs [kv_head, q_heads_per_group, head_size]. - With zero weights/biases for Q (and gate), the gate output is sigmoid(0) = 0.5. + With zero weights/biases for Q and gate, the gate output is sigmoid(0) = 0.5. With v_proj bias=1 and o_proj weight=1, the attention output before gating is 1. 
After gating: 1 * 0.5 = 0.5 """ @@ -148,8 +147,8 @@ def test_attention_with_gating_module(): config = AttentionConfig(Embed=Embed, num_heads=1, num_kv_heads=1, use_bias=True, gated="elementwise") attn = Attention.init(config, key=jrandom.PRNGKey(0)) - # q_proj now has shape [embed, kv_head, q_heads_per_group, head_size*2] for elementwise gating - # The first half is Q, the second half is the gate + # q_proj has shape [embed, kv_head, q_heads_per_group, head_size] + # gate_proj is a separate projection with same output shape attn = eqx.tree_at(lambda a: a.q_proj.weight, attn, hax.zeros(attn.q_proj.weight.axes)) attn = eqx.tree_at(lambda a: a.q_proj.bias, attn, hax.zeros(attn.q_proj.bias.axes)) attn = eqx.tree_at(lambda a: a.k_proj.weight, attn, hax.zeros(attn.k_proj.weight.axes)) @@ -158,6 +157,9 @@ def test_attention_with_gating_module(): attn = eqx.tree_at(lambda a: a.v_proj.bias, attn, hax.ones(attn.v_proj.bias.axes)) attn = eqx.tree_at(lambda a: a.o_proj.weight, attn, hax.ones(attn.o_proj.weight.axes)) attn = eqx.tree_at(lambda a: a.o_proj.bias, attn, hax.zeros(attn.o_proj.bias.axes)) + # Zero out gate_proj so sigmoid(0) = 0.5 + attn = eqx.tree_at(lambda a: a.gate_proj.weight, attn, hax.zeros(attn.gate_proj.weight.axes)) + attn = eqx.tree_at(lambda a: a.gate_proj.bias, attn, hax.zeros(attn.gate_proj.bias.axes)) x = hax.zeros((Pos, Embed)) out = attn(x, None) @@ -169,10 +171,10 @@ def test_attention_with_gating_module(): def test_attention_with_headwise_gating_module(): """Test headwise gated attention. - The gate is fused with q_proj. When gated="headwise", q_proj outputs - [head_size + 1] and the last value is the gate (one scalar per head). + When gated="headwise", a separate gate_proj outputs [kv_head, q_heads_per_group, 1] + (one scalar per head). - With zero weights/biases for Q (and gate), the gate output is sigmoid(0) = 0.5. + With zero weights/biases for Q and gate, the gate output is sigmoid(0) = 0.5. With v_proj bias=1 and o_proj weight=1, the attention output before gating is 1. After gating: 1 * 0.5 = 0.5 """ @@ -182,8 +184,8 @@ def test_attention_with_headwise_gating_module(): config = AttentionConfig(Embed=Embed, num_heads=1, num_kv_heads=1, use_bias=True, gated="headwise") attn = Attention.init(config, key=jrandom.PRNGKey(0)) - # q_proj now has shape [embed, kv_head, q_heads_per_group, head_size+1] for headwise gating - # The first head_size values are Q, the last value is the gate + # q_proj has shape [embed, kv_head, q_heads_per_group, head_size] + # gate_proj is a separate projection with output [kv_head, q_heads_per_group, 1] attn = eqx.tree_at(lambda a: a.q_proj.weight, attn, hax.zeros(attn.q_proj.weight.axes)) attn = eqx.tree_at(lambda a: a.q_proj.bias, attn, hax.zeros(attn.q_proj.bias.axes)) attn = eqx.tree_at(lambda a: a.k_proj.weight, attn, hax.zeros(attn.k_proj.weight.axes)) @@ -192,6 +194,9 @@ def test_attention_with_headwise_gating_module(): attn = eqx.tree_at(lambda a: a.v_proj.bias, attn, hax.ones(attn.v_proj.bias.axes)) attn = eqx.tree_at(lambda a: a.o_proj.weight, attn, hax.ones(attn.o_proj.weight.axes)) attn = eqx.tree_at(lambda a: a.o_proj.bias, attn, hax.zeros(attn.o_proj.bias.axes)) + # Zero out gate_proj so sigmoid(0) = 0.5 + attn = eqx.tree_at(lambda a: a.gate_proj.weight, attn, hax.zeros(attn.gate_proj.weight.axes)) + attn = eqx.tree_at(lambda a: a.gate_proj.bias, attn, hax.zeros(attn.gate_proj.bias.axes)) x = hax.zeros((Pos, Embed)) out = attn(x, None)
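For readers skimming the patch series: the final patch's change in attention.py amounts to projecting the layer input through a separate gate_proj, squashing it with a sigmoid, and multiplying the per-head attention output by the result; elementwise gating uses one gate value per head dimension, while headwise gating uses one scalar per head that broadcasts over it. The sketch below reproduces that arithmetic in plain Python/NumPy rather than through Levanter's haliax/equinox API; the apply_attention_gate helper and the flattened [pos, heads, head_size] shapes are illustrative assumptions, not the library's actual signatures.

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def apply_attention_gate(attn_out, x, w_gate, b_gate):
    # attn_out: [pos, heads, head_size]   per-head attention output before the output projection
    # x:        [pos, embed]              layer input fed to the gate projection
    # w_gate:   [embed, heads, gate_size] gate_size == head_size (elementwise) or 1 (headwise)
    # b_gate:   [heads, gate_size]
    gate = sigmoid(np.einsum("pe,ehg->phg", x, w_gate) + b_gate)
    # A headwise gate (gate_size == 1) broadcasts over head_size; an elementwise gate matches it.
    return attn_out * gate

# Mirrors the unit tests above: zero gate weights and biases give sigmoid(0) = 0.5, halving the output.
pos, embed, heads, head_size = 4, 8, 2, 8
attn_out = np.ones((pos, heads, head_size))
x = np.zeros((pos, embed))
for gate_size in (head_size, 1):  # elementwise, then headwise
    w = np.zeros((embed, heads, gate_size))
    b = np.zeros((heads, gate_size))
    assert np.allclose(apply_attention_gate(attn_out, x, w, b), 0.5)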
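As a closing note on the speedrun_results.json files earlier in this series, the reported figures are internally consistent: model_flops equals 3 × model_flops_per_token × total_tokens (the usual one-forward-plus-two-backward accounting), and training_hardware_flops equals num_devices × device_flops × training_time. Both relationships are inferred from the numbers themselves rather than taken from the speedrun tooling; the quick check below plugs in the v5p-64 lr_x2 run as an example.

# Values copied from the lr_x2 v5p-64 results above.
model_flops_per_token = 2877292544.0
total_tokens = 59938701312
model_flops = 5.1738353514618185e+20
num_devices = 32
device_flops = 459e12                       # per-chip peak used in the results
training_time = 122665.14305276837          # seconds
training_hardware_flops = 1.8017056211590617e+21

assert abs(3 * model_flops_per_token * total_tokens / model_flops - 1) < 1e-6
assert abs(num_devices * device_flops * training_time / training_hardware_flops - 1) < 1e-6

# Implied hardware utilization for this run: model_flops / training_hardware_flops is roughly 0.29.
print(model_flops / training_hardware_flops)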