Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add muon code #168

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions configs/10B/H100.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@ sched_type = "wsd-sqrt"
batch_size = 128 #1M tokens bs
warmup_steps = 1000
total_steps = 1_000_000_000_000
lr = 7.5e-5

adam_betas1 = 0.9
adam_betas2 = 0.95
weight_decay = 0.1

z_loss = true

[optim.optim]
lr = 7.5e-5
betas1 = 0.9
betas2 = 0.95
weight_decay = 0.1

[data]
seq_length = 8192
dataset_name_or_paths = "/data/datasets/fineweb-edu,/data/datasets/fineweb,/data/datasets/StackV1-popular,/data/datasets/dclm-baseline-1.0-parquet,/data/datasets/open-web-math"
Expand Down
11 changes: 6 additions & 5 deletions configs/10B/H100_cooldown.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@ batch_size = 128 #1M tokens bs
warmup_steps = 1000
stable_steps = 74700
total_steps = 90400
lr = 7.5e-5

adam_betas1 = 0.9
adam_betas2 = 0.95
weight_decay = 0.1

z_loss = true

[optim.optim]
lr = 7.5e-5
betas1 = 0.9
betas2 = 0.95
weight_decay = 0.1

[data]
seq_length = 8192
dataset_name_or_paths = "/data/datasets/fineweb-edu,/data/datasets/fineweb,/data/datasets/StackV1-popular"
Expand Down
2 changes: 2 additions & 0 deletions configs/13B/H100.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ ac_ckpt = true
batch_size = 1024 #2M tokens bs
warmup_steps = 1000
total_steps = 88_000

[optim.optim]
lr = 3e-4

[data]
Expand Down
6 changes: 5 additions & 1 deletion configs/150M/3090.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,8 @@ reshard_after_forward = true
batch_size = 512
warmup_steps = 1000
total_steps = 88_000
lr = 4e-4


[optim.optim]
lr = 4e-4

5 changes: 4 additions & 1 deletion configs/150M/A40.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,7 @@ reshard_after_forward = true
batch_size = 512
warmup_steps = 1000
total_steps = 88_000
lr = 4e-4

[optim.optim]
lr = 4e-4

5 changes: 4 additions & 1 deletion configs/150M/H100.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,7 @@ reshard_after_forward = true
batch_size = 512
warmup_steps = 1000
total_steps = 88_000
lr = 4e-4

[optim.optim]
lr = 4e-4

5 changes: 4 additions & 1 deletion configs/150M_short/3090.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,7 @@ reshard_after_forward = true
batch_size = 512
warmup_steps = 500
total_steps = 8192
lr = 4e-4


[optim.optim]
lr = 4e-4
6 changes: 5 additions & 1 deletion configs/150M_short/A40.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@ type_model = "llama2"
micro_bs = 32 # change this base on the gpu
reshard_after_forward = true


[optim]
batch_size = 512
warmup_steps = 500
total_steps = 8192
lr = 4e-4


[optim.optim]
lr = 4e-4
5 changes: 4 additions & 1 deletion configs/150M_short/H100.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,7 @@ reshard_after_forward = true
batch_size = 512
warmup_steps = 500
total_steps = 8192
lr = 4e-4


[optim.optim]
lr = 4e-4
11 changes: 7 additions & 4 deletions configs/1B/H100.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@ project = "debug_1B_zero_band"
type_model = "llama2"

[train]
micro_bs = 16
micro_bs = 32
reshard_after_forward = true

[optim]
batch_size = 2048
batch_size = 1024
warmup_steps = 1000
total_steps = 88_000
lr = 4e-4
total_steps = 8192

[optim.optim]
lr = 7e-4
15 changes: 0 additions & 15 deletions configs/1B/H100_c4.toml

This file was deleted.

21 changes: 0 additions & 21 deletions configs/1B/H100_llama2_edu.toml

This file was deleted.

23 changes: 0 additions & 23 deletions configs/1B/H100_llama2_edu_no_feat.toml

This file was deleted.

22 changes: 0 additions & 22 deletions configs/1B/H100_llama3.toml

This file was deleted.

25 changes: 0 additions & 25 deletions configs/1B_diloco/H100.toml

This file was deleted.

2 changes: 2 additions & 0 deletions configs/7B/H100.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ micro_bs = 1
batch_size = 1024 #2M tokens bs
warmup_steps = 1000
total_steps = 88_000

[optim.optim]
lr = 3e-4

[data]
Expand Down
2 changes: 2 additions & 0 deletions configs/7B_diloco/H100.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ micro_bs = 1
batch_size = 1024 #2M tokens bs
warmup_steps = 1000
total_steps = 88_000

[optim.optim]
lr = 3e-4

[data]
Expand Down
2 changes: 2 additions & 0 deletions configs/test.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,6 @@ num_workers = 1
batch_size = 128
warmup_steps = 1000
total_steps = 88_000

[optim.optim]
lr = 4e-4
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies = [
"pyarrow",
"toposolve",
"psutil",
"torch-shampoo @ git+https://github.com/facebookresearch/optimizers.git@main",
]

[project.optional-dependencies]
Expand Down
2 changes: 1 addition & 1 deletion scripts/simulate_multi_node_diloco.sh
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ export GLOO_SOCKET_IFNAME=lo
for i in $(seq 0 $(($N - 1 )))
do
> logs/log$i.log
WANDB_MODE=$([ $i -eq 0 ] && echo "online" || echo "online") GLOBAL_UNIQUE_ID=$i GLOBAL_RANK=$i CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) uv run torchrun --nproc_per_node=$NUM_GPU --node-rank 0 --rdzv-endpoint localhost:$((BASE_PORT + $i)) --nnodes=1 $@ --data.data_rank $i --data.data_world_size $N > logs/log$i.log 2>&1 &
WANDB_MODE=$([ $i -eq 0 ] && echo "online" || echo "offline") GLOBAL_UNIQUE_ID=$i GLOBAL_RANK=$i CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) uv run torchrun --nproc_per_node=$NUM_GPU --node-rank 0 --rdzv-endpoint localhost:$((BASE_PORT + $i)) --nnodes=1 $@ --data.data_rank $i --data.data_world_size $N > logs/log$i.log 2>&1 &
child_pids+=($!)
done

Expand Down
59 changes: 59 additions & 0 deletions src/zeroband/optimizers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from typing import TypeAlias
from pydantic_config import BaseConfig
import torch
from zeroband.optimizers.muon import Muon, AdamConfig, MuonConfig
from distributed_shampoo import EighEigenvalueCorrectionConfig, DistributedShampoo, FullyShardShampooConfig


class SoapConfig(BaseConfig):
lr: float = 4e-4
weight_decay: float = 1e-05
betas1: float = 0.9
betas2: float = 0.95

max_preconditioner_dim: int = 8192
precondition_frequency: int = 100


OptimizersConfig: TypeAlias = AdamConfig | MuonConfig | SoapConfig


def get_optimizer(params: list[torch.nn.Parameter], config: OptimizersConfig) -> torch.optim.Optimizer:
if isinstance(config, AdamConfig):
return torch.optim.AdamW(
params,
lr=config.lr,
weight_decay=config.weight_decay,
betas=(config.betas1, config.betas2),
)
elif isinstance(config, MuonConfig):
return Muon(
params,
lr=config.lr,
momentum=config.momentum,
nesterov=config.nesterov,
ns_steps=config.pseudo_order_steps,
adamw_lr=config.adam.lr,
adamw_betas=(config.adam.betas1, config.adam.betas2),
adamw_wd=config.adam.weight_decay,
)
elif isinstance(config, SoapConfig):
return DistributedShampoo(
params,
lr=config.lr,
betas=(config.betas1, config.betas2),
epsilon=1e-12,
weight_decay=config.weight_decay,
max_preconditioner_dim=config.max_preconditioner_dim,
precondition_frequency=config.precondition_frequency,
use_decoupled_weight_decay=True,
# This can also be set to `QREigenvalueCorrectionConfig` which is less expensive
# and might therefore allow for a smaller `precondition_frequency`.
preconditioner_computation_config=EighEigenvalueCorrectionConfig(),
distributed_config=FullyShardShampooConfig(),
)
else:
raise ValueError(f"Unknown optimizer {config.optimizer}")


__all__ = ["OptimizersConfig", "get_optimizer"]
Loading
Loading