-
Notifications
You must be signed in to change notification settings - Fork 71
speedrun submission: Add llama_50m_muon_1x - Muon optimizer at 1× Chinchilla scale #2185
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| # 50M Llama with Muon Optimizer (1× Chinchilla) | ||
|
|
||
| **Model:** 50M parameter Llama | ||
| **Optimizer:** Muon | ||
| **Data Scale:** 1× Chinchilla-optimal (1B tokens) | ||
| **Expected BPB:** 1.38-1.42 | ||
|
|
||
| ## Rationale | ||
|
|
||
| Baseline validation experiment. Tests whether Muon beats Adam at standard Chinchilla-optimal scale before investing in 4× data runs. This experiment establishes whether Muon's advantages hold at the canonical 20:1 token-to-parameter ratio. | ||
|
|
||
| ## Hyperparameters | ||
|
|
||
| - **Learning rate:** 0.020 | ||
| - **Batch size:** 128 | ||
| - **Training steps:** 7,629 (1B tokens / 128 batch / 1024 seq_len) | ||
| - **Optimizer:** Muon with momentum=0.95, warmup=0 | ||
| - **Evaluation:** Every 500 steps | ||
|
|
||
| ## Key Differences from 4x Experiment | ||
|
|
||
| - 1/4 the data (1B vs 4B tokens) | ||
| - 1/4 the training steps (7,629 vs 30,518) | ||
| - Shorter time limit (30 minutes vs 2 hours) | ||
| - Same hyperparameters | ||
|
|
||
| ## Running | ||
|
|
||
| ```bash | ||
| sbatch submit_slurm.sh | ||
| ``` | ||
|
|
||
| ## Expected Results | ||
|
|
||
| If Muon maintains its efficiency advantage at 1× scale, we expect BPB around 1.38-1.42. This would validate proceeding with 4× data experiments. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,130 @@ | ||
| { | ||
| "runs": [ | ||
| { | ||
| "run_info": { | ||
| "author": { | ||
| "affiliation": "Northeastern University", | ||
| "name": "redagavin", | ||
| "url": "https://redagavin.github.io/" | ||
| }, | ||
| "description": "Phase 1: 50M Llama with Muon at 1\u00d7 Chinchilla (1B tokens). Baseline validation.", | ||
| "device_flops": 989500000000000.0, | ||
| "eval/paloma/c4_en/bpb": 1.398867130279541, | ||
| "model_config": { | ||
| "activation_function": "silu", | ||
| "attn_backend": null, | ||
| "cross_entropy_block_size": null, | ||
| "flash_attention_block_size": null, | ||
| "gradient_checkpointing": true, | ||
| "head_dim": null, | ||
| "hidden_dim": 192, | ||
| "hybrid_norm": false, | ||
| "initializer_range": 0.02, | ||
| "input_embedding_norm": false, | ||
| "intermediate_dim": 448, | ||
| "layer_norm_epsilon": 1e-05, | ||
| "num_heads": 2, | ||
| "num_kv_heads": 2, | ||
| "num_layers": 4, | ||
| "reference_checkpoint": "NousResearch/Llama-2-7b-hf", | ||
| "rope": { | ||
| "factor": 1.0, | ||
| "theta": 10000 | ||
| }, | ||
| "scan_layers": true, | ||
| "seq_len": 1024, | ||
| "tie_word_embeddings": false, | ||
| "tokenizer": null, | ||
| "upcast_attn": false, | ||
| "use_bias": false, | ||
| "use_layer_norm_weight": true, | ||
| "use_qk_norm": false | ||
| }, | ||
| "model_flops": 1.6698528441040896e+17, | ||
| "model_size": 50874048, | ||
| "num_chips": 1, | ||
| "num_devices": 1, | ||
| "resources": { | ||
| "accelerator_type": "H200", | ||
| "device_flops_override": null, | ||
| "gpu_count": 1 | ||
| }, | ||
| "run_completion_timestamp": "2025-11-23 07:58:21 UTC", | ||
| "tokenized_dataset": "ExecutorStep(name='tokenized/subcache/fineweb-edu-10B', fn=<function _actually_download_pretokenized_cache at 0x154b84eb4ea0>, config=PretokenizedCacheDownloadConfig(cache_path=OutputName(name=None), tokenizer=VersionedValue(value='stanford-crfm/marin-tokenizer'), hf_repo_id=VersionedValue(value='marin-community/fineweb-edu-pretokenized-10B'), hf_revision=VersionedValue(value=None), hf_repo_type_prefix='datasets', hf_token=None, format=TextLmDatasetFormat(text_key='text'), cache_options=None, tags=[]), description=None, override_output_path=None, pip_dependency_groups=None)", | ||
| "total_tokens": 999948288, | ||
| "train_config": { | ||
| "allow_partial_checkpoint": false, | ||
| "beta1": null, | ||
| "beta2": null, | ||
| "cycle_length": null, | ||
| "data_seed": null, | ||
| "decay": null, | ||
| "ema_beta": null, | ||
| "epsilon": null, | ||
| "initialize_from_checkpoint_path": null, | ||
| "initialize_from_hf": null, | ||
| "int8": false, | ||
| "learning_rate": 0.02, | ||
| "lr_schedule": null, | ||
| "max_eval_batches": null, | ||
| "max_grad_norm": null, | ||
| "min_lr_ratio": null, | ||
| "num_train_steps": 7629, | ||
| "optimizer_config": { | ||
| "adam_lr": 0.004, | ||
| "adam_weight_decay": null, | ||
| "backend_steps": 5, | ||
| "beta1": 0.8, | ||
| "beta2": 0.98, | ||
| "cooldown": null, | ||
| "cycle_length": null, | ||
| "cycles": null, | ||
| "decay": 0.8, | ||
| "default_weight_decay_mask": null, | ||
| "epsilon": 1e-15, | ||
| "haps": null, | ||
| "learning_rate": 0.02, | ||
| "lr": 0.02, | ||
| "lr_schedule": "linear", | ||
| "max_grad_norm": 1, | ||
| "min_lr_ratio": 0, | ||
| "momentum": 0.95, | ||
| "muon_epsilon": 1e-05, | ||
| "nesterov": true, | ||
| "rewarmup": 0.0, | ||
| "use_kimi_scaling": false, | ||
| "warmup": 0, | ||
| "weight_decay": 0.0, | ||
| "weight_decay_modules": null | ||
| }, | ||
| "per_device_eval_parallelism": null, | ||
| "reset_data_loader_on_init": true, | ||
| "rewarmup": null, | ||
| "skip_bad_steps": false, | ||
| "steps_per_eval": 500, | ||
| "steps_per_export": 10000, | ||
| "steps_per_hf_export": null, | ||
| "steps_per_task_eval": null, | ||
| "train_batch_size": 128, | ||
| "warmup": null, | ||
| "watch": { | ||
| "include_histograms": false, | ||
| "include_norms": true, | ||
| "include_per_parameter_norms": true, | ||
| "interval": 10, | ||
| "split_scan_layers": true, | ||
| "watch_targets": [ | ||
| "grads", | ||
| "params" | ||
| ] | ||
| }, | ||
| "weight_decay": null, | ||
| "z_loss_weight": null | ||
| }, | ||
| "training_hardware_flops": 1.1666472839337308e+18, | ||
| "training_time": 1179.0270681492984, | ||
| "wandb_run_link": "https://wandb.ai/marin-speedrun/marin-speedrun/runs/llama_50m_muon_1x-bd5fc4" | ||
| } | ||
| } | ||
| ] | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
| # Copyright 2025 The Marin Authors | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """ | ||
| Phase 1: 50M Llama with Muon optimizer at 1× Chinchilla-optimal data. | ||
|
|
||
| Baseline validation: Does Muon beat Adam at standard scale? | ||
| Expected: 1.38-1.42 BPB | ||
| """ | ||
|
|
||
| from levanter.optim import MuonConfig | ||
|
|
||
| from experiments.llama import llama_50m | ||
| from experiments.simple_train_config import SimpleTrainConfig | ||
| from marin.execution.executor import executor_main | ||
| from marin.resources import GpuConfig | ||
| from marin.speedrun.speedrun import Author, SpeedrunConfig, default_speedrun | ||
|
|
||
|
|
||
| muon_config = MuonConfig( | ||
| learning_rate=0.020, | ||
| adam_lr=0.004, | ||
| momentum=0.95, | ||
| beta1=0.8, | ||
| beta2=0.98, | ||
| epsilon=1e-15, | ||
| muon_epsilon=1e-5, | ||
| max_grad_norm=1, | ||
| warmup=0, | ||
| min_lr_ratio=0, | ||
| lr_schedule="linear", | ||
| decay=0.8, | ||
| ) | ||
|
|
||
| # Calculate steps for 1× Chinchilla (1B tokens) | ||
| # 50M params × 1 × 20 = 1B tokens | ||
| # 1B tokens / (128 batch × 1024 seq_len) = 7,629 steps | ||
| num_train_steps = 7629 | ||
|
|
||
| speedrun_config = SpeedrunConfig( | ||
| author=Author( | ||
| name="redagavin", | ||
| affiliation="Northeastern University", | ||
| url="https://redagavin.github.io/" | ||
| ), | ||
| description="Phase 1: 50M Llama with Muon at 1× Chinchilla (1B tokens). Baseline validation.", | ||
| model_config=llama_50m, | ||
| train_config=SimpleTrainConfig( | ||
| GpuConfig(gpu_count=1, accelerator_type="H200"), | ||
|
||
| train_batch_size=128, | ||
| num_train_steps=num_train_steps, | ||
| learning_rate=muon_config.learning_rate, | ||
| optimizer_config=muon_config, | ||
| steps_per_eval=500, | ||
| ), | ||
| ) | ||
|
|
||
| speedrun_config.print_run_info() | ||
|
|
||
| if __name__ == "__main__": | ||
| executor_main(steps=default_speedrun("llama_50m_muon_1x", speedrun_config)) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The module
marin.resourcesdoes not exist.GpuConfigshould be imported fromfray.clusterinstead. Change to:from fray.cluster import ResourceConfig