From 6c9a727cd7dedb2208832f7b6c56e92d36c1cb04 Mon Sep 17 00:00:00 2001
From: Franz Louis Cesista
Date: Fri, 5 Sep 2025 23:45:32 +0800
Subject: [PATCH 1/4] add code for batch size scaling

---
 .../muon_adamw_batch_size_sweep.py | 242 ++++++++++++++++++
 1 file changed, 242 insertions(+)
 create mode 100644 experiments/speedrun/muon_adamw_llama_batch_size_scaling/muon_adamw_batch_size_sweep.py

diff --git a/experiments/speedrun/muon_adamw_llama_batch_size_scaling/muon_adamw_batch_size_sweep.py b/experiments/speedrun/muon_adamw_llama_batch_size_scaling/muon_adamw_batch_size_sweep.py
new file mode 100644
index 0000000000..7b061fff31
--- /dev/null
+++ b/experiments/speedrun/muon_adamw_llama_batch_size_scaling/muon_adamw_batch_size_sweep.py
@@ -0,0 +1,242 @@
+# Copyright 2025 The Marin Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Speedruns using the Muon optimizer for various Llama model sizes (Chinchilla optimal steps).
+
+Optimizer configs were searched & provided by Kaiyue Wen in https://wandb.ai/marin-community/marin/reports/Fantastic-Optimizers-and-Where-to-Find-Them--VmlldzoxMjgzMzQ2NQ
+"""
+
+import dataclasses
+import logging
+
+from levanter.optim import AdamConfig, MuonConfig
+
+from experiments.llama import llama_1_4b, llama_150m, llama_300m, llama_600m
+from experiments.simple_train_config import SimpleTrainConfig
+from marin.execution.executor import executor_main
+from marin.resources import TpuPodConfig
+from marin.speedrun.speedrun import Author, SpeedrunConfig, default_speedrun
+
+AUTHOR = Author(name="Franz Cesista", affiliation="", url="https://leloykun.github.io")
+
+logger = logging.getLogger("ray")
+
+
+def get_num_train_steps(param_count, batch_size, seq_len):
+    """Compute the number of steps for Chinchilla optimal training (20x params tokens)."""
+    total_tokens = param_count * 20
+    tokens_per_step = batch_size * seq_len
+    return total_tokens // tokens_per_step
+
+
+def build_config(optimizer_name: str, size: str, batch_size: int, seq_len: int=4096) -> tuple[str, SpeedrunConfig]:
+    # Parameter counts
+    param_counts = {
+        "130m": 130_000_000,
+        "300m": 300_000_000,
+        "520m": 520_000_000,
+        "1_2b": 1_200_000_000,
+    }
+
+    # Model configs
+    model_cfgs = {
+        "130m": llama_150m,
+        "300m": llama_300m,
+        "520m": llama_600m,
+        "1_2b": llama_1_4b,
+    }
+
+    # Resource configs
+    resource_cfgs = {
+        "130m": TpuPodConfig(tpu_type="v5p-32"),
+        "300m": TpuPodConfig(tpu_type="v5p-32"),
+        "520m": TpuPodConfig(tpu_type="v5p-32"),
+        "1_2b": TpuPodConfig(tpu_type="v5p-32"),
+    }
+
+    # Optimizer configs for each size
+    muon_configs = {
+        "130m": MuonConfig(
+            learning_rate=0.016,
+            adam_lr=0.0032,
+            weight_decay=0.1,
+            min_lr_ratio=0,
+            warmup=0,
+            momentum=0.95,
+            beta1=0.8,
+            beta2=0.98,
+            epsilon=1e-15,
+            muon_epsilon=1e-5,
+            max_grad_norm=1,
+            lr_schedule="linear",
+            decay=0.8,
+        ),
+        "300m": MuonConfig(
+            learning_rate=0.008,
+            adam_lr=0.0024,
+            weight_decay=0.1,
+            min_lr_ratio=0,
+            warmup=0,
+            momentum=0.98,
+            beta1=0.8,
+            beta2=0.98,
+            epsilon=1e-15,
+            muon_epsilon=1e-5,
+            max_grad_norm=1,
+            lr_schedule="linear",
+            decay=0.8,
+        ),
+        "520m": MuonConfig(
+            learning_rate=0.008,
+            adam_lr=0.0024,
+            weight_decay=0.1,
+            min_lr_ratio=0,
+            warmup=0,
+            momentum=0.98,
+            beta1=0.8,
+            beta2=0.98,
+            epsilon=1e-25,
+            muon_epsilon=1e-5,
+            max_grad_norm=1,
+            lr_schedule="linear",
+            decay=1,
+        ),
+        "1_2b": MuonConfig(
+            learning_rate=0.004,
+            adam_lr=0.0012,
+            weight_decay=0.1,
+            min_lr_ratio=0,
+            warmup=0,
+            momentum=0.98,
+            beta1=0.8,
+            beta2=0.98,
+            epsilon=1e-15,
+            muon_epsilon=1e-5,
+            max_grad_norm=2,
+            lr_schedule="linear",
+            decay=1,
+        ),
+    }
+    # AdamW optimizer configs for each size
+    adam_configs = {
+        "130m": AdamConfig(
+            learning_rate=0.008,
+            weight_decay=0.1,
+            min_lr_ratio=0,
+            warmup=2000,
+            beta1=0.9,
+            beta2=0.98,
+            epsilon=1e-20,
+            max_grad_norm=1,
+            nesterov=False,
+        ),
+        "300m": AdamConfig(
+            learning_rate=0.008,
+            weight_decay=0.1,
+            min_lr_ratio=0,
+            warmup=2000,
+            beta1=0.9,
+            beta2=0.98,
+            epsilon=1e-10,
+            max_grad_norm=1,
+            nesterov=False,
+        ),
+        "520m": AdamConfig(
+            learning_rate=0.004,
+            weight_decay=0.2,
+            min_lr_ratio=0,
+            warmup=1000,
+            beta1=0.9,
+            beta2=0.98,
+            epsilon=1e-10,
+            max_grad_norm=1,
+            nesterov=False,
+        ),
+        "1_2b": AdamConfig(
+            learning_rate=0.002,
+            weight_decay=0.2,
+            min_lr_ratio=0,
+            warmup=1000,
+            beta1=0.9,
+            beta2=0.98,
+            epsilon=1e-25,
+            max_grad_norm=2,
+            nesterov=False,
+        ),
+    }
+
+    # Descriptions
+    descriptions = {
+        "130m": "130M parameter model trained with the Muon optimizer.",
+        "300m": "300M parameter model trained with the Muon optimizer.",
+        "520m": "520M parameter model trained with the Muon optimizer.",
+        "1_2b": "1.2B parameter model trained with the Muon optimizer.",
+    }
+
+    # Names for the runs
+    run_names = {
+        "130m": f"llama_130m_{optimizer_name}_tps{seq_len*batch_size}",
+        "300m": f"llama_300m_{optimizer_name}_tps{seq_len*batch_size}",
+        "520m": f"llama_520m_{optimizer_name}_tps{seq_len*batch_size}",
+        "1_2b": f"llama_1_2b_{optimizer_name}_tps{seq_len*batch_size}",
+    }
+
+    # Gather config for the requested size
+    if size not in param_counts:
+        raise ValueError(f"Unknown size: {size}")
+
+    param_count = param_counts[size]
+    model_config = dataclasses.replace(model_cfgs[size], seq_len=seq_len)
+    seq_len = model_config.seq_len
+    resource_config = resource_cfgs[size]
+    if optimizer_name == "muon":
+        optimizer_config = muon_configs[size]
+    elif optimizer_name == "adamw":
+        optimizer_config = adam_configs[size]
+    else:
+        raise NotImplementedError(f"Optimizer {optimizer_name} not supported yet in this sweep.")
+    description = descriptions[size]
+    run_name = run_names[size]
+
+    num_train_steps = get_num_train_steps(param_count, batch_size, seq_len)
+
+    train = SimpleTrainConfig(
+        resource_config,
+        train_batch_size=batch_size,
+        num_train_steps=num_train_steps,
+        learning_rate=optimizer_config.learning_rate,
+        optimizer_config=optimizer_config,
+    )
+    cfg = SpeedrunConfig(
+        author=AUTHOR,
+        description=description,
+        model_config=model_config,
+        train_config=train,
+    )
+    return run_name, cfg
+
+
+if __name__ == "__main__":
+    runs = []
+    for optimizer_name in ["muon", "adamw"]:
+        for model_size in ["130m", "300m"]:  # For future sweep, add "520m", "1_2b"
+            for batch_size in [128, 256, 512, 1024]:
+                runs.append(build_config(optimizer_name, model_size, batch_size))
+
+    steps = []
+    for name, cfg in runs:
+        cfg.print_run_info()
+        steps.extend(default_speedrun(name, cfg))
+
+    executor_main(steps=steps, description="Muon/AdamW speedruns (Chinchilla optimal) | Batch Size Sweep")
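For reference, the Chinchilla rule in get_num_train_steps above (a 20-tokens-per-parameter budget divided by tokens per step) fixes the step count for every run in the sweep. The snippet below is a standalone sketch that mirrors that helper; the chinchilla_steps name and the printed grid are illustrative only and not part of the patch series:

# Standalone sketch mirroring get_num_train_steps from the patch above.
# Assumes the default seq_len of 4096 used by build_config.

def chinchilla_steps(param_count: int, batch_size: int, seq_len: int = 4096) -> int:
    total_tokens = param_count * 20          # Chinchilla-optimal token budget: 20 tokens per parameter
    tokens_per_step = batch_size * seq_len   # tokens consumed per optimizer step
    return total_tokens // tokens_per_step

if __name__ == "__main__":
    # e.g. 130M params at batch 128: 2.6e9 tokens / 524,288 tokens-per-step -> 4,959 steps
    for params in (130_000_000, 300_000_000):
        for batch_size in (128, 256, 512, 1024):
            print(f"{params // 1_000_000}M params, batch {batch_size}: {chinchilla_steps(params, batch_size)} steps")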
From dc7fc0a54cf99b1a01c564c8e113bf10c42c87a9 Mon Sep 17 00:00:00 2001
From: Franz Louis Cesista
Date: Fri, 5 Sep 2025 23:49:37 +0800
Subject: [PATCH 2/4] add 64 bz & fix script description

---
 .../muon_adamw_batch_size_sweep.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/experiments/speedrun/muon_adamw_llama_batch_size_scaling/muon_adamw_batch_size_sweep.py b/experiments/speedrun/muon_adamw_llama_batch_size_scaling/muon_adamw_batch_size_sweep.py
index 7b061fff31..5e94594e25 100644
--- a/experiments/speedrun/muon_adamw_llama_batch_size_scaling/muon_adamw_batch_size_sweep.py
+++ b/experiments/speedrun/muon_adamw_llama_batch_size_scaling/muon_adamw_batch_size_sweep.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Speedruns using the Muon optimizer for various Llama model sizes (Chinchilla optimal steps).
+"""Speedruns using the AdamW/Muon optimizer for various Llama model sizes (Chinchilla optimal steps) and batch size.
 
 Optimizer configs were searched & provided by Kaiyue Wen in https://wandb.ai/marin-community/marin/reports/Fantastic-Optimizers-and-Where-to-Find-Them--VmlldzoxMjgzMzQ2NQ
 """
@@ -231,7 +231,7 @@ def build_config(optimizer_name: str, size: str, batch_size: int, seq_len: int=4
     runs = []
     for optimizer_name in ["muon", "adamw"]:
         for model_size in ["130m", "300m"]:  # For future sweep, add "520m", "1_2b"
-            for batch_size in [128, 256, 512, 1024]:
+            for batch_size in [64, 128, 256, 512, 1024]:
                 runs.append(build_config(optimizer_name, model_size, batch_size))
 
     steps = []

From 48de2a759fb56cd5324e14dde599e04f6585970d Mon Sep 17 00:00:00 2001
From: Franz Louis Cesista
Date: Fri, 5 Sep 2025 23:57:33 +0800
Subject: [PATCH 3/4] improve run descriptions

---
 .../muon_adamw_batch_size_sweep.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/experiments/speedrun/muon_adamw_llama_batch_size_scaling/muon_adamw_batch_size_sweep.py b/experiments/speedrun/muon_adamw_llama_batch_size_scaling/muon_adamw_batch_size_sweep.py
index 5e94594e25..70ce5efa2f 100644
--- a/experiments/speedrun/muon_adamw_llama_batch_size_scaling/muon_adamw_batch_size_sweep.py
+++ b/experiments/speedrun/muon_adamw_llama_batch_size_scaling/muon_adamw_batch_size_sweep.py
@@ -178,10 +178,10 @@ def build_config(optimizer_name: str, size: str, batch_size: int, seq_len: int=4
 
     # Descriptions
     descriptions = {
-        "130m": "130M parameter model trained with the Muon optimizer.",
-        "300m": "300M parameter model trained with the Muon optimizer.",
-        "520m": "520M parameter model trained with the Muon optimizer.",
-        "1_2b": "1.2B parameter model trained with the Muon optimizer.",
+        "130m": f"130M parameter model trained with the {optimizer_name} optimizer with tokens-per-step={seq_len*batch_size}.",
+        "300m": f"300M parameter model trained with the {optimizer_name} optimizer with tokens-per-step={seq_len*batch_size}.",
+        "520m": f"520M parameter model trained with the {optimizer_name} optimizer with tokens-per-step={seq_len*batch_size}.",
+        "1_2b": f"1.2B parameter model trained with the {optimizer_name} optimizer with tokens-per-step={seq_len*batch_size}.",
     }
 
     # Names for the runs

From fe25d7bdf677250534a38fe4d8184afaed5e876d Mon Sep 17 00:00:00 2001
From: Franz Louis Cesista
Date: Sat, 6 Sep 2025 02:38:31 +0800
Subject: [PATCH 4/4] impl lr ~ sqrt(BS) law

---
 .../muon_adamw_batch_size_sweep.py | 23 +++++++++++++++++------
 1 file changed, 17 insertions(+), 6 deletions(-)
diff --git a/experiments/speedrun/muon_adamw_llama_batch_size_scaling/muon_adamw_batch_size_sweep.py b/experiments/speedrun/muon_adamw_llama_batch_size_scaling/muon_adamw_batch_size_sweep.py
index 70ce5efa2f..2c38af8f45 100644
--- a/experiments/speedrun/muon_adamw_llama_batch_size_scaling/muon_adamw_batch_size_sweep.py
+++ b/experiments/speedrun/muon_adamw_llama_batch_size_scaling/muon_adamw_batch_size_sweep.py
@@ -40,7 +40,7 @@ def get_num_train_steps(param_count, batch_size, seq_len):
     return total_tokens // tokens_per_step
 
 
-def build_config(optimizer_name: str, size: str, batch_size: int, seq_len: int=4096) -> tuple[str, SpeedrunConfig]:
+def build_config(optimizer_name: str, size: str, batch_size: int, seq_len: int = 4096) -> tuple[str, SpeedrunConfig]:
     # Parameter counts
     param_counts = {
         "130m": 130_000_000,
@@ -178,10 +178,18 @@ def build_config(optimizer_name: str, size: str, batch_size: int, seq_len: int=4
 
     # Descriptions
     descriptions = {
-        "130m": f"130M parameter model trained with the {optimizer_name} optimizer with tokens-per-step={seq_len*batch_size}.",
-        "300m": f"300M parameter model trained with the {optimizer_name} optimizer with tokens-per-step={seq_len*batch_size}.",
-        "520m": f"520M parameter model trained with the {optimizer_name} optimizer with tokens-per-step={seq_len*batch_size}.",
-        "1_2b": f"1.2B parameter model trained with the {optimizer_name} optimizer with tokens-per-step={seq_len*batch_size}.",
+        "130m": (
+            f"130M parameter model trained with the {optimizer_name} optimizer with tokens-per-step={seq_len*batch_size}."
+        ),
+        "300m": (
+            f"300M parameter model trained with the {optimizer_name} optimizer with tokens-per-step={seq_len*batch_size}."
+        ),
+        "520m": (
+            f"520M parameter model trained with the {optimizer_name} optimizer with tokens-per-step={seq_len*batch_size}."
+        ),
+        "1_2b": (
+            f"1.2B parameter model trained with the {optimizer_name} optimizer with tokens-per-step={seq_len*batch_size}."
+        ),
     }
 
     # Names for the runs
@@ -211,11 +219,14 @@ def build_config(optimizer_name: str, size: str, batch_size: int, seq_len: int=4
 
     num_train_steps = get_num_train_steps(param_count, batch_size, seq_len)
 
+    # Taken from Simo Ryu's observation that lr ~ sqrt(BS) also holds for Shampoo & Muon: https://x.com/cloneofsimo/status/1907731069878825400
+    baseline_batch_size = 128
+    learning_rate = optimizer_config.learning_rate * (batch_size / baseline_batch_size)**0.5
     train = SimpleTrainConfig(
         resource_config,
         train_batch_size=batch_size,
         num_train_steps=num_train_steps,
-        learning_rate=optimizer_config.learning_rate,
+        learning_rate=learning_rate,
         optimizer_config=optimizer_config,
     )
     cfg = SpeedrunConfig(
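The final patch rescales each run's learning rate from the batch-size-128 baseline using the lr ~ sqrt(batch size) heuristic referenced in the commit. The snippet below is a standalone restatement of that rule: the base rates are the 130m Muon/AdamW values from the sweep's optimizer configs, while the scaled_lr helper and the printed grid are illustrative only and not part of the patches:

# Standalone sketch of the lr ~ sqrt(batch_size) scaling applied in PATCH 4/4.
# Base learning rates below are the 130m entries from the sweep's optimizer configs.

BASELINE_BATCH_SIZE = 128  # batch size the optimizer configs were tuned at

def scaled_lr(base_lr: float, batch_size: int, baseline: int = BASELINE_BATCH_SIZE) -> float:
    """Scale a learning rate tuned at `baseline` to a new batch size via the sqrt rule."""
    return base_lr * (batch_size / baseline) ** 0.5

if __name__ == "__main__":
    # e.g. Muon 130m at batch 256: 0.016 * sqrt(2) ~= 0.0226
    for name, base_lr in (("muon_130m", 0.016), ("adamw_130m", 0.008)):
        for batch_size in (64, 128, 256, 512, 1024):
            print(f"{name} @ batch {batch_size}: lr = {scaled_lr(base_lr, batch_size):.4f}")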