Skip to content

Commit

Permalink
Introduce OptimizersConfig type alias for optimizer configs (AdamConfig | MuonConfig)
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Dec 6, 2024
1 parent 30a0f23 commit bcbf8fd
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
8 changes: 6 additions & 2 deletions src/zeroband/optimizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from typing import TypeAlias
import torch
from zeroband.optimizers.muon import Muon, AdamConfig, MuonConfig


def get_optimizer(params: list[torch.nn.Parameter], config: AdamConfig | MuonConfig) -> torch.optim.Optimizer:
OptimizersConfig: TypeAlias = AdamConfig | MuonConfig


def get_optimizer(params: list[torch.nn.Parameter], config: OptimizersConfig) -> torch.optim.Optimizer:
if isinstance(config, AdamConfig):
return torch.optim.AdamW(
params,
Expand All @@ -25,4 +29,4 @@ def get_optimizer(params: list[torch.nn.Parameter], config: AdamConfig | MuonCon
raise ValueError(f"Unknown optimizer {config.optimizer}")


__all__ = ["AdamConfig", "MuonConfig", "get_optimizer"]
__all__ = ["OptimizersConfig", "get_optimizer"]
4 changes: 2 additions & 2 deletions src/zeroband/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from zeroband.diloco import Diloco, DilocoConfig
from zeroband.comms import ElasticDeviceMesh
from zeroband.loss import cross_entropy_max_z_loss
from zeroband.optimizers import AdamConfig, MuonConfig, get_optimizer
from zeroband.optimizers import AdamConfig, OptimizersConfig, get_optimizer

from zeroband.utils import (
FakeTokenizer,
Expand All @@ -42,7 +42,7 @@


class OptimConfig(BaseConfig):
optim: AdamConfig | MuonConfig = AdamConfig()
optim: OptimizersConfig = AdamConfig()

sched_type: Literal["cosine", "linear", "wsd-sqrt"] = "cosine"
warmup_steps: int = 1000
Expand Down

0 comments on commit bcbf8fd

Please sign in to comment.