
Commit

add soap
samsja committed Dec 6, 2024
1 parent bcbf8fd commit d5b71e1
Showing 4 changed files with 120 additions and 57 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -5,7 +5,7 @@ description = "ZeroBand is a production ready codebase for decentralized trainin
 readme = "README.md"
 requires-python = ">=3.10"
 dependencies = [
-    "torch==2.4.1",
+    "torch==2.5.1",
     "numpy",
     "setuptools",
     "transformers>=4.44.2",
@@ -19,6 +19,7 @@ dependencies = [
     "pyarrow",
     "toposolve",
     "psutil",
+    "torch-shampoo @ git+https://github.com/facebookresearch/optimizers.git@main",
 ]
 
 [project.optional-dependencies]
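A quick, hedged sanity check for anyone syncing their environment after this dependency change (a sketch, not part of the commit; it assumes the project has been reinstalled against the updated pyproject.toml):

# Sanity-check sketch: the torch pin moved to 2.5.1 and the optimizer code below relies on the
# torch-shampoo git dependency, which installs as the `distributed_shampoo` package.
import torch
from distributed_shampoo import DistributedShampoo

print(torch.__version__)            # expected to start with "2.5.1" under this pin
print(DistributedShampoo.__name__)  # confirms the Shampoo implementation is importable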
29 changes: 28 additions & 1 deletion src/zeroband/optimizers/__init__.py
@@ -1,9 +1,21 @@
 from typing import TypeAlias
+from pydantic_config import BaseConfig
 import torch
 from zeroband.optimizers.muon import Muon, AdamConfig, MuonConfig
+from distributed_shampoo import EighEigenvalueCorrectionConfig, DistributedShampoo, FullyShardShampooConfig
 
 
-OptimizersConfig: TypeAlias = AdamConfig | MuonConfig
+class SoapConfig(BaseConfig):
+    lr: float = 4e-4
+    weight_decay: float = 0.1
+    betas1: float = 0.9
+    betas2: float = 0.95
+
+    max_preconditioner_dim: int = 8192
+    precondition_frequency: int = 100
+
+
+OptimizersConfig: TypeAlias = AdamConfig | MuonConfig | SoapConfig
 
 
 def get_optimizer(params: list[torch.nn.Parameter], config: OptimizersConfig) -> torch.optim.Optimizer:
@@ -25,6 +37,21 @@ def get_optimizer(params: list[torch.nn.Parameter], config: OptimizersConfig) -> torch.optim.Optimizer:
             adamw_betas=(config.adam.betas1, config.adam.betas2),
             adamw_wd=config.adam.weight_decay,
         )
+    elif isinstance(config, SoapConfig):
+        return DistributedShampoo(
+            params,
+            lr=config.lr,
+            betas=(config.betas1, config.betas2),
+            epsilon=1e-12,
+            weight_decay=config.weight_decay,
+            max_preconditioner_dim=config.max_preconditioner_dim,
+            precondition_frequency=config.precondition_frequency,
+            use_decoupled_weight_decay=True,
+            # This can also be set to `QREigenvalueCorrectionConfig` which is less expensive
+            # and might therefore allow for a smaller `precondition_frequency`.
+            preconditioner_computation_config=EighEigenvalueCorrectionConfig(),
+            distributed_config=FullyShardShampooConfig(),
+        )
     else:
         raise ValueError(f"Unknown optimizer {config.optimizer}")

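For orientation, a small sketch of how the new config type plugs into the existing dispatch (illustrative only, not part of the commit; in the training loop the parameters come from the FSDP-sharded model, which is why FullyShardShampooConfig is passed to DistributedShampoo):

# Illustrative sketch: SoapConfig is just another pydantic config in the OptimizersConfig
# union, and get_optimizer() dispatches on its type.
from zeroband.optimizers import OptimizersConfig, SoapConfig

config: OptimizersConfig = SoapConfig(precondition_frequency=50)  # override one default
assert isinstance(config, SoapConfig)
# get_optimizer(model.parameters(), config) would now return a DistributedShampoo instance
# configured with EighEigenvalueCorrectionConfig, i.e. the SOAP variant of Shampoo.
print(config.lr, config.betas1, config.betas2)  # 0.0004 0.9 0.95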
11 changes: 11 additions & 0 deletions tests/test_torchrun/test_train.py
@@ -123,3 +123,14 @@ def test_muon(diloco: bool):
         extra_args=["--optim.optim.pseudo_order_steps", "6"],
         diloco=diloco,
     )
+
+
+@pytest.mark.parametrize("diloco", [False, True])
+def test_soap(diloco: bool):
+    num_gpus = [1, 2] if diloco else [2, 1]
+    _test_multi_gpu(
+        num_gpus,
+        "debug/diloco.toml" if diloco else "debug/normal.toml",
+        extra_args=["--optim.optim.precondition_frequency", "1"],
+        diloco=diloco,
+    )
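The `--optim.optim.precondition_frequency 1` override makes Shampoo refresh its preconditioners on every step, so even the handful of steps in the debug configs exercises the eigenvalue-correction path. A rough Python equivalent of that CLI override (a sketch, assuming the flag maps onto the SoapConfig field as the muon test's flag does for MuonConfig):

# Rough equivalent of the "--optim.optim.precondition_frequency 1" override used by test_soap.
from zeroband.optimizers import SoapConfig

debug_config = SoapConfig(precondition_frequency=1)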