From bd40a41b1df4e0b3a131254e3a9f854dd1322e4b Mon Sep 17 00:00:00 2001 From: Jiaqi Zeng Date: Mon, 20 Oct 2025 14:00:08 -0700 Subject: [PATCH] genrm rlhf Signed-off-by: Jiaqi Zeng --- examples/configs/grpo_genrm_rlhf_1B.yaml | 271 +++++++++ .../configs/grpo_genrm_rlhf_1B_megatron.yaml | 219 +++++++ examples/run_grpo_genrm_rlhf.py | 256 ++++++++ nemo_rl/algorithms/grpo.py | 5 + .../datasets/response_datasets/__init__.py | 2 + .../datasets/response_datasets/genrm_rlhf.py | 89 +++ nemo_rl/data/processors.py | 98 +++ .../ray_actor_environment_registry.py | 1 + .../environments/genrm_rlhf_environment.py | 573 ++++++++++++++++++ .../pairwise_reward_aggregators.py | 440 ++++++++++++++ 10 files changed, 1954 insertions(+) create mode 100644 examples/configs/grpo_genrm_rlhf_1B.yaml create mode 100644 examples/configs/grpo_genrm_rlhf_1B_megatron.yaml create mode 100644 examples/run_grpo_genrm_rlhf.py create mode 100644 nemo_rl/data/datasets/response_datasets/genrm_rlhf.py create mode 100644 nemo_rl/environments/genrm_rlhf_environment.py create mode 100644 nemo_rl/environments/pairwise_reward_aggregators.py diff --git a/examples/configs/grpo_genrm_rlhf_1B.yaml b/examples/configs/grpo_genrm_rlhf_1B.yaml new file mode 100644 index 0000000000..8ed0a64fa1 --- /dev/null +++ b/examples/configs/grpo_genrm_rlhf_1B.yaml @@ -0,0 +1,271 @@ +# GRPO Algorithm Configuration for GenRM RLHF +grpo: + num_prompts_per_step: 32 + num_generations_per_prompt: 8 + num_val_generations_per_prompt: 4 # Number of responses to generate per prompt during validation + max_rollout_turns: 1 # Single turn evaluation task + max_num_epochs: 1 + max_num_steps: 1000000 + normalize_rewards: true + use_leave_one_out_baseline: true + val_period: 10 + val_at_start: false + overlong_filtering: false + max_val_samples: 32 + val_batch_size: 32 + seed: 42 + async_grpo: + enabled: false # Set to true to enable async training mode + # Max age (in training steps) for trajectories used in training + max_trajectory_age_steps: 1 + +loss_fn: + reference_policy_kl_penalty: 0.01 + ratio_clip_min: 0.2 + ratio_clip_max: 0.2 + ratio_clip_c: null + # (default off) loss formulation improvements (docs/guides/grpo.md#loss) + use_on_policy_kl_approximation: false + # Async GRPO requires importance sampling correction enabled + # Set to true when async_grpo.enabled is true + use_importance_sampling_correction: false + sequence_level_importance_ratios: false + token_level_loss: true + imp_clip_max: null # Maximum value for importance sampling weight clipping (null to disable) + +checkpointing: + enabled: true + checkpoint_dir: "results/grpo_genrm_rlhf" + metric_name: "val_reward" + higher_is_better: true + keep_top_k: 3 + save_period: 10 + checkpoint_must_save_by: null + model_save_format: "safetensors" + save_consolidated: false + +policy: + model_name: "Qwen/Qwen2.5-1.5B" # Use base model, not instruct version + tokenizer: + name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default + train_global_batch_size: 512 + train_micro_batch_size: 4 + generation_batch_size: 32 # Only used when generating using HF backend + logprob_batch_size: 4 + max_total_sequence_length: 8192 # Increased for long evaluation prompts + precision: "bfloat16" + logprob_chunk_size: null + + dtensor_cfg: + _v2: true + enabled: true + cpu_offload: False + sequence_parallel: false + activation_checkpointing: false + tensor_parallel_size: 1 + context_parallel_size: 1 + custom_parallel_plan: null + + megatron_cfg: + enabled: false + 
empty_unused_memory_level: 0 + activation_checkpointing: false + converter_type: "Qwen2ForCausalLM" + tensor_model_parallel_size: 1 + expert_tensor_parallel_size: 1 + expert_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + num_layers_in_first_pipeline_stage: null + num_layers_in_last_pipeline_stage: null + context_parallel_size: 1 + pipeline_dtype: ${policy.precision} + sequence_parallel: false + freeze_moe_router: true + moe_router_dtype: "fp64" + moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo + moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo + moe_permute_fusion: false + #gives ~20% training perf speedup with sequence packing + apply_rope_fusion: True + defer_fp32_logits: null + + optimizer: + optimizer: "adam" + lr: 5.0e-6 + min_lr: 5.0e-7 + weight_decay: 0.01 + bf16: true + fp16: false + params_dtype: "float32" + + #adam + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_eps: 1e-8 + + #sgd + sgd_momentum: 0.9 + + #distributed optimizer + use_distributed_optimizer: true + use_precision_aware_optimizer: true + + clip_grad: ${policy.max_grad_norm} + + scheduler: + start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + weight_decay_incr_style: "constant" + lr_decay_style: "constant" + lr_decay_iters: 1000 + lr_warmup_iters: 13 + lr_warmup_init: 5.0e-7 + + distributed_data_parallel_config: + grad_reduce_in_fp32: false + overlap_grad_reduce: true + overlap_param_gather: true + average_in_collective: true + use_custom_fsdp: false + data_parallel_sharding_strategy: "optim_grads_params" + + env_vars: null + + # See docs/design-docs/sequence-packing-and-dynamic-batching.md + # for more details on dynamic batching and sequence packing. 
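+  # With the values in this config, each microbatch token budget below resolves to
+  # max_total_sequence_length * train_micro_batch_size = 8192 * 4 = 32768 tokens
+  # (and likewise 8192 * 4 = 32768 tokens for logprob microbatches via logprob_batch_size).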
+ dynamic_batching: + enabled: False + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} + sequence_length_round: 64 + + sequence_packing: + enabled: True + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} + algorithm: "modified_first_fit_decreasing" + sequence_length_round: 64 + + # makes the training sequence length divisible by the tensor parallel size + # this is useful for sequence parallel training + make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} + max_grad_norm: 1.0 + + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 5.0e-6 + weight_decay: 0.01 + betas: [0.9, 0.999] + eps: 1e-8 + # when using Dtensor, we need to set foreach + # and fused to False + foreach: False + fused: False + + scheduler: + - name: "torch.optim.lr_scheduler.LinearLR" + kwargs: + start_factor: 0.1 + end_factor: 1.0 + total_iters: 50 + - name: "torch.optim.lr_scheduler.ConstantLR" + kwargs: + factor: 1.0 + total_iters: 10000000000 + - milestones: [50] + + generation: + backend: "vllm" + max_new_tokens: ${policy.max_total_sequence_length} + temperature: 1.0 + top_p: 1.0 + top_k: null + stop_token_ids: null + stop_strings: null + vllm_cfg: + async_engine: false + precision: ${policy.precision} + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + enable_expert_parallel: false + gpu_memory_utilization: 0.6 + max_model_len: ${policy.max_total_sequence_length} + enforce_eager: False + use_deep_gemm: False + num_last_layers_in_bf16: 0 + num_first_layers_in_bf16: 0 + vllm_kwargs: + compilation_config: + # when enforce_eager is False, set ++policy.generation.vllm_kwargs.compilation_config.use_inductor=False for better accuracy, + # with the flag, vllm will use the custom CUDA kernels instead of the Triton kernels generated by torch.compile + # for more details, see convergence issue https://github.com/NVIDIA-NeMo/RL/issues/998 + use_inductor: False + colocated: + # true: generation shares training GPUs + # false: uses dedicated generation resources + enabled: true + # only relevant when enabled is false + resources: + gpus_per_node: null # Decides num gpus to be dedicated to generation when there is one node in the cluster i.e cluster.num_nodes == 1 + num_nodes: null # Decides number of nodes to be dedicated to generation + +# Data configuration +data: + max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len + prompt_file: null # Optional prompt template file + system_prompt_file: null # Optional system prompt file + shuffle: true + dataset_name: "genrm_rlhf" + train_data_path: "path/to/your/train.jsonl" # Update with your data path + val_data_path: "path/to/your/val.jsonl" # Optional validation data path + +env: + genrm_rlhf: + num_workers: 2 # Number of parallel GenRM workers + model_name: "nvidia/Llama-3_3-Nemotron-Super-49B-GenRM" # GenRM model for pairwise comparison + tensor_parallel_size: 4 # TP size for the GenRM model (requires substantial GPU memory) + gpu_memory_utilization: 0.95 # GPU memory utilization for GenRM model + max_model_len: 40000 # Max sequence length for GenRM model + num_generations_per_prompt: ${grpo.num_generations_per_prompt} # The expected number of responses to generate per prompt during training + 
num_val_generations_per_prompt: ${grpo.num_val_generations_per_prompt} # The expected number of responses per prompt during validation + num_judges_per_comparison: 1 # Number of independent GenRM passes per pairwise comparison (for majority voting) + temperature: 0.0 # GenRM temperature (usually 0 for deterministic evaluation) + top_p: 1.0 # Optional sampling parameter for vLLM + max_tokens: 32768 # Max tokens for GenRM's comparison output + stop: null # Stop strings for GenRM evaluation + reasoning_split_word: "" # The word to split the response into reasoning and answer + max_concurrency: 16 # Maximum concurrent step calls for the environment actor + # Reward aggregation configuration + aggregator_method: "simple_tiebreaker" # Options: "weighted_win_loss", "simple_tiebreaker", "individual_scores" + # aggregator_config: # Additional config for the aggregator + # score_mapping: # Mapping from ranking scores (1-6) to weighted points for weighted_win_loss + # 1: 1.0 # Much better + # 2: 0.8 # Better + # 3: 0.6 # Slightly better + # 4: 0.4 # Slightly worse + # 5: 0.2 # Worse + # 6: 0.0 # Much worse + +# Logger configuration +logger: + log_dir: "logs" # Base directory for all logs + num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal + wandb_enabled: false + tensorboard_enabled: false + mlflow_enabled: false # Disable MLflow logging + monitor_gpus: true # If true, will monitor GPU usage and log to wandb and/or tensorboard + wandb: + project: "grpo-genrm-rlhf" + name: "grpo-genrm-rlhf-logger" + tensorboard: {} + mlflow: + experiment_name: "grpo-genrm-rlhf" + run_name: "grpo-genrm-rlhf-logger" + gpu_monitoring: + collection_interval: 10 # How often to collect GPU usage metrics (in seconds) + flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) + +cluster: + gpus_per_node: 1 + num_nodes: 1 \ No newline at end of file diff --git a/examples/configs/grpo_genrm_rlhf_1B_megatron.yaml b/examples/configs/grpo_genrm_rlhf_1B_megatron.yaml new file mode 100644 index 0000000000..e132cee820 --- /dev/null +++ b/examples/configs/grpo_genrm_rlhf_1B_megatron.yaml @@ -0,0 +1,219 @@ +# GRPO Algorithm Configuration for GenRM RLHF with Megatron +defaults: "grpo_genrm_rlhf_1B.yaml" + +grpo: + num_prompts_per_step: 32 + num_generations_per_prompt: 8 + num_val_generations_per_prompt: 8 # Number of responses to generate per prompt during validation + max_rollout_turns: 1 # Single turn evaluation task + max_num_epochs: 1 + max_num_steps: 1000000 + normalize_rewards: true + use_leave_one_out_baseline: true + val_period: 10 + val_at_start: false + overlong_filtering: false + max_val_samples: 32 + val_batch_size: 32 + seed: 42 + async_grpo: + enabled: false # Set to true to enable async training mode + # Max age (in training steps) for trajectories used in training + max_trajectory_age_steps: 1 + +loss_fn: + reference_policy_kl_penalty: 0.01 + ratio_clip_min: 0.2 + ratio_clip_max: 0.2 + # (default off) loss formulation improvements (docs/guides/grpo.md#loss) + use_on_policy_kl_approximation: false + use_importance_sampling_correction: false + token_level_loss: true + ratio_clip_c: null + scale_no_clip: false + sequence_level_importance_ratios: false + +checkpointing: + enabled: true + checkpoint_dir: "results/grpo_genrm_rlhf_megatron" + metric_name: "val_reward" + higher_is_better: true + keep_top_k: 3 + save_period: 10 + checkpoint_must_save_by: null + +policy: + model_name: "Qwen/Qwen2.5-1.5B" # Use base model, not instruct version + tokenizer: + name: 
${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default + train_global_batch_size: 512 + train_micro_batch_size: 4 + generation_batch_size: 64 # Used when generating using megatron backend + logprob_batch_size: 8 + max_total_sequence_length: 8192 # Increased for long evaluation prompts + precision: "bfloat16" + + dtensor_cfg: + enabled: false + + # See docs/design-docs/sequence-packing-and-dynamic-batching.md + # for more details on dynamic batching and sequence packing. + # + # We disable dynamic batching for Megatron as it is incompatible with Pipeline parallelism. + # Instead, we use sequence packing. + dynamic_batching: + enabled: False + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} + sequence_length_round: 64 + + sequence_packing: + enabled: True + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} + algorithm: "modified_first_fit_decreasing" + sequence_length_round: 64 + + # makes the training sequence length divisible by the tensor parallel size + # this is useful for sequence parallel training + make_sequence_length_divisible_by: ${policy.megatron_cfg.tensor_model_parallel_size} + max_grad_norm: 1.0 + + optimizer: null # remove default FSDP optimizer + + megatron_cfg: + enabled: true + empty_unused_memory_level: 0 + activation_checkpointing: false + converter_type: "Qwen2ForCausalLM" + tensor_model_parallel_size: 1 + expert_tensor_parallel_size: 1 + expert_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + num_layers_in_first_pipeline_stage: null + num_layers_in_last_pipeline_stage: null + context_parallel_size: 1 + pipeline_dtype: ${policy.precision} + sequence_parallel: false + freeze_moe_router: true + moe_router_dtype: "bf16" + moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo + moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo + #gives ~20% training perf speedup with sequence packing + apply_rope_fusion: True + + optimizer: + optimizer: "adam" + lr: 5.0e-6 + min_lr: 5.0e-7 + weight_decay: 0.0 + bf16: true + fp16: false + params_dtype: "float32" + + #adam + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_eps: 1e-8 + + #sgd + sgd_momentum: 0.9 + + #distributed optimizer + use_distributed_optimizer: true + use_precision_aware_optimizer: true + + clip_grad: ${policy.max_grad_norm} + + scheduler: + start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + weight_decay_incr_style: "constant" + lr_decay_style: "constant" + lr_decay_iters: null + lr_warmup_iters: 10 + lr_warmup_init: 0.0 + + distributed_data_parallel_config: + grad_reduce_in_fp32: false + overlap_grad_reduce: true + overlap_param_gather: true + average_in_collective: true + use_custom_fsdp: false + data_parallel_sharding_strategy: "optim_grads_params" + + env_vars: null + + generation: + backend: "vllm" + max_new_tokens: ${policy.max_total_sequence_length} + temperature: 1.0 + top_p: 1.0 + top_k: null + stop_token_ids: null + stop_strings: null # Can add specific stop strings if needed + vllm_cfg: + async_engine: false + tensor_parallel_size: 1 + gpu_memory_utilization: 0.6 + max_model_len: ${policy.max_total_sequence_length} + +# Data configuration 
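+# Each record in train_data_path / val_data_path is expected to be a JSON object with
+# a "messages" field (a list of conversations, each ending in a user turn whose
+# metadata carries "conversation_history") and an optional "task_name"; see
+# nemo_rl/data/datasets/response_datasets/genrm_rlhf.py and
+# nemo_rl/data/processors.py::genrm_rlhf_data_processor for the exact requirements.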
+data: + max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len + prompt_file: null # Optional prompt template file + system_prompt_file: null # Optional system prompt file + shuffle: true + dataset_name: "genrm_rlhf" + train_data_path: "path/to/your/train.jsonl" # Update with your data path + val_data_path: "path/to/your/val.jsonl" # Optional validation data path + +env: + genrm_rlhf: + num_workers: 2 # Number of parallel GenRM workers + model_name: "nvidia/Llama-3_3-Nemotron-Super-49B-GenRM" # GenRM model for pairwise comparison + tensor_parallel_size: 4 # TP size for the GenRM model (requires substantial GPU memory) + gpu_memory_utilization: 0.95 # GPU memory utilization for GenRM model + max_model_len: 40000 # Max sequence length for GenRM model + num_generations_per_prompt: ${grpo.num_generations_per_prompt} # The expected number of responses to generate per prompt during training + num_val_generations_per_prompt: ${grpo.num_val_generations_per_prompt} # The expected number of responses per prompt during validation + num_judges_per_comparison: 1 # Number of independent GenRM passes per pairwise comparison (for majority voting) + temperature: 0.0 # GenRM temperature (usually 0 for deterministic evaluation) + top_p: 1.0 # Optional sampling parameter for vLLM + max_tokens: 32768 # Max tokens for GenRM's comparison output + stop: null # Stop strings for GenRM evaluation + reasoning_split_word: "" # The word to split the response into reasoning and answer + max_concurrency: 16 # Maximum concurrent step calls for the environment actor + # Reward aggregation configuration + aggregator_method: "simple_tiebreaker" # Options: "weighted_win_loss", "simple_tiebreaker", "individual_scores" + # aggregator_config: # Additional config for the aggregator + # score_mapping: # Mapping from ranking scores (1-6) to weighted points for weighted_win_loss + # 1: 1.0 # Much better + # 2: 0.8 # Better + # 3: 0.6 # Slightly better + # 4: 0.4 # Slightly worse + # 5: 0.2 # Worse + # 6: 0.0 # Much worse + +# Logger configuration +logger: + log_dir: "logs" # Base directory for all logs + num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal + wandb_enabled: false + tensorboard_enabled: false + mlflow_enabled: false # Disable MLflow logging + monitor_gpus: true # If true, will monitor GPU usage and log to wandb and/or tensorboard + wandb: + project: "grpo-genrm-rlhf-megatron" + name: "grpo-genrm-rlhf-megatron-logger" + tensorboard: {} + mlflow: + experiment_name: "grpo-genrm-rlhf-megatron" + run_name: "grpo-genrm-rlhf-megatron-logger" + gpu_monitoring: + collection_interval: 10 # How often to collect GPU usage metrics (in seconds) + flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) + +cluster: + gpus_per_node: 1 + num_nodes: 1 \ No newline at end of file diff --git a/examples/run_grpo_genrm_rlhf.py b/examples/run_grpo_genrm_rlhf.py new file mode 100644 index 0000000000..48e3295e5f --- /dev/null +++ b/examples/run_grpo_genrm_rlhf.py @@ -0,0 +1,256 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import pprint +from collections import defaultdict +from typing import Any, Optional + +from omegaconf import OmegaConf +from transformers import PreTrainedTokenizerBase + +from nemo_rl.algorithms.grpo import MasterConfig, grpo_train, setup +from nemo_rl.algorithms.utils import get_tokenizer +from nemo_rl.data import DataConfig +from nemo_rl.data.datasets import AllTaskProcessedDataset +from nemo_rl.data.datasets.response_datasets.genrm_rlhf import GenRMRLHFDataset +from nemo_rl.data.interfaces import ( + TaskDataProcessFnCallable, + TaskDataSpec, +) +from nemo_rl.data.processors import genrm_rlhf_data_processor +from nemo_rl.distributed.ray_actor_environment_registry import ( + get_actor_python_env, +) +from nemo_rl.distributed.virtual_cluster import init_ray +from nemo_rl.environments.genrm_rlhf_environment import GenRMRLHFEnvironment +from nemo_rl.environments.interfaces import EnvironmentInterface +from nemo_rl.models.generation import configure_generation_config +from nemo_rl.utils.config import load_config, parse_hydra_overrides +from nemo_rl.utils.logger import get_next_experiment_dir + +OmegaConf.register_new_resolver("mul", lambda a, b: a * b) + + +def parse_args() -> tuple[argparse.Namespace, list[str]]: + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Run GRPO training for GenRM RLHF") + parser.add_argument( + "--config", type=str, default=None, help="Path to YAML config file" + ) + + # Parse known args for the script + args, overrides = parser.parse_known_args() + + return args, overrides + + +# =============================================================================== +# GenRM RLHF Data Setup +# =============================================================================== +TokenizerType = PreTrainedTokenizerBase + + +def setup_data( + tokenizer: TokenizerType, + data_config: DataConfig, + env_configs: dict[str, Any], + seed: int, +) -> tuple[ + AllTaskProcessedDataset, + Optional[AllTaskProcessedDataset], + dict[str, EnvironmentInterface], + dict[str, EnvironmentInterface], +]: + print("\nā–¶ Setting up GenRM RLHF data...") + + # Create task spec for GenRM RLHF + genrm_rlhf_task_spec = TaskDataSpec( + task_name="genrm_rlhf", + prompt_file=data_config.get("prompt_file"), + system_prompt_file=data_config.get("system_prompt_file"), + ) + + # Load dataset + data: Any = GenRMRLHFDataset( + train_data_path=data_config["train_data_path"], + val_data_path=data_config.get("val_data_path"), + task_name="genrm_rlhf", + ) + + # Set up data processor + task_data_processors: dict[str, tuple[TaskDataSpec, TaskDataProcessFnCallable]] = ( + defaultdict(lambda: (genrm_rlhf_task_spec, genrm_rlhf_data_processor)) + ) + task_data_processors["genrm_rlhf"] = ( + genrm_rlhf_task_spec, + genrm_rlhf_data_processor, + ) + # Also register under alternate name for backward compatibility + task_data_processors["rlhf_genrm"] = ( + genrm_rlhf_task_spec, + genrm_rlhf_data_processor, + ) + + # Setup GenRM RLHF environment + genrm_rlhf_env = GenRMRLHFEnvironment.options( # type: ignore # it's wrapped with ray.remote + 
runtime_env={ + "py_executable": get_actor_python_env( + "nemo_rl.environments.genrm_rlhf_environment.GenRMRLHFEnvironment" + ), + "env_vars": dict(os.environ), # Pass thru all user environment variables + } + ).remote(env_configs["genrm_rlhf"]) + + # Create training dataset + dataset = AllTaskProcessedDataset( + data.formatted_ds["train"], + tokenizer, + genrm_rlhf_task_spec, + task_data_processors, + max_seq_length=data_config["max_input_seq_length"], + ) + + # Create validation dataset if available + val_dataset: Optional[AllTaskProcessedDataset] = None + if data.formatted_ds["validation"]: + val_dataset = AllTaskProcessedDataset( + data.formatted_ds["validation"], + tokenizer, + genrm_rlhf_task_spec, + task_data_processors, + max_seq_length=data_config["max_input_seq_length"], + ) + else: + val_dataset = None + + # Map task to environment + task_to_env: dict[str, EnvironmentInterface] = defaultdict(lambda: genrm_rlhf_env) + task_to_env["genrm_rlhf"] = genrm_rlhf_env + # Also register under alternate name for backward compatibility + task_to_env["rlhf_genrm"] = genrm_rlhf_env + + return dataset, val_dataset, task_to_env, task_to_env + + +def main() -> None: + """Main entry point for GenRM RLHF training.""" + # Parse arguments + args, overrides = parse_args() + + if not args.config: + args.config = os.path.join( + os.path.dirname(__file__), "configs", "grpo_genrm_rlhf_1B.yaml" + ) + + config = load_config(args.config) + print(f"Loaded configuration from: {args.config}") + + if overrides: + print(f"Overrides: {overrides}") + config = parse_hydra_overrides(config, overrides) + + config: MasterConfig = OmegaConf.to_container(config, resolve=True) + print("Applied CLI overrides") + + # Print config + print("Final config:") + pprint.pprint(config) + + # Get the next experiment directory with incremented ID + config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"]) + print(f"šŸ“Š Using log directory: {config['logger']['log_dir']}") + if config["checkpointing"]["enabled"]: + print( + f"šŸ“Š Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}" + ) + + init_ray() + + # Setup tokenizer + tokenizer = get_tokenizer(config["policy"]["tokenizer"]) + assert config["policy"]["generation"] is not None, ( + "A generation config is required for GRPO" + ) + config["policy"]["generation"] = configure_generation_config( + config["policy"]["generation"], tokenizer + ) + + # Setup data + ( + dataset, + val_dataset, + task_to_env, + val_task_to_env, + ) = setup_data(tokenizer, config["data"], config["env"], config["grpo"]["seed"]) + + ( + policy, + policy_generation, + cluster, + dataloader, + val_dataloader, + loss_fn, + logger, + checkpointer, + grpo_state, + master_config, + ) = setup(config, tokenizer, dataset, val_dataset) + + # Check if async mode is enabled + if "async_grpo" in config["grpo"] and config["grpo"]["async_grpo"]["enabled"]: + from nemo_rl.algorithms.grpo import async_grpo_train + + print("šŸš€ Running async GRPO training") + + async_config = config["grpo"]["async_grpo"] + # Run async GRPO training + async_grpo_train( + policy=policy, + policy_generation=policy_generation, + dataloader=dataloader, + val_dataloader=val_dataloader, + tokenizer=tokenizer, + loss_fn=loss_fn, + task_to_env=task_to_env, + val_task_to_env=val_task_to_env, + logger=logger, + checkpointer=checkpointer, + grpo_save_state=grpo_state, + master_config=master_config, + max_trajectory_age_steps=async_config["max_trajectory_age_steps"], + ) + else: + print("šŸš€ Running 
synchronous GRPO training for GenRM RLHF") + + # Run standard GRPO training + grpo_train( + policy, + policy_generation, + dataloader, + val_dataloader, + tokenizer, + loss_fn, + task_to_env, + val_task_to_env, + logger, + checkpointer, + grpo_state, + master_config, + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index c67b532498..cf0bf1aa54 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -1593,6 +1593,11 @@ def validate( greedy=False, ) else: + num_val_generations = master_config["grpo"].get( + "num_val_generations_per_prompt", 1 + ) + if num_val_generations > 1: + val_batch = val_batch.repeat_interleave(num_val_generations) val_batch, gen_metrics = run_multi_turn_rollout( policy_generation, val_batch, diff --git a/nemo_rl/data/datasets/response_datasets/__init__.py b/nemo_rl/data/datasets/response_datasets/__init__.py index 8e75a99a0c..2078d40128 100644 --- a/nemo_rl/data/datasets/response_datasets/__init__.py +++ b/nemo_rl/data/datasets/response_datasets/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Any +from nemo_rl.data.datasets.response_datasets.genrm_rlhf import GenRMRLHFDataset from nemo_rl.data.datasets.response_datasets.clevr import CLEVRCoGenTDataset from nemo_rl.data.datasets.response_datasets.dapo_math import DAPOMath17KDataset from nemo_rl.data.datasets.response_datasets.deepscaler import DeepScalerDataset @@ -134,4 +135,5 @@ def load_response_dataset(data_config, seed: int = 42): "RefCOCODataset", "ResponseDataset", "SquadDataset", + "GenRMRLHFDataset", ] diff --git a/nemo_rl/data/datasets/response_datasets/genrm_rlhf.py b/nemo_rl/data/datasets/response_datasets/genrm_rlhf.py new file mode 100644 index 0000000000..e246b9835b --- /dev/null +++ b/nemo_rl/data/datasets/response_datasets/genrm_rlhf.py @@ -0,0 +1,89 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional + +from datasets import load_dataset + +from nemo_rl.data.interfaces import TaskDataSpec + + +class GenRMRLHFDataset: + """Dataset class for GenRM RLHF training data. + + This class handles loading of data for GenRM RLHF training where the model + generates multiple responses per prompt and uses GenRM for pairwise comparisons. + + The input JSONL files should contain valid JSON objects formatted like this: + { + "messages": [[{role, content, metadata}, ...]], # List of message lists (full conversations) + "task_name": str, + "dataset": str # Optional dataset identifier + } + + Each message list is a complete conversation with multiple user/assistant turns. 
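+
+    For illustration only, a minimal (hypothetical) record could look like:
+    {
+        "messages": [[
+            {"role": "user",
+             "content": "How do I reverse a list in Python?",
+             "metadata": {"conversation_history": []}}
+        ]],
+        "task_name": "genrm_rlhf"
+    }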
+ The last user message MUST contain metadata with: + - conversation_history: Previous conversation context (required, can be empty list []) + + Args: + train_data_path: Path to the JSON file containing training data + val_data_path: Optional path to the JSON file containing validation data + task_name: Name of the task (default: "genrm_rlhf") + """ + + def __init__( + self, + train_data_path: str, + val_data_path: Optional[str] = None, + task_name: str = "genrm_rlhf", + ): + self.task_name = task_name + + # Load from json file + train_ds = load_dataset("json", data_files=train_data_path)["train"] + val_ds = None + if val_data_path: + val_ds = load_dataset("json", data_files=val_data_path)["train"] + + # Process the data to extract conversation history from metadata + train_ds = train_ds.map(lambda x: self._process_format(x)) + if val_ds: + val_ds = val_ds.map(lambda x: self._process_format(x)) + + # Store the formatted dataset + self.formatted_ds = { + "train": train_ds, + "validation": val_ds, + } + + self.task_spec = TaskDataSpec(task_name=self.task_name) + + def _process_format(self, example: dict[str, Any]) -> dict[str, Any]: + """Ensure the example has the required format. + + Note: Metadata is expected to be in the last user message of each conversation. + The data processor will extract it from there. + """ + # Make sure task_name is set + if "task_name" not in example: + example["task_name"] = self.task_name + + # Ensure messages is a list of lists + if "messages" in example: + messages = example["messages"] + # If messages is a single list of dicts, wrap it + if messages and isinstance(messages[0], dict): + example["messages"] = [messages] + + return example \ No newline at end of file diff --git a/nemo_rl/data/processors.py b/nemo_rl/data/processors.py index 3a90f384fe..70062b7d72 100644 --- a/nemo_rl/data/processors.py +++ b/nemo_rl/data/processors.py @@ -231,3 +231,101 @@ def multichoice_qa_processor( if "task_name" in datum_dict: output["task_name"] = datum_dict["task_name"] return output + + + + +def genrm_rlhf_data_processor( + datum_dict: dict[str, Any], + task_data_spec: TaskDataSpec, + tokenizer: TokenizerType, + max_seq_length: int, + idx: int, +) -> DatumSpec: + """Process a datum dictionary for GenRM RLHF training tasks. + + The datum_dict should contain: + - messages: List of message lists, where each list is a full conversation + that MUST end with a user message + + The last user message MUST contain metadata with: + - conversation_history: Previous conversation context (required, can be empty list) + + The processor will: + 1. Extract metadata from the last user message + 2. 
Use the entire conversation as the generation prompt + """ + # Extract the messages (should be a list containing message lists) + messages_list = datum_dict["messages"] + + # Get the first (and likely only) message list + if messages_list and len(messages_list) > 0: + messages = messages_list[0] + else: + raise ValueError("No messages found in datum_dict") + + # The last message must be a user turn with required metadata + assert messages, "Messages list is empty" + assert messages[-1].get("role") == "user", "Last message must be a user turn" + assert "metadata" in messages[-1], "Last user turn must have metadata" + assert "conversation_history" in messages[-1]["metadata"], ( + "Metadata must contain conversation_history" + ) + + # Extract metadata from the last user turn + metadata = { + "conversation_history": messages[-1]["metadata"]["conversation_history"] + } + + # Clean conversation for tokenization (remove metadata from messages) + clean_conversation = [ + {"role": msg["role"], "content": msg["content"]} for msg in messages + ] + + # Apply chat template to the entire conversation and tokenize + templated_conversation = tokenizer.apply_chat_template( + clean_conversation, + tokenize=False, + add_generation_prompt=True, + add_special_tokens=True, + ) + + # Tokenize and create message log entry + token_ids = tokenizer( + templated_conversation, + return_tensors="pt", + add_special_tokens=False, # Already added by apply_chat_template + )["input_ids"][0] + + # Handle max sequence length + length = len(token_ids) + loss_multiplier = 1.0 + if length > max_seq_length: + # Truncate if necessary + token_ids = token_ids[:max_seq_length] + length = max_seq_length + loss_multiplier = 0.0 + + message_log: LLMMessageLogType = [ + { + "role": "user", # Full conversation prompt for generation + "content": templated_conversation, + "token_ids": token_ids, + } + ] + + # Prepare extra_env_info with conversation history + extra_env_info = { + "conversation_history": metadata.get("conversation_history", []), + } + + output: DatumSpec = { + "message_log": message_log, + "length": length, + "extra_env_info": extra_env_info, + "loss_multiplier": loss_multiplier, + "idx": idx, + "task_name": datum_dict.get("task_name", "genrm_rlhf"), + } + + return output diff --git a/nemo_rl/distributed/ray_actor_environment_registry.py b/nemo_rl/distributed/ray_actor_environment_registry.py index 6a3529d4a1..83cb519591 100644 --- a/nemo_rl/distributed/ray_actor_environment_registry.py +++ b/nemo_rl/distributed/ray_actor_environment_registry.py @@ -35,6 +35,7 @@ "nemo_rl.environments.math_environment.MathEnvironment": PY_EXECUTABLES.SYSTEM, "nemo_rl.environments.vlm_environment.VLMEnvironment": PY_EXECUTABLES.SYSTEM, "nemo_rl.environments.code_environment.CodeEnvironment": PY_EXECUTABLES.SYSTEM, + "nemo_rl.environments.genrm_rlhf_environment.GenRMRLHFEnvironment": PY_EXECUTABLES.SYSTEM, "nemo_rl.environments.reward_model_environment.RewardModelEnvironment": PY_EXECUTABLES.SYSTEM, "nemo_rl.environments.games.sliding_puzzle.SlidingPuzzleEnv": PY_EXECUTABLES.SYSTEM, # AsyncTrajectoryCollector needs vLLM environment to handle exceptions from VllmGenerationWorker diff --git a/nemo_rl/environments/genrm_rlhf_environment.py b/nemo_rl/environments/genrm_rlhf_environment.py new file mode 100644 index 0000000000..31da6ebe0f --- /dev/null +++ b/nemo_rl/environments/genrm_rlhf_environment.py @@ -0,0 +1,573 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import itertools +import logging +import os +import uuid +from typing import Any, Dict, List, Optional, Tuple, TypedDict + +import ray +import torch + +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.virtual_cluster import PY_EXECUTABLES, RayVirtualCluster +from nemo_rl.environments.interfaces import ( + EnvironmentInterface, + EnvironmentReturn, +) +from nemo_rl.environments.pairwise_reward_aggregators import create_aggregator + + +class GenRMRLHFConfig(TypedDict): + num_workers: int + model_name: ( + str # GenRM model name, e.g., "nvidia/Llama-3_3-Nemotron-Super-49B-GenRM" + ) + tensor_parallel_size: int + gpu_memory_utilization: float + max_model_len: int + num_generations_per_prompt: ( + int # e.g., 8 - number of responses to generate per prompt during training + ) + num_val_generations_per_prompt: Optional[ + int + ] # Number of responses per prompt during validation (defaults to num_generations_per_prompt) + # Default sampling parameters for the GenRM + temperature: Optional[float] + top_p: Optional[float] + max_tokens: Optional[int] + stop: Optional[List[str]] + max_concurrency: Optional[ + int + ] # Maximum concurrent step calls for the environment actor + reasoning_split_word: Optional[str] # Default: "" + # Reward aggregation configuration + aggregator_method: Optional[ + str + ] # e.g., "individual_scores", "weighted_win_loss", "simple_tiebreaker", etc. 
+ aggregator_config: Optional[Dict[str, Any]] # Additional config for the aggregator + num_judges_per_comparison: Optional[ + int + ] # Number of times to evaluate each response pair (majority voting) + + +class GenRMEnvironmentMetadata(TypedDict): + conversation_history: List[ + Dict[str, str] + ] # The conversation history in user/assistant format + + +@ray.remote +class AsyncGenRMWorker: + """Worker that serves GenRM using vLLM AsyncEngine for pairwise response comparisons.""" + + DEFAULT_PY_EXECUTABLE = PY_EXECUTABLES.VLLM + + def __init__( + self, + model_name: str, + tensor_parallel_size: int = 1, + gpu_memory_utilization: float = 0.85, + max_model_len: Optional[int] = None, + disable_log_stats: bool = True, + reasoning_split_word: Optional[str] = "", + **engine_kwargs, + ): + # Configure logging for Ray worker + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + force=True, + ) + + # Imports moved here to be within the Ray actor's context + from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE + from transformers import AutoTokenizer + from vllm.engine.arg_utils import AsyncEngineArgs + from vllm.engine.async_llm_engine import AsyncLLMEngine + from vllm.inputs import TokensPrompt + from vllm.sampling_params import SamplingParams + + self.SamplingParams = SamplingParams + self.TokensPrompt = TokensPrompt + + # Setup HF cache path + hf_home_cache_path = os.environ.get("HF_HOME", HUGGINGFACE_HUB_CACHE) + if not os.path.isdir(hf_home_cache_path): + try: + os.makedirs(hf_home_cache_path, exist_ok=True) + logging.info( + f"Created HF cache directory for GenRM worker: {hf_home_cache_path}" + ) + except OSError as e: + logging.warning( + f"GenRM worker could not create HF cache directory {hf_home_cache_path}: {e}. " + "This might lead to download issues if the default cache is not writable." 
+ ) + + # Load tokenizer for chat template functionality + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, + cache_dir=hf_home_cache_path, + trust_remote_code=True, + ) + + # Initialize AsyncEngine with GenRM model + engine_args = AsyncEngineArgs( + model=model_name, + tensor_parallel_size=tensor_parallel_size, + gpu_memory_utilization=gpu_memory_utilization, + max_model_len=max_model_len, + disable_log_stats=disable_log_stats, + download_dir=hf_home_cache_path, + ignore_patterns=[ + "*.safetensors.index.json", + "*.pt", + "*.bin.index.json", + "*.gitattributes", + ], + trust_remote_code=True, # GenRM models typically need trust_remote_code + **engine_kwargs, + ) + self.engine = AsyncLLMEngine.from_engine_args(engine_args) + self.reasoning_split_word = reasoning_split_word + logging.info(f"AsyncGenRMWorker initialized with GenRM model: {model_name}") + + def _format_genrm_messages( + self, + conversation_history: List[Dict[str, str]], + response_1: str, + response_2: str, + ) -> List[Dict[str, str]]: + """Format the conversation and responses into GenRM's expected message format.""" + # Build messages list in the format expected by GenRM + messages = conversation_history.copy() + + # Add the responses to be compared + messages.extend( + [ + {"role": "response_1", "content": response_1}, + {"role": "response_2", "content": response_2}, + ] + ) + + return messages + + async def compare_responses( + self, + request_id: str, + conversation_history: List[Dict[str, str]], + response_1: str, + response_2: str, + sampling_params_dict: dict, + ) -> Tuple[str, float, float, float]: + """Compare two responses using GenRM via AsyncEngine. + + Args: + request_id: Unique ID for this comparison request + conversation_history: List of conversation messages in user/assistant format + response_1: First response to compare + response_2: Second response to compare + sampling_params_dict: Parameters for vLLM sampling + + Returns: + Tuple of (request_id, individual_score_1, individual_score_2, ranking_score) + """ + try: + # Format messages for GenRM + if self.reasoning_split_word and self.reasoning_split_word in response_1: + response_1 = response_1.split(self.reasoning_split_word)[-1].lstrip() + if self.reasoning_split_word and self.reasoning_split_word in response_2: + response_2 = response_2.split(self.reasoning_split_word)[-1].lstrip() + messages = self._format_genrm_messages( + conversation_history, response_1, response_2 + ) + + # Apply chat template to get text for debugging + chat_template_text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + + # logging.info(f"GenRM chat template text for {request_id}:\n{chat_template_text}") + + # Apply chat template and tokenize + token_ids = self.tokenizer.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_tensors=None, # Return list of token IDs + ) + + # logging.info(f"GenRM tokenized prompt for {request_id}: {len(token_ids)} tokens") + + # Create sampling parameters + sampling_params = self.SamplingParams(**sampling_params_dict) + + # Create TokensPrompt object for vLLM + tokens_prompt = self.TokensPrompt(prompt_token_ids=token_ids) + + # Generate using AsyncEngine with TokensPrompt + results_generator = self.engine.generate( + tokens_prompt, sampling_params, request_id + ) + + final_output = None + async for request_output in results_generator: + final_output = request_output + + if final_output and final_output.outputs: + generated_text = 
final_output.outputs[0].text.strip() + + # Split by reasoning word if provided + if ( + self.reasoning_split_word + and self.reasoning_split_word in generated_text + ): + generated_text = generated_text.split(self.reasoning_split_word)[ + -1 + ].lstrip() + + # Parse the scores from GenRM output + individual_score_1, individual_score_2, ranking_score = ( + self._parse_genrm_output(generated_text) + ) + + logging.info( + f"GenRM comparison {request_id}: scores=({individual_score_1}, {individual_score_2}), ranking={ranking_score}, generated_text={generated_text}" + ) + + return request_id, individual_score_1, individual_score_2, ranking_score + else: + logging.warning( + f"No output received from GenRM for request {request_id}" + ) + return request_id, 3.0, 3.0, 3.5 + + except Exception as e: + logging.error(f"Error in GenRM comparison {request_id}: {e}") + return request_id, 3.0, 3.0, 3.5 + + def _parse_genrm_output(self, output: str) -> Tuple[float, float, float]: + """Parse GenRM output to extract individual and ranking scores from JSON format.""" + import json + import re + + try: + # Try to find JSON in the response (same as vanilla GenRM) + json_match = re.search(r"\{.*\}", output, re.DOTALL) + if json_match: + json_str = json_match.group(0) + parsed = json.loads(json_str) + + score_1 = float(parsed.get("score_1", 3.0)) + score_2 = float(parsed.get("score_2", 3.0)) + ranking = float(parsed.get("ranking", 3.5)) + + logging.debug( + f"Extracted scores from JSON: score_1={score_1}, score_2={score_2}, ranking={ranking}" + ) + return score_1, score_2, ranking + else: + logging.warning(f"No JSON found in GenRM output: {output}...") + return 3.0, 3.0, 3.5 # Default neutral scores + + except json.JSONDecodeError as e: + logging.error( + f"Failed to parse JSON from GenRM output: {e}. Output was: {output}..." + ) + return 3.0, 3.0, 3.5 + except Exception as e: + logging.error(f"Error parsing GenRM output: {e}. 
Output was: {output}...") + return 3.0, 3.0, 3.5 + + +@ray.remote +class GenRMRLHFEnvironment(EnvironmentInterface): + """Environment that uses GenRM for pairwise comparison of multiple responses per prompt.""" + + DEFAULT_PY_EXECUTABLE = PY_EXECUTABLES.SYSTEM + + def __init__(self, cfg: GenRMRLHFConfig): + self.cfg = cfg + self.num_workers = cfg["num_workers"] + self.num_generations_per_prompt = cfg["num_generations_per_prompt"] + # If validation generations not specified, default to training value + self.num_val_generations_per_prompt = cfg.get( + "num_val_generations_per_prompt", self.num_generations_per_prompt + ) + self.num_judges_per_comparison = int( + cfg.get("num_judges_per_comparison", 1) or 1 + ) + if self.num_judges_per_comparison < 1: + raise ValueError("num_judges_per_comparison must be at least 1") + + # Initialize the reward aggregator + aggregator_method = cfg.get( + "aggregator_method", "simple_tiebreaker" + ) # Default to simple_tiebreaker + aggregator_config = cfg.get("aggregator_config", {}) + self.reward_aggregator = create_aggregator( + aggregator_method, **aggregator_config + ) + logging.info( + f"Initialized GenRM environment with {self.reward_aggregator.name} aggregator" + ) + + tensor_parallel_size = cfg.get("tensor_parallel_size", 1) + + # Create RayVirtualCluster for GPU allocation if needed + if tensor_parallel_size == 1: + bundle_ct_per_node_list = [tensor_parallel_size] * self.num_workers + + self.virtual_cluster = RayVirtualCluster( + bundle_ct_per_node_list=bundle_ct_per_node_list, + use_gpus=True, + name="genrm_pairwise_vc", + ) + self.virtual_cluster.print_cluster_grid() + placement_groups = self.virtual_cluster.get_placement_groups() + else: + self.virtual_cluster = None + placement_groups = [] + + # Pass down environment variables to workers + env_vars_to_pass = {} + for key in [ + "HF_HOME", + "TRANSFORMERS_CACHE", + "WANDB_API_KEY", + "HUGGINGFACE_HUB_DISABLE_XET", + "HF_TOKEN", + ]: + if key in os.environ: + env_vars_to_pass[key] = os.environ[key] + + env_vars_to_pass.setdefault("HUGGINGFACE_HUB_DISABLE_XET", "1") + + worker_options = { + "runtime_env": { + "py_executable": AsyncGenRMWorker.DEFAULT_PY_EXECUTABLE, + "env_vars": env_vars_to_pass, + }, + "num_gpus": tensor_parallel_size, + } + + # Create GenRM workers + self.workers = [] + for i in range(self.num_workers): + if tensor_parallel_size == 1: + pg_index = i % len(placement_groups) + pg = placement_groups[pg_index] + scheduling_kwargs = dict( + scheduling_strategy=ray.util.scheduling_strategies.PlacementGroupSchedulingStrategy( + placement_group=pg + ) + ) + else: + scheduling_kwargs = {} + + worker = AsyncGenRMWorker.options( + **worker_options, + **scheduling_kwargs, + ).remote( + model_name=cfg["model_name"], + tensor_parallel_size=tensor_parallel_size, + gpu_memory_utilization=cfg.get("gpu_memory_utilization", 0.85), + max_model_len=cfg.get("max_model_len"), + reasoning_split_word=cfg.get("reasoning_split_word", ""), + ) + self.workers.append(worker) + + logging.info(f"Created {len(self.workers)} AsyncGenRMWorker actors.") + self._request_counter = 0 + self._actor_id_prefix = str(uuid.uuid4())[:8] + self._last_additional_metrics = {} # Store additional metrics from last step for global_post_process_and_metrics + + def shutdown(self): + for worker in self.workers: + ray.kill(worker) + if self.virtual_cluster is not None: + self.virtual_cluster.shutdown() + + def step( + self, + message_log_batch: List[List[Dict[str, str]]], + metadata: List[GenRMEnvironmentMetadata], + ) -> 
EnvironmentReturn: + """Step function for GenRM pairwise comparison environment. + + Args: + message_log_batch: List of conversations, where each conversation is a list of messages + metadata: List of metadata for each conversation + + Returns: + EnvironmentReturn with rewards based on pairwise comparison aggregation + """ + + def get_prompt_key(conversation_history: List[Dict[str, str]]) -> str: + """Extract the conversation history as a grouping key.""" + # Create a key from the conversation history (the prompt context) + prompt_parts = [] + for msg in conversation_history: + prompt_parts.append(f"{msg['role']}: {msg['content']}") + + return " | ".join(prompt_parts) + + # Group responses by prompt (conversation history from metadata) + prompt_groups = {} + for i, (conversation, single_metadata) in enumerate( + zip(message_log_batch, metadata) + ): + prompt_key = get_prompt_key(single_metadata["conversation_history"]) + if prompt_key not in prompt_groups: + prompt_groups[prompt_key] = { + "conversations": [], + "metadata": single_metadata, + "indices": [], + } + prompt_groups[prompt_key]["conversations"].append(conversation) + prompt_groups[prompt_key]["indices"].append(i) + + # Prepare default sampling parameters + default_sampling_params = { + "temperature": self.cfg.get("temperature", 0.0), + "top_p": self.cfg.get("top_p", 1.0), + "max_tokens": self.cfg.get("max_tokens", 32768), + "stop": self.cfg.get("stop", None), + } + + # Collect all pairwise comparison tasks + comparison_futures = [] + comparison_metadata = [] # Track which prompt group, response indices, and judge iteration each comparison belongs to + + for prompt_key, group_data in prompt_groups.items(): + conversations = group_data["conversations"] + group_metadata = group_data["metadata"] + conversation_history = group_metadata["conversation_history"] + + # Extract responses from conversations (assuming last message is assistant response) + responses = [] + for conversation in conversations: + assert len(conversation) >= 1, ( + "Each conversation should have at least one message" + ) + # Get the last assistant message as the response + assistant_msgs = [ + msg for msg in conversation if msg["role"] == "assistant" + ] + assert len(assistant_msgs) >= 1, ( + "Each conversation should have at least one assistant message" + ) + responses.append(assistant_msgs[-1]["content"]) + + # Check that we have the expected number of generations per prompt + if len(responses) not in [ + self.num_generations_per_prompt, + self.num_val_generations_per_prompt, + ]: + raise ValueError( + f"Expected {self.num_generations_per_prompt} (training) or {self.num_val_generations_per_prompt} (validation) " + f"generations per prompt, but found {len(responses)} responses. " + f"This may be because generations for the same prompt are distributed to multiple dp ranks." 
+ ) + + # Generate all pairwise comparisons for this prompt group + for judge_idx in range(self.num_judges_per_comparison): + for i, j in itertools.combinations(range(len(responses)), 2): + request_id = ( + f"genrm_{self._actor_id_prefix}_step_{self._request_counter}_pk{hash(prompt_key)}" + f"_r{i}_r{j}_judge{judge_idx}" + ) + worker_idx = len(comparison_futures) % self.num_workers + + future = self.workers[worker_idx].compare_responses.remote( + request_id, + conversation_history, + responses[i], + responses[j], + default_sampling_params, + ) + comparison_futures.append(future) + comparison_metadata.append((prompt_key, i, j, judge_idx)) + + self._request_counter += 1 + + # Get all comparison results + comparison_results = ray.get(comparison_futures) + + # Aggregate pairwise comparisons into final scores for each response using the configured aggregator + final_scores = self.reward_aggregator.aggregate_scores( + comparison_results, comparison_metadata, prompt_groups + ) + + # Get additional metrics from the aggregator (e.g., individual scores for ranking-based aggregators) + self._last_additional_metrics = self.reward_aggregator.get_additional_metrics( + comparison_results, comparison_metadata, prompt_groups, final_scores + ) + + # Create observations and prepare return values in the same order as input + observations = [] + all_metadata = [] + rewards_list = [] + + for i, (conversation, single_metadata) in enumerate( + zip(message_log_batch, metadata) + ): + prompt_key = get_prompt_key(single_metadata["conversation_history"]) + + # Find which response index this is within its group + group_indices = prompt_groups[prompt_key]["indices"] + response_idx_in_group = group_indices.index(i) + + # Get the score for this response (default to 0.5 if no comparisons were made) + if prompt_key in final_scores: + score = final_scores[prompt_key][response_idx_in_group] + else: + score = 0.5 # Neutral score for single responses + + observations.append( + { + "role": "environment", + "content": f"Environment: {self.reward_aggregator.name} Score = {score:.3f}", + } + ) + all_metadata.append(single_metadata) + rewards_list.append(score) + + rewards_tensor = torch.tensor(rewards_list, dtype=torch.float32).cpu() + terminateds_tensor = torch.ones_like(rewards_tensor).cpu() + next_stop_strings = [None] * len(rewards_list) + + return EnvironmentReturn( + observations=observations, + metadata=all_metadata, + next_stop_strings=next_stop_strings, + rewards=rewards_tensor, + terminateds=terminateds_tensor, + answers=None, # GenRM RLHF doesn't extract specific answers + ) + + def global_post_process_and_metrics( + self, batch: BatchedDataDict + ) -> Tuple[BatchedDataDict, dict]: + """Computes metrics for the GenRM pairwise environment.""" + metrics = {} + + # Add additional metrics from the aggregator (e.g., individual scores for ranking-based aggregators) + if self._last_additional_metrics: + metrics.update(self._last_additional_metrics) + + return batch, metrics \ No newline at end of file diff --git a/nemo_rl/environments/pairwise_reward_aggregators.py b/nemo_rl/environments/pairwise_reward_aggregators.py new file mode 100644 index 0000000000..4703337856 --- /dev/null +++ b/nemo_rl/environments/pairwise_reward_aggregators.py @@ -0,0 +1,440 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Tuple + +import numpy as np + + +class PairwiseRewardAggregator(ABC): + """Abstract base class for aggregating pairwise comparison results into scalar rewards.""" + + def __init__(self, **kwargs): + """Default initialization that accepts any keyword arguments for compatibility with factory function.""" + # Accept any kwargs to maintain compatibility with the factory function + # Subclasses can override this if they need specific configuration parameters + pass + + @abstractmethod + def aggregate_scores( + self, + comparison_results: List[ + Tuple[str, float, float, float] + ], # (request_id, score_1, score_2, ranking_score) + comparison_metadata: List[ + Tuple[str, int, int, int] + ], # (prompt_key, resp_i, resp_j, judge_idx) + prompt_groups: Dict[ + str, Dict[str, Any] + ], # {prompt_key: {"conversations": [...], "metadata": {...}, "indices": [...]}} + ) -> Dict[str, List[float]]: + """Aggregate pairwise comparison results into final rewards for each response.""" + pass + + def get_additional_metrics( + self, + comparison_results: List[ + Tuple[str, float, float, float] + ], # (request_id, score_1, score_2, ranking_score) + comparison_metadata: List[ + Tuple[str, int, int, int] + ], # (prompt_key, resp_i, resp_j, judge_idx) + prompt_groups: Dict[str, Dict[str, Any]], + final_scores: Dict[ + str, List[float] + ], # The aggregated scores from aggregate_scores + ) -> Dict[str, float]: + """Compute additional metrics to log alongside the main aggregated scores. + + Default implementation returns empty dict. Subclasses can override to provide + additional metrics like individual scores, win rates, etc. + + Args: + comparison_results: Raw comparison results from the GenRM model + comparison_metadata: Metadata about which responses were compared (including judge index) + prompt_groups: Grouped prompt data + final_scores: The final aggregated scores returned by aggregate_scores + + Returns: + Dictionary of additional metrics to log + """ + return {} + + @property + @abstractmethod + def name(self) -> str: + """Name of the aggregation method.""" + pass + + +class IndividualScoreAggregator(PairwiseRewardAggregator): + """Use GenRM's individual helpfulness scores directly.""" + + def __init__(self, score_range: Tuple[float, float] = (1.0, 5.0)): + """Initialize the aggregator. + + Args: + score_range: (min_score, max_score) for normalization. 
+ """ + self.min_score, self.max_score = score_range + + def aggregate_scores( + self, + comparison_results: List[Tuple[str, float, float, float]], + comparison_metadata: List[Tuple[str, int, int]], + prompt_groups: Dict[str, Dict[str, Any]], + ) -> Dict[str, List[float]]: + """Aggregate using individual helpfulness scores.""" + # Collect individual scores for each response + individual_scores = {} + score_counts = {} + + for prompt_key, group_data in prompt_groups.items(): + num_responses = len(group_data["conversations"]) + individual_scores[prompt_key] = [0.0 for _ in range(num_responses)] + score_counts[prompt_key] = [0 for _ in range(num_responses)] + + # Process each comparison result + for (request_id, score_1, score_2, ranking_score), ( + prompt_key, + resp_i, + resp_j, + judge_idx, + ) in zip(comparison_results, comparison_metadata): + # Accumulate individual scores + individual_scores[prompt_key][resp_i] += score_1 + individual_scores[prompt_key][resp_j] += score_2 + score_counts[prompt_key][resp_i] += 1 + score_counts[prompt_key][resp_j] += 1 + + # Calculate average individual scores and normalize + final_scores = {} + for prompt_key, group_data in prompt_groups.items(): + num_responses = len(group_data["conversations"]) + final_scores[prompt_key] = [] + for resp_idx in range(num_responses): + if score_counts[prompt_key][resp_idx] > 0: + avg_score = ( + individual_scores[prompt_key][resp_idx] + / score_counts[prompt_key][resp_idx] + ) + else: + avg_score = ( + self.min_score + self.max_score + ) / 2 # Neutral score if no comparisons + final_scores[prompt_key].append(avg_score) + + return final_scores + + @property + def name(self) -> str: + return "individual_scores" + + +class WeightedWinLossAggregator(PairwiseRewardAggregator): + """Use weighted scores based on the magnitude of GenRM ranking preferences.""" + + def __init__(self, score_mapping: Dict[int, float] = None): + """Initialize the aggregator. + + Args: + score_mapping: Mapping from ranking scores (1-6) to weighted points. 
+ """ + self.score_mapping = score_mapping or { + 1: 1.0, # Much better + 2: 0.75, # Better + 3: 0.6, # Slightly better + 4: 0.4, # Slightly worse + 5: 0.25, # Worse + 6: 0.0, # Much worse + } + + def aggregate_scores( + self, + comparison_results: List[Tuple[str, float, float, float]], + comparison_metadata: List[Tuple[str, int, int, int]], + prompt_groups: Dict[str, Dict[str, Any]], + ) -> Dict[str, List[float]]: + """Aggregate using weighted win-loss scores.""" + # Initialize weighted scores + weighted_scores = {} + total_comparisons = {} + + for prompt_key, group_data in prompt_groups.items(): + num_responses = len(group_data["conversations"]) + weighted_scores[prompt_key] = [0.0 for _ in range(num_responses)] + total_comparisons[prompt_key] = [0 for _ in range(num_responses)] + + # Process each comparison result + for (request_id, score_1, score_2, ranking_score), ( + prompt_key, + resp_i, + resp_j, + judge_idx, + ) in zip(comparison_results, comparison_metadata): + # Convert ranking score to weighted points + ranking_int = int(round(ranking_score)) + + if ranking_int <= 3: + # Response i wins with different magnitudes + weight_i = self.score_mapping.get(ranking_int, 0.5) + weight_j = 1.0 - weight_i + else: + # Response j wins with different magnitudes + weight_j = self.score_mapping.get(ranking_int, 0.5) + weight_i = 1.0 - weight_j + + # Accumulate weighted scores + weighted_scores[prompt_key][resp_i] += weight_i + weighted_scores[prompt_key][resp_j] += weight_j + total_comparisons[prompt_key][resp_i] += 1 + total_comparisons[prompt_key][resp_j] += 1 + + # Calculate final scores as weighted averages + final_scores = {} + for prompt_key, group_data in prompt_groups.items(): + num_responses = len(group_data["conversations"]) + final_scores[prompt_key] = [] + for resp_idx in range(num_responses): + if total_comparisons[prompt_key][resp_idx] > 0: + avg_weighted_score = ( + weighted_scores[prompt_key][resp_idx] + / total_comparisons[prompt_key][resp_idx] + ) + else: + avg_weighted_score = 0.5 # Neutral score if no comparisons + final_scores[prompt_key].append(avg_weighted_score) + + return final_scores + + def get_additional_metrics( + self, + comparison_results: List[Tuple[str, float, float, float]], + comparison_metadata: List[Tuple[str, int, int, int]], + prompt_groups: Dict[str, Dict[str, Any]], + final_scores: Dict[str, List[float]], + ) -> Dict[str, float]: + """Compute individual score metrics alongside the weighted win-loss scores.""" + # Collect individual scores for metrics + all_individual_scores_1 = [] + all_individual_scores_2 = [] + all_individual_scores = [] + + # Process each comparison result to extract individual scores + for (request_id, score_1, score_2, ranking_score), ( + prompt_key, + resp_i, + resp_j, + judge_idx, + ) in zip(comparison_results, comparison_metadata): + all_individual_scores_1.append(score_1) + all_individual_scores_2.append(score_2) + all_individual_scores.extend([score_1, score_2]) + + # Compute statistics for individual scores + individual_metrics = {} + if all_individual_scores: + individual_metrics.update( + { + "mean_individual_score": np.mean(all_individual_scores), + "std_individual_score": np.std(all_individual_scores), + "min_individual_score": np.min(all_individual_scores), + "max_individual_score": np.max(all_individual_scores), + "median_individual_score": np.median(all_individual_scores), + } + ) + + # Also compute individual score metrics per response position + if all_individual_scores_1: + individual_metrics.update( + { + 
"mean_individual_score_first": np.mean(all_individual_scores_1), + "mean_individual_score_second": np.mean( + all_individual_scores_2 + ), + } + ) + + return individual_metrics + + @property + def name(self) -> str: + return "weighted_win_loss" + + +class SimpleTiebreakerAggregator(PairwiseRewardAggregator): + """Use individual scores primarily, with simple ranking-based tiebreaking when scores are equal. + + GenRM Scoring System: + - Individual scores: 1-5 (where 5 is most helpful) + - Ranking scores: 1-6 (where 1 = Response 1 much better, 6 = Response 2 much better) + - Neutral ranking = 3.5 (no preference between responses) + + Tiebreaker Logic: + When individual scores are equal, we use the ranking to break ties: + - score1 = score1 + (3.5 - ranking_score) + - score2 = score2 + (ranking_score - 3.5) + + Why this works: + - When ranking < 3.5: Response 1 is preferred, so score1 gets positive adjustment, score2 gets negative + - When ranking > 3.5: Response 2 is preferred, so score2 gets positive adjustment, score1 gets negative + - When ranking = 3.5: No preference, so no adjustments (both get 0) + - The further from 3.5, the larger the adjustment magnitude + + Examples: + - ranking = 1 (Response 1 much better): score1 += 2.5, score2 -= 2.5 + - ranking = 3 (Response 1 slightly better): score1 += 0.5, score2 -= 0.5 + - ranking = 6 (Response 2 much better): score1 -= 2.5, score2 += 2.5 + """ + + def __init__(self, score_range: Tuple[float, float] = (1.0, 5.0)): + """Initialize the aggregator. + + Args: + score_range: (min_score, max_score) for normalization. + """ + self.min_score, self.max_score = score_range + + def aggregate_scores( + self, + comparison_results: List[Tuple[str, float, float, float]], + comparison_metadata: List[Tuple[str, int, int, int]], + prompt_groups: Dict[str, Dict[str, Any]], + ) -> Dict[str, List[float]]: + """Aggregate using individual scores with simple ranking tiebreaking.""" + # Collect individual scores for each response + individual_scores = {} + score_counts = {} + + for prompt_key, group_data in prompt_groups.items(): + num_responses = len(group_data["conversations"]) + individual_scores[prompt_key] = [0.0 for _ in range(num_responses)] + score_counts[prompt_key] = [0 for _ in range(num_responses)] + + # Process each comparison result + for (request_id, score_1, score_2, ranking_score), ( + prompt_key, + resp_i, + resp_j, + judge_idx, + ) in zip(comparison_results, comparison_metadata): + # Apply simple tiebreaking logic when scores are equal + if score_1 == score_2: + # When individual scores are equal, use ranking to break ties + score_1 = score_1 + 3.5 - ranking_score # Response 1 adjustment + score_2 = score_2 + ranking_score - 3.5 # Response 2 adjustment + + # Accumulate individual scores (with tiebreaking adjustments if applied) + individual_scores[prompt_key][resp_i] += score_1 + individual_scores[prompt_key][resp_j] += score_2 + score_counts[prompt_key][resp_i] += 1 + score_counts[prompt_key][resp_j] += 1 + + # Calculate average individual scores + final_scores = {} + for prompt_key, group_data in prompt_groups.items(): + num_responses = len(group_data["conversations"]) + final_scores[prompt_key] = [] + for resp_idx in range(num_responses): + if score_counts[prompt_key][resp_idx] > 0: + avg_score = ( + individual_scores[prompt_key][resp_idx] + / score_counts[prompt_key][resp_idx] + ) + else: + avg_score = ( + self.min_score + self.max_score + ) / 2 # Neutral score if no comparisons + final_scores[prompt_key].append(avg_score) + + return 
final_scores + + def get_additional_metrics( + self, + comparison_results: List[Tuple[str, float, float, float]], + comparison_metadata: List[Tuple[str, int, int, int]], + prompt_groups: Dict[str, Dict[str, Any]], + final_scores: Dict[str, List[float]], + ) -> Dict[str, float]: + """Compute individual score metrics alongside the simple tiebreaker scores.""" + # Collect individual scores and track tiebreaking usage + all_individual_scores = [] + all_ranking_scores = [] + tiebreak_used_count = 0 + + # Process each comparison result to extract individual scores + for (request_id, score_1, score_2, ranking_score), ( + prompt_key, + resp_i, + resp_j, + judge_idx, + ) in zip(comparison_results, comparison_metadata): + # Track original individual scores (before any tiebreaking) + all_individual_scores.extend([score_1, score_2]) + all_ranking_scores.append(ranking_score) + + # Count how often tiebreaking is used + if score_1 == score_2: + tiebreak_used_count += 1 + + # Compute statistics for individual scores + individual_metrics = {} + if all_individual_scores: + individual_metrics.update( + { + "mean_individual_score": np.mean(all_individual_scores), + "std_individual_score": np.std(all_individual_scores), + "min_individual_score": np.min(all_individual_scores), + "max_individual_score": np.max(all_individual_scores), + } + ) + + if all_ranking_scores: + individual_metrics.update( + { + "mean_ranking_score": np.mean(all_ranking_scores), + "std_ranking_score": np.std(all_ranking_scores), + } + ) + + # Add tiebreaking statistics + total_comparisons = len(comparison_results) + if total_comparisons > 0: + individual_metrics["tiebreak_usage_rate"] = ( + tiebreak_used_count / total_comparisons + ) + + return individual_metrics + + @property + def name(self) -> str: + return "simple_tiebreaker" + + +# Factory function to create aggregators +def create_aggregator(method: str, **kwargs) -> PairwiseRewardAggregator: + """Create a reward aggregator by name.""" + aggregators = { + "individual_scores": IndividualScoreAggregator, + "weighted_win_loss": WeightedWinLossAggregator, + "simple_tiebreaker": SimpleTiebreakerAggregator, + } + + if method not in aggregators: + raise ValueError( + f"Unknown aggregation method: {method}. Available: {list(aggregators.keys())}" + ) + + return aggregators[method](**kwargs) \ No newline at end of file
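
Usage sketch (illustrative only, not part of the patch): the snippet below shows one way the aggregator factory added above could be exercised in isolation. The prompt key, request ID, and judge scores are invented for illustration; only the shapes follow the structures used in the diff (comparison_results entries as (request_id, score_1, score_2, ranking_score), comparison_metadata entries as (prompt_key, resp_i, resp_j, judge_idx), and prompt_groups keyed by prompt with "conversations", "metadata", and "indices").

    from nemo_rl.environments.pairwise_reward_aggregators import create_aggregator

    # One prompt with two candidate responses, compared once by a single judge.
    # The "conversations" entries are placeholders; the aggregators only use
    # how many there are per prompt.
    prompt_groups = {
        "prompt_0": {
            "conversations": [["<response A>"], ["<response B>"]],
            "metadata": {},
            "indices": [0, 1],  # positions of these responses in the rollout batch
        }
    }
    # (request_id, score_1, score_2, ranking_score): the judge rated A=4 and B=2
    # on the 1-5 helpfulness scale and ranked A as "Better" (2 on the 1-6 scale).
    comparison_results = [("toy_request_0", 4.0, 2.0, 2.0)]
    # (prompt_key, resp_i, resp_j, judge_idx)
    comparison_metadata = [("prompt_0", 0, 1, 0)]

    # "individual_scores" and "simple_tiebreaker" can be selected the same way.
    aggregator = create_aggregator("weighted_win_loss")
    final_scores = aggregator.aggregate_scores(
        comparison_results, comparison_metadata, prompt_groups
    )
    print(final_scores)  # {'prompt_0': [0.75, 0.25]} with the default mapping
    metrics = aggregator.get_additional_metrics(
        comparison_results, comparison_metadata, prompt_groups, final_scores
    )
    print(metrics)  # mean/std/min/max/median of the 1-5 scores plus per-position means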