diff --git a/recipe/tppo/config/tppo_trainer.yaml b/recipe/tppo/config/tppo_trainer.yaml new file mode 100644 index 00000000000..67cc7b8aaa8 --- /dev/null +++ b/recipe/tppo/config/tppo_trainer.yaml @@ -0,0 +1,57 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_trainer + - _self_ + +algorithm: + all_samples_with_grad: True + all_samples_with_grad_sync: True + use_variable_lambda: True + variable_lambda_scalar: 0.05 + use_separate_critic_lam: True + critic_lam: 1.0 + add_eos: False + rollout_pool: + strategy: v1 + min_score: -1 + max_score: 1 + +data: + actor_training_batch_size: 510 + window_response_length: 8192 + answer_key: answer + +actor_rollout_ref: + actor: + loss_agg_mode: batch + window_response_length: ${data.window_response_length} + lm_loss_weight: 0.1 + scale_pg_by_local_kl: False + scale_pg_by_kl: False + + rollout: + train_generate_kwargs: + max_new_tokens: 8192 + num_bon: 16 + bon_strategy: all + +critic: + cliprange_value_low: 0.5 + cliprange_value_high: 0.6 + optim: + lr_warmup_steps: 20 + +reward_model: + delete_eos: False + mean: 0.0 + std: 1.0 + use_last_response: False + punish_format: False + format_punish_score: -0.5 + add_int_verify: False + strict_box_verify: False + need_punish_duplicate: True + punish_score: \'rule-lighteval/MATH_v2:-1\' diff --git a/recipe/tppo/main_tppo.py b/recipe/tppo/main_tppo.py new file mode 100644 index 00000000000..9e5bbf82449 --- /dev/null +++ b/recipe/tppo/main_tppo.py @@ -0,0 +1,277 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. 
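+This entry point builds the Ray worker mapping, the TPPO reward managers and the RL datasets, then launches recipe.tppo.tppo_trainer.RayPPOTrainer.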
+""" + +import os +import socket + +import hydra +import ray +from omegaconf import OmegaConf + +from recipe.tppo.tppo_trainer import RayPPOTrainer + + +@hydra.main(config_path="config", config_name="tppo_trainer", version_base=None) +def main(config): + run_ppo(config) + + +# Define a function to run the PPO-like training process +def run_ppo(config) -> None: + # Check if Ray is not initialized + if not ray.is_initialized(): + # Initialize Ray with a local cluster configuration + # Set environment variables in the runtime environment to control tokenizer parallelism, + # NCCL debug level, VLLM logging level, and allow runtime LoRA updating + # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration + ray.init( + runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN", "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "true"}}, + num_cpus=config.ray_init.num_cpus, + ) + + # Create a remote instance of the TaskRunner class, and + # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete + if OmegaConf.select(config.trainer, "profile_steps") is not None and len(OmegaConf.select(config.trainer, "profile_steps")) > 0: + nsight_options = OmegaConf.to_container(config.trainer.controller_nsight_options) + runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote() + else: + runner = TaskRunner.remote() + ray.get(runner.run.remote(config)) + + # [Optional] get the path of the timeline trace file from the configuration, default to None + # This file is used for performance analysis + timeline_json_file = config.ray_init.get("timeline_json_file", None) + if timeline_json_file: + ray.timeline(filename=timeline_json_file) + + +@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head +class TaskRunner: + def run(self, config): + # Print the initial configuration. `resolve=True` will evaluate symbolic values. + from pprint import pprint + + from omegaconf import OmegaConf + + from verl.utils.fs import copy_to_local + + print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") + + pprint(OmegaConf.to_container(config, resolve=True)) + + OmegaConf.resolve(config) + + # Download the checkpoint from HDFS to the local machine. + # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on + local_path = copy_to_local(config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False)) + + # Instantiate the tokenizer and processor. + from verl.utils import hf_processor, hf_tokenizer + + trust_remote_code = config.data.get("trust_remote_code", False) + tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + # Used for multimodal LLM, could be None + processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) + + # Version validation for vllm. + if config.actor_rollout_ref.rollout.name in ["vllm"]: + from verl.utils.vllm_utils import is_version_ge + + if config.actor_rollout_ref.model.get("lora_rank", 0) > 0: + if not is_version_ge(pkg="vllm", minver="0.7.3"): + raise NotImplementedError("PPO LoRA is not supported before vllm 0.7.3") + + # Define worker classes based on the actor strategy. 
+ if config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]: + assert config.critic.strategy in ["fsdp", "fsdp2"] + from verl.single_controller.ray import RayWorkerGroup + from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker + + actor_rollout_cls = AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == "megatron": + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup + from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker + + actor_rollout_cls = AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker + ray_worker_group_cls = NVMegatronRayWorkerGroup + + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role + + # Map roles to their corresponding remote worker classes. + role_worker_mapping = { + Role.ActorRollout: ray.remote(actor_rollout_cls), + Role.Critic: ray.remote(CriticWorker), + } + + # Define the resource pool specification. + # Map roles to the resource pool. + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + mapping = { + Role.ActorRollout: global_pool_id, + Role.Critic: global_pool_id, + } + + # We should adopt a multi-source reward function here: + # - for rule-based rm, we directly call a reward score + # - for model-based rm, we call a model + # - for code related prompt, we send to a sandbox if there are test cases + # finally, we combine all the rewards together + # The reward type depends on the tag of the data + if config.reward_model.enable: + if config.reward_model.strategy in ["fsdp", "fsdp2"]: + from verl.workers.fsdp_workers import RewardModelWorker + elif config.reward_model.strategy == "megatron": + from verl.workers.megatron_workers import RewardModelWorker + else: + raise NotImplementedError + role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) + mapping[Role.RewardModel] = global_pool_id + + # Add a reference policy worker if KL loss or KL reward is used. + if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: + role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) + mapping[Role.RefPolicy] = global_pool_id + + from recipe.tppo.tppo_reward_manager import TPPORewardManager + from verl.utils.reward_score import default_compute_score + # Load the reward manager for training and validation. + reward_fn = TPPORewardManager( + config=config, + tokenizer=tokenizer, + num_examine=0, + compute_score=default_compute_score, + reward_fn_key=config.data.reward_fn_key, + **config.reward_model.get("reward_kwargs", {}) + ) + val_reward_fn = TPPORewardManager( + config=config, + tokenizer=tokenizer, + num_examine=1, + compute_score=default_compute_score, + reward_fn_key=config.data.reward_fn_key, + **config.reward_model.get("reward_kwargs", {}) + ) + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + + from verl.utils.dataset.rl_dataset import collate_fn + + # Create training and validation datasets. 
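+        # create_rl_dataset / create_rl_sampler are the helpers defined at the bottom of this file;
+        # data.custom_cls can plug in a user-defined Dataset subclass, otherwise RLHFDataset is used.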
+ train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor) + val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor) + train_sampler = create_rl_sampler(config.data, train_dataset) + + # Initialize the PPO trainer. + trainer = RayPPOTrainer( + config=config, + tokenizer=tokenizer, + processor=processor, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn, + train_dataset=train_dataset, + val_dataset=val_dataset, + collate_fn=collate_fn, + train_sampler=train_sampler, + device_name=config.trainer.device, + ) + # Initialize the workers of the trainer. + trainer.init_workers() + # Start the training process. + trainer.fit() + + +def create_rl_dataset(data_paths, data_config, tokenizer, processor): + """Create a dataset. + + Arguments: + data_paths: List of paths to data files. + data_config: The data config. + tokenizer (Tokenizer): The tokenizer. + processor (Processor): The processor. + + Returns: + dataset (Dataset): The dataset. + """ + from torch.utils.data import Dataset + + from verl.utils.dataset.rl_dataset import RLHFDataset + + # Check if a custom dataset class is specified in the data configuration + # and if the path to the custom class is provided + if "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None: + from verl.utils.import_utils import load_extern_type + + # Dynamically load the custom dataset class + dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name) + # Verify that the custom dataset class inherits from torch.utils.data.Dataset + if not issubclass(dataset_cls, Dataset): + raise TypeError(f"The custom dataset class '{data_config.custom_cls.name}' from '{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset") + else: + # Use the default RLHFDataset class if no custom class is specified + dataset_cls = RLHFDataset + print(f"Using dataset class: {dataset_cls.__name__}") + + # Instantiate the dataset using the determined dataset class + dataset = dataset_cls( + data_files=data_paths, + tokenizer=tokenizer, + processor=processor, + config=data_config, + ) + + return dataset + + +def create_rl_sampler(data_config, dataset): + """Create a sampler for the dataset. + + Arguments: + data_config: The data config. + dataset (Dataset): The dataset. + + Returns: + sampler (Sampler): The sampler. + """ + import torch + from torch.utils.data import RandomSampler, SequentialSampler + + # Use a sampler to facilitate checkpoint resumption. + # If shuffling is enabled in the data configuration, create a random sampler. + if data_config.shuffle: + train_dataloader_generator = torch.Generator() + train_dataloader_generator.manual_seed(data_config.get("seed", 1)) + sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator) + else: + # If shuffling is disabled, use a sequential sampler to iterate through the dataset in order. 
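+        # The TPPO run scripts in this recipe set data.shuffle=False, so this deterministic branch is the one normally taken.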
+ sampler = SequentialSampler(data_source=dataset) + + return sampler + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/recipe/tppo/run_ppo.sh b/recipe/tppo/run_ppo.sh new file mode 100644 index 00000000000..0e96bc9ff90 --- /dev/null +++ b/recipe/tppo/run_ppo.sh @@ -0,0 +1,205 @@ +set -x + +# EXP=0514_qwen_ppo_partial_rollout_values_mix_diff_bo16_v2 +# ckpt和路径 +MODEL_PATH=/file_system/common-models +#MODEL_NAME=Qwen2.5-7B-Instruct +MODEL_NAME=Qwen/Qwen3-4B + +SFT_MODEL_PATH=$MODEL_PATH/$MODEL_NAME +RM_MODEL_PATH=$MODEL_PATH/$MODEL_NAME +TRAIN_FILE=$HOME/data/gsm8k/train.parquet +TEST_FILE=$HOME/data/gsm8k/test.parquet + +NNODES=1 + +# SFT_MODEL_PATH=hdfs://haruna/home/byte_data_seed/ssd_hldy/user/fantiantian.tt/fantiantian/alphaseed_workspace/grpo/alpha_seed_DeepSeek-R1-Distill-Qwen-7B +# RM_MODEL_PATH=hdfs://haruna/home/byte_data_seed/ssd_hldy/user/fantiantian.tt/fantiantian/alphaseed_workspace/grpo/alpha_seed_qwen7B_SFT32_MATH1222a14p123_ppo_rm_ntk20_clip02_lam998_priority0124/checkpoints/global_step_25/critic/huggingface +# TRAIN_FILE=hdfs://haruna/home/byte_data_seed/lf_lq/user/qiying.01/datasets/alphaseed/release1.5/0224d1.parquet +# TEST_FILE=hdfs://haruna/home/byte_data_seed/lf_lq/user/qiying.01/datasets/alphaseed/release1.5/0224d1_eval.parquet +# # default_hdfs_dir=hdfs://haruna/home/byte_data_seed/ssd_hldy/user/zhouht.00/tppo/qwen_7b_test + +chat_template=raw + +# 训练长度 +max_prompt_length=2048 +max_response_length=24576 +max_num_batched_tokens=32768 +# batch size && 训练epoch +# step = total_epochs * 4? +train_batch_size=1536 +ppo_epochs=1 +ppo_mini_batch_size=512 +val_batch_size=960 +total_epochs=50 +test_freq=5 +save_freq=5 +# 算法相关的参数 +actor_lr=8e-7 +critic_lr=2e-6 +lr_warmup_steps=20 # 10 / (train_size * total_epochs / train_batch_size) +kl_coef=0.0 +use_last_response=False +use_ref_answer=False +gae_gamma=1.0 +gae_lam=0.95 +force_append_eos=False +upgo_loss_weight=0.0 +upgo_loss_version=1 +clip_ratio_low=0.2 +clip_ratio_high=0.28 +cliprange_value_low=0.5 +cliprange_value_high=0.6 +loss_agg_mode='token-mean' # token-mean, seq-mean-token-sum, seq-mean-token-mean, seq-mean-token-sum-norm +clip_ratio2=10.0 +kl_penalty=low_var_kl +weight_decay=0.1 +adv_estimator=gae +kl_loss_weight=0.0 +num_bon=16 +bon_strategy=all +# tracking实验名 +project_name='msft' +experiment_name=${EXP} +# 工程参数 +gen_micro_batch_size=512 # use_dynamic_bsz=True时仍然生效 +infer_micro_batch_size=512 # use_dynamic_bsz=True时不生效 +train_micro_batch_size=64 # use_dynamic_bsz=True时不生效 +actor_sp_size=4 +critic_sp_size=4 +ref_sp_size=4 +reward_sp_size=4 +num_attention_heads=28 +use_dynamic_bsz=True +actor_ppo_max_token_len=50000 +critic_ppo_max_token_len=50000 +infer_ppo_max_token_len=50000 +fsdp_size=30 +gen_tp=4 +critic_tp=1 + +python3 -m verl.trainer.main_ppo \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + data.shuffle=False \ + critic.ppo_epochs=${ppo_epochs} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_weight} \ + actor_rollout_ref.actor.shuffle=False \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.kl_penalty=${kl_penalty} \ + data.train_files=${TRAIN_FILE} \ + data.val_files=${TEST_FILE} \ + data.prompt_key=prompt \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_batch_size} \ + data.val_batch_size=${val_batch_size} \ + data.truncation='left' \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + 
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path=${SFT_MODEL_PATH} \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.optim.lr=${actor_lr} \ + actor_rollout_ref.actor.optim.lr_warmup_steps=${lr_warmup_steps} \ + actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.entropy_coeff=0.0 \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=${clip_ratio2} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${actor_sp_size} \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.name=vllm \ + +actor_rollout_ref.rollout.use_vllm=True \ + +actor_rollout_ref.rollout.num_slots=256 \ + +actor_rollout_ref.rollout.slot_block_size=512 \ + actor_rollout_ref.rollout.max_num_batched_tokens=${max_num_batched_tokens} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.optim.weight_decay=${weight_decay} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${ref_sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + critic.use_dynamic_bsz=${use_dynamic_bsz} \ + critic.ppo_max_token_len_per_gpu=${critic_ppo_max_token_len} \ + critic.optim.lr=${critic_lr} \ + critic.optim.lr_warmup_steps=${lr_warmup_steps} \ + critic.model.path=${RM_MODEL_PATH} \ + critic.model.enable_gradient_checkpointing=True \ + critic.ppo_micro_batch_size=${train_micro_batch_size} \ + critic.model.fsdp_config.param_offload=False \ + critic.ulysses_sequence_parallel_size=${critic_sp_size} \ + +critic.model.override_config.attention_dropout=0. \ + +critic.model.override_config.embd_pdrop=0. \ + +critic.model.override_config.resid_pdrop=0. 
\ + critic.model.use_remove_padding=True \ + reward_model.enable=False \ + reward_model.model.input_tokenizer=null \ + reward_model.model.path=${RM_MODEL_PATH} \ + reward_model.micro_batch_size=${infer_micro_batch_size} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.gamma=${gae_gamma} \ + algorithm.lam=${gae_lam} \ + trainer.critic_warmup=10 \ + trainer.logger=['console'] \ + trainer.project_name=${project_name} \ + trainer.experiment_name=${experiment_name} \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=${NNODES} \ + trainer.save_freq=${save_freq} \ + trainer.test_freq=${test_freq} \ + trainer.total_epochs=${total_epochs} \ + trainer.val_only=False \ + + + data.actor_training_batch_size=510 \ + algorithm.all_samples_with_grad=True \ + algorithm.all_samples_with_grad_sync=True \ + critic.cliprange_value_low=${cliprange_value_low} \ + critic.cliprange_value_high=${cliprange_value_high} \ + data.window_response_length=8192 \ + actor_rollout_ref.rollout.train_generate_kwargs.max_new_tokens=8192 \ + actor_rollout_ref.actor.window_response_length=8192 \ + actor_rollout_ref.actor.lm_loss_weight=0.1 \ + algorithm.use_variable_lambda=True \ + algorithm.variable_lambda_scalar=0.05 \ + algorithm.use_separate_critic_lam=True \ + algorithm.critic_lam=1.0 \ + +algorithm.use_actual_values=True \ + +algorithm.adv_whiten=True \ + +algorithm.adv_bias=0.0 \ + +algorithm.adv_clamp=True \ + reward_model.delete_eos=False \ + algorithm.add_eos=False \ + +algorithm.force_append_eos=${false_append_eos} \ + actor_rollout_ref.actor.ppo_epochs=${ppo_epochs} \ + actor_rollout_ref.actor.scale_pg_by_local_kl=False \ + actor_rollout_ref.rollout.num_bon=${num_bon} \ + actor_rollout_ref.rollout.bon_strategy=${bon_strategy} \ + data.answer_key=answer \ + +actor_rollout_ref.model.use_rmpad=True \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. 
\ + actor_rollout_ref.actor.scale_pg_by_kl=False \ + reward_model.mean=0.0 \ + reward_model.std=1.0 \ + reward_model.use_last_response=${use_last_response} \ + reward_model.punish_format=False \ + reward_model.format_punish_score=-0.1 \ + + + reward_model.add_int_verify=False \ + reward_model.strict_box_verify=False \ + reward_model.need_punish_duplicate=True \ + reward_model.punish_score=\'rule-lighteval/MATH_v2:-1\' + + trainer.default_hdfs_dir=${default_hdfs_dir} \ + actor_rollout_ref.model.external_lib=seed_models \ + + critic.model.external_lib=seed_models \ \ No newline at end of file diff --git a/recipe/tppo/run_tppo.sh b/recipe/tppo/run_tppo.sh new file mode 100644 index 00000000000..b9961c96da1 --- /dev/null +++ b/recipe/tppo/run_tppo.sh @@ -0,0 +1,208 @@ +set -x + +# EXP=0514_qwen_ppo_partial_rollout_values_mix_diff_bo16_v2 +# ckpt和路径 +MODEL_PATH=/file_system/common-models +#MODEL_NAME=Qwen2.5-7B-Instruct +MODEL_NAME=Qwen/Qwen3-4B +ckpts_home="/file_system/dhl/save_ckpt/run_tppo/checkpoints" + +SFT_MODEL_PATH=$MODEL_PATH/$MODEL_NAME +RM_MODEL_PATH=$MODEL_PATH/$MODEL_NAME +TRAIN_FILE=$HOME/data/gsm8k/train.parquet +TEST_FILE=$HOME/data/gsm8k/test.parquet + +NNODES=1 + +# SFT_MODEL_PATH=hdfs://haruna/home/byte_data_seed/ssd_hldy/user/fantiantian.tt/fantiantian/alphaseed_workspace/grpo/alpha_seed_DeepSeek-R1-Distill-Qwen-7B +# RM_MODEL_PATH=hdfs://haruna/home/byte_data_seed/ssd_hldy/user/fantiantian.tt/fantiantian/alphaseed_workspace/grpo/alpha_seed_qwen7B_SFT32_MATH1222a14p123_ppo_rm_ntk20_clip02_lam998_priority0124/checkpoints/global_step_25/critic/huggingface +# TRAIN_FILE=hdfs://haruna/home/byte_data_seed/lf_lq/user/qiying.01/datasets/alphaseed/release1.5/0224d1.parquet +# TEST_FILE=hdfs://haruna/home/byte_data_seed/lf_lq/user/qiying.01/datasets/alphaseed/release1.5/0224d1_eval.parquet +# # default_hdfs_dir=hdfs://haruna/home/byte_data_seed/ssd_hldy/user/zhouht.00/tppo/qwen_7b_test + +chat_template=raw + +# 训练长度 +max_prompt_length=2048 +max_response_length=8192 +max_num_batched_tokens=32768 +# batch size && 训练epoch +# step = total_epochs * 4? 
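+# (rough sanity check, assuming the GSM8K train split of ~7.5k prompts: 7473 / 1536 ≈ 5 optimizer steps per epoch)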
+train_batch_size=1536 +ppo_epochs=1 +ppo_mini_batch_size=512 +val_batch_size=960 +total_epochs=50 +# 测试精度时只需要关注 test_freq,不用开 save +test_freq=5 +save_freq=200 +# 算法相关的参数 +actor_lr=8e-7 +critic_lr=2e-6 +#lr_warmup_steps=20 # 10 / (train_size * total_epochs / train_batch_size) +lr_warmup_steps=2 +kl_coef=0.0 +use_last_response=False +use_ref_answer=False +gae_gamma=1.0 +gae_lam=0.95 +force_append_eos=False +upgo_loss_weight=0.0 +upgo_loss_version=1 +clip_ratio_low=0.2 +clip_ratio_high=0.28 +cliprange_value_low=0.5 +cliprange_value_high=0.6 +loss_agg_mode='batch-sum' # token-mean, seq-mean-token-sum, seq-mean-token-mean, seq-mean-token-sum-norm, minibatch-sum, batch-sum +clip_ratio2=10.0 +kl_penalty=low_var_kl +weight_decay=0.1 +adv_estimator=gae-trunc +kl_loss_weight=0.0 +num_bon=16 +bon_strategy=all +# tracking实验名 +project_name='msft' +experiment_name=${EXP} +# 工程参数 +gen_micro_batch_size=512 # use_dynamic_bsz=True时仍然生效 +infer_micro_batch_size=512 # use_dynamic_bsz=True时不生效 +train_micro_batch_size=64 # use_dynamic_bsz=True时不生效 +actor_sp_size=4 +critic_sp_size=4 +ref_sp_size=4 +reward_sp_size=4 +num_attention_heads=28 +use_dynamic_bsz=True +actor_ppo_max_token_len=50000 +critic_ppo_max_token_len=50000 +infer_ppo_max_token_len=50000 +fsdp_size=30 +gen_tp=4 +critic_tp=1 + +python3 -m recipe.tppo.main_tppo \ + data.actor_training_batch_size=510 \ + algorithm.all_samples_with_grad=True \ + algorithm.all_samples_with_grad_sync=True \ + critic.cliprange_value_low=${cliprange_value_low} \ + critic.cliprange_value_high=${cliprange_value_high} \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.rollout.train_generate_kwargs.max_new_tokens=8192 \ + data.window_response_length=8192 \ + actor_rollout_ref.actor.window_response_length=8192 \ + actor_rollout_ref.actor.lm_loss_weight=0.1 \ + algorithm.use_variable_lambda=True \ + algorithm.variable_lambda_scalar=0.05 \ + algorithm.use_separate_critic_lam=True \ + algorithm.critic_lam=1.0 \ + +algorithm.use_actual_values=True \ + +algorithm.adv_whiten=True \ + +algorithm.adv_bias=0.0 \ + +algorithm.adv_clamp=True \ + reward_model.delete_eos=False \ + data.shuffle=False \ + algorithm.add_eos=False \ + +algorithm.force_append_eos=${false_append_eos} \ + actor_rollout_ref.actor.ppo_epochs=${ppo_epochs} \ + critic.ppo_epochs=${ppo_epochs} \ + actor_rollout_ref.actor.scale_pg_by_local_kl=False \ + actor_rollout_ref.rollout.num_bon=${num_bon} \ + actor_rollout_ref.rollout.bon_strategy=${bon_strategy} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_weight} \ + actor_rollout_ref.actor.shuffle=False \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.kl_penalty=${kl_penalty} \ + data.train_files=${TRAIN_FILE} \ + data.val_files=${TEST_FILE} \ + data.prompt_key=prompt \ + data.answer_key=answer \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_batch_size} \ + data.val_batch_size=${val_batch_size} \ + data.truncation='left' \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path=${SFT_MODEL_PATH} \ + 
+actor_rollout_ref.model.use_rmpad=True \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.optim.lr=${actor_lr} \ + actor_rollout_ref.actor.optim.lr_warmup_steps=${lr_warmup_steps} \ + actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.entropy_coeff=0.0 \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=${clip_ratio2} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${actor_sp_size} \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.name=vllm \ + +actor_rollout_ref.rollout.use_vllm=True \ + +actor_rollout_ref.rollout.num_slots=256 \ + +actor_rollout_ref.rollout.slot_block_size=512 \ + actor_rollout_ref.rollout.max_num_batched_tokens=${max_num_batched_tokens} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.scale_pg_by_kl=False \ + actor_rollout_ref.actor.optim.weight_decay=${weight_decay} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${ref_sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + critic.use_dynamic_bsz=${use_dynamic_bsz} \ + critic.ppo_max_token_len_per_gpu=${critic_ppo_max_token_len} \ + critic.optim.lr=${critic_lr} \ + critic.optim.lr_warmup_steps=${lr_warmup_steps} \ + critic.model.path=${RM_MODEL_PATH} \ + critic.model.enable_gradient_checkpointing=True \ + critic.ppo_micro_batch_size=${train_micro_batch_size} \ + critic.model.fsdp_config.param_offload=False \ + critic.ulysses_sequence_parallel_size=${critic_sp_size} \ + +critic.tp_size=${critic_tp} \ + +critic.model.override_config.attention_dropout=0. \ + +critic.model.override_config.embd_pdrop=0. \ + +critic.model.override_config.resid_pdrop=0. 
\ + critic.model.use_remove_padding=True \ + +critic.use_rmpad=True \ + reward_model.enable=False \ + reward_model.model.input_tokenizer=null \ + reward_model.model.path=${RM_MODEL_PATH} \ + reward_model.micro_batch_size=${infer_micro_batch_size} \ + reward_model.mean=0.0 \ + reward_model.std=1.0 \ + reward_model.use_last_response=${use_last_response} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.gamma=${gae_gamma} \ + algorithm.lam=${gae_lam} \ + trainer.critic_warmup=10 \ + trainer.logger=['console'] \ + trainer.project_name=${project_name} \ + trainer.experiment_name=${experiment_name} \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=${NNODES} \ + trainer.save_freq=${save_freq} \ + trainer.test_freq=${test_freq} \ + trainer.total_epochs=${total_epochs} \ + trainer.val_only=False \ + reward_model.punish_format=False \ + reward_model.format_punish_score=-0.1 \ + trainer.default_local_dir="${ckpts_home}" \ + + reward_model.add_int_verify=False \ + reward_model.strict_box_verify=False \ + reward_model.need_punish_duplicate=True \ + reward_model.punish_score=\'rule-lighteval/MATH_v2:-1\' + + trainer.default_hdfs_dir=${default_hdfs_dir} \ + actor_rollout_ref.model.external_lib=seed_models \ + + critic.model.external_lib=seed_models \ \ No newline at end of file diff --git a/recipe/tppo/run_tppo_tiny.sh b/recipe/tppo/run_tppo_tiny.sh new file mode 100644 index 00000000000..0834a1ffaa5 --- /dev/null +++ b/recipe/tppo/run_tppo_tiny.sh @@ -0,0 +1,193 @@ +set -x + +# EXP=0514_qwen_ppo_partial_rollout_values_mix_diff_bo16_v2 +# ckpt和路径 +MODEL_PATH=/file_system/common-models +#MODEL_NAME=Qwen2.5-7B-Instruct +MODEL_NAME=Qwen/Qwen3-4B + +SFT_MODEL_PATH=$MODEL_PATH/$MODEL_NAME +RM_MODEL_PATH=$MODEL_PATH/$MODEL_NAME +TRAIN_FILE=$HOME/data/gsm8k/train.parquet +TEST_FILE=$HOME/data/gsm8k/test.parquet + +NNODES=1 + +# SFT_MODEL_PATH=hdfs://haruna/home/byte_data_seed/ssd_hldy/user/fantiantian.tt/fantiantian/alphaseed_workspace/grpo/alpha_seed_DeepSeek-R1-Distill-Qwen-7B +# RM_MODEL_PATH=hdfs://haruna/home/byte_data_seed/ssd_hldy/user/fantiantian.tt/fantiantian/alphaseed_workspace/grpo/alpha_seed_qwen7B_SFT32_MATH1222a14p123_ppo_rm_ntk20_clip02_lam998_priority0124/checkpoints/global_step_25/critic/huggingface +# TRAIN_FILE=hdfs://haruna/home/byte_data_seed/lf_lq/user/qiying.01/datasets/alphaseed/release1.5/0224d1.parquet +# TEST_FILE=hdfs://haruna/home/byte_data_seed/lf_lq/user/qiying.01/datasets/alphaseed/release1.5/0224d1_eval.parquet +# # default_hdfs_dir=hdfs://haruna/home/byte_data_seed/ssd_hldy/user/zhouht.00/tppo/qwen_7b_test + +chat_template=raw + +# 训练长度 +max_prompt_length=2048 +max_response_length=8192 +max_num_batched_tokens=32768 +# batch size && 训练epoch +train_batch_size=1536 +ppo_epochs=1 +ppo_mini_batch_size=512 +val_batch_size=960 +total_epochs=2 +test_freq=5 +save_freq=5 +# 算法相关的参数 +actor_lr=8e-7 +critic_lr=2e-6 +lr_warmup_steps=2 # 10 / (train_size * total_epochs / train_batch_size) ?? 
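+# warmup is counted in optimizer steps; this smoke-test run (total_epochs=2) only takes a handful of steps, hence the small value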
+kl_coef=0.0 +use_last_response=False +use_ref_answer=False +gae_gamma=1.0 +gae_lam=0.95 +force_append_eos=False +upgo_loss_weight=0.0 +upgo_loss_version=1 +clip_ratio_low=0.2 +clip_ratio_high=0.28 +cliprange_value_low=0.5 +cliprange_value_high=0.6 +loss_agg_mode='batch-sum' # token-mean, seq-mean-token-sum, seq-mean-token-mean, seq-mean-token-sum-norm, minibatch-sum, batch-sum +clip_ratio2=10.0 +kl_penalty=low_var_kl +weight_decay=0.1 +adv_estimator=gae-trunc +kl_loss_weight=0.0 +num_bon=16 +bon_strategy=all +# tracking实验名 +project_name='msft' +experiment_name=${EXP} +# 工程参数 +gen_micro_batch_size=512 # use_dynamic_bsz=True时仍然生效 +infer_micro_batch_size=512 # use_dynamic_bsz=True时不生效 +train_micro_batch_size=64 # use_dynamic_bsz=True时不生效 +actor_sp_size=4 +critic_sp_size=4 +ref_sp_size=4 +reward_sp_size=4 +num_attention_heads=28 +use_dynamic_bsz=True +actor_ppo_max_token_len=50000 +critic_ppo_max_token_len=50000 +infer_ppo_max_token_len=50000 +fsdp_size=30 +gen_tp=4 +critic_tp=1 + +python3 -m recipe.tppo.main_tppo \ + data.actor_training_batch_size=510 \ + algorithm.all_samples_with_grad=True \ + algorithm.all_samples_with_grad_sync=True \ + critic.cliprange_value_low=${cliprange_value_low} \ + critic.cliprange_value_high=${cliprange_value_high} \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.rollout.train_generate_kwargs.max_new_tokens=8192 \ + data.window_response_length=8192 \ + actor_rollout_ref.actor.window_response_length=8192 \ + actor_rollout_ref.actor.lm_loss_weight=0.1 \ + algorithm.use_variable_lambda=True \ + algorithm.variable_lambda_scalar=0.05 \ + algorithm.use_separate_critic_lam=True \ + algorithm.critic_lam=1.0 \ + +algorithm.use_actual_values=True \ + +algorithm.adv_whiten=True \ + +algorithm.adv_bias=0.0 \ + +algorithm.adv_clamp=True \ + reward_model.delete_eos=False \ + data.shuffle=False \ + algorithm.add_eos=False \ + +algorithm.force_append_eos=${false_append_eos} \ + actor_rollout_ref.actor.ppo_epochs=${ppo_epochs} \ + critic.ppo_epochs=${ppo_epochs} \ + actor_rollout_ref.actor.scale_pg_by_local_kl=False \ + actor_rollout_ref.rollout.num_bon=${num_bon} \ + actor_rollout_ref.rollout.bon_strategy=${bon_strategy} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_weight} \ + actor_rollout_ref.actor.shuffle=False \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.kl_penalty=${kl_penalty} \ + data.train_files=${TRAIN_FILE} \ + data.val_files=${TEST_FILE} \ + data.prompt_key=prompt \ + data.answer_key=answer \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_batch_size} \ + data.val_batch_size=${val_batch_size} \ + data.truncation='left' \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path=${SFT_MODEL_PATH} \ + +actor_rollout_ref.model.use_rmpad=True \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. 
\ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.optim.lr=${actor_lr} \ + actor_rollout_ref.actor.optim.lr_warmup_steps=${lr_warmup_steps} \ + actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.entropy_coeff=0.0 \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=${clip_ratio2} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${actor_sp_size} \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.name=vllm \ + +actor_rollout_ref.rollout.use_vllm=True \ + +actor_rollout_ref.rollout.num_slots=256 \ + +actor_rollout_ref.rollout.slot_block_size=512 \ + actor_rollout_ref.rollout.max_num_batched_tokens=${max_num_batched_tokens} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.scale_pg_by_kl=False \ + actor_rollout_ref.actor.optim.weight_decay=${weight_decay} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${ref_sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + critic.use_dynamic_bsz=${use_dynamic_bsz} \ + critic.ppo_max_token_len_per_gpu=${critic_ppo_max_token_len} \ + critic.optim.lr=${critic_lr} \ + critic.optim.lr_warmup_steps=${lr_warmup_steps} \ + critic.model.path=${RM_MODEL_PATH} \ + critic.model.enable_gradient_checkpointing=True \ + critic.ppo_micro_batch_size=${train_micro_batch_size} \ + critic.model.fsdp_config.param_offload=False \ + critic.ulysses_sequence_parallel_size=${critic_sp_size} \ + +critic.tp_size=${critic_tp} \ + +critic.model.override_config.attention_dropout=0. \ + +critic.model.override_config.embd_pdrop=0. \ + +critic.model.override_config.resid_pdrop=0. 
\ + critic.model.use_remove_padding=True \ + +critic.use_rmpad=True \ + reward_model.enable=False \ + reward_model.model.input_tokenizer=null \ + reward_model.model.path=${RM_MODEL_PATH} \ + reward_model.micro_batch_size=${infer_micro_batch_size} \ + reward_model.mean=0.0 \ + reward_model.std=1.0 \ + reward_model.use_last_response=${use_last_response} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.gamma=${gae_gamma} \ + algorithm.lam=${gae_lam} \ + trainer.critic_warmup=10 \ + trainer.logger=['console'] \ + trainer.project_name=${project_name} \ + trainer.experiment_name=${experiment_name} \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=${NNODES} \ + trainer.save_freq=${save_freq} \ + trainer.test_freq=${test_freq} \ + trainer.total_epochs=${total_epochs} \ + trainer.val_only=False \ + reward_model.punish_format=False \ + reward_model.format_punish_score=-0.1 \ diff --git a/recipe/tppo/sample_pool.py b/recipe/tppo/sample_pool.py new file mode 100644 index 00000000000..c049c33befe --- /dev/null +++ b/recipe/tppo/sample_pool.py @@ -0,0 +1,187 @@ +import numpy as np +from collections import defaultdict, deque +import random +import logging +from verl import DataProto +import random +import torch +import torch.nn.functional as F +from copy import deepcopy +import uuid +from collections import defaultdict + +logger = logging.getLogger(__file__) + + +def is_finished(item, config, tokenizer, max_window_rounds): + max_response_length = config.data.get('window_response_length', None) + max_prompt_length = config.data.max_prompt_length + max_window_rounds * max_response_length + actual_prompt_length = max_prompt_length - item.batch['window_rounds'].item() * max_response_length + prompt_ids = item.batch['input_ids'][:actual_prompt_length] + response_ids = item.batch['input_ids'][actual_prompt_length:] + response_length = response_ids.shape[-1] + valid_response_length = item.batch['attention_mask'][actual_prompt_length:].sum().item() + valid_response_ids = response_ids[:valid_response_length] + is_trunc = (response_length == valid_response_length) and (valid_response_ids[valid_response_length-1].item() != tokenizer.eos_token_id) + is_last_round = item.batch['window_rounds'].item() == config.data.max_response_length // config.data.window_response_length - 1 + is_finished = (not is_trunc) or is_last_round + return is_finished + + +class SamplePool: + name = "sample_pool" + + def __init__(self, config, tokenizer): + self.config = config + self.tokenizer = tokenizer + self.num_bon = config.actor_rollout_ref.rollout.get("num_bon", 1) + self.sample_pool = [] + self.prompt_list = [] + self.batch_keys = [] + self.non_tensor_batch_keys = [] + self.meta_info_keys = [] + self.pool_with_grad = defaultdict(list) + self.pool_with_unfinished = [] + self.id2acc = defaultdict(list) + + def rearrange_sample_pool(self): + new_sample_list = [] + for v in self.prompt_list: + new_sample_list.append(deepcopy(v)) + self.prompt_list = [] + for idx, item in enumerate(self.sample_pool): + new_sample_list.append(deepcopy(item)) + self.sample_pool = [i for i in new_sample_list] + + def fill_sample_pool(self, batch): + batch_lst = batch.chunk(len(batch)) + max_prompt_length = self.config.data.max_prompt_length + self.batch_keys = list(batch.batch.keys()) + ['left_pad_len', 'actual_prompt_len', 'window_rounds'] + if 'values' not in self.batch_keys: self.batch_keys += ['values'] + if 'answer_input_ids' in self.batch_keys: self.batch_keys.remove('answer_input_ids') + if 'answer_attention_mask' in self.batch_keys: 
self.batch_keys.remove('answer_attention_mask') + self.non_tensor_batch_keys = list(batch.non_tensor_batch.keys()) + ['rollout_id'] + self.meta_info_keys = batch.meta_info.keys() + for item in batch_lst: + item.batch['window_rounds'] = torch.tensor([0], device=item.batch['attention_mask'].device) + item.batch['actual_prompt_len'] = item.batch['attention_mask'][:, :max_prompt_length].sum(-1) + item.batch['left_pad_len'] = max_prompt_length - item.batch['actual_prompt_len'] + item.batch['values'] = torch.zeros_like(item.batch['attention_mask'], dtype=torch.float32)[:, :0] + item.non_tensor_batch['rollout_id'] = np.array([str(uuid.uuid4())], dtype=object) + for _ in range(self.num_bon): + self.sample_pool.append(deepcopy(item)) + print("[SamplePool] fill_batch:", len(batch_lst), "sample_pool size:", len(self.sample_pool)) + + def get_gen_batch(self, return_batch_size): + return_batch = [] + window_round = 0 + padded_size = 0 + while len(return_batch) < return_batch_size: + item = self.sample_pool[0] + self.sample_pool = self.sample_pool[1:] + window_round = max(window_round, item.batch['window_rounds'].item()) + padded_size = max(padded_size, item.batch['input_ids'].size(1)) + new_item = item.select(batch_keys=self.batch_keys, non_tensor_batch_keys=self.non_tensor_batch_keys, meta_info_keys=self.meta_info_keys) + return_batch.append(deepcopy(new_item)) + for idx, item in enumerate(return_batch): + if item.batch['window_rounds'].item() == 0: + pad_size = (window_round - item.batch['window_rounds'].item()) * self.config.data.window_response_length + item.batch['left_pad_len'] += pad_size + item.batch['input_ids'] = F.pad(item.batch['input_ids'], (pad_size, 0), value=self.tokenizer.pad_token_id) + item.batch['attention_mask'] = F.pad(item.batch['attention_mask'], (pad_size, 0), value=0) + item.batch['values'] = F.pad(item.batch['values'], (pad_size, 0), value=0) + # item.batch['answer_input_ids'] = F.pad(item.batch['answer_input_ids'], (pad_size, 0), value=self.tokenizer.pad_token_id) + # item.batch['answer_attention_mask'] = F.pad(item.batch['answer_attention_mask'], (pad_size, 0), value=0) + elif item.batch['input_ids'].size(1) < padded_size: + pad_size = padded_size - item.batch['input_ids'].size(1) + item.batch['input_ids'] = F.pad(item.batch['input_ids'], (pad_size, 0), value=self.tokenizer.pad_token_id) + item.batch['attention_mask'] = F.pad(item.batch['attention_mask'], (pad_size, 0), value=0) + item.batch['values'] = F.pad(item.batch['values'], (pad_size, 0), value=0) + item.batch['left_pad_len'] += pad_size + return DataProto.concat(return_batch) + + def update_multi_round_pool(self, batch): + batch_lst = batch.chunk(len(batch)) + for item in batch_lst: + is_finished = item.batch['is_finished'] + # uid = item.non_tensor_batch['index'][0] + # prompt = item.batch['prompts'] + # response = item.batch['responses'] + if (not is_finished) and item.batch['window_rounds'].item() < self.config.data.max_response_length // self.config.data.window_response_length - 1: + start_idx = torch.nonzero(item.batch['attention_mask'].flatten())[0].item() + real_len = item.batch['attention_mask'].sum(-1).item() + max_prompt_length = self.config.data.max_prompt_length + (item.batch['window_rounds'].item() + 1) * self.config.data.window_response_length + prompt_ids = F.pad(item.batch['input_ids'][:, start_idx:start_idx + real_len], (max_prompt_length - real_len, 0), value=self.tokenizer.pad_token_id) + prompt_attention_mask = F.pad(item.batch['attention_mask'][:, start_idx:start_idx + real_len], 
(max_prompt_length - real_len, 0), value=0) + new_item = deepcopy(item.select(batch_keys=self.batch_keys, non_tensor_batch_keys=self.non_tensor_batch_keys, meta_info_keys=self.meta_info_keys)) + new_item.batch['left_pad_len'] = torch.tensor([max_prompt_length - real_len], device=prompt_ids.device) + new_item.batch['input_ids'] = prompt_ids + new_item.batch['attention_mask'] = prompt_attention_mask + new_item.batch['values'] = new_item.batch['values'][:, :(item.batch['window_rounds'].item() + 1) * self.config.data.window_response_length] + # new_item.batch['answer_input_ids'] = prompt_ids + # new_item.batch['answer_attention_mask'] = prompt_attention_mask + new_item.batch['window_rounds'] += 1 + self.prompt_list.append(deepcopy(new_item)) + + def fill_rollout_pool_grad(self, batch): + batch_lst = batch.chunk(len(batch)) + self.id2data = defaultdict(list) + self.id2finish = defaultdict(list) + self.id2unfinish = defaultdict(list) + # get acc + for item in batch_lst: + score = item.batch['token_level_scores'].sum(-1).item() + is_finished = item.batch['is_finished'] + rollout_id = item.non_tensor_batch['rollout_id'][0] + if is_finished: + self.id2acc[rollout_id].append(score) + self.id2finish[rollout_id].append(item) + else: + self.id2unfinish[rollout_id].append(item) + self.id2data[rollout_id].append(item) + for k, v in self.id2acc.items(): + if np.mean(v) != self.config.algorithm.rollout_pool.min_score and np.mean(v) != self.config.algorithm.rollout_pool.max_score: + if self.config.algorithm.rollout_pool.strategy == 'v2': + for item in self.id2finish[k]: + self.pool_with_grad[k].append(item) + else: + for item in self.id2data[k]: + self.pool_with_grad[k].append(item) + print("[SamplePool/fill_rollout_pool_grad] fill_batch:", len(batch_lst), "pool_size after fill:", len(self.pool_with_grad)) + + + def get_train_batch_grad(self, return_batch_size): + return_batch = [] + k_lst = list(self.pool_with_grad.keys()) + random.shuffle(k_lst) + while len(return_batch) < return_batch_size: + if len(k_lst) == 0: break + k = k_lst[0] + k_lst = k_lst[1:] + for item in self.pool_with_grad[k]: + return_batch.append(item) + k_lst = sorted(self.id2unfinish.keys(), key = lambda k: len(self.id2unfinish[k]), reverse=True) + if self.config.algorithm.rollout_pool.strategy == 'v2': + k_lst = list(filter(lambda k: len(self.id2unfinish[k]) >= self.config.actor_rollout_ref.rollout.num_bon // 8, k_lst)) + while len(return_batch) < return_batch_size: + if len(k_lst) == 0: break + k = k_lst[0] + k_lst = k_lst[1:] + if self.config.algorithm.rollout_pool.strategy == 'v2': + for item in self.id2unfinish[k]: + return_batch.append(item) + elif k not in self.pool_with_grad.keys(): + for item in self.id2data[k]: + return_batch.append(item) + k_lst = [k for k in self.id2data.keys() if ((k not in self.pool_with_grad.keys()) and (k not in self.id2unfinish.keys()))] + while len(return_batch) < return_batch_size: + if len(k_lst) == 0: break + k = k_lst[0] + k_lst = k_lst[1:] + for item in self.id2data[k]: + return_batch.append(item) + if len(return_batch) < return_batch_size: + return_batch.extend([random.choice(return_batch) for _ in range(return_batch_size - len(return_batch))]) + return_batch = return_batch[:return_batch_size] + self.pool_with_grad = defaultdict(list) + return DataProto.concat(return_batch) diff --git a/recipe/tppo/tppo_actor.py b/recipe/tppo/tppo_actor.py new file mode 100644 index 00000000000..59fca1b4de3 --- /dev/null +++ b/recipe/tppo/tppo_actor.py @@ -0,0 +1,646 @@ +# Copyright 2024 Bytedance Ltd. 
and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Single Process Actor +""" + +import itertools +import logging +import os +from typing import Tuple + +import torch +from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +import verl.utils.torch_functional as verl_F +from verl import DataProto +# from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, get_policy_loss_fn, kl_penalty +from verl.utils.debug import GPUMemoryLogger +from verl.utils.device import get_device_id, get_device_name, is_cuda_available, is_npu_available +from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_ +from verl.utils.py_functional import append_to_dict +from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches +from verl.utils.torch_functional import logprobs_from_logits +from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs +from verl.workers.actor import BasePPOActor + +if is_cuda_available: + from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input +elif is_npu_available: + from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input + + +__all__ = ["DataParallelPPOActor"] + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class DataParallelPPOActor(BasePPOActor): + def __init__(self, + config, + actor_module: nn.Module, + actor_optimizer: torch.optim.Optimizer = None, + tokenizer=None + ): + """When optimizer is None, it is Reference Policy""" + super().__init__(config) + self.actor_module = actor_module + self.actor_optimizer = actor_optimizer + + self.use_remove_padding = self.config.get("use_remove_padding", False) + if torch.distributed.get_rank() == 0: + print(f"Actor use_remove_padding={self.use_remove_padding}") + self.use_fused_kernels = self.config.get("use_fused_kernels", False) + if torch.distributed.get_rank() == 0: + print(f"Actor use_fused_kernels={self.use_fused_kernels}") + + self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size + self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1 + + if self.config.entropy_from_logits_with_chunking: + entropy_from_logits = verl_F.entropy_from_logits_with_chunking + else: + entropy_from_logits = verl_F.entropy_from_logits + + self.compute_entropy_from_logits = ( + torch.compile(entropy_from_logits, dynamic=True) + if self.config.get("use_torch_compile", True) # use torch compile by default + else entropy_from_logits + ) + self.device_name = get_device_name() + self.tokenizer = tokenizer + + def _forward_micro_batch(self, micro_batch, temperature, calculate_entropy=False) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Returns: + entropy: # (bs, response_len) + log_probs: # (bs, response_len) + """ + if 'left_pad_len' in micro_batch.keys():# NOTE(zht) + 
response_length = micro_batch['input_ids'].size(-1) - micro_batch['left_pad_len'] - micro_batch['actual_prompt_len'] + else: + response_length = micro_batch["responses"].size(-1) + + multi_modal_inputs = {} + if "multi_modal_inputs" in micro_batch.keys(): + for key in micro_batch["multi_modal_inputs"][0].keys(): + # Special handling for MiniCPM-o model: pixel_values, image_bound, and tgt_sizes + # need different concatenation strategies compared to other multimodal inputs + if (key == "pixel_values" and isinstance(micro_batch["multi_modal_inputs"][0]["pixel_values"], list)) or key == "image_bound" or key == "tgt_sizes": + # For MiniCPM-o: keep as list structure instead of concatenating tensors + multi_modal_inputs[key] = [inputs[key] for inputs in micro_batch["multi_modal_inputs"]] + else: + multi_modal_inputs[key] = torch.cat([inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0) + + with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16): + input_ids = micro_batch["input_ids"] + batch_size, seqlen = input_ids.shape + attention_mask = micro_batch["attention_mask"] + position_ids = micro_batch["position_ids"] + entropy = None + if position_ids.dim() == 3: # qwen2vl mrope + position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen) + + if self.use_remove_padding: + input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) + + # unpad the position_ids to align the rotary + if position_ids.dim() == 3: + position_ids_rmpad = index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices).transpose(0, 1).unsqueeze(1) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen) + else: + position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... 
-> (b s) ..."), indices).transpose(0, 1) + + if "multi_modal_inputs" in micro_batch: + # MiniCPM-o specific processing for image bounds and pixel values + if "image_bound" in multi_modal_inputs: + # Adjust image bounds based on left padding and cumulative sequence lengths + # This is necessary for MiniCPM-o's vision-language alignment + left_padding_length = torch.argmax(attention_mask, dim=1) + image_bounds = [] + for i in range(len(multi_modal_inputs["image_bound"])): + image_bound = multi_modal_inputs["image_bound"][i].to(left_padding_length.device) - left_padding_length[i] + cu_seqlens[i] + image_bounds.append(image_bound) + multi_modal_inputs["image_bound"] = [torch.vstack(image_bounds)] + # Flatten pixel values list for MiniCPM-o processing + pixel_values = [] + for i in range(len(multi_modal_inputs["pixel_values"])): + pixel_values.extend([p for p in multi_modal_inputs["pixel_values"][i]]) + multi_modal_inputs["pixel_values"] = [pixel_values] + # Handle target sizes for MiniCPM-o vision processing + if "tgt_sizes" in multi_modal_inputs: + multi_modal_inputs["tgt_sizes"] = [torch.vstack(multi_modal_inputs["tgt_sizes"])] + + # for compute the log_prob + input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) + + # pad and slice the inputs if sp > 1 + if self.use_ulysses_sp: + is_vlm_model = "multi_modal_inputs" in micro_batch.keys() + if is_vlm_model: + # vlm model's inputs will be sliced after embedding + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad( + input_ids_rmpad, + position_ids_rmpad=position_ids_rmpad, + sp_size=self.ulysses_sequence_parallel_size, + ) + else: + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( + input_ids_rmpad, + position_ids_rmpad=position_ids_rmpad, + sp_size=self.ulysses_sequence_parallel_size, + ) + input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( + input_ids_rmpad_rolled, + position_ids_rmpad=None, + sp_size=self.ulysses_sequence_parallel_size, + ) + + input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad) + + # only pass input_ids and position_ids to enable flash_attn_varlen + extra_args = {} + if self.use_fused_kernels: + extra_args["temperature"] = temperature + extra_args["return_dict"] = True + + output = self.actor_module( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids_rmpad, + **multi_modal_inputs, + use_cache=False, + **extra_args, + ) # prevent model thinks we are generating + + if self.use_fused_kernels: + log_probs = output.log_probs.squeeze(0) # (total_nnz,) + entropy_rmpad = output.entropy.squeeze(0) # (total_nnz,) + + else: + logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size) + logits_rmpad.div_(temperature) + + # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) + + if calculate_entropy: + inplace_backward = False + else: + inplace_backward = True + + log_probs = logprobs_from_logits( + logits=logits_rmpad, + labels=input_ids_rmpad_rolled, + inplace_backward=inplace_backward, + ) + + # compute entropy + if calculate_entropy: + if not self.config.entropy_checkpointing: + entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad) # ((total_nnz / sp) + pad) + else: + entropy_rmpad = torch.utils.checkpoint.checkpoint(self.compute_entropy_from_logits, logits_rmpad) + + # gather log_prob if sp > 1 + if self.use_ulysses_sp: + # gather and unpad for the ulysses sp + log_probs = gather_outputs_and_unpad( + log_probs, + gather_dim=0, + unpad_dim=0, + 
padding_size=pad_size, + ) + if calculate_entropy: + entropy_rmpad = gather_outputs_and_unpad( + entropy_rmpad, + gather_dim=0, + unpad_dim=0, + padding_size=pad_size, + ) + # pad back to (bsz, seqlen) + full_log_probs = pad_input( + hidden_states=log_probs.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, + ) + if calculate_entropy: + full_entropy = pad_input( + hidden_states=entropy_rmpad.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, + ) + + # only return response part: + full_output = full_log_probs.squeeze(-1) + if 'left_pad_len' in micro_batch.keys(): # NOTE(zht) + log_probs = torch.zeros_like(full_output) + for idx, r in enumerate(response_length): + log_probs[idx, :r] = full_output[idx, -r - 1: -1] + else: + log_probs = full_output[:, -response_length - 1 : -1] # [batch_size, response_length] + + if calculate_entropy: + entropy = full_entropy.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length) + + else: # not using rmpad and no ulysses sp + extra_args = {} + if self.use_fused_kernels: + extra_args["temperature"] = temperature + extra_args["return_dict"] = True + + output = self.actor_module( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + **multi_modal_inputs, + use_cache=False, + **extra_args, + ) # prevent model thinks we are generating + + if self.use_fused_kernels: + log_probs = output.log_probs[:, -response_length - 1 : -1] + entropy = output.entropy[:, -response_length - 1 : -1] # (bsz, response_length) + + else: + logits = output.logits + + logits.div_(temperature) + logits = logits[:, -response_length - 1 : -1, :] # (bsz, response_length, vocab_size) + log_probs = logprobs_from_logits(logits, micro_batch["responses"]) + if calculate_entropy: + entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length) + + return entropy, log_probs + + + def _optimizer_step(self): + assert self.config.grad_clip is not None + + if isinstance(self.actor_module, FSDP): + grad_norm = self.actor_module.clip_grad_norm_(max_norm=self.config.grad_clip) + elif isinstance(self.actor_module, FSDPModule): + grad_norm = fsdp2_clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip) + else: + grad_norm = torch.nn.utils.clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip) + + # if grad_norm is not finite, skip the update + if not torch.isfinite(grad_norm): + print(f"WARN: rank {torch.distributed.get_rank()} grad_norm is not finite: {grad_norm}") + self.actor_optimizer.zero_grad() + else: + self.actor_optimizer.step() + return grad_norm + + + @GPUMemoryLogger(role="dp actor", logger=logger) + def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Tensor: + """Compute the log probability of the responses given input_ids, attention_mask and position_ids + + Args: + data (DataProto): a DataProto containing keys + + ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the + concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``. + + ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64. + + ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. + + ``responses``: tensor of shape [batch_size, response_length]. torch.int64. 
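+                ``data.meta_info`` must also provide ``micro_batch_size``, ``temperature`` and ``use_dynamic_bsz`` (plus ``max_token_len`` when dynamic batching is enabled).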
+ + Returns: + torch.Tensor: the log_prob tensor + """ + # set to eval + self.actor_module.eval() + + micro_batch_size = data.meta_info["micro_batch_size"] + temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid silent error + use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] + + def _get_micro_batches(data: DataProto) -> Tuple[list, list | None]: + select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] + batch = data.select(batch_keys=select_keys).batch + has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch + + if has_multi_modal_inputs: + all_multi_modal_inputs_list = data.non_tensor_batch["multi_modal_inputs"] + if use_dynamic_bsz: + max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size + rearranged_text_micro_batches, textual_indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) + + final_micro_batches_list = [] + for i, text_mb_td in enumerate(rearranged_text_micro_batches): + current_original_indices = textual_indices[i] + current_mm_inputs_list = [all_multi_modal_inputs_list[idx] for idx in current_original_indices] + + mb_dict = {k: v for k, v in text_mb_td.items()} + mb_dict["multi_modal_inputs"] = current_mm_inputs_list + final_micro_batches_list.append(mb_dict) + return final_micro_batches_list, textual_indices + else: + num_micro_batches = batch.batch_size[0] // micro_batch_size + micro_batches_dp = data.chunk(num_micro_batches) + return micro_batches_dp, None + elif use_dynamic_bsz: + max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size + micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) + return micro_batches, indices + else: + micro_batches = batch.split(micro_batch_size) + return micro_batches, None + + micro_batches, indices = _get_micro_batches(data) + + log_probs_lst = [] + entropy_lst = [] + for micro_batch in micro_batches: + if isinstance(micro_batch, DataProto): + micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch} + with torch.no_grad(): + entropy, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature, calculate_entropy=calculate_entropy) + log_probs_lst.append(log_probs) + if calculate_entropy: + entropy_lst.append(entropy) + + log_probs = torch.concat(log_probs_lst, dim=0) + entropys = None + if calculate_entropy: + entropys = torch.concat(entropy_lst, dim=0) + if use_dynamic_bsz: + indices = list(itertools.chain.from_iterable(indices)) + assert len(indices) == log_probs.size(0), f"{len(indices)} vs. 
{log_probs.size()}" + revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) + log_probs = log_probs[revert_indices] + if calculate_entropy: + entropys = entropys[revert_indices] + + return log_probs, entropys + + + @GPUMemoryLogger(role="dp actor", logger=logger) + def update_policy(self, data: DataProto): + # make sure we are in training mode + self.actor_module.train() + + temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid silent error + multi_turn = data.meta_info.get("multi_turn", False) + + select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "advantages"] + if multi_turn: + select_keys.append("loss_mask") + if self.config.use_kl_loss: + select_keys.append("ref_log_prob") + if 'window_rounds' in data.batch.keys(): # ADD(zht) + select_keys += ['left_pad_len', 'actual_prompt_len', 'is_finished', 'rounds_eos_mask', 'window_rounds', 'token_level_scores'] + + batch = data.select(batch_keys=select_keys).batch + + ################################################################################### + ################################################################################### + response_length = batch['responses'].size(1) # NOTE(zht) + batch_full_token_count_mask = batch['attention_mask'][:, -response_length:] + if 'overlong_mask' in batch.keys(): + batch_full_token_count_mask *= batch['overlong_mask'].unsqueeze(-1) + if 'acc_mask' in batch.keys(): + batch_full_token_count_mask *= batch['acc_mask'].unsqueeze(-1) + batch_full_token_count = max(1, batch_full_token_count_mask.sum().item()) + # compute batch lm full token count + batch_rounds_eos_mask = batch.get('rounds_eos_mask', None) + if batch_rounds_eos_mask is not None: + batch_eos_ids = batch['eos_ids'].unsqueeze(1) + batch_raw_scores = batch['token_level_scores'] + batch_scores = torch.gather(batch_raw_scores, 1, batch_eos_ids) + batch_mask = (batch_scores > 0).repeat(1, batch_rounds_eos_mask.shape[1]) + batch_mask &= batch_rounds_eos_mask.bool() + batch_lm_full_token_count = max(1, batch_mask.sum().item()) + else: + batch_lm_full_token_count = batch_full_token_count + ################################################################################### + ################################################################################### + + has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + + # Split to make minibatch iterator for updating the actor + # See PPO paper for details. 
https://arxiv.org/abs/1707.06347 + if has_multi_modal_inputs: + num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size + non_tensor_select_keys = ["multi_modal_inputs"] + dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches) + else: + dataloader = batch.split(self.config.ppo_mini_batch_size) + + metrics = {} + for epoch in range(self.config.ppo_epochs): + for batch_idx, data in enumerate(dataloader): + # split batch into micro_batches + mini_batch = data + if has_multi_modal_inputs: + micro_batches = [] + if self.config.use_dynamic_bsz: + all_multi_modal_inputs_list = data.non_tensor_batch["multi_modal_inputs"] + batch_tensordict_for_rearrange = data.batch + + max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size + rearranged_text_micro_batches_tds, textual_indices = rearrange_micro_batches(batch=batch_tensordict_for_rearrange, max_token_len=max_token_len) + + for current_original_indices, text_mb_td in zip(textual_indices, rearranged_text_micro_batches_tds): + current_mm_inputs_list = [all_multi_modal_inputs_list[idx] for idx in current_original_indices] + mb_dict = {k: v for k, v in text_mb_td.items()} + mb_dict["multi_modal_inputs"] = current_mm_inputs_list + micro_batches.append(mb_dict) + else: + self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu + micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) + elif self.config.use_dynamic_bsz: + max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size + micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) + else: + self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + # split batch into micro_batches + micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) + + self.actor_optimizer.zero_grad() + + ############################################################################################ + # compute minibatch full token count # NOTE(zht) + mini_batch_full_token_count_mask = mini_batch['attention_mask'][:, -response_length:] + if 'overlong_mask' in mini_batch.keys(): + mini_batch_full_token_count_mask *= mini_batch['overlong_mask'].unsqueeze(-1) + if 'acc_mask' in mini_batch.keys(): + mini_batch_full_token_count_mask *= mini_batch['acc_mask'].unsqueeze(-1) + mini_batch_full_token_count = max(1, mini_batch_full_token_count_mask.sum().item()) + + mini_batch_rounds_eos_mask = mini_batch.get('rounds_eos_mask', None) + if mini_batch_rounds_eos_mask is not None: + mini_batch_eos_ids = mini_batch['eos_ids'].unsqueeze(1) + mini_batch_raw_scores = mini_batch['token_level_scores'] + mini_batch_scores = torch.gather(mini_batch_raw_scores, 1, mini_batch_eos_ids) + mini_batch_mask = (mini_batch_scores > 0).repeat(1, mini_batch_rounds_eos_mask.shape[1]) + mini_batch_mask &= mini_batch_rounds_eos_mask.bool() + mini_batch_lm_full_token_count = max(1, mini_batch_mask.sum().item()) + else: + mini_batch_lm_full_token_count = mini_batch_full_token_count + ############################################################################################ + + for data in micro_batches: + # Support all hardwares + if isinstance(data, DataProto): + data = {**data.batch.to(get_device_id()), **data.non_tensor_batch} + elif isinstance(data, dict): + for k, v 
in data.items(): + if isinstance(v, torch.Tensor): + data[k] = v.to(get_device_id()) + elif k == "multi_modal_inputs" and v is not None: + data[k] = [{kk: vv.to(get_device_id()) for kk, vv in item_dict.items()} for item_dict in v] + else: + data[k] = v + else: + data = data.to(get_device_id()) # actor device is cpu when using offload + responses = data["responses"] + response_length = responses.size(1) + attention_mask = data["attention_mask"] + if multi_turn: + response_mask = data["loss_mask"][:, -response_length:] + else: + response_mask = attention_mask[:, -response_length:] + + old_log_prob = data["old_log_probs"] + advantages = data["advantages"] + + clip_ratio = self.config.clip_ratio + clip_ratio_low = self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio + clip_ratio_high = self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio + clip_ratio_c = self.config.get("clip_ratio_c", 3.0) + entropy_coeff = self.config.entropy_coeff + loss_agg_mode = self.config.loss_agg_mode + + # all return: (bsz, response_length) + calculate_entropy = False + if entropy_coeff != 0: + calculate_entropy = True + entropy, log_prob = self._forward_micro_batch(micro_batch=data, temperature=temperature, calculate_entropy=calculate_entropy) + + loss_mode = self.config.policy_loss.get("loss_mode", "vanilla") + + ################################################################################################## + eos_ids = data.get('eos_ids', None) # NOTE(zht): microdata actually + overlong_mask = data.get('overlong_mask', None) + acc_mask = data.get('acc_mask', None) + single_eos_ids = eos_ids.clone() + if data.get('window_rounds', None) is not None: + total_len = log_prob.size(1) + window_response_length = self.config.window_response_length + single_log_prob = torch.zeros_like(log_prob) + for idx, window_round in enumerate(data['window_rounds']): + single_log_prob[idx, :total_len-window_round * window_response_length] = log_prob[idx, window_round * window_response_length:] + single_eos_ids[idx] -= window_round * window_response_length + single_log_prob = single_log_prob[:, :window_response_length] + else: + single_log_prob = log_prob + + if self.config.policy_loss.loss_mode == "tppo": # loss_mode must be tppo !!!!!!! 
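The loop above left-shifts each sample's log-probs by window_round * window_response_length so that the current window starts at position zero, then truncates to a single window (the matching eos-index shift is handled alongside it). A minimal standalone sketch of the same realignment on toy tensors; the function and variable names here are illustrative, not part of the recipe:

    import torch

    def realign_to_window(log_prob: torch.Tensor,
                          window_rounds: torch.Tensor,
                          window_response_length: int) -> torch.Tensor:
        # shift row i left by window_rounds[i] * window_response_length,
        # then keep only the first window_response_length positions
        total_len = log_prob.size(1)
        shifted = torch.zeros_like(log_prob)
        for i, w in enumerate(window_rounds.tolist()):
            offset = int(w) * window_response_length
            shifted[i, : total_len - offset] = log_prob[i, offset:]
        return shifted[:, :window_response_length]

    lp = torch.arange(12, dtype=torch.float32).reshape(2, 6)  # 2 samples, 6 positions
    out = realign_to_window(lp, torch.tensor([0, 1]), 3)
    # row 0 (round 0) keeps positions 0..2 -> [0., 1., 2.]
    # row 1 (round 1) keeps positions 3..5 -> [9., 10., 11.]
    print(out)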
+ from recipe.tppo.tppo_algos import compute_policy_loss + pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss( + old_log_prob=old_log_prob, + log_prob=single_log_prob, + advantages=advantages, + response_mask=response_mask, + cliprange=clip_ratio, + cliprange_low=clip_ratio_low, + cliprange_high=clip_ratio_high, + clip_ratio_c=clip_ratio_c, + loss_agg_mode=loss_agg_mode, + eos_ids=single_eos_ids, + overlong_mask=overlong_mask, + acc_mask=acc_mask, + ) + else: + raise ValueError(f"Unknown loss mode: {loss_mode}") + # policy_loss_fn = get_policy_loss_fn(loss_mode) + # pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(old_log_prob, log_prob, advantages, response_mask, loss_agg_mode, self.config) + ################################################################################################## + + if entropy_coeff != 0: + from recipe.tppo.tppo_algos import agg_loss + entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + + # # compute policy loss + # policy_loss = pg_loss - entropy_loss * entropy_coeff + # else: + # policy_loss = pg_loss + else: + entropy_loss = torch.zeros((), device=pg_loss.device) + + if self.config.use_kl_loss: + from recipe.tppo.tppo_algos import compute_kl_loss + ref_log_prob = data["ref_log_prob"] + kl_loss = compute_kl_loss(log_prob, ref_log_prob, response_mask, self.config.kl_loss_type, loss_agg_mode) + kl_loss_coef = self.config.kl_loss_coef + # policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef + metrics["actor/kl_loss"] = kl_loss.detach().item() + metrics["actor/kl_coef"] = self.config.kl_loss_coef + else: + kl_loss = torch.zeros((), device=pg_loss.device) + kl_loss_coef = 0 + + ################################################################################## + # lm_loss # NOTE(zht): need to add 'token_level_scores' + lm_loss_weight = self.config.lm_loss_weight + if lm_loss_weight > 0: + from recipe.tppo.tppo_algos import compute_lm_loss + rounds_eos_mask = data.get('rounds_eos_mask', None) + lm_loss = compute_lm_loss( + log_prob=log_prob, + raw_score=data['token_level_scores'], + eos_ids=eos_ids, + rounds_eos_mask=rounds_eos_mask, + loss_average_method=loss_agg_mode, + ) + + ################################################################################## + # if self.config.use_dynamic_bsz: + # # relative to the dynamic bsz + # loss = policy_loss * (len(data) / self.config.ppo_mini_batch_size) + # else: + # loss = policy_loss / self.gradient_accumulation + if loss_agg_mode=='minibatch': + loss = (pg_loss - entropy_loss*entropy_coeff + kl_loss*kl_loss_coef) / mini_batch_full_token_count + lm_loss_weight*lm_loss / mini_batch_lm_full_token_count + elif loss_agg_mode=='batch': + loss = ((pg_loss - entropy_loss*entropy_coeff + kl_loss*kl_loss_coef) / batch_full_token_count + lm_loss_weight*lm_loss / batch_lm_full_token_count)*len(dataloader) + else: + raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}") + + loss.backward() + + data = { + "actor/pg_loss": pg_loss.detach().item(), + "actor/pg_clipfrac": pg_clipfrac.detach().item(), + "actor/ppo_kl": ppo_kl.detach().item(), + "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(), + } + append_to_dict(metrics, data) + + grad_norm = self._optimizer_step() + data = {"actor/grad_norm": grad_norm.detach().item()} + append_to_dict(metrics, data) + self.actor_optimizer.zero_grad() + return metrics diff --git a/recipe/tppo/tppo_algos.py b/recipe/tppo/tppo_algos.py new file mode 100644 index 00000000000..228ae9ea4e3 --- 
/dev/null +++ b/recipe/tppo/tppo_algos.py @@ -0,0 +1,205 @@ +import torch +import verl.utils.torch_functional as verl_F +from verl.trainer.ppo import core_algos + +def compute_truncate_gae_advantage_return(token_level_rewards: torch.Tensor, single_token_level_rewards: torch.Tensor, values: torch.Tensor, eos_mask: torch.Tensor, + gamma: torch.Tensor, lam: torch.Tensor, use_variable_lambda: torch.Tensor, + variable_lambda_scalar: torch.Tensor, adv_whiten: bool, use_separate_critic_lam: bool, + critic_lam: torch.Tensor, is_finished: torch.Tensor, ignore_token_num: int, rounds_eos_mask: torch.Tensor, seq_len_per_sample: torch.Tensor, is_clamp: bool): + window_mask = torch.ones_like(eos_mask) + window_mask[:, -ignore_token_num:] = 0 + window_mask = torch.maximum(window_mask, torch.ones_like(eos_mask) * is_finished.unsqueeze(-1)) + values = values * eos_mask + token_level_rewards = token_level_rewards * rounds_eos_mask[:, :token_level_rewards.size(-1)] + if use_variable_lambda: + # seq_len_per_sample = torch.clamp(torch.sum(rounds_eos_mask, dim=1), min=1.0) + seq_len_per_sample += (1 - is_finished) * values.shape[-1] // 2 + lam = torch.clamp(1 - 1 / (variable_lambda_scalar * seq_len_per_sample), min=lam) + with torch.no_grad(): + lastgaelam = 0 + advantages_reversed = [] + gen_len = values.shape[-1] + for t in reversed(range(gen_len)): + nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0 + delta = single_token_level_rewards[:, t] + gamma * nextvalues - values[:, t] + if t == gen_len - 1: + delta = delta * is_finished + ~is_finished * (gamma-1) * values[:, t] + lastgaelam = delta + gamma * lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], dim=1) + advantages *= window_mask + if use_separate_critic_lam and critic_lam == 1: + cumsum_rewards = torch.cumsum(token_level_rewards, dim=1) + returns = token_level_rewards - cumsum_rewards + cumsum_rewards[:, -1:None] + returns *= is_finished.unsqueeze(-1) + origin_advantages = advantages + if adv_whiten: + advantages = verl_F.masked_whiten(origin_advantages, eos_mask*window_mask) + advantages = advantages * eos_mask * window_mask + else: + advantages = torch.clone(origin_advantages) + if is_clamp: + advantages = torch.clamp(advantages, max=10.0, min=-10.0) + return origin_advantages, advantages, returns + + +def compute_lm_loss(log_prob, raw_scores, eos_ids, rounds_eos_mask=None, loss_average_method='token'): + if rounds_eos_mask is not None: + log_prob = log_prob[:, :rounds_eos_mask.shape[-1]] + eos_ids = eos_ids.unsqueeze(1) + scores = torch.gather(raw_scores, 1, eos_ids) + ids = torch.arange(log_prob.shape[1], device=eos_ids.device).unsqueeze(0).repeat(log_prob.shape[0], 1) + mask0 = ids <= eos_ids + mask1 = (scores > 0).repeat(1, log_prob.shape[1]) + mask = mask0 & mask1 + if rounds_eos_mask is not None: + mask &= rounds_eos_mask.bool() + lm_loss = torch.masked_select(log_prob, mask) + if loss_average_method in ['sample', 'token']: + lm_loss = -torch.sum(lm_loss) / max(lm_loss.numel(), 1) + elif loss_average_method in ['minibatch', 'batch']: + lm_loss = -torch.sum(lm_loss) + else: + raise NotImplementedError(f"loss_average_method {loss_average_method} not implemented") + return lm_loss + + +def compute_kl_loss(log_prob, ref_log_prob, eos_mask, kl_penalty_, loss_average_method='token'): + if kl_penalty_ in ("abs", "mse", "low_var_kl"): + kl = core_algos.kl_penalty(log_prob, ref_log_prob, kl_penalty_) + elif kl_penalty_ in ("kl"): + kl = core_algos.kl_penalty(log_prob, ref_log_prob, 
kl_penalty_).square() + else: + raise NotImplementedError + if loss_average_method == 'sample': + seq_len_per_sample = torch.clamp(torch.sum(eos_mask, dim=1), min=1.0) + kl_loss = torch.mean(torch.sum(kl * eos_mask, dim=1) / seq_len_per_sample) + elif loss_average_method == 'token': + kl_loss = (kl * eos_mask).sum() / (eos_mask.sum() + 1e-6) + elif loss_average_method in ['minibatch', 'batch']: + kl_loss = (kl * eos_mask).sum() + return kl_loss + + +def compute_value_loss(vpreds, returns, values, eos_mask, cliprange_value_low, cliprange_value_high, overlong_mask): + """Compute the value loss. Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151 + Args: + vpreds (`torch.FloatTensor`): + Predicted values of the value head, shape (`batch_size`, `response_length`) + values (`torch.FloatTensor`): + Old values of value head, shape (`batch_size`, `response_length`) + returns: (`torch.FloatTensor`): + Ground truth returns, shape (`batch_size`, `response_length`) + Returns: + vf_loss: a scalar (`torch.FloatTensor`): + value function loss + vf_clipfrac: a float + The ratio of vf being clipped + """ + vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value_low, values + cliprange_value_high) + vf_losses1 = (vpreds - returns)**2 + vf_losses2 = (vpredclipped - returns)**2 + seq_len_per_sample = torch.clamp(torch.sum(eos_mask, dim=1), min=1.0) + if overlong_mask is not None: + vf_loss = 0.5 * torch.mean( + torch.sum(torch.max(vf_losses1, vf_losses2) * eos_mask, dim=1) / seq_len_per_sample * overlong_mask) + else: + vf_loss = 0.5 * torch.mean(torch.sum(torch.max(vf_losses1, vf_losses2) * eos_mask, dim=1) / seq_len_per_sample) + vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), eos_mask) + vf_loss = vf_loss + return vf_loss, vf_clipfrac + + +def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str): + """ + Aggregate the loss matrix into a scalar. + + Args: + loss_mat: `(torch.Tensor)`: + shape: (bs, response_length) + loss_mask: `(torch.Tensor)`: + shape: (bs, response_length) + loss_agg_mode: (str) choices: + method to aggregate the loss matrix into a scalar. + Returns: + loss: `a scalar torch.Tensor` + aggregated loss + """ + if loss_agg_mode == "token-mean": + loss = verl_F.masked_mean(loss_mat, loss_mask) + elif loss_agg_mode == "seq-mean-token-sum": + seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) # token-sum + loss = torch.mean(seq_losses) # seq-mean + elif loss_agg_mode == "seq-mean-token-mean": + seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1) # token-mean + loss = torch.mean(seq_losses) # seq-mean + elif loss_agg_mode == "seq-mean-token-sum-norm": + seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) + loss = torch.sum(seq_losses) / loss_mask.shape[-1] # The divisor + # (loss_mask.shape[-1]) should ideally be constant + # throughout training to well-replicate the DrGRPO paper. + # TODO: Perhaps add user-defined normalizer argument to + # agg_loss to ensure divisor stays constant throughout. + elif loss_agg_mode in ['minibatch-sum', 'batch-sum']: + # NOTE(HanlinDu): this case is the only incremental difference from ppo.core_algos.agg_loss + # Maybe we can replace the original one with this one in the TPPO usages, + # or even in all of the Verl PPO usages. 
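To make the NOTE above concrete, here is a small numeric comparison of the existing token-mean aggregation and the masked-sum mode added for TPPO, which returns an unnormalized sum so that update_policy can later divide by a batch-level token count; this is an illustration only, not part of the recipe:

    import torch

    loss_mat = torch.tensor([[1.0, 2.0, 3.0],
                             [4.0, 5.0, 6.0]])
    loss_mask = torch.tensor([[1.0, 1.0, 0.0],
                              [1.0, 0.0, 0.0]])

    # token-mean: average over the masked tokens inside this micro-batch
    token_mean = (loss_mat * loss_mask).sum() / loss_mask.sum()   # (1+2+4)/3 = 2.3333

    # minibatch-sum / batch-sum: plain masked sum; normalization is deferred
    # to the caller, which divides by the (mini)batch token count
    masked_sum = (loss_mat * loss_mask).sum()                     # 1+2+4 = 7.0

    print(token_mean.item(), masked_sum.item())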
+ loss = (loss_mat * loss_mask).sum() + else: + raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}") + + return loss + + +def compute_policy_loss( + old_log_prob, + log_prob, + advantages, + response_mask, + cliprange=None, + cliprange_low=None, + cliprange_high=None, + clip_ratio_c=10.0, # cliprange2 + loss_agg_mode: str = "batch-sum", + eos_ids=None, + overlong_mask=None, + acc_mask=None, + kl_penalty='low_var_kl', +): + negative_approx_kl = core_algos.kl_penalty(log_prob, old_log_prob, kl_penalty=kl_penalty) + # Clamp negative_approx_kl for stability + ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask) + negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0) + ratio = torch.exp(negative_approx_kl) + + pg_losses1 = -advantages * ratio + if cliprange_high is None: + assert cliprange is not None, "cliprange_high is None, but cliprange is also None, please set one of them" + cliprange_high = cliprange + + if cliprange_low is None: + assert cliprange is not None, "cliprange_low is None, but cliprange is also None, please set one of them" + cliprange_low = cliprange + + pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange_low, 1.0 + cliprange_high) + pg_losses3 = torch.abs(-advantages * clip_ratio_c) + pg_losses_clip = torch.maximum(pg_losses1, pg_losses2) + pg_losses = torch.minimum(pg_losses_clip, pg_losses3) + + if overlong_mask is not None: + response_mask = response_mask * overlong_mask.unsqueeze(-1) + if acc_mask is not None: + response_mask = response_mask * acc_mask.unsqueeze(-1) + + pg_loss = core_algos.agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + + pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask) + pg_clipfrac_lower = verl_F.masked_mean(torch.gt(pg_losses_clip, pg_losses3) * (advantages < 0).float(), response_mask) + + return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower + + + + + + diff --git a/recipe/tppo/tppo_critic.py b/recipe/tppo/tppo_critic.py new file mode 100644 index 00000000000..c905b9fe72b --- /dev/null +++ b/recipe/tppo/tppo_critic.py @@ -0,0 +1,330 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
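Stepping back to compute_policy_loss in tppo_algos.py above: it is a dual-clip PPO objective, the usual clipped surrogate plus a second cap of |clip_ratio_c * A| that bounds the loss when advantages are negative and the ratio explodes. Below is a condensed, self-contained sketch of just that surrogate on dummy tensors, using the plain importance ratio exp(logp - logp_old) in place of the recipe's KL-penalty-based ratio, with masks and aggregation omitted; it is not the recipe's API:

    import torch

    def dual_clip_pg_loss(old_log_prob, log_prob, advantages,
                          clip_low=0.2, clip_high=0.2, clip_ratio_c=3.0):
        # importance ratio, clamped in log space for numerical stability
        ratio = torch.exp(torch.clamp(log_prob - old_log_prob, -20.0, 20.0))
        loss1 = -advantages * ratio
        loss2 = -advantages * torch.clamp(ratio, 1.0 - clip_low, 1.0 + clip_high)
        clipped = torch.maximum(loss1, loss2)           # standard PPO clipping
        capped = torch.abs(-advantages * clip_ratio_c)  # dual-clip upper bound
        return torch.minimum(clipped, capped).mean()

    old_lp = torch.zeros(4)
    new_lp = torch.tensor([0.5, -0.5, 2.0, -2.0])
    adv = torch.tensor([1.0, 1.0, -1.0, -1.0])
    print(dual_clip_pg_loss(old_lp, new_lp, adv))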
+""" +Implement a multiprocess PPOCritic +""" + +import itertools +import logging +import os + +import torch +import torch.distributed as dist +from torch import nn, optim +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +import torch.nn.functional as F + +from verl import DataProto +from verl.trainer.ppo import core_algos +from verl.utils.debug import GPUMemoryLogger +from verl.utils.device import get_device_id, get_device_name, is_cuda_available, is_npu_available +from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_ +from verl.utils.py_functional import append_to_dict +from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches +from verl.utils.torch_functional import masked_mean +from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs +from verl.workers.critic import BasePPOCritic + +if is_cuda_available: + from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input +elif is_npu_available: + from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class DataParallelPPOCritic(BasePPOCritic): + def __init__(self, config, critic_module: nn.Module, critic_optimizer: optim.Optimizer): + super().__init__(config=config) + self.critic_module = critic_module + self.critic_optimizer = critic_optimizer + self.use_remove_padding = self.config.model.get("use_remove_padding", False) + print(f"Critic use_remove_padding={self.use_remove_padding}") + + self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) + self.device_name = get_device_name() + + def _forward_micro_batch(self, micro_batch): + if 'left_pad_len' in micro_batch.keys(): + response_length = micro_batch['input_ids'].size(-1) - micro_batch['left_pad_len'] - micro_batch['actual_prompt_len'] + else: + response_length = micro_batch['responses'].size(-1) + multi_modal_inputs = {} + if "multi_modal_inputs" in micro_batch.keys(): + for key in micro_batch["multi_modal_inputs"][0].keys(): + multi_modal_inputs[key] = torch.cat([inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0) + + with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16): + input_ids = micro_batch["input_ids"] + batch, seqlen = input_ids.shape + attention_mask = micro_batch["attention_mask"] + position_ids = micro_batch["position_ids"] + if position_ids.dim() == 3: # qwen2vl mrope + position_ids = position_ids.transpose(0, 1) + + if self.use_remove_padding: + input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) + + # unpad the position_ids to align the rotary + if position_ids.dim() == 3: + position_ids_rmpad = index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices).transpose(0, 1).unsqueeze(1) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen) + else: + position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... 
-> (b s) ..."), indices).transpose(0, 1) + + # pad and slice the inputs if sp > 1 + if self.ulysses_sequence_parallel_size > 1: + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size) + + # only pass input_ids and position_ids to enable flash_attn_varlen + output = self.critic_module( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids_rmpad, + **multi_modal_inputs, + use_cache=False, + ) # prevent model thinks we are generating + + if hasattr(self.critic_module, "v_head"): + # For trl.AutoModelForCausalLMWithValueHead + values_rmpad = output[2].squeeze(0).unsqueeze(-1) + else: + values_rmpad = output.logits + values_rmpad = values_rmpad.squeeze(0) # (total_nnz) + + # gather output if sp > 1 + if self.ulysses_sequence_parallel_size > 1: + values_rmpad = gather_outputs_and_unpad(values_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size) + + # pad it back + # NOTE(HanlinDu): why we need padding here? + values = pad_input(values_rmpad, indices=indices, batch=batch, seqlen=seqlen).squeeze(-1) + if 'left_pad_len' in micro_batch.keys(): + new_values = torch.zeros_like(values) + for idx, r in enumerate(response_length): + new_values[idx, :r] = values[idx, -r - 1: -1] + # FIXME (HanlinDu): should not truncate the values here for temporarily fixing the bug + # values = new_values + values = new_values[:, :r] + else: + values = values[:, -response_length - 1:-1] + else: + output = self.critic_module( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + **multi_modal_inputs, + use_cache=False, + ) # prevent model thinks we are generating + if hasattr(self.critic_module, "v_head"): + # For trl.AutoModelForCausalLMWithValueHead + values = output[2] + else: + values = output.logits + values = values[:, -response_length - 1 : -1].squeeze(-1) + return values + + def _optimizer_step(self): + assert self.config.grad_clip is not None + + if isinstance(self.critic_module, FSDP): + grad_norm = self.critic_module.clip_grad_norm_(self.config.grad_clip) + elif isinstance(self.critic_module, FSDPModule): + grad_norm = fsdp2_clip_grad_norm_(self.critic_module.parameters(), max_norm=self.config.grad_clip) + else: + grad_norm = torch.nn.utils.clip_grad_norm_(self.critic_module.parameters(), max_norm=self.config.grad_clip) + + # if grad_norm is not finite, skip the update + if not torch.isfinite(grad_norm): + print(f"WARN: grad_norm is not finite: {grad_norm}") + self.critic_optimizer.zero_grad() + else: + self.critic_optimizer.step() + return grad_norm + + @GPUMemoryLogger(role="dp critic", logger=logger) + def compute_values(self, data: DataProto) -> torch.Tensor: + self.critic_module.eval() + micro_batch_size = data.meta_info["micro_batch_size"] + select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] + if 'window_rounds' in data.batch.keys(): + select_keys += ['left_pad_len', 'actual_prompt_len', 'window_rounds'] + batch = data.select(batch_keys=select_keys).batch + use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] + has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + + if has_multi_modal_inputs: + num_micro_batches = data.batch.batch_size[0] // micro_batch_size + non_tensor_select_keys = ["multi_modal_inputs"] + micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) + elif use_dynamic_bsz: + # split using dynamic bsz + max_token_len = 
data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size + micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) + else: + micro_batches = batch.split(micro_batch_size) + + values_lst = [] + for micro_batch in micro_batches: + if isinstance(micro_batch, DataProto): + micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch} + + with torch.no_grad(): + values = self._forward_micro_batch(micro_batch) + values_lst.append(values) + values = torch.concat(values_lst, dim=0) + + if use_dynamic_bsz: + indices = list(itertools.chain.from_iterable(indices)) + assert len(indices) == values.size(0), f"{len(indices)} vs. {values.size()}" + revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) + values = values[revert_indices] + + response_mask = data.batch["response_mask"] + if 'left_pad_len' in data.batch.keys(): + print(" --- use left pad --- ") + pad_len = values.shape[-1] - response_mask.shape[-1] + # pad from right + response_mask = F.pad(response_mask, (0, pad_len), value=0) + values = values * response_mask # Only action tokens have values + return values + + @GPUMemoryLogger(role="dp critic", logger=logger) + def update_critic(self, data: DataProto): + # make sure we are in training mode + self.critic_module.train() + metrics = {} + + select_keys = ["input_ids", "responses", "attention_mask", "position_ids", "values", "returns"] + + use_window_rollout = 'window_rounds' in data.batch.keys() + if use_window_rollout: + select_keys += ['left_pad_len', 'actual_prompt_len', 'is_finished', 'rounds_eos_mask'] + batch = data.select(batch_keys=select_keys).batch + has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + + # Split to make minibatch iterator for updating the actor + # See PPO paper for details. 
https://arxiv.org/abs/1707.06347 + if has_multi_modal_inputs: + num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size + non_tensor_select_keys = ["multi_modal_inputs"] + dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches) + else: + dataloader = batch.split(self.config.ppo_mini_batch_size) + + for epoch in range(self.config.ppo_epochs): + for batch_idx, data in enumerate(dataloader): + # split batch into micro_batches + mini_batch = data + ############################################################### + if use_window_rollout: + mini_batch_mask, new_batch_size = [], 0 + for item in mini_batch: + if item['is_finished']: + new_batch_size += 1 + mini_batch_mask.append(True) + else: + mini_batch_mask.append(False) + if dist.is_initialized(): + new_batch_sizes = torch.tensor([new_batch_size], device='cuda') + dist.all_reduce(new_batch_sizes, op=dist.ReduceOp.MAX, group=None) + new_batch_sizes = new_batch_sizes.cpu().item() + else: + new_batch_sizes = new_batch_size + dp_size = torch.distributed.get_world_size() // self.config.tp_size // self.config.ulysses_sequence_parallel_size + total_num = (new_batch_sizes // dp_size + 1 if new_batch_sizes % dp_size else new_batch_sizes // dp_size) * dp_size + if total_num == 0: + continue + elif total_num != new_batch_size: + for i, m in enumerate(mini_batch_mask): + if (not m) and new_batch_size < total_num: + mini_batch_mask[i] = True + new_batch_size += 1 + mini_batch_mask = torch.tensor(mini_batch_mask) + mini_batch = mini_batch.masked_select(mini_batch_mask) + ################################################################### + if has_multi_modal_inputs: + num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu + micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) + self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + elif self.config.use_dynamic_bsz: + max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size + micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) + else: + micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) + self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + + self.critic_optimizer.zero_grad() + + for data in micro_batches: + # Support all devices + if isinstance(data, DataProto): + data = {**data.batch.to(get_device_id()), **data.non_tensor_batch} + else: + data = data.to(get_device_id()) # critic device is cpu when using offload + responses = data["responses"] + attention_mask = data["attention_mask"] + values = data["values"] + returns = data["returns"] + response_length = responses.size(1) + + response_mask = attention_mask[:, -response_length:] # TODO (zht) maybe no use + overlong_mask = data.get('overlong_mask', None) if not use_window_rollout else data.get('is_finished', None) + if 'rounds_eos_mask' in data.keys(): + eos_mask = data['rounds_eos_mask'] # TODO(zht):confirm eos_mask + else: + eos_mask = attention_mask[:, -response_length - 1:-1] + + vpreds = self._forward_micro_batch(data) + + if returns.size(-1) < vpreds.size(-1): + vpreds = vpreds[:, :returns.size(-1)] + if returns.size(-1) < values.size(-1): + values = values[:, :returns.size(-1)] + cliprange_value_low = self.config.get('cliprange_value_low', self.config.cliprange_value) + cliprange_value_high = 
self.config.get('cliprange_value_high', self.config.cliprange_value) + from recipe.tppo.tppo_algos import compute_value_loss + + vf_loss, vf_clipfrac = compute_value_loss( + vpreds=vpreds, + values=values, + returns=returns, + eos_mask=eos_mask, + cliprange_value_low=cliprange_value_low, + cliprange_value_high=cliprange_value_high, + overlong_mask=overlong_mask + ) + + if self.config.use_dynamic_bsz: + # relative to the dynamic bsz + loss = vf_loss * (len(data) / self.config.ppo_mini_batch_size) + else: + loss = vf_loss / self.gradient_accumulation + + loss.backward() + + data = { + "critic/vf_loss": vf_loss.detach().item(), + "critic/vf_clipfrac": vf_clipfrac.detach().item(), + "critic/vpred_mean": masked_mean(vpreds, eos_mask).detach().item(), + } + + append_to_dict(metrics, data) + + grad_norm = self._optimizer_step() + data = {"critic/grad_norm": grad_norm.detach().item()} + append_to_dict(metrics, data) + self.critic_optimizer.zero_grad() + return metrics diff --git a/recipe/tppo/tppo_reward_manager.py b/recipe/tppo/tppo_reward_manager.py new file mode 100644 index 00000000000..5257ada6df5 --- /dev/null +++ b/recipe/tppo/tppo_reward_manager.py @@ -0,0 +1,172 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict + +import torch + +from verl import DataProto +from verl.utils.reward_score import default_compute_score +from verl.workers.reward_manager import register + + +@register("tppo") +class TPPORewardManager: + """The reward manager.""" + + def __init__(self, config, tokenizer, num_examine, compute_score=None, reward_fn_key="data_source") -> None: + """ + Initialize the NaiveRewardManager instance. + + Args: + tokenizer: The tokenizer used to decode token IDs into text. + num_examine: The number of batches of decoded responses to print to the console for debugging purpose. + compute_score: A function to compute the reward score. If None, `default_compute_score` will be used. + reward_fn_key: The key used to access the data source in the non-tensor batch data. Defaults to "data_source". + """ + self.config = config + self.tokenizer = tokenizer # Store the tokenizer for decoding token IDs + self.num_examine = num_examine # the number of batches of decoded responses to print to the console + self.compute_score = compute_score or default_compute_score + self.reward_fn_key = reward_fn_key # Store the key for accessing the data source + + def __call__(self, data: DataProto, return_dict=False, is_validation=False, use_window_rollout=True): + """We will expand this function gradually based on the available datasets""" + if "rm_scores" in data.batch.keys(): + if return_dict: + return {"reward_tensor": data.batch["rm_scores"]} + else: + return data.batch["rm_scores"] + + # If there is rm score, we directly return rm score. 
Otherwise, we compute via rm_score_fn + if use_window_rollout: + max_response_length = self.config.data.get('window_response_length', None) + max_prompt_length = self.config.data.max_prompt_length + data.batch['window_rounds'].max().item() * max_response_length + else: + max_prompt_length = self.config.data.max_prompt_length + + response_ids = data.batch['input_ids'][:, max_prompt_length:] + empty_response_ids = data.batch['input_ids'][:, self.config.data.max_prompt_length:] + + reward_tensor = torch.zeros_like(empty_response_ids, dtype=torch.float32) + + # NOTE(HanlinDu): this conditional branch seems to be redundant + if use_window_rollout: + # raw_scores = torch.zeros_like(empty_response_ids, dtype=torch.float32) + # format_scores = torch.zeros_like(empty_response_ids, dtype=torch.float32) + # len_scores = torch.zeros_like(empty_response_ids, dtype=torch.float32) + # idx_tensor = torch.zeros(response_ids.shape[0], dtype=torch.int64, device=response_ids.device) + is_finished = torch.ones(response_ids.shape[0], dtype=torch.int64, device=response_ids.device) + + print(f" --- {response_ids.shape=}, {is_finished.shape=}") + + current_mean_len = data.batch['attention_mask'][:, max_prompt_length:].sum(-1).float().mean().item() + + + reward_extra_info = defaultdict(list) + + already_print_data_sources = {} + + for i in range(len(data)): + data_item = data[i] # DataProtoItem + + # process + if use_window_rollout: + actual_prompt_length = max_prompt_length - data_item.batch['window_rounds'].item() * max_response_length + else: + actual_prompt_length = max_prompt_length + + prompt_ids = data_item.batch['input_ids'][:actual_prompt_length] + response_ids = data_item.batch['input_ids'][actual_prompt_length:] + + prompt_length = prompt_ids.shape[-1] + valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum().item() + valid_prompt_ids = prompt_ids[-valid_prompt_length:] + response_length = response_ids.shape[-1] + valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum().item() + valid_response_ids = response_ids[:valid_response_length] + + # decode + prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True) + response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True) + + ground_truth = data_item.non_tensor_batch["reward_model"]["ground_truth"] + data_source = data_item.non_tensor_batch[self.reward_fn_key] + + score_fn_inputs = { + "data_source": data_source, + "solution_str": response_str, + "ground_truth": ground_truth, + } + score = self.compute_score(**score_fn_inputs) + + # ================ unfinished response handling ================ + if use_window_rollout: + curr_window_round = data_item.batch['window_rounds'].item() + last_token_id = valid_response_ids[valid_response_length-1].item() + last_window_round = self.config.data.max_response_length // self.config.data.window_response_length - 1 + # is_trunc: unfinished response + is_trunc = (response_length == valid_response_length) and \ + (last_token_id != self.tokenizer.eos_token_id) + is_last_round = (curr_window_round == last_window_round) + if is_trunc and (not is_last_round): + is_finished[i] = 0 + if not is_validation: + score = 0 + # ================================================================ + + if isinstance(score, dict): + print(f" --- size of score dict: {len(score)}") + reward = score["score"] + # Store the information including original reward + for key, value in score.items(): + reward_extra_info[key].append(value) + else: + reward = score + + 
reward_tensor[i, valid_response_length - 1] = reward + + if data_source not in already_print_data_sources: + already_print_data_sources[data_source] = 0 + + if already_print_data_sources[data_source] < self.num_examine: + already_print_data_sources[data_source] += 1 + print("[prompt]", prompt_str) + print("[response]", response_str) + print("[ground_truth]", ground_truth) + if isinstance(score, dict): + for key, value in score.items(): + print(f"[{key}]", value) + else: + print("[score]", score) + + # raw_scores[i, valid_response_length - 1] = score + # format_scores[i, valid_response_length - 1] = 0 + # len_scores[i, valid_response_length - 1] = 0 + # idx_tensor[i] = valid_response_length - 1 + + if use_window_rollout: + # reward_extra_info['raw_scores'] = raw_scores + # reward_extra_info['idx_tensor'] = idx_tensor + reward_extra_info['is_finished'] = is_finished + # reward_extra_info['format_scores'] = format_scores + # reward_extra_info['len_scores'] = len_scores + + if return_dict: + return { + "reward_tensor": reward_tensor, + "reward_extra_info": reward_extra_info, + } + else: + return reward_tensor diff --git a/recipe/tppo/tppo_trainer.py b/recipe/tppo/tppo_trainer.py new file mode 100644 index 00000000000..435185eeb6d --- /dev/null +++ b/recipe/tppo/tppo_trainer.py @@ -0,0 +1,1234 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +FSDP PPO Trainer with Ray-based single controller. 
+This trainer supports model-agonistic model initialization with huggingface +""" + +import json +import os +import uuid +from collections import defaultdict +from copy import deepcopy +from dataclasses import dataclass, field +from enum import Enum +from pprint import pprint +from typing import Optional, Type + +import numpy as np +import ray +import torch +import torch.nn.functional as F +from omegaconf import OmegaConf, open_dict +from torch.utils.data import Dataset, Sampler +from torchdata.stateful_dataloader import StatefulDataLoader +from tqdm import tqdm + +from verl import DataProto +from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto +from verl.single_controller.base import Worker +from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup +from verl.single_controller.ray.base import create_colocated_worker_cls +from verl.trainer.ppo import core_algos +from verl.trainer.ppo.core_algos import AdvantageEstimator +from verl.trainer.ppo.metric_utils import ( + compute_data_metrics, + compute_throughout_metrics, + compute_timing_metrics, + process_validation_metrics, +) +from verl.trainer.ppo.reward import compute_reward, compute_reward_async +from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path +from verl.utils.debug import marked_timer +from verl.utils.metric import ( + reduce_metrics, +) +from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance +from verl.utils.torch_functional import masked_mean +from verl.utils.tracking import ValidationGenerationsLogger +from verl.trainer.ppo.ray_trainer import Role, ResourcePoolManager + +from recipe.tppo.tppo_algos import compute_truncate_gae_advantage_return, agg_loss + +from recipe.tppo.sample_pool import SamplePool +import queue + +WorkerType = Type[Worker] + +def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl", multi_turn=False): + """Apply KL penalty to the token-level rewards. + + This function computes the KL divergence between the reference policy and current policy, + then applies a penalty to the token-level rewards based on this divergence. + + Args: + data (DataProto): The data containing batched model outputs and inputs. + kl_ctrl (core_algos.AdaptiveKLController): Controller for adaptive KL penalty. + kl_penalty (str, optional): Type of KL penalty to apply. Defaults to "kl". + multi_turn (bool, optional): Whether the data is from a multi-turn conversation. Defaults to False. + + Returns: + tuple: A tuple containing: + - The updated data with token-level rewards adjusted by KL penalty + - A dictionary of metrics related to the KL penalty + """ + responses = data.batch["responses"] + response_length = responses.size(1) + token_level_scores = data.batch["token_level_scores"] + batch_size = data.batch.batch_size[0] + + if multi_turn: + loss_mask = data.batch["loss_mask"] + response_mask = loss_mask[:, -response_length:] + else: + attention_mask = data.batch["attention_mask"] + response_mask = attention_mask[:, -response_length:] + + rollout_log_probs = data.batch['rollout_log_probs'] + old_log_probs = data.batch['old_log_probs'] + if rollout_log_probs.size(1) < old_log_probs.size(1): ##### NOTE(zht) + rollout_log_probs = F.pad(rollout_log_probs, (response_length - rollout_log_probs.size(1), 0), value=0) + # compute kl between ref_policy and current policy + # When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled. 
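The shaping performed in the remainder of this function is token-level score minus beta times the per-token KL to the reference policy, masked to response tokens, with beta driven by an adaptive controller. A minimal sketch with a fixed beta and the simple k1 KL estimator (the adaptive controller and the exact estimator choice are omitted); tensors and names are illustrative:

    import torch

    def kl_shaped_rewards(token_level_scores, old_log_prob, ref_log_prob,
                          response_mask, beta=0.05):
        # k1 estimator: KL ~= logp_current - logp_ref, masked to response tokens
        kld = (old_log_prob - ref_log_prob) * response_mask
        shaped = token_level_scores - beta * kld
        # mean KL per sequence, averaged over the batch (what an adaptive
        # controller would consume to update beta)
        seq_kl = (kld.sum(-1) / response_mask.sum(-1).clamp(min=1.0)).mean()
        return shaped, seq_kl

    scores = torch.zeros(2, 4)
    scores[:, -1] = 1.0                        # terminal reward only
    old_lp = torch.tensor([[-1.0, -1.2, -0.9, -1.1],
                           [-0.8, -1.0, -1.3, -0.7]])
    ref_lp = old_lp - 0.1                      # pretend the policy drifted slightly
    mask = torch.ones(2, 4)
    shaped, seq_kl = kl_shaped_rewards(scores, old_lp, ref_lp, mask)
    print(shaped, seq_kl.item())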
+ kld = core_algos.kl_penalty(old_log_probs, data.batch["ref_log_prob"], kl_penalty=kl_penalty) # (batch_size, response_length) + kld = kld * response_mask + beta = kl_ctrl.value + + token_level_rewards = token_level_scores - beta * kld if beta > 0 else token_level_scores ###### NOTE(zht) + + current_kl = masked_mean(kld, mask=response_mask, axis=-1) # average over sequence + current_kl = torch.mean(current_kl, dim=0).item() + + # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837 + kl_ctrl.update(current_kl=current_kl, n_steps=batch_size) + data.batch["token_level_rewards"] = token_level_rewards + + metrics = {"actor/reward_kl_penalty": current_kl, "actor/reward_kl_penalty_coeff": beta} + + return data, metrics + + +def print_batch_info(data: DataProto, prefix=""): + """Print batch information for debugging purposes. + + This function prints the shapes of tensors in the batch and non-tensor batch data, + along with the batch size. + + Args: + data (DataProto): The data containing batched model outputs and inputs. + prefix (str, optional): A prefix string to prepend to each printed line. Defaults to "". + """ + batch_size = data.batch.batch_size[0] + non_t_kv_check_list = {} + for key, val in data.non_tensor_batch.items(): + non_t_kv_check_list[key] = val.shape + try: + print(f" -*- {prefix}: bsz = {batch_size}, is_f = {data.batch['is_finished'].shape[0]}") + except KeyError: + print(f" -*- {prefix}: bsz = {batch_size}, is_f not defined") + print(f" -*- {prefix} non_t_list = {non_t_kv_check_list}") + + +def compute_response_mask(data: DataProto): + """Compute the attention mask for the response part of the sequence. + + This function extracts the portion of the attention mask that corresponds to the model's response, + which is used for masking computations that should only apply to response tokens. + + Args: + data (DataProto): The data containing batched model outputs and inputs. + + Returns: + torch.Tensor: The attention mask for the response tokens. 
+ """ + responses = data.batch["responses"] + response_length = responses.size(1) + attention_mask = data.batch["attention_mask"] + return attention_mask[:, -response_length:] + + +def compute_advantage_tppo(data: DataProto, gamma, lam, use_variable_lambda, variable_lambda_scalar, + adv_estimator, adv_whiten=False, use_separate_critic_lam=True, critic_lam=1.0, + adv_bias=0, window_response_length=None, ignore_token_num=0, is_clamp=False): + # TODO: add other ways to estimate advantages + token_level_rewards = data.batch['token_level_rewards'] + responses = data.batch['responses'] + response_length = responses.size(1) + attention_mask = data.batch['attention_mask'] + response_mask = attention_mask[:, -response_length:] + # response_mask = data.batch['response_mask'] + + if adv_estimator == 'gae-trunc' and (window_response_length is not None): + is_finished = data.batch['is_finished'] + values = data.batch['values'] + total_len = values.size(1) + new_values = torch.zeros_like(values) + window_rounds = data.batch['window_rounds'] + rounds_eos_mask = torch.zeros_like(token_level_rewards, dtype=response_mask.dtype) + single_token_level_rewards = torch.zeros_like(token_level_rewards) + total_len1 = token_level_rewards.size(1) + for idx, window_round in enumerate(window_rounds): + new_values[idx, :total_len-window_round * window_response_length] = values[idx, window_round * window_response_length:] + rounds_eos_mask[idx, window_round * window_response_length : (window_round+1) * window_response_length] = response_mask[idx, :] + rounds_eos_mask[idx, :window_round * window_response_length] = 1 + single_token_level_rewards[idx, :total_len1-window_round * window_response_length] = token_level_rewards[idx, window_round * window_response_length:] + seq_len_per_sample = torch.clamp(torch.sum(rounds_eos_mask, dim=1), min=1.0) + single_token_level_rewards = single_token_level_rewards[:, :window_response_length] + for idx, window_round in enumerate(window_rounds): + if not is_finished[idx]: + rounds_eos_mask[idx, :] = 0 + # torch.set_printoptions(threshold=torch.inf) + # print('@@@@@@@@@@@@@@@@ compute_advantage: window_rounds ', window_rounds, ' rounds_eos_mask ', rounds_eos_mask.max(), rounds_eos_mask, ' seq_len_per_sample ', seq_len_per_sample) + rounds_eos_mask = torch.clamp(rounds_eos_mask, min=0, max=1) + data.batch['rounds_eos_mask'] = rounds_eos_mask + values = new_values[:, :window_response_length] + origin_advantages, advantages, returns = compute_truncate_gae_advantage_return( + token_level_rewards=token_level_rewards, + single_token_level_rewards=single_token_level_rewards, + values=values, + eos_mask=response_mask, + gamma=gamma, + lam=lam, + use_variable_lambda=use_variable_lambda, + variable_lambda_scalar=variable_lambda_scalar, + adv_whiten=adv_whiten, + use_separate_critic_lam=use_separate_critic_lam, + critic_lam=critic_lam, + is_finished=is_finished, + ignore_token_num=ignore_token_num, + rounds_eos_mask=rounds_eos_mask, + seq_len_per_sample=seq_len_per_sample, + is_clamp=is_clamp) + data.batch['advantages'] = advantages + adv_bias + data.batch['origin_advantages'] = origin_advantages + data.batch['returns'] = returns + data.batch['upgo_advantages'] = torch.zeros_like(advantages) + adv_metrics = {} + else: + raise NotImplementedError + return data, adv_metrics + + +class RayPPOTrainer: + # TODO: support each role have individual ray_worker_group_cls, + # i.e., support different backend of different role + def __init__( + self, + config, + tokenizer, + role_worker_mapping: dict[Role, 
WorkerType], + resource_pool_manager: ResourcePoolManager, + ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, + processor=None, + reward_fn=None, + val_reward_fn=None, + train_dataset: Optional[Dataset] = None, + val_dataset: Optional[Dataset] = None, + collate_fn=None, + train_sampler: Optional[Sampler] = None, + device_name="cuda", + ): + """ + Initialize distributed PPO trainer with Ray backend. + Note that this trainer runs on the driver process on a single CPU/GPU node. + + Args: + config: Configuration object containing training parameters. + tokenizer: Tokenizer used for encoding and decoding text. + role_worker_mapping (dict[Role, WorkerType]): Mapping from roles to worker classes. + resource_pool_manager (ResourcePoolManager): Manager for Ray resource pools. + ray_worker_group_cls (RayWorkerGroup, optional): Class for Ray worker groups. Defaults to RayWorkerGroup. + processor: Optional data processor, used for multimodal data + reward_fn: Function for computing rewards during training. + val_reward_fn: Function for computing rewards during validation. + train_dataset (Optional[Dataset], optional): Training dataset. Defaults to None. + val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None. + collate_fn: Function to collate data samples into batches. + train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None. + device_name (str, optional): Device name for training (e.g., "cuda", "cpu"). Defaults to "cuda". + """ + + # Store the tokenizer for text processing + self.tokenizer = tokenizer + self.processor = processor + self.config = config + self.reward_fn = reward_fn + self.val_reward_fn = val_reward_fn + + self.hybrid_engine = config.actor_rollout_ref.hybrid_engine + assert self.hybrid_engine, "Currently, only support hybrid engine" + + if self.hybrid_engine: + assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}" + + self.role_worker_mapping = role_worker_mapping + self.resource_pool_manager = resource_pool_manager + self.use_reference_policy = Role.RefPolicy in role_worker_mapping + self.use_rm = Role.RewardModel in role_worker_mapping + self.ray_worker_group_cls = ray_worker_group_cls + self.device_name = device_name + self.validation_generations_logger = ValidationGenerationsLogger() + + # if ref_in_actor is True, the reference policy will be actor without lora applied + self.ref_in_actor = config.actor_rollout_ref.model.get("lora_rank", 0) > 0 + + # define in-reward KL control + # kl loss control currently not suppoorted + if config.algorithm.use_kl_in_reward: + self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl) + + # if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE: + # self.use_critic = True + # elif self.config.algorithm.adv_estimator in [ + # AdvantageEstimator.GRPO, + # AdvantageEstimator.GRPO_PASSK, + # AdvantageEstimator.REINFORCE_PLUS_PLUS, + # AdvantageEstimator.REMAX, + # AdvantageEstimator.RLOO, + # AdvantageEstimator.OPO, + # AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE, + # ]: + # self.use_critic = False + # else: + # raise NotImplementedError + self.use_critic = True + self.num_bon = config.actor_rollout_ref.rollout.get("num_bon", 1) + + self._validate_config() + self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) + + def _validate_config(self): + config = self.config + # number of GPUs total + n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes + if 
config.actor_rollout_ref.actor.strategy == "megatron":
+            model_parallel_size = config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size * config.actor_rollout_ref.actor.megatron.pipeline_model_parallel_size
+            assert n_gpus % (model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size) == 0, f"n_gpus ({n_gpus}) must be divisible by model_parallel_size ({model_parallel_size}) times context_parallel_size ({config.actor_rollout_ref.actor.megatron.context_parallel_size})"
+            megatron_dp = n_gpus // (model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size)
+            minimal_bsz = megatron_dp * config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu
+        else:
+            minimal_bsz = n_gpus
+
+        # 1. Check total batch size for data correctness
+        real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n
+        assert real_train_batch_size % minimal_bsz == 0, f"real_train_batch_size ({real_train_batch_size}) must be divisible by minimal possible batch size ({minimal_bsz})"
+
+        # A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu"
+        # We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu".
+        def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
+            settings = {
+                "actor_rollout_ref.actor": "micro_batch_size",
+                "critic": "micro_batch_size",
+                "reward_model": "micro_batch_size",
+                "actor_rollout_ref.ref": "log_prob_micro_batch_size",
+                "actor_rollout_ref.rollout": "log_prob_micro_batch_size",
+            }
+
+            if name in settings:
+                param = settings[name]
+                param_per_gpu = f"{param}_per_gpu"
+
+                if mbs is None and mbs_per_gpu is None:
+                    raise ValueError(f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.")
+
+                if mbs is not None and mbs_per_gpu is not None:
+                    raise ValueError(f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove '{name}.{param}' because only '*_{param_per_gpu}'"
+                                     " is supported (the former is deprecated).")
+
+        if not config.actor_rollout_ref.actor.use_dynamic_bsz:
+            # actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu
+            check_mutually_exclusive(
+                config.actor_rollout_ref.actor.ppo_micro_batch_size,
+                config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu,
+                "actor_rollout_ref.actor",
+            )
+
+            if self.use_reference_policy:
+                # reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
+                check_mutually_exclusive(
+                    config.actor_rollout_ref.ref.log_prob_micro_batch_size,
+                    config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu,
+                    "actor_rollout_ref.ref",
+                )
+
+            # The rollout section also has log_prob_micro_batch_size vs. 
log_prob_micro_batch_size_per_gpu + check_mutually_exclusive( + config.actor_rollout_ref.rollout.log_prob_micro_batch_size, + config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu, + "actor_rollout_ref.rollout", + ) + + if self.use_critic and not config.critic.use_dynamic_bsz: + # Check for critic micro-batch size conflicts + check_mutually_exclusive(config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, "critic") + + # Check for reward model micro-batch size conflicts + if config.reward_model.enable and not config.reward_model.use_dynamic_bsz: + check_mutually_exclusive(config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model") + + # Actor + # check if train_batch_size is larger than ppo_mini_batch_size + # if NOT dynamic_bsz, we must ensure: + # ppo_mini_batch_size is divisible by ppo_micro_batch_size + # ppo_micro_batch_size * sequence_parallel_size >= n_gpus + if not config.actor_rollout_ref.actor.use_dynamic_bsz: + assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size + sp_size = config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) + if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None: + assert config.actor_rollout_ref.actor.ppo_mini_batch_size % config.actor_rollout_ref.actor.ppo_micro_batch_size == 0 + assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus + + assert config.actor_rollout_ref.actor.loss_agg_mode in [ + "token-mean", + "seq-mean-token-sum", + "seq-mean-token-mean", + "seq-mean-token-sum-norm", + "minibatch-sum", + "batch-sum", + ], f"Invalid loss_agg_mode: {config.actor_rollout_ref.actor.loss_agg_mode}" + + if config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss: + print("NOTICE: You have both enabled in-reward kl and kl loss.") + + # critic + if self.use_critic and not config.critic.use_dynamic_bsz: + assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size + sp_size = config.critic.get("ulysses_sequence_parallel_size", 1) + if config.critic.ppo_micro_batch_size is not None: + assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0 + assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus + + # Check if use_remove_padding is enabled when using sequence parallelism for fsdp + if config.actor_rollout_ref.actor.strategy == "fsdp" and (config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) > 1 or config.actor_rollout_ref.ref.get("ulysses_sequence_parallel_size", 1) > 1): + assert config.actor_rollout_ref.model.use_remove_padding, "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`." + + if self.use_critic and config.critic.strategy == "fsdp": + if config.critic.get("ulysses_sequence_parallel_size", 1) > 1: + assert config.critic.model.use_remove_padding, "When using sequence parallelism for critic, you must enable `use_remove_padding`." + + if config.data.get("val_batch_size", None) is not None: + print("WARNING: val_batch_size is deprecated." 
+ " Validation datasets are sent to inference engines as a whole batch," + " which will schedule the memory themselves.") + + # check eval config + if config.actor_rollout_ref.rollout.val_kwargs.do_sample: + assert config.actor_rollout_ref.rollout.temperature > 0, "validation gen temperature should be greater than 0 when enabling do_sample" + + # check multi_turn with tool config + if config.actor_rollout_ref.rollout.multi_turn.enable: + assert config.actor_rollout_ref.rollout.multi_turn.tool_config_path is not None or config.actor_rollout_ref.rollout.multi_turn.interaction_config_path is not None, "tool_config_path or interaction_config_path must be set when enabling multi_turn with tool, due to no role-playing support" + assert config.algorithm.adv_estimator in [AdvantageEstimator.GRPO], "only GRPO is tested for multi-turn with tool" + + print("[validate_config] All configuration checks passed successfully!") + + def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler): + """ + Creates the train and validation dataloaders. + """ + # TODO: we have to make sure the batch size is divisible by the dp size + from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler + + if train_dataset is None: + train_dataset = create_rl_dataset(self.config.data.train_files, self.config.data, self.tokenizer, self.processor) + if val_dataset is None: + val_dataset = create_rl_dataset(self.config.data.val_files, self.config.data, self.tokenizer, self.processor) + self.train_dataset, self.val_dataset = train_dataset, val_dataset + + if train_sampler is None: + train_sampler = create_rl_sampler(self.config.data, self.train_dataset) + if collate_fn is None: + from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn + + collate_fn = default_collate_fn + + self.train_dataloader = StatefulDataLoader( + dataset=self.train_dataset, + batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size), + num_workers=self.config.data.get("dataloader_num_workers", 8), + drop_last=True, + collate_fn=collate_fn, + sampler=train_sampler, + ) + + val_batch_size = self.config.data.val_batch_size # Prefer config value if set + if val_batch_size is None: + val_batch_size = len(self.val_dataset) + + self.val_dataloader = StatefulDataLoader( + dataset=self.val_dataset, + batch_size=val_batch_size, + num_workers=self.config.data.get("dataloader_num_workers", 8), + shuffle=self.config.data.get("validation_shuffle", True), + drop_last=False, + collate_fn=collate_fn, + ) + + assert len(self.train_dataloader) >= 1, "Train dataloader is empty!" + assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!" 
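+        # Note: the train dataloader is a StatefulDataLoader; _save_checkpoint persists its
+        # state_dict() (data.pt) and _load_checkpoint restores it, so the data position is
+        # recovered on resume.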
+ + print(f"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: {len(self.val_dataloader)}") + + total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + + if self.config.trainer.total_training_steps is not None: + total_training_steps = self.config.trainer.total_training_steps + + self.total_training_steps = total_training_steps + print(f"Total training steps: {self.total_training_steps}") + + try: + OmegaConf.set_struct(self.config, True) + with open_dict(self.config): + if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"): + self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps + if OmegaConf.select(self.config, "critic.optim"): + self.config.critic.optim.total_training_steps = total_training_steps + except Exception as e: + print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}") + + def _dump_generations(self, inputs, outputs, scores, reward_extra_infos_dict, dump_path): + """Dump rollout/validation samples as JSONL.""" + os.makedirs(dump_path, exist_ok=True) + filename = os.path.join(dump_path, f"{self.global_steps}.jsonl") + + n = len(inputs) + base_data = { + "input": inputs, + "output": outputs, + "score": scores, + "step": [self.global_steps] * n, + } + + for k, v in reward_extra_infos_dict.items(): + if len(v) == n: + base_data[k] = v + + lines = [] + for i in range(n): + entry = {k: v[i] for k, v in base_data.items()} + lines.append(json.dumps(entry, ensure_ascii=False)) + + with open(filename, "w") as f: + f.write("\n".join(lines) + "\n") + + print(f"Dumped generations to {filename}") + + def _maybe_log_val_generations(self, inputs, outputs, scores): + """Log a table of validation samples to the configured logger (wandb or swanlab)""" + + generations_to_log = self.config.trainer.log_val_generations + + if generations_to_log == 0: + return + + import numpy as np + + # Create tuples of (input, output, score) and sort by input text + samples = list(zip(inputs, outputs, scores)) + samples.sort(key=lambda x: x[0]) # Sort by input text + + # Use fixed random seed for deterministic shuffling + rng = np.random.RandomState(42) + rng.shuffle(samples) + + # Take first N samples after shuffling + samples = samples[:generations_to_log] + + # Log to each configured logger + self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps) + + def _validate(self): + data_source_lst = [] + reward_extra_infos_dict: dict[str, list] = defaultdict(list) + + # Lists to collect samples for the table + sample_inputs = [] + sample_outputs = [] + sample_scores = [] + + for test_data in self.val_dataloader: + test_batch = DataProto.from_single_dict(test_data) + + # repeat test batch + test_batch = test_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True) + + # we only do validation on rule-based rm + if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model": + return {} + + # Store original inputs + input_ids = test_batch.batch["input_ids"] + # TODO: Can we keep special tokens except for padding tokens? 
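+            # Decode the prompts here so they can be logged later next to the generated
+            # responses and scores (used by _maybe_log_val_generations and _dump_generations).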
+ input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] + sample_inputs.extend(input_texts) + + batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] + non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] + if "multi_modal_data" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("multi_modal_data") + if "raw_prompt" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("raw_prompt") + if "tools_kwargs" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("tools_kwargs") + if "interaction_kwargs" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("interaction_kwargs") + test_gen_batch = test_batch.pop( + batch_keys=batch_keys_to_pop, + non_tensor_batch_keys=non_tensor_batch_keys_to_pop, + ) + + test_gen_batch.meta_info = { + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.pad_token_id, + "recompute_log_prob": False, + "do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample, + "validate": True, + } + print(f"test_gen_batch meta info: {test_gen_batch.meta_info}") + + # pad to be divisible by dp_size + test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size) + if not self.async_rollout_mode: + test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded) + else: + self.async_rollout_manager.wake_up() + test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded) + self.async_rollout_manager.sleep() + + # unpad + test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size) + print("validation generation end") + + # Store generated outputs + output_ids = test_output_gen_batch.batch["responses"] + output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] + sample_outputs.extend(output_texts) + + test_batch = test_batch.union(test_output_gen_batch) + + # evaluate using reward_function + result = self.val_reward_fn(test_batch, return_dict=True, use_window_rollout=False) + reward_tensor = result["reward_tensor"] + scores = reward_tensor.sum(-1).cpu().tolist() + sample_scores.extend(scores) + + reward_extra_infos_dict["reward"].extend(scores) + print(f"len reward_extra_infos_dict['reward']: {len(reward_extra_infos_dict['reward'])}") + if "reward_extra_info" in result: + for key, lst in result["reward_extra_info"].items(): + reward_extra_infos_dict[key].extend(lst) + print(f"len reward_extra_infos_dict['{key}']: {len(reward_extra_infos_dict[key])}") + + data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0])) + + self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) + + # dump generations + val_data_dir = self.config.trainer.get("validation_data_dir", None) + if val_data_dir: + self._dump_generations( + inputs=sample_inputs, + outputs=sample_outputs, + scores=sample_scores, + reward_extra_infos_dict=reward_extra_infos_dict, + dump_path=val_data_dir, + ) + + for key_info, lst in reward_extra_infos_dict.items(): + assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}" + + data_sources = np.concatenate(data_source_lst, axis=0) + + data_src2var2metric2val = process_validation_metrics(data_sources, sample_inputs, reward_extra_infos_dict) + metric_dict = {} + for data_source, var2metric2val in 
data_src2var2metric2val.items(): + core_var = "acc" if "acc" in var2metric2val else "reward" + for var_name, metric2val in var2metric2val.items(): + n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()]) + for metric_name, metric_val in metric2val.items(): + if (var_name == core_var) and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"]) and (f"@{n_max}" in metric_name): + metric_sec = "val-core" + else: + metric_sec = "val-aux" + pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}" + metric_dict[pfx] = metric_val + + return metric_dict + + def init_workers(self): + """Initialize distributed training workers using Ray backend. + + Creates: + 1. Ray resource pools from configuration + 2. Worker groups for each role (actor, critic, etc.) + """ + self.resource_pool_manager.create_resource_pool() + + self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} + + # create actor and rollout + if self.hybrid_engine: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) + actor_rollout_cls = RayClassWithInitArgs( + cls=self.role_worker_mapping[Role.ActorRollout], + config=self.config.actor_rollout_ref, + role="actor_rollout", + ) + self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls + else: + raise NotImplementedError + + # create critic + if self.use_critic: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) + critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic) + self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls + + # create reference policy if needed + if self.use_reference_policy: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) + ref_policy_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RefPolicy], config=self.config.actor_rollout_ref, role="ref") + self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls + + # create a reward model if reward_fn is None + if self.use_rm: + # we create a RM here + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) + rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model) + self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls + + # initialize WorkerGroup + # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, + # you should not use `create_colocated_worker_cls`. + # Instead, directly pass different resource pool to different worker groups. + # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. 
+ all_wg = {} + wg_kwargs = {} # Setting up kwargs for RayWorkerGroup + if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: + wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout + if OmegaConf.select(self.config.trainer, "profile_steps") is not None: + wg_kwargs["profile_steps"] = OmegaConf.select(self.config.trainer, "profile_steps") + assert OmegaConf.select(self.config.trainer, "worker_nsight_options") is not None, "worker_nsight_options must be set when profile_steps is set" + wg_kwargs["worker_nsight_options"] = OmegaConf.to_container(OmegaConf.select(self.config.trainer, "worker_nsight_options")) + + for resource_pool, class_dict in self.resource_pool_to_cls.items(): + worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) + wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls, device_name=self.device_name, **wg_kwargs) + spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) + all_wg.update(spawn_wg) + + if self.use_critic: + self.critic_wg = all_wg["critic"] + self.critic_wg.init_model() + + if self.use_reference_policy and not self.ref_in_actor: + self.ref_policy_wg = all_wg["ref"] + self.ref_policy_wg.init_model() + + if self.use_rm: + self.rm_wg = all_wg["rm"] + self.rm_wg.init_model() + + # sample pool + self.sample_pool = SamplePool(self.config, self.tokenizer) + + # we should create rollout at the end so that vllm can have a better estimation of kv cache memory + self.actor_rollout_wg = all_wg["actor_rollout"] + self.actor_rollout_wg.init_model() + + # create async rollout manager and request scheduler + self.async_rollout_mode = False + if self.config.actor_rollout_ref.rollout.mode == "async": + from verl.workers.rollout.async_server import AsyncLLMServerManager + + self.async_rollout_mode = True + self.async_rollout_manager = AsyncLLMServerManager( + config=self.config, + worker_group=self.actor_rollout_wg, + ) + + def _save_checkpoint(self): + from verl.utils.fs import local_mkdir_safe + + # path: given_path + `/global_step_{global_steps}` + `/actor` + local_global_step_folder = os.path.join(self.config.trainer.default_local_dir, f"global_step_{self.global_steps}") + + print(f"local_global_step_folder: {local_global_step_folder}") + actor_local_path = os.path.join(local_global_step_folder, "actor") + + actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor") + + remove_previous_ckpt_in_save = self.config.trainer.get("remove_previous_ckpt_in_save", False) + if remove_previous_ckpt_in_save: + print("Warning: remove_previous_ckpt_in_save is deprecated," + " set max_actor_ckpt_to_keep=1 and max_critic_ckpt_to_keep=1 instead") + max_actor_ckpt_to_keep = self.config.trainer.get("max_actor_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + max_critic_ckpt_to_keep = self.config.trainer.get("max_critic_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + + self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep) + + if self.use_critic: + critic_local_path = os.path.join(local_global_step_folder, "critic") + critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "critic") + 
self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep) + + # save dataloader + local_mkdir_safe(local_global_step_folder) + dataloader_local_path = os.path.join(local_global_step_folder, "data.pt") + dataloader_state_dict = self.train_dataloader.state_dict() + torch.save(dataloader_state_dict, dataloader_local_path) + + # latest checkpointed iteration tracker (for atomic usage) + local_latest_checkpointed_iteration = os.path.join(self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt") + with open(local_latest_checkpointed_iteration, "w") as f: + f.write(str(self.global_steps)) + + def _load_checkpoint(self): + if self.config.trainer.resume_mode == "disable": + return 0 + + # load from hdfs + if self.config.trainer.default_hdfs_dir is not None: + raise NotImplementedError("load from hdfs is not implemented yet") + else: + checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path + if not os.path.isabs(checkpoint_folder): + working_dir = os.getcwd() + checkpoint_folder = os.path.join(working_dir, checkpoint_folder) + global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest + + # find global_step_folder + if self.config.trainer.resume_mode == "auto": + if global_step_folder is None: + print("Training from scratch") + return 0 + else: + if self.config.trainer.resume_mode == "resume_path": + assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type" + assert "global_step_" in self.config.trainer.resume_from_path, "resume ckpt must specify the global_steps" + global_step_folder = self.config.trainer.resume_from_path + if not os.path.isabs(global_step_folder): + working_dir = os.getcwd() + global_step_folder = os.path.join(working_dir, global_step_folder) + print(f"Load from checkpoint folder: {global_step_folder}") + # set global step + self.global_steps = int(global_step_folder.split("global_step_")[-1]) + + print(f"Setting global step to {self.global_steps}") + print(f"Resuming from {global_step_folder}") + + actor_path = os.path.join(global_step_folder, "actor") + critic_path = os.path.join(global_step_folder, "critic") + # load actor + self.actor_rollout_wg.load_checkpoint(actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load) + # load critic + if self.use_critic: + self.critic_wg.load_checkpoint(critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load) + + # load dataloader, + # TODO: from remote not implemented yet + dataloader_local_path = os.path.join(global_step_folder, "data.pt") + if os.path.exists(dataloader_local_path): + dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False) + self.train_dataloader.load_state_dict(dataloader_state_dict) + else: + print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch") + + def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"): + """Reorder the data on single controller such that each dp rank gets similar total tokens""" + attention_mask = batch.batch["attention_mask"] + batch_size = attention_mask.shape[0] + global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,) + world_size = self.actor_rollout_wg.world_size + global_partition_lst = get_seqlen_balanced_partitions(global_seqlen_lst, k_partitions=world_size, equal_size=True) + # reorder based on index. 
The data will be automatically equally partitioned by dispatch function
+        global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])
+        batch.reorder(global_idx)
+        global_balance_stats = log_seqlen_unbalance(seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix)
+        metrics.update(global_balance_stats)
+
+    def fit(self):
+        """
+        The training loop of PPO.
+        The driver process only needs to call the compute functions of the worker group through RPC
+        to construct the PPO dataflow.
+        The light-weight advantage computation is done on the driver process.
+        """
+        from omegaconf import OmegaConf
+
+        from verl.utils.tracking import Tracking
+
+        logger = Tracking(
+            project_name=self.config.trainer.project_name,
+            experiment_name=self.config.trainer.experiment_name,
+            default_backend=self.config.trainer.logger,
+            config=OmegaConf.to_container(self.config, resolve=True),
+        )
+
+        self.global_steps = 0
+
+        # load checkpoint before doing anything
+        self._load_checkpoint()
+
+        # perform validation before training
+        # currently, we only support validation using the reward_function.
+        if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
+            val_metrics = self._validate()
+            assert val_metrics, f"{val_metrics=}"
+            pprint(f"Initial validation metrics: {val_metrics}")
+            logger.log(data=val_metrics, step=self.global_steps)
+            if self.config.trainer.get("val_only", False):
+                return
+
+        # add tqdm
+        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
+
+        # we start from step 1
+        self.global_steps += 1
+        last_val_metrics = None
+
+        for epoch in range(self.config.trainer.total_epochs):
+            for step_idx, batch_dict in enumerate(self.train_dataloader):
+                do_profile = self.global_steps in self.config.trainer.profile_steps if self.config.trainer.profile_steps is not None else False
+                if do_profile:
+                    self.actor_rollout_wg.start_profile()
+                    if self.use_reference_policy:
+                        self.ref_policy_wg.start_profile()
+                    if self.use_critic:
+                        self.critic_wg.start_profile()
+                    if self.use_rm:
+                        self.rm_wg.start_profile()
+
+                metrics = {}
+                timing_raw = {}
+                batch: DataProto = DataProto.from_single_dict(batch_dict)
+
+                # max len
+                use_window_rollout = (self.config.data.get('window_response_length', None) is not None) and \
+                    ((self.config.algorithm.get("mix_window_freq", None) is None) or \
+                    (self.global_steps % self.config.algorithm.mix_window_freq != 0))
+
+                if use_window_rollout:
+                    assert not self.config.algorithm.force_append_eos, '`window_response_length` cannot be used with `force_append_eos`.'
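+                    # Window rollout: new prompts are pushed into the sample pool, then a
+                    # generation batch of train_batch_size * num_bon samples is drawn from it.
+                    # Entries carried over from earlier windows have window_rounds > 0, which is
+                    # why the prompt budget below is extended by max_window_rounds * window_response_length.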
+ self.sample_pool.fill_sample_pool(batch) + self.sample_pool.rearrange_sample_pool() + batch = self.sample_pool.get_gen_batch(self.config.data.train_batch_size * self.num_bon) + + max_response_length = self.config.data.get('window_response_length', None) + max_window_rounds = batch.batch['window_rounds'].max().item() + print(f" --- max_window_rounds at step {step_idx} : {max_window_rounds}") + max_prompt_length = self.config.data.max_prompt_length + max_window_rounds * max_response_length + else: + max_response_length = self.config.data.max_response_length + max_prompt_length = self.config.data.max_prompt_length + + if 'rollout_log_probs' not in batch: + batch.batch['rollout_log_probs'] = torch.zeros(batch.batch['input_ids'].shape[0], + max_response_length, + dtype=torch.bfloat16, + device=batch.batch['input_ids'].device).fill_(-1) + + # pop those keys for generation + batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids", "rollout_log_probs"] + non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] + if "multi_modal_data" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("multi_modal_data") + if "raw_prompt" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("raw_prompt") + if "tools_kwargs" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("tools_kwargs") + if "interaction_kwargs" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("interaction_kwargs") + gen_batch = batch.pop( + batch_keys=batch_keys_to_pop, + non_tensor_batch_keys=non_tensor_batch_keys_to_pop, + ) + + is_last_step = self.global_steps >= self.total_training_steps + + with marked_timer("step", timing_raw): + # generate a batch + print(" --- Begin rollout step ---") + with marked_timer("gen", timing_raw, color="red"): + if not self.async_rollout_mode: + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + else: + self.async_rollout_manager.wake_up() + gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch) + self.async_rollout_manager.sleep() + + # NOTE(HanlinDu): should we truncate the all of the sequences to max_response_length? + if use_window_rollout: + seq_trunc_len = max_prompt_length + max_response_length + for k in ["input_ids", "attention_mask", "position_ids"]: + arr = gen_batch_output.batch[k] + gen_batch_output.batch[k] = arr[:, :seq_trunc_len] + + gen_batch_output.batch['prompts'] = gen_batch_output.batch['input_ids'][:, :max_prompt_length] + gen_batch_output.batch['responses'] = gen_batch_output.batch['input_ids'][:, max_prompt_length:] + + timing_raw.update(gen_batch_output.meta_info["timing"]) + gen_batch_output.meta_info.pop("timing", None) + + + batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object) + # repeat to align with repeated responses in rollout + batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch = batch.union(gen_batch_output) + + batch.batch["response_mask"] = compute_response_mask(batch) + # Balance the number of valid tokens across DP ranks. + # NOTE: This usually changes the order of data in the `batch`, + # which won't affect the advantage calculation (since it's based on uid), + # but might affect the loss calculation (due to the change of mini-batching). + # TODO: Decouple the DP balancing and mini-batching. 
+                if self.config.trainer.balance_batch:
+                    self._balance_batch(batch, metrics=metrics)
+
+                # compute global_valid tokens
+                batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
+
+                # print_batch_info(batch, prefix="980")
+
+                with marked_timer("reward", timing_raw, color="yellow"):
+                    # compute reward model score
+                    if self.use_rm:
+                        reward_tensor = self.rm_wg.compute_rm_score(batch)
+                        batch = batch.union(reward_tensor)
+
+                    if self.config.reward_model.launch_reward_fn_async:
+                        # FIXME async reward may track reliance bugs for now. e.g. `token_level_scores`
+                        assert self.config.reward_model.launch_reward_fn_async
+                        future_reward = compute_reward_async.remote(batch, self.config, self.tokenizer)
+                    else:
+                        reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)
+                        # NOTE(zht) use reward_fn only for tppo, so we don't need a flag arg here;
+                        # it needs to return a dict instead of args:
+                        # ['raw_scores', 'format_scores', 'length_scores', 'eos_ids', 'is_finished']
+                        # batch.batch['raw_scores'] = reward_extra_infos_dict['raw_scores']
+                        # print(f" --- {batch.batch['raw_scores'].shape=}")
+                        # batch.batch['format_scores'] = reward_extra_infos_dict['format_scores']
+                        # batch.batch['length_scores'] = reward_extra_infos_dict['length_scores']
+                        print_batch_info(batch, prefix="1000")
+
+                        if 'eos_ids' in reward_extra_infos_dict and len(reward_extra_infos_dict['eos_ids']) == len(batch.batch['input_ids']):
+                            batch.batch['eos_ids'] = reward_extra_infos_dict['eos_ids']
+                        else:
+                            print(" --- assign eos_ids by default eos_token_id")
+                            batch.batch['eos_ids'] = torch.full(
+                                (len(batch.batch['input_ids']),),
+                                self.tokenizer.eos_token_id,
+                                dtype=torch.int64,
+                                device=batch.batch['input_ids'].device
+                            )
+                        batch.batch['is_finished'] = reward_extra_infos_dict['is_finished']
+
+                        # `token_level_scores` is initialized only in sync mode
+                        batch.batch["token_level_scores"] = reward_tensor
+
+                if self.config.algorithm.all_samples_with_grad and use_window_rollout:
+                    self.sample_pool.fill_rollout_pool_grad(batch)
+                    return_batch_size = self.config.data.actor_training_batch_size * self.num_bon
+                    batch = self.sample_pool.get_train_batch_grad(return_batch_size)
+
+                # recompute old_log_probs
+                with marked_timer("old_log_prob", timing_raw, color="blue"):
+                    old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
+                    entropys = old_log_prob.batch["entropys"]
+                    response_masks = batch.batch["response_mask"]
+                    loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
+                    entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
+                    old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
+                    metrics.update(old_log_prob_metrics)
+                    old_log_prob.batch.pop("entropys")
+                    batch = batch.union(old_log_prob)
+
+                    if "rollout_log_probs" in batch.batch.keys():
+                        # TODO: we may want to add diff of probs too.
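+                        # Log how far the rollout engine's token probabilities drift from the
+                        # probabilities recomputed by the actor (max / mean / std over response tokens).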
+ rollout_old_log_probs = batch.batch["rollout_log_probs"] + actor_old_log_probs = batch.batch["old_log_probs"] + attention_mask = batch.batch["attention_mask"] + responses = batch.batch["responses"] + response_length = responses.size(1) + response_mask = attention_mask[:, -response_length:] + + rollout_probs = torch.exp(rollout_old_log_probs) + actor_probs = torch.exp(actor_old_log_probs) + rollout_probs_diff = torch.abs(rollout_probs - actor_probs) + rollout_probs_diff = torch.masked_select(rollout_probs_diff, response_mask.bool()) + rollout_probs_diff_max = torch.max(rollout_probs_diff) + rollout_probs_diff_mean = torch.mean(rollout_probs_diff) + rollout_probs_diff_std = torch.std(rollout_probs_diff) + metrics.update( + { + "training/rollout_probs_diff_max": rollout_probs_diff_max.detach().item(), + "training/rollout_probs_diff_mean": rollout_probs_diff_mean.detach().item(), + "training/rollout_probs_diff_std": rollout_probs_diff_std.detach().item(), + } + ) + + # print_batch_info(batch, prefix="1059") + + if self.use_reference_policy: + # compute reference log_prob + with marked_timer("ref", timing_raw, color="olive"): + if not self.ref_in_actor: + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + else: + ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + # compute values + if self.use_critic: + with marked_timer("values", timing_raw, color="cyan"): + values = self.critic_wg.compute_values(batch) + if 'values' in batch.batch and use_window_rollout and self.config.algorithm.use_actual_values: + is_finished = batch.batch['is_finished'] + window_rounds = batch.batch['window_rounds'] + saved_values = batch.batch['values'] + window_response_length = self.config.data.window_response_length + new_values = torch.zeros_like(values.batch['values']) + for idx, window_round in enumerate(window_rounds): + new_values[idx, :window_round * window_response_length] = saved_values[idx, :window_round * window_response_length] + new_values[idx, window_round * window_response_length:] = values.batch['values'][idx, window_round * window_response_length:] + batch.batch['values'] = new_values + else: + batch.batch['values'] = values.batch['values'] + + # fill multi round replay buffer + if use_window_rollout: + self.sample_pool.update_multi_round_pool(batch) + + # print_batch_info(batch, prefix="1098") + + with marked_timer("adv", timing_raw, color="brown"): + # we combine with rule-based rm + if self.config.reward_model.launch_reward_fn_async: + # FIXME async reward may track reliance bugs for now. e.g. `token_level_scores` + assert self.config.reward_model.launch_reward_fn_async + reward_tensor, reward_extra_infos_dict = ray.get(future_reward) + # batch.batch["token_level_scores"] = reward_tensor + + # NOTE(HanlinDu): we do not put reward_extra_infos_dict into non_tensor_batch, + # if reward_extra_infos_dict: + # batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) + + # compute rewards. 
apply_kl_penalty if available + if self.config.algorithm.use_kl_in_reward: + print(" --- apply kl penalty --- ") + batch, kl_metrics = apply_kl_penalty(batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty) + metrics.update(kl_metrics) + else: + batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] + + # print_batch_info(batch, prefix="1133") + + # compute advantages, executed on the driver process + batch, adv_metrics = compute_advantage_tppo( + batch, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + use_variable_lambda=self.config.algorithm.use_variable_lambda, + variable_lambda_scalar=self.config.algorithm.variable_lambda_scalar, + adv_estimator='gae-trunc', + adv_whiten=self.config.algorithm.adv_whiten, + use_separate_critic_lam=self.config.algorithm.use_separate_critic_lam, + critic_lam=self.config.algorithm.critic_lam, + adv_bias=self.config.algorithm.adv_bias, + window_response_length=self.config.data.get('window_response_length', None) if use_window_rollout else None, + ignore_token_num=self.config.algorithm.get('window_ignore_token_num', 8), + is_clamp=self.config.algorithm.adv_clamp, + ) + metrics.update(adv_metrics) + + # print_batch_info(batch, prefix="1153") + + # update critic + if self.use_critic: + with marked_timer("update_critic", timing_raw, color="pink"): + critic_output = self.critic_wg.update_critic(batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + metrics.update(critic_output_metrics) + + # print_batch_info(batch, prefix="1162") + + # implement critic warmup + if self.config.trainer.critic_warmup <= self.global_steps: + # update actor + with marked_timer("update_actor", timing_raw, color="red"): + batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable + actor_output = self.actor_rollout_wg.update_actor(batch) + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + metrics.update(actor_output_metrics) + + # Log rollout generations if enabled + rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) + if rollout_data_dir: + with marked_timer("dump_rollout_generations", timing_raw, color="green"): + print(batch.batch.keys()) + inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True) + outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True) + scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist() + self._dump_generations( + inputs=inputs, + outputs=outputs, + scores=scores, + reward_extra_infos_dict=reward_extra_infos_dict, + dump_path=rollout_data_dir, + ) + + # validate + if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0): + with marked_timer("testing", timing_raw, color="green"): + val_metrics: dict = self._validate() + if is_last_step: + last_val_metrics = val_metrics + metrics.update(val_metrics) + + if self.config.trainer.save_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.save_freq == 0): + with marked_timer("save_checkpoint", timing_raw, color="green"): + self._save_checkpoint() + + # training metrics + metrics.update( + { + "training/global_step": self.global_steps, + "training/epoch": epoch, + } + ) + # collect metrics + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + # TODO: implement actual 
tflpo and theoretical tflpo + n_gpus = self.resource_pool_manager.get_n_gpus() + metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + + # TODO: make a canonical logger that supports various backend + logger.log(data=metrics, step=self.global_steps) + + progress_bar.update(1) + self.global_steps += 1 + + if do_profile: + self.actor_rollout_wg.stop_profile() + if self.use_reference_policy: + self.ref_policy_wg.stop_profile() + if self.use_critic: + self.critic_wg.stop_profile() + if self.use_rm: + self.rm_wg.stop_profile() + + if is_last_step: + pprint(f"Final validation metrics: {last_val_metrics}") + progress_bar.close() + return diff --git a/verl/workers/config/actor.py b/verl/workers/config/actor.py index 5d7d59c197a..e5a42d2e163 100644 --- a/verl/workers/config/actor.py +++ b/verl/workers/config/actor.py @@ -126,15 +126,6 @@ def __post_init__(self): "'actor.ppo_micro_batch_size_per_gpu' if use_dynamic_bsz is not enabled." ) - valid_loss_agg_modes = [ - "token-mean", - "seq-mean-token-sum", - "seq-mean-token-mean", - "seq-mean-token-sum-norm", - ] - if self.loss_agg_mode not in valid_loss_agg_modes: - raise ValueError(f"Invalid loss_agg_mode: {self.loss_agg_mode}") - def validate(self, n_gpus: int, train_batch_size: int, model_config: dict = None): """Validate actor configuration with runtime parameters.""" if not self.use_dynamic_bsz: @@ -219,6 +210,12 @@ class FSDPActorConfig(ActorConfig): fsdp_config: FSDPEngineConfig = field(default_factory=FSDPEngineConfig) use_remove_padding: bool = False + # For tppo + window_response_length: int = 8192 + lm_loss_weight: float = 0.1 + scale_pg_by_local_kl: bool = False + scale_pg_by_kl: bool = False + def __post_init__(self): """Validate FSDP actor configuration parameters.""" super().__post_init__() diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index be2bcf50aca..26179c156c9 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -562,7 +562,9 @@ def _build_rollout(self, trust_remote_code=False): @register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): - from verl.workers.actor import DataParallelPPOActor + # FIXME: do not import tppo statically + from recipe.tppo.tppo_actor import DataParallelPPOActor + # from verl.workers.actor import DataParallelPPOActor # This is used to import external_lib into the huggingface systems import_external_libs(self.config.model.get("external_lib", None)) @@ -1176,7 +1178,9 @@ def init_model(self): # This is used to import external_lib into the huggingface systems import_external_libs(self.config.model.get("external_lib", None)) - from verl.workers.critic import DataParallelPPOCritic + # FIXME: do not import tppo statically + from recipe.tppo.tppo_critic import DataParallelPPOCritic + # from verl.workers.critic import DataParallelPPOCritic self.critic_module, self.critic_optimizer, self.critic_lr_scheduler = self._build_critic_model_optimizer( self.config