
Commit 2585881

delock, sfc-gh-truwase, and tohtana authored
Make Muon optimizer easier to enable (#7555)
The original Muon optimizer PR (#7509) requires the user to explicitly set the `use_muon` flag on `model.parameters()`, as shown in the test https://github.com/deepspeedai/DeepSpeed/blob/master/tests/unit/ops/muon/test_muon.py#L27 . This PR integrates the setting of `use_muon` into DeepSpeed before engine initialization, which makes the Muon optimizer easier to use: the user only needs to change the optimizer in `config.json` from `AdamW` to `Muon`, with no code changes. It resolves issue #7552.

---------

Signed-off-by: Ma, Guokai <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Masahiro Tanaka <[email protected]>
1 parent aa539c6 commit 2585881
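
With this change in place, enabling Muon becomes a config-only switch. A minimal sketch of the intended usage follows; the toy model, batch size, and learning-rate value are illustrative assumptions and are not taken from this commit, while the `Muon` optimizer type in the config is the one named in the commit message above.

import torch
import deepspeed

# Toy model for illustration; any torch.nn.Module works.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4))

# Switching from AdamW to Muon is now a config-only change; no per-parameter
# use_muon flags need to be set in user code. The batch size and lr below are
# placeholder values.
ds_config = {
    "train_batch_size": 8,
    "optimizer": {
        "type": "Muon",
        "params": {
            "lr": 0.01
        }
    },
}

# deepspeed.initialize() now calls set_optimizer_flags() internally, which tags
# each parameter with use_muon (True for ndim >= 2, False otherwise).
engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)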

2 files changed: +12 / -11 lines changed


deepspeed/__init__.py

Lines changed: 12 additions & 1 deletion
@@ -28,7 +28,7 @@
 from .accelerator import get_accelerator
 from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT
 from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable
-from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
+from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER, MUON_OPTIMIZER
 from .runtime.hybrid_engine import DeepSpeedHybridEngine
 from .runtime.pipe.engine import PipelineEngine
 from .inference.engine import InferenceEngine
@@ -66,6 +66,15 @@ def _parse_version(version_str):
 dist = None


+def set_optimizer_flags(config_class, model):
+    if config_class.optimizer_name == MUON_OPTIMIZER:
+        for p in model.parameters():
+            if p.ndim >= 2:
+                setattr(p, "use_muon", True)
+            else:
+                setattr(p, "use_muon", False)
+
+
 def initialize(args=None,
                model: torch.nn.Module = None,
                optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]] = None,
@@ -177,6 +186,7 @@ def initialize(args=None,
     assert config is not None, "DeepSpeed requires --deepspeed_config to specify configuration file"
     if not isinstance(model, PipelineModule):
         config_class = DeepSpeedConfig(config, mpu, mesh_device=mesh_device)
+        set_optimizer_flags(config_class, model)
         if config_class.hybrid_engine.enabled:
             engine = DeepSpeedHybridEngine(args=args,
                                            model=model,
@@ -206,6 +216,7 @@ def initialize(args=None,
         assert mpu is None, "mpu must be None with pipeline parallelism"
         mpu = model.mpu()
         config_class = DeepSpeedConfig(config, mpu)
+        set_optimizer_flags(config_class, model)
         engine = PipelineEngine(args=args,
                                 model=model,
                                 optimizer=optimizer,
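
The routing rule in set_optimizer_flags is purely dimensionality-based: tensors with ndim >= 2 (weight matrices) are marked for Muon, while 1-D tensors such as biases and norm scales get use_muon = False. A small standalone sketch of the same rule, using a throwaway two-layer module that is not part of this commit, shows which parameters end up flagged:

import torch

# Throwaway module: one Linear (2-D weight, 1-D bias) and one LayerNorm (two 1-D parameters).
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))

# Same rule as set_optimizer_flags above: only tensors with ndim >= 2 are routed to Muon.
for name, p in model.named_parameters():
    setattr(p, "use_muon", p.ndim >= 2)
    print(f"{name}: ndim={p.ndim}, use_muon={p.use_muon}")

# Expected output:
# 0.weight: ndim=2, use_muon=True
# 0.bias: ndim=1, use_muon=False
# 1.weight: ndim=1, use_muon=False
# 1.bias: ndim=1, use_muon=False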

tests/unit/ops/muon/test_muon.py

Lines changed: 0 additions & 10 deletions
@@ -24,14 +24,6 @@
                 muon_configs.append([optimizer_name, stage, lr, model_dim, nlayer])


-def set_muon_flag(params):
-    for p in params:
-        if p.ndim >= 2:
-            setattr(p, "use_muon", True)
-        else:
-            setattr(p, "use_muon", False)
-
-
 @pytest.mark.parametrize('optimizer_type, zero_stage, lr, hidden_dim, nlayer', muon_configs)
 class TestMuonConfigs(DistributedTest):

@@ -55,8 +47,6 @@ def test(self, optimizer_type, zero_stage, lr, hidden_dim, nlayer):
         # Perform a few training steps to ensure the optimizer works correctly

         model = SimpleModel(hidden_dim=hidden_dim, nlayers=nlayer)
-        if 'muon' in optimizer_type:
-            set_muon_flag(model.parameters())
         initial_params = [p.clone().cpu() for p in model.parameters()]
         engine, optimizer, _, _ = deepspeed.initialize(
             config=config_dict,
