Commit 6cf90ce
[BREAKING][misc] feat: Abstract optimizer (#3656)
Abstracts the optimizer so it can be used with whatever module and optimizer class a user wants. This should be backwards compatible, since the default is `torch.optim.AdamW`. Adds `{actor_rollout_ref.actor,critic}.optim.{optimizer,optimizer_impl,override_optimizer_config}`:

```yaml
# Default
optimizer_impl: torch.optim
optimizer: AdamW
```

```yaml
# Example
optimizer_impl: torchao.optim
optimizer: _AdamW
override_optimizer_config:
  bf16_stochastic_round: true
```

**Important**: the fsdp_sft_trainer optim config is now aligned with the FSDP optim config: `optim.warmup_steps_ratio` -> `optim.lr_warmup_steps_ratio`.
1 parent f2d6271 commit 6cf90ce

File tree

12 files changed: +133 −61 lines changed


docs/examples/config.rst

Lines changed: 9 additions & 2 deletions
@@ -643,21 +643,28 @@ Optim
 .. code:: yaml

    optim:
+     optimizer: AdamW
+     optimizer_impl: torch.optim
      lr: 1e-5
      weight_decay: 0.01
-     warmup_steps_ratio: 0.1
+     lr_warmup_steps_ratio: 0.1
      clip_grad: 1.0
      lr_scheduler: cosine
+     override_optimizer_config: null

+- ``optimizer``: Optimizer class name (e.g., ``"AdamW"``, ``"AdamW8bit"``, ``"_AdamW"``). The class name as it appears in the module.
+- ``optimizer_impl``: Module path to import optimizer from (e.g., ``"torch.optim"``, ``"torchao.optim"``, ``"bitsandbytes.optim"``).
 - ``optim.lr``: Learning rate for the optimizer.
 - ``optim.weight_decay``: Weight decay for the optimizer.
-- ``optim.warmup_steps_ratio``: Ratio of warmup steps to total training steps.
+- ``optim.lr_warmup_steps_ratio``: Ratio of warmup steps to total training steps.
 - ``optim.clip_grad``: Gradient clipping value.
 - ``optim.lr_scheduler``: Learning rate scheduler type. Options:

   - ``cosine``: Cosine learning rate scheduler with warmup (default).
   - ``wsd``: Warmup-Stable-Decay scheduler that provides a stable learning rate phase between warmup and decay phases.

+- ``override_optimizer_config``: Dictionary of additional optimizer-specific keyword arguments. For example, to use ``torchao.optim``'s ``_AdamW`` with BF16 stochastic rounding: ``{"bf16_stochastic_round": true}``
+
 Model
 ~~~~~~~~~~~~
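
The call sites in the diffs below switch to `build_optimizer` from `verl.workers.config.optimizer`, which is what consumes the three new fields. Its implementation is not shown in this commit; as a rough sketch of the mechanism the fields imply (dynamic import of the configured class, plus pass-through of the override kwargs), something along these lines would work. The helper name and exact keyword handling are assumptions for illustration, not the actual verl code.

```python
# Illustrative sketch only: the real build_optimizer in
# verl/workers/config/optimizer.py may differ in signature and defaults.
import importlib


def build_optimizer_sketch(params, optim_config):
    # Resolve the optimizer class from its module path and class name,
    # e.g. "torch.optim" + "AdamW" or "torchao.optim" + "_AdamW".
    module = importlib.import_module(optim_config.optimizer_impl)
    optimizer_cls = getattr(module, optim_config.optimizer)

    # Keyword arguments shared by most optimizers.
    kwargs = {"lr": optim_config.lr, "weight_decay": optim_config.weight_decay}

    # Merge optimizer-specific extras, e.g. bf16_stochastic_round
    # for torchao's _AdamW (assumed dict-like when set).
    if optim_config.override_optimizer_config:
        kwargs.update(optim_config.override_optimizer_config)

    return optimizer_cls(params, **kwargs)
```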

recipe/prime/prime_fsdp_workers.py

Lines changed: 2 additions & 7 deletions
@@ -40,6 +40,7 @@
 )
 from verl.utils.import_utils import import_external_libs
 from verl.utils.profiler import log_gpu_memory_usage
+from verl.workers.config.optimizer import build_optimizer
 from verl.workers.fsdp_workers import create_device_mesh, get_sharding_strategy
 from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

@@ -87,7 +88,6 @@ def __init__(self, config):

     def _build_reward_ref_model_optimizer(self, config):
         # the following line is necessary
-        from torch import optim
         from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
         from torch.distributed.fsdp import MixedPrecision

@@ -219,12 +219,7 @@ def _build_reward_ref_model_optimizer(self, config):
             cpu_offload=None,
         )

-        reward_optimizer = optim.AdamW(
-            reward_module.parameters(),
-            lr=config.model.optim.lr,
-            betas=config.model.optim.get("betas", (0.9, 0.999)),
-            weight_decay=config.model.optim.get("weight_decay", 1e-2),
-        )
+        reward_optimizer = build_optimizer(reward_module.parameters(), config.model.optim)

        total_steps = config.model.optim.get("total_training_steps", 0)
        num_warmup_steps = int(config.model.optim.get("lr_warmup_steps", -1))

tests/workers/config/test_actor_config_on_cpu.py

Lines changed: 8 additions & 3 deletions
@@ -16,7 +16,12 @@
 import unittest

 from verl.utils.config import omega_conf_to_dataclass
-from verl.workers.config import ActorConfig, FSDPActorConfig, McoreActorConfig, OptimizerConfig
+from verl.workers.config import (
+    ActorConfig,
+    FSDPActorConfig,
+    McoreActorConfig,
+    OptimizerConfig,
+)


 class TestActorConfig(unittest.TestCase):

@@ -31,7 +36,7 @@ def test_config_inheritance(self):
             "ppo_micro_batch_size_per_gpu": 256,
             "clip_ratio": 0.2,
             "optim": {
-                "_target_": "verl.workers.config.OptimizerConfig",
+                "_target_": "verl.workers.config.McoreOptimizerConfig",
                 "lr": 0.1,
             },
         }

@@ -42,7 +47,7 @@ def test_config_inheritance(self):
             "ppo_micro_batch_size_per_gpu": 256,
             "clip_ratio": 0.2,
             "optim": {
-                "_target_": "verl.workers.config.OptimizerConfig",
+                "_target_": "verl.workers.config.FSDPOptimizerConfig",
                 "lr": 0.1,
             },
         }

tests/workers/config/test_critic_config_on_cpu.py

Lines changed: 12 additions & 13 deletions
@@ -23,7 +23,9 @@
 from verl.workers.config import (
     CriticConfig,
     FSDPCriticConfig,
+    FSDPOptimizerConfig,
     McoreCriticConfig,
+    McoreOptimizerConfig,
     OptimizerConfig,
 )

@@ -103,16 +105,15 @@ def test_fsdp_critic_config_instantiation_from_yaml(self, config_dir):

     def test_config_inheritance_hierarchy(self):
         """Test that the inheritance hierarchy is correct."""
-        optim = OptimizerConfig(lr=0.1)
-        megatron_config = McoreCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=optim)
+        megatron_config = McoreCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=McoreOptimizerConfig(lr=0.1))
         assert isinstance(megatron_config, CriticConfig)
         assert isinstance(megatron_config, McoreCriticConfig)

-        fsdp_config = FSDPCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=optim)
+        fsdp_config = FSDPCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=FSDPOptimizerConfig(lr=0.1))
         assert isinstance(fsdp_config, CriticConfig)
         assert isinstance(fsdp_config, FSDPCriticConfig)

-        critic_config = CriticConfig(ppo_micro_batch_size_per_gpu=1, strategy="fsdp2", optim=optim)
+        critic_config = CriticConfig(ppo_micro_batch_size_per_gpu=1, strategy="fsdp2", optim=OptimizerConfig(lr=0.1))
         assert isinstance(critic_config, CriticConfig)
         assert not isinstance(critic_config, McoreCriticConfig)
         assert not isinstance(critic_config, FSDPCriticConfig)

@@ -136,22 +137,21 @@ def test_config_dict_interface(self):

     def test_frozen_fields_immutability(self):
         """Test that frozen fields raise exceptions when modified after creation."""
-        optim = OptimizerConfig(lr=0.1)
-        critic_config = CriticConfig(ppo_micro_batch_size_per_gpu=1, strategy="fsdp2", optim=optim)
+        critic_config = CriticConfig(ppo_micro_batch_size_per_gpu=1, strategy="fsdp2", optim=OptimizerConfig(lr=0.1))
         frozen_fields = ["rollout_n", "strategy", "cliprange_value"]

         for field_name in frozen_fields:
             with pytest.raises((AttributeError, TypeError, ValueError)):
                 setattr(critic_config, field_name, "modified_value")

-        megatron_config = McoreCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=optim)
+        megatron_config = McoreCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=McoreOptimizerConfig(lr=0.1))
         megatron_frozen_fields = ["nccl_timeout", "load_weight", "data_loader_seed"]

         for field_name in megatron_frozen_fields:
             with pytest.raises((AttributeError, TypeError, ValueError)):
                 setattr(megatron_config, field_name, "modified_value")

-        fsdp_config = FSDPCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=optim)
+        fsdp_config = FSDPCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=FSDPOptimizerConfig(lr=0.1))
         fsdp_frozen_fields = ["ulysses_sequence_parallel_size", "grad_clip"]

         for field_name in fsdp_frozen_fields:

@@ -171,7 +171,7 @@ def test_batch_size_fields_modifiable(self):
         assert critic_config.ppo_micro_batch_size == 4
         assert critic_config.ppo_micro_batch_size_per_gpu == 2

-        fsdp_config = FSDPCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=optim)
+        fsdp_config = FSDPCriticConfig(ppo_micro_batch_size_per_gpu=1, optim=FSDPOptimizerConfig(lr=0.1))

         fsdp_config.forward_micro_batch_size = 16
         fsdp_config.forward_micro_batch_size_per_gpu = 8

@@ -277,12 +277,11 @@ def test_micro_batch_size_divisibility_validation(self):

     def test_fsdp_sequence_parallelism_validation(self):
         """Test FSDP sequence parallelism validation in FSDPCriticConfig.__post_init__."""
-        optim = OptimizerConfig(lr=0.1)
         valid_config = FSDPCriticConfig(
             ppo_micro_batch_size_per_gpu=2,
             ulysses_sequence_parallel_size=2,
             model={"use_remove_padding": True},
-            optim=optim,
+            optim=FSDPOptimizerConfig(lr=0.1),
         )
         assert valid_config.ulysses_sequence_parallel_size == 2

@@ -293,13 +292,13 @@ def test_fsdp_sequence_parallelism_validation(self):
             ppo_micro_batch_size_per_gpu=2,
             ulysses_sequence_parallel_size=2,
             model={"use_remove_padding": False},
-            optim=optim,
+            optim=FSDPOptimizerConfig(lr=0.1),
         )

         valid_config_no_sp = FSDPCriticConfig(
             ppo_micro_batch_size_per_gpu=2,
             ulysses_sequence_parallel_size=1,
             model={"use_remove_padding": False},
-            optim=optim,
+            optim=FSDPOptimizerConfig(lr=0.1),
         )
         assert valid_config_no_sp.ulysses_sequence_parallel_size == 1

tests/workers/critic/test_special_dp_critic.py

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@
 from transformers import AutoConfig

 from verl import DataProto
-from verl.workers.config import FSDPCriticConfig, OptimizerConfig
+from verl.workers.config import FSDPCriticConfig, FSDPOptimizerConfig
 from verl.workers.config.critic import FSDPCriticModelCfg
 from verl.workers.config.engine import FSDPEngineConfig
 from verl.workers.fsdp_workers import CriticWorker

@@ -72,7 +72,7 @@ def setUp(self):
             use_dynamic_bsz=False,
             ulysses_sequence_parallel_size=1,
             rollout_n=1,
-            optim=OptimizerConfig(lr=1e-6),
+            optim=FSDPOptimizerConfig(lr=1e-6),
             model=FSDPCriticModelCfg(
                 path="Qwen/Qwen2.5-0.5B-Instruct",
                 tokenizer_path="Qwen/Qwen2.5-0.5B-Instruct",

verl/trainer/config/_generated_ppo_trainer.yaml

Lines changed: 6 additions & 0 deletions
@@ -7,6 +7,8 @@ actor_rollout_ref:
   actor:
     optim:
       _target_: verl.workers.config.FSDPOptimizerConfig
+      optimizer: AdamW
+      optimizer_impl: torch.optim
       lr: 1.0e-06
       lr_warmup_steps_ratio: 0.0
       total_training_steps: -1

@@ -20,6 +22,7 @@ actor_rollout_ref:
       num_cycles: 0.5
       lr_scheduler_type: constant
       warmup_style: null
+      override_optimizer_config: null
     fsdp_config:
       _target_: verl.workers.config.FSDPEngineConfig
       wrap_policy:

@@ -307,6 +310,8 @@ data:
 critic:
   optim:
     _target_: verl.workers.config.FSDPOptimizerConfig
+    optimizer: AdamW
+    optimizer_impl: torch.optim
     lr: 1.0e-05
     lr_warmup_steps_ratio: 0.0
     total_training_steps: -1

@@ -320,6 +325,7 @@
     num_cycles: 0.5
     lr_scheduler_type: constant
     warmup_style: null
+    override_optimizer_config: null
   model:
     fsdp_config:
       _target_: verl.workers.config.FSDPEngineConfig

verl/trainer/config/optim/fsdp.yaml

Lines changed: 15 additions & 0 deletions
@@ -1,6 +1,13 @@
 # Target class for this configuration
 _target_: verl.workers.config.FSDPOptimizerConfig

+# Optimizer class name (e.g., "AdamW", "AdamW8bit", "_AdamW", "Adam")
+optimizer: AdamW
+
+# Module path to import optimizer
+# Examples: "torch.optim", "torchao.optim", "bitsandbytes.optim"
+optimizer_impl: torch.optim
+
 # Learning rate
 lr: 1e-3

@@ -33,3 +40,11 @@ lr_scheduler_type: constant

 # deprecated
 warmup_style: null
+
+# Additional optimizer-specific keyword arguments
+# Example for torchao with bf16 stochastic rounding:
+#   optimizer_impl: torchao.optim
+#   optimizer: _AdamW
+#   override_optimizer_config:
+#     bf16_stochastic_round: true
+override_optimizer_config: null
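
The commented torchao example generalizes to any optimizer importable from a module path. For instance (illustrative only, not shipped in this commit, and assuming `bitsandbytes` is installed), the `bitsandbytes.optim` / `AdamW8bit` combination named in the comments above could be selected the same way:

```yaml
# Hypothetical override (not part of this commit): select
# bitsandbytes' 8-bit AdamW through the same fields.
optimizer_impl: bitsandbytes.optim
optimizer: AdamW8bit
```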

verl/trainer/config/sft_trainer.yaml

Lines changed: 5 additions & 1 deletion
@@ -1,3 +1,7 @@
+defaults:
+  - optim: fsdp
+  - _self_
+
 data:
   train_batch_size: 256
   micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu

@@ -47,7 +51,7 @@ optim:
   lr: 1e-5
   betas: [0.9, 0.95]
   weight_decay: 0.01
-  warmup_steps_ratio: 0.1
+  lr_warmup_steps_ratio: 0.1
   clip_grad: 1.0
   lr_scheduler: cosine
 ulysses_sequence_parallel_size: 1

verl/trainer/fsdp_sft_trainer.py

Lines changed: 4 additions & 10 deletions
@@ -34,7 +34,7 @@
 from omegaconf import DictConfig, OmegaConf
 from peft import LoraConfig, TaskType, get_peft_model
 from tensordict import TensorDict
-from torch import nn, optim
+from torch import nn
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

@@ -73,6 +73,7 @@
     get_ulysses_sequence_parallel_world_size,
     ulysses_pad_and_slice_inputs,
 )
+from verl.workers.config.optimizer import build_optimizer
 from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

 logger = logging.getLogger(__file__)

@@ -317,14 +318,7 @@ def _build_model_optimizer(self):

         log_gpu_memory_usage("After FSDP wrapping", logger=logger)

-        self.optimizer = optim.AdamW(
-            self.fsdp_model.parameters(),
-            lr=self.config.optim.lr,
-            betas=self.config.optim.betas,
-            weight_decay=self.config.optim.weight_decay,
-            eps=self.config.optim.get("eps", 1e-08),
-            fused=True,
-        )
+        self.optimizer = build_optimizer(self.fsdp_model.parameters(), self.config.optim)

         log_gpu_memory_usage("After initialize optimizer", logger=logger)

@@ -337,7 +331,7 @@ def _build_model_optimizer(self):
             f"{self.config.trainer.total_epochs}, total number of steps {self.total_steps}"
         )

-        num_warmup_steps = int(self.total_steps * self.config.optim.warmup_steps_ratio)
+        num_warmup_steps = int(self.total_steps * self.config.optim.lr_warmup_steps_ratio)

         if not hasattr(self.config.optim, "lr_scheduler") or self.config.optim.lr_scheduler == "cosine":
             self.lr_scheduler = get_cosine_schedule_with_warmup(
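
Since the SFT trainer now builds its optimizer through the shared optim config and reads `lr_warmup_steps_ratio`, SFT configs that still set the old key must be updated. A hypothetical user override after this change, for illustration:

```yaml
# Hypothetical fsdp_sft_trainer override; warmup_steps_ratio was renamed.
optim:
  optimizer: AdamW            # default, imported from optimizer_impl
  optimizer_impl: torch.optim
  lr: 1e-5
  lr_warmup_steps_ratio: 0.1  # previously: warmup_steps_ratio
```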
