44 changes: 44 additions & 0 deletions recipe/partial_rollout/config/partial_rollout_trainer.yaml
@@ -0,0 +1,44 @@
hydra:
  searchpath:
    - file://verl/trainer/config
    # - file://recipe/dapo/config

defaults:
  - ppo_trainer
  #- dapo_trainer
  - _self_


data:
  gen_batch_size: ${data.train_batch_size}

reward_model:
  reward_manager: dapo
  overlong_buffer:
    enable: False # We try to avoid forgetting to set enable
    len: 0
    penalty_factor: 0.0
    log: False

algorithm:
  filter_groups:
    _target_: verl.trainer.config.FilterGroupsConfig
    enable: False # We try to avoid forgetting to set enable
    metric: null # acc / score / seq_reward / seq_final_reward / ...
    max_num_gen_batches: 0 # Non-positive values mean no upper limit

trainer:
  project_name: verl-dapo


actor_rollout_ref:
  # config for the rollout (only for resource isolation)
  rollout:
    # max rounds of rollout before the prompt is forced finished, 1 means no partial rollout
    partial_rollout_max_split: -1

    partial_rollout_mode: async

    rollout_threshold_rate: 0.8
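
The three rollout keys above are the knobs this recipe adds; the gating check in main_dapo.py further down only wires in the extra Aggregator role when the mode is async and max_split is positive. A minimal sketch of that check, using a hand-built OmegaConf stand-in for the Hydra-composed trainer config (the values shown are illustrative overrides, not the shipped defaults):

from omegaconf import OmegaConf

# stand-in for the composed trainer config; only the keys relevant here are filled in
cfg = OmegaConf.create(
    {
        "actor_rollout_ref": {
            "rollout": {
                "partial_rollout_max_split": 4,  # shipped default is -1, i.e. disabled
                "partial_rollout_mode": "async",
                "rollout_threshold_rate": 0.8,
            }
        }
    }
)

rollout = cfg.actor_rollout_ref.rollout
# same condition main_dapo.py evaluates before registering Role.Aggregator
partial_rollout_enabled = (
    rollout.get("partial_rollout_mode", None) == "async"
    and rollout.get("partial_rollout_max_split", -1) > 0
)
print(partial_rollout_enabled)  # True with this override, False with the defaults above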
487 changes: 487 additions & 0 deletions recipe/partial_rollout/dapo_ray_trainer.py

Large diffs are not rendered by default.
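
The 487-line dapo_ray_trainer.py is not shown in this view. Purely as an illustration of the general idea behind a partial rollout (a toy sketch, not the file's actual implementation; toy_partial_rollout, generate_fn and tokens_per_round are made-up names): generation is capped per round, unfinished prompts are carried into the next round, and only after the last allowed round are they force-finished.

# Toy sketch only -- not the recipe's dapo_ray_trainer implementation.
def toy_partial_rollout(prompts, generate_fn, max_split, tokens_per_round):
    """Generate in at most max_split rounds, carrying unfinished sequences forward."""
    pending = [(p, "") for p in prompts]  # (prompt, text generated so far)
    finished = []
    for _ in range(max_split):
        if not pending:
            break
        still_pending = []
        for prompt, prefix in pending:
            # generate_fn is a placeholder for an engine call that returns a bounded
            # chunk of text plus a flag saying whether the sequence reached EOS
            chunk, done = generate_fn(prompt, prefix, max_new_tokens=tokens_per_round)
            if done:
                finished.append((prompt, prefix + chunk))
            else:
                still_pending.append((prompt, prefix + chunk))
        pending = still_pending
    # whatever is left after max_split rounds is treated as finished (forced stop)
    finished.extend(pending)
    return finished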

195 changes: 195 additions & 0 deletions recipe/partial_rollout/main_dapo.py
@@ -0,0 +1,195 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note that we don't combine the main with ray_trainer, as ray_trainer is used by other main entry points.
"""

import os
import socket

import hydra
import ray
from omegaconf import OmegaConf

from recipe.dapo.dapo_ray_trainer import RayDAPOTrainer
from verl.trainer.ppo.reward import load_reward_manager
from verl.utils.device import is_cuda_available

from .patch_manager import apply_partialrollout_patch


@hydra.main(config_path="config", config_name="partial_rollout_trainer", version_base=None)
def main(config):
    run_ppo(config)


def run_ppo(config) -> None:
    if not ray.is_initialized():
        # this is for local ray cluster
        default_runtime_env = {
            "env_vars": {
                "TOKENIZERS_PARALLELISM": "true",
                "NCCL_DEBUG": "WARN",
                "VLLM_LOGGING_LEVEL": "WARN",
            }
        }
        ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
        runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
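        # OmegaConf.merge gives the later argument precedence, so any runtime_env
        # entries supplied via config.ray_kwargs.ray_init override the defaults above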
        runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
        ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
        print(f"ray init kwargs: {ray_init_kwargs}")
        ray.init(**OmegaConf.to_container(ray_init_kwargs))

    try:
        if (
            is_cuda_available
            and config.global_profiler.tool == "nsys"
            and OmegaConf.select(config.global_profiler, "steps") is not None
            and len(OmegaConf.select(config.global_profiler, "steps")) > 0
        ):
            nsight_options = OmegaConf.to_container(
                config.global_profiler.global_tool_config.nsys.controller_nsight_options
            )
            runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote()
        else:
            runner = TaskRunner.remote()
        ray.get(runner.run.remote(config))
    finally:
        if ray.is_initialized():
            ray.shutdown()


@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
class TaskRunner:
    def run(self, config):
        # print initial config
        from pprint import pprint

        apply_partialrollout_patch()
        from omegaconf import OmegaConf

        from verl.utils.fs import copy_to_local

        print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")

        pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
        OmegaConf.resolve(config)

        # download the checkpoint from hdfs
        local_path = copy_to_local(config.actor_rollout_ref.model.path)

        # instantiate tokenizer
        from verl.utils import hf_processor, hf_tokenizer

        trust_remote_code = config.data.get("trust_remote_code", False)
        tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
        # used for multimodal LLM, could be none
        processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)

        from verl.single_controller.ray import RayWorkerGroup

        # define worker classes
        if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
            assert config.critic.strategy in {"fsdp", "fsdp2"}

            from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker

            ray_worker_group_cls = RayWorkerGroup

        elif config.actor_rollout_ref.actor.strategy == "megatron":
            assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
            from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker

            ray_worker_group_cls = RayWorkerGroup

        else:
            raise NotImplementedError

        from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role

        role_worker_mapping = {
            Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
            Role.Critic: ray.remote(CriticWorker),
        }

        global_pool_id = "global_pool"
        resource_pool_spec = {
            global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
        }
        mapping = {
            Role.ActorRollout: global_pool_id,
            Role.Critic: global_pool_id,
        }

        # we should adopt a multi-source reward function here
        # - for rule-based rm, we directly call a reward score
        # - for model-based rm, we call a model
        # - for code related prompt, we send to a sandbox if there are test cases
        # - finally, we combine all the rewards together
        # - The reward type depends on the tag of the data
        if config.reward_model.enable:
            if config.reward_model.strategy in {"fsdp", "fsdp2"}:
                from verl.workers.fsdp_workers import RewardModelWorker
            elif config.reward_model.strategy == "megatron":
                from verl.workers.megatron_workers import RewardModelWorker
            else:
                raise NotImplementedError
            role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
            mapping[Role.RewardModel] = global_pool_id

        # reference model
        if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
            role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
            mapping[Role.RefPolicy] = global_pool_id

        reward_fn = load_reward_manager(
            config,
            tokenizer,
            0,
            max_resp_len=config.data.max_response_length,
            overlong_buffer_cfg=config.reward_model.overlong_buffer,
        )
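        # only when async partial rollout is switched on (mode == "async" and
        # partial_rollout_max_split > 0) does the recipe register the extra Aggregator role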
        if (
            getattr(config.actor_rollout_ref.rollout, "partial_rollout_mode", None) == "async"
            and getattr(config.actor_rollout_ref.rollout, "partial_rollout_max_split", -1) > 0
        ):
            from recipe.partial_rollout.ray_trainer import AggregatorActor

            role_worker_mapping[Role.Aggregator] = ray.remote(AggregatorActor)
        # Note that we always use function-based RM for validation
        val_reward_fn = load_reward_manager(
            config,
            tokenizer,
            1,
            max_resp_len=config.data.max_response_length,
            overlong_buffer_cfg=config.reward_model.overlong_buffer,
        )
        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
        trainer = RayDAPOTrainer(
            config=config,
            tokenizer=tokenizer,
            processor=processor,
            role_worker_mapping=role_worker_mapping,
            resource_pool_manager=resource_pool_manager,
            ray_worker_group_cls=ray_worker_group_cls,
            reward_fn=reward_fn,
            val_reward_fn=val_reward_fn,
        )
        trainer.init_workers()
        trainer.fit()


if __name__ == "__main__":
    main()
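
For reference, a hypothetical launch of this entry point with the partial rollout knobs overridden on the command line (the module path matches the file added in this PR; the override values are placeholders, and the usual verl data/model settings are omitted):

# Hypothetical launch sketch; values are illustrative, not defaults from the PR.
import subprocess

subprocess.run(
    [
        "python", "-m", "recipe.partial_rollout.main_dapo",
        "actor_rollout_ref.rollout.partial_rollout_mode=async",
        "actor_rollout_ref.rollout.partial_rollout_max_split=4",
        "actor_rollout_ref.rollout.rollout_threshold_rate=0.8",
    ],
    check=True,
)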