|
28 | 28 | from .accelerator import get_accelerator
|
29 | 29 | from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT
|
30 | 30 | from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable
|
31 |
| -from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER |
| 31 | +from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER, MUON_OPTIMIZER |
32 | 32 | from .runtime.hybrid_engine import DeepSpeedHybridEngine
|
33 | 33 | from .runtime.pipe.engine import PipelineEngine
|
34 | 34 | from .inference.engine import InferenceEngine
|
@@ -66,6 +66,15 @@ def _parse_version(version_str):
|
66 | 66 | dist = None
|
67 | 67 |
|
68 | 68 |
|
def set_optimizer_flags(config_class, model):
    """Tag model parameters for the Muon optimizer.

    When the configured optimizer is Muon, attach a ``use_muon``
    attribute to every parameter: ``True`` for parameters with two or
    more dimensions, ``False`` otherwise. For any other optimizer this
    function does nothing.

    Args:
        config_class: DeepSpeed config object exposing ``optimizer_name``.
        model: ``torch.nn.Module`` whose parameters are tagged in place.
    """
    if config_class.optimizer_name != MUON_OPTIMIZER:
        return
    for param in model.parameters():
        # Muon targets matrix-shaped parameters only; 1-D params
        # (e.g. biases) are explicitly flagged off.
        param.use_muon = param.ndim >= 2
| 76 | + |
| 77 | + |
69 | 78 | def initialize(args=None,
|
70 | 79 | model: torch.nn.Module = None,
|
71 | 80 | optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]] = None,
|
@@ -177,6 +186,7 @@ def initialize(args=None,
|
177 | 186 | assert config is not None, "DeepSpeed requires --deepspeed_config to specify configuration file"
|
178 | 187 | if not isinstance(model, PipelineModule):
|
179 | 188 | config_class = DeepSpeedConfig(config, mpu, mesh_device=mesh_device)
|
| 189 | + set_optimizer_flags(config_class, model) |
180 | 190 | if config_class.hybrid_engine.enabled:
|
181 | 191 | engine = DeepSpeedHybridEngine(args=args,
|
182 | 192 | model=model,
|
@@ -206,6 +216,7 @@ def initialize(args=None,
|
206 | 216 | assert mpu is None, "mpu must be None with pipeline parallelism"
|
207 | 217 | mpu = model.mpu()
|
208 | 218 | config_class = DeepSpeedConfig(config, mpu)
|
| 219 | + set_optimizer_flags(config_class, model) |
209 | 220 | engine = PipelineEngine(args=args,
|
210 | 221 | model=model,
|
211 | 222 | optimizer=optimizer,
|
|
0 commit comments