diff --git a/nemo_gym/cli.py b/nemo_gym/cli.py index ae4330c20..05d9c5b93 100644 --- a/nemo_gym/cli.py +++ b/nemo_gym/cli.py @@ -28,7 +28,7 @@ from subprocess import Popen from threading import Thread from time import sleep, time -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import psutil import rich @@ -49,6 +49,9 @@ GlobalConfigDictParserConfig, get_global_config_dict, ) +from nemo_gym.ray_utils import ( + _start_global_ray_gpu_scheduling_helper, +) from nemo_gym.server_status import StatusCommand from nemo_gym.server_utils import ( HEAD_SERVER_KEY_NAME, @@ -60,10 +63,15 @@ ) -def _setup_env_command(dir_path: Path, global_config_dict: DictConfig) -> str: # pragma: no cover +def _setup_env_command(dir_path: Path, global_config_dict: DictConfig, top_level_name: Optional[str] = None) -> str: # pragma: no cover head_server_deps = global_config_dict[HEAD_SERVER_DEPS_KEY_NAME] - uv_venv_cmd = f"uv venv --seed --allow-existing --python {global_config_dict[PYTHON_VERSION_KEY_NAME]} .venv" + if top_level_name is not None: + venv = f".venv-{top_level_name}" + else: + venv = ".venv" + + uv_venv_cmd = f"uv venv --seed --allow-existing --python {global_config_dict[PYTHON_VERSION_KEY_NAME]} {venv}" has_pyproject_toml = exists(f"{dir_path / 'pyproject.toml'}") has_requirements_txt = exists(f"{dir_path / 'requirements.txt'}") @@ -74,21 +82,29 @@ def _setup_env_command(dir_path: Path, global_config_dict: DictConfig) -> str: ) elif has_pyproject_toml: install_cmd = f"""uv pip install '-e .' {" ".join(head_server_deps)}""" + if dir_path.name == "vllm_model": + # NB: --no-deps is a workaround for installing vllm (current version: 0.11.2) on a cpu target, + # b/c `uv pip install` resolves dependencies differently vs `pip install`. 
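Stepping back from the vllm special case for a moment, here is a simplified sketch of the command this function now assembles when a `top_level_name` is passed (requirements.txt path assumed; the real function also handles pyproject.toml and the `--no-deps` workaround noted above). The sed process substitutions tag each server's output so interleaved logs stay attributable:

```python
# Illustrative sketch only; mirrors the per-server venv naming and output tagging above.
from pathlib import Path
from typing import Optional

def sketch_setup_env_command(dir_path: Path, python_version: str, top_level_name: Optional[str] = None) -> str:
    venv = f".venv-{top_level_name}" if top_level_name else ".venv"
    uv_venv_cmd = f"uv venv --seed --allow-existing --python {python_version} {venv}"
    install_cmd = "uv pip install -r requirements.txt"
    if top_level_name:
        tag = f"sed 's/^/({top_level_name}) /'"
        uv_venv_cmd = f"{uv_venv_cmd} > >({tag}) 2> >({tag} >&2)"
        install_cmd = f"{install_cmd} > >({tag}) 2> >({tag} >&2)"
    return f"cd {dir_path} && {uv_venv_cmd} && source {venv}/bin/activate && {install_cmd}"
```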
+ install_cmd = f"""uv pip install --no-deps 'vllm==0.11.2' && {install_cmd}""" elif has_requirements_txt: install_cmd = f"""uv pip install -r requirements.txt {" ".join(head_server_deps)}""" else: raise RuntimeError(f"Missing pyproject.toml or requirements.txt for uv venv setup in server dir: {dir_path}") + if top_level_name is not None: + uv_venv_cmd = f"{uv_venv_cmd} > >(sed 's/^/({top_level_name}) /') 2> >(sed 's/^/({top_level_name}) /' >&2)" + install_cmd = f"{install_cmd} > >(sed 's/^/({top_level_name}) /') 2> >(sed 's/^/({top_level_name}) /' >&2)" + cmd = f"""cd {dir_path} \\ && {uv_venv_cmd} \\ - && source .venv/bin/activate \\ + && source {venv}/bin/activate \\ && {install_cmd} \\ """ return cmd -def _run_command(command: str, working_dir_path: Path) -> Popen: # pragma: no cover +def _run_command(command: str, working_dir_path: Path, top_level_name: Optional[str] = None) -> Popen: # pragma: no cover work_dir = f"{working_dir_path.absolute()}" custom_env = environ.copy() py_path = custom_env.get("PYTHONPATH", None) @@ -96,7 +112,19 @@ def _run_command(command: str, working_dir_path: Path) -> Popen: # pragma: no c custom_env["PYTHONPATH"] = f"{work_dir}:{py_path}" else: custom_env["PYTHONPATH"] = work_dir - return Popen(command, executable="/bin/bash", shell=True, env=custom_env) + redirect_stdout = sys.stdout + redirect_stderr = sys.stderr + if top_level_name is not None: + redirect_stdout = open(f"{work_dir}/run-{top_level_name}.out.log", "a") + redirect_stderr = open(f"{work_dir}/run-{top_level_name}.err.log", "a") + return Popen( + command, + executable="/bin/bash", + shell=True, + env=custom_env, + stdout=redirect_stdout, + stderr=redirect_stderr, + ) class RunConfig(BaseNeMoGymCLIConfig): @@ -152,6 +180,7 @@ class RunHelper: # pragma: no cover _head_server: uvicorn.Server _head_server_thread: Thread _head_server_instance: HeadServer + _head_ray_gpu_helper: Any _processes: Dict[str, Popen] _server_instance_display_configs: List[ServerInstanceDisplayConfig] @@ -164,6 +193,8 @@ def start(self, global_config_dict_parser_config: GlobalConfigDictParserConfig) # Note: This function will modify the global config dict - update `ray_head_node_address` initialize_ray() + self._head_ray_gpu_helper = _start_global_ray_gpu_scheduling_helper() + # Assume Nemo Gym Run is for a single agent. escaped_config_dict_yaml_str = shlex.quote(OmegaConf.to_yaml(global_config_dict)) @@ -201,7 +232,7 @@ def start(self, global_config_dict_parser_config: GlobalConfigDictParserConfig) dir_path = PARENT_DIR / Path(first_key, second_key) - command = f"""{_setup_env_command(dir_path, global_config_dict)} \\ + command = f"""{_setup_env_command(dir_path, global_config_dict, top_level_path)} \\ && {NEMO_GYM_CONFIG_DICT_ENV_VAR_NAME}={escaped_config_dict_yaml_str} \\ {NEMO_GYM_CONFIG_PATH_ENV_VAR_NAME}={shlex.quote(top_level_path)} \\ python {str(entrypoint_fpath)}""" diff --git a/nemo_gym/comparison_strategies.py b/nemo_gym/comparison_strategies.py new file mode 100644 index 000000000..6efef3782 --- /dev/null +++ b/nemo_gym/comparison_strategies.py @@ -0,0 +1,150 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Comparison strategies for multi-generation reward computation. 
+""" +import hashlib +import json +from typing import Any, Dict, List, Optional, Protocol, Tuple, runtime_checkable + +from pydantic import BaseModel, Field + +from nemo_gym.server_utils import ServerClient, raise_for_status + + +@runtime_checkable +class ComparisonStrategy(Protocol): + """Protocol for comparison strategies that compute rewards from multiple generations.""" + + agent_names: List[str] + num_generations_per_prompt: int + policy_model_server_name: str + + async def compare( + self, + conversation_history: List[Dict[str, str]], + responses: List[str], + server_client: ServerClient, + principle: Optional[str] = None, + ) -> Tuple[List[float], Dict[str, float]]: + """Compare N responses and return (rewards, metrics).""" + ... + + +class GenRMStrategyConfig(BaseModel): + """Configuration for GenRM comparison strategy.""" + agent_names: List[str] = Field(default_factory=lambda: ["genrm_simple_agent"]) + num_generations_per_prompt: int = 16 + genrm_compare_server_name: str = "genrm_compare" + policy_model_server_name: str = "policy_model" + + +class GenRMStrategy: + """GenRM comparison strategy using pairwise comparisons.""" + + def __init__(self, config: GenRMStrategyConfig): + self.config = config + self.agent_names = config.agent_names + self.num_generations_per_prompt = config.num_generations_per_prompt + self.policy_model_server_name = config.policy_model_server_name + + async def compare( + self, + conversation_history: List[Dict[str, str]], + response_objs: List[Dict], + server_client: ServerClient, + principle: Optional[str] = None, + ) -> Tuple[List[float], Dict[str, float]]: + """Call genrm_compare server to get rewards for each response. + + Args: + conversation_history: The conversation context + response_objs: List of raw Response API objects + server_client: The server client for making requests + principle: Optional principle for principle-based GenRM comparison + + Returns: + Tuple of (rewards, metrics) from GenRM comparison + """ + payload = { + "conversation_history": conversation_history, + "response_objs": response_objs, + } + + if principle is not None: + payload["principle"] = principle + + res = await server_client.post( + server_name=self.config.genrm_compare_server_name, + url_path="/compare", + json=payload, + ) + await raise_for_status(res) + result = await res.json() + + rewards = result.get("rewards", [0.0] * len(response_objs)) + metrics = result.get("metrics", {}) + + return rewards, metrics + + +def get_prompt_key(example: Dict) -> str: + """Get stable key for grouping examples by prompt and principle. + + Examples with the same conversation history but different principles + should be in separate groups, so we include principle in the hash. + """ + if "prompt_id" in example: + # If prompt_id exists, combine it with principle for uniqueness + prompt_id = str(example["prompt_id"]) + principle = example.get("principle") + if principle is not None: + return f"{prompt_id}::{principle}" + return prompt_id + + # Hash both conversation history and principle together + conv = extract_conversation_history(example) + principle = example.get("principle") + key_data = { + "conversation": conv, + "principle": principle, + } + return hashlib.sha256(json.dumps(key_data, sort_keys=True).encode()).hexdigest() + + +def extract_conversation_history(example: Dict) -> List[Dict]: + """Extract conversation history from example. 
+ + Gym examples store history in responses_create_params.input + """ + responses_create_params = example.get("responses_create_params") + if responses_create_params is None: + raise ValueError(f"Example missing 'responses_create_params': {list(example.keys())}") + if "input" not in responses_create_params: + raise ValueError(f"responses_create_params missing 'input': {list(responses_create_params.keys())}") + return responses_create_params["input"] + + +def extract_generated_text(gen_result: Dict) -> str: + """Extract generated text from generation result.""" + if not isinstance(gen_result, dict): + raise ValueError(f"Expected dict, got {type(gen_result)}") + if "output" in gen_result: + output = gen_result["output"] + if isinstance(output, list) and output: + return output[0].get("content", "") + if isinstance(output, str): + return output + if "response" in gen_result: + return gen_result["response"] + raise ValueError(f"Cannot extract generated text from: {list(gen_result.keys())}") + + +async def generate_response(example: Dict, server_client: ServerClient, model_server: str) -> Dict: + """Generate a single response using the policy model.""" + params = example.get("responses_create_params") + if params is None: + raise ValueError(f"Example missing 'responses_create_params': {list(example.keys())}") + res = await server_client.post(server_name=model_server, url_path="/v1/responses", json=params) + await raise_for_status(res) + return await res.json() diff --git a/nemo_gym/config_types.py b/nemo_gym/config_types.py index 55e21e4b5..57ccdd640 100644 --- a/nemo_gym/config_types.py +++ b/nemo_gym/config_types.py @@ -393,6 +393,7 @@ class Domain(str, Enum): GAMES = "games" TRANSLATION = "translation" E2E = "e2e" + RLHF = "rlhf" OTHER = "other" diff --git a/nemo_gym/global_config.py b/nemo_gym/global_config.py index d69b3025d..0215f0fb8 100644 --- a/nemo_gym/global_config.py +++ b/nemo_gym/global_config.py @@ -45,6 +45,10 @@ DISALLOWED_PORTS_KEY_NAME = "disallowed_ports" HEAD_SERVER_DEPS_KEY_NAME = "head_server_deps" PYTHON_VERSION_KEY_NAME = "python_version" +RAY_HEAD_NODE_ADDRESS_KEY_NAME = "ray_head_node_address" +RAY_NAMESPACE_KEY_NAME = "ray_namespace" +RAY_GPU_NODES_KEY_NAME = "ray_gpu_nodes" +RAY_NUM_GPUS_PER_NODE_KEY_NAME = "ray_num_gpus_per_node" USE_ABSOLUTE_IP = "use_absolute_ip" NEMO_GYM_RESERVED_TOP_LEVEL_KEYS = [ CONFIG_PATHS_KEY_NAME, @@ -54,6 +58,10 @@ DISALLOWED_PORTS_KEY_NAME, HEAD_SERVER_DEPS_KEY_NAME, PYTHON_VERSION_KEY_NAME, + RAY_HEAD_NODE_ADDRESS_KEY_NAME, + RAY_NAMESPACE_KEY_NAME, + RAY_GPU_NODES_KEY_NAME, + RAY_NUM_GPUS_PER_NODE_KEY_NAME, USE_ABSOLUTE_IP, ] @@ -371,11 +379,19 @@ def get_first_server_config_dict(global_config_dict: DictConfig, top_level_path: def find_open_port( disallowed_ports: Optional[List[int]] = None, - max_retries: int = 50, + max_retries: int = 100, ) -> int: # pragma: no cover if disallowed_ports is None: disallowed_ports = [] + default_disallowed_ports = set( + list(range(53000, 53010+1)) + + list(range(54000, 60000+1)) + + [10001, 8265, 52365, 52365+1] + ) + + disallowed_ports = default_disallowed_ports | set(disallowed_ports) + # Find an open port that doesn't conflict with disallowed ports. 
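A minimal sketch of the probe loop this hunk leads into (the real code may differ in bind address and error handling): bind to port 0 so the OS picks a free ephemeral port, then retry if it falls in the disallowed set.

```python
# Illustrative only; assumes disallowed_ports is the merged default + caller set built above.
from socket import socket

def sketch_find_open_port(disallowed_ports: set, max_retries: int = 100) -> int:
    for _ in range(max_retries):
        with socket() as s:
            s.bind(("", 0))                 # port 0 -> OS assigns a free ephemeral port
            port = s.getsockname()[1]
        if port not in disallowed_ports:
            return port
    raise RuntimeError(f"No open port found after {max_retries} attempts")
```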
for _ in range(max_retries): with socket() as s: diff --git a/nemo_gym/openai_utils.py b/nemo_gym/openai_utils.py index baae61476..bb214256a 100644 --- a/nemo_gym/openai_utils.py +++ b/nemo_gym/openai_utils.py @@ -25,19 +25,42 @@ Union, ) -from openai.types.chat import ( - ChatCompletion, - ChatCompletionAssistantMessageParam, - ChatCompletionContentPartTextParam, - ChatCompletionDeveloperMessageParam, - ChatCompletionMessage, - ChatCompletionMessageToolCall, - ChatCompletionMessageToolCallParam, - ChatCompletionSystemMessageParam, - ChatCompletionToolMessageParam, - ChatCompletionToolParam, - ChatCompletionUserMessageParam, -) + +# Backward-compatible imports for different openai versions +try: + from openai.types.chat import ( + ChatCompletion, + ChatCompletionAssistantMessageParam, + ChatCompletionContentPartTextParam, + ChatCompletionDeveloperMessageParam, + ChatCompletionMessage, + ChatCompletionMessageToolCall, + ChatCompletionMessageToolCallParam, + ChatCompletionSystemMessageParam, + ChatCompletionToolMessageParam, + ChatCompletionToolParam, + ChatCompletionUserMessageParam, + ) +except ImportError: + from openai.types.chat.chat_completion import ChatCompletion + from openai.types.chat.chat_completion_message import ChatCompletionMessage + from openai.types.chat.chat_completion_message_tool_call import ( + ChatCompletionMessageToolCall, + ) + + try: + from openai.types.chat import ( + ChatCompletionAssistantMessageParam, + ChatCompletionContentPartTextParam, + ChatCompletionDeveloperMessageParam, + ChatCompletionMessageToolCallParam, + ChatCompletionSystemMessageParam, + ChatCompletionToolMessageParam, + ChatCompletionToolParam, + ChatCompletionUserMessageParam, + ) + except ImportError as exc2: + raise ImportError("openai types.chat param classes missing; please install a newer openai package") from exc2 from openai.types.chat.chat_completion import Choice from openai.types.chat.chat_completion_assistant_message_param import ( ContentArrayOfContentPart, @@ -147,13 +170,13 @@ class NeMoGymResponseOutputMessage(BaseModel): class NeMoGymEasyInputMessage(BaseModel): content: Union[str, ResponseInputMessageContentListParam] - role: Literal["user", "assistant", "system", "developer"] + role: Literal["user", "assistant", "system", "developer", "response_1", "response_2", "principle"] type: Literal["message"] = "message" class NeMoGymMessage(BaseModel): content: ResponseInputMessageContentListParam - role: Literal["user", "system", "developer"] + role: Literal["user", "system", "developer", "response_1", "response_2", "principle"] status: Literal["in_progress", "completed", "incomplete"] = "completed" type: Literal["message"] = "message" @@ -365,6 +388,13 @@ class NeMoGymChatCompletionToolMessageParam(ChatCompletionToolMessageParam): content: Required[Union[str, List[NeMoGymChatCompletionContentPartTextParam]]] +class NeMoGymChatCompletionCustomRoleMessageParam(TypedDict): + """Message param for custom roles not in the standard OpenAI API (e.g., GenRM response_1/response_2).""" + + role: Required[str] + content: Required[str] + + class NeMoGymFunctionToolParam(FunctionToolParam): pass @@ -375,6 +405,7 @@ class NeMoGymFunctionToolParam(FunctionToolParam): NeMoGymChatCompletionUserMessageParam, NeMoGymChatCompletionAssistantMessageParam, NeMoGymChatCompletionToolMessageParam, + NeMoGymChatCompletionCustomRoleMessageParam, # Don't add deprecated. 
# NeMoGymChatCompletionFunctionMessageParam, # Training: diff --git a/nemo_gym/ray_utils.py b/nemo_gym/ray_utils.py new file mode 100644 index 000000000..77ce6430b --- /dev/null +++ b/nemo_gym/ray_utils.py @@ -0,0 +1,163 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +from collections import defaultdict +from time import sleep +from typing import Dict, Optional + +import ray +import ray.util.state +from ray.actor import ActorClass, ActorProxy +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy + +from nemo_gym.global_config import ( + RAY_GPU_NODES_KEY_NAME, + RAY_NUM_GPUS_PER_NODE_KEY_NAME, + get_global_config_dict, +) + + +def _prepare_ray_worker_env_vars() -> Dict[str, str]: # pragma: no cover + worker_env_vars = { + **os.environ, + } + pop_env_vars = [ + "CUDA_VISIBLE_DEVICES", + "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", + "RAY_JOB_ID", + "RAY_RAYLET_PID", + ] + for k in pop_env_vars: + worker_env_vars.pop(k, None) + return worker_env_vars + + +def _start_global_ray_gpu_scheduling_helper(node_id: Optional[str] = None) -> ActorProxy: # pragma: no cover + cfg = get_global_config_dict() + helper_options = { + "name": "_NeMoGymRayGPUSchedulingHelper", + "num_cpus": 0, + } + if node_id is not None: + helper_options["scheduling_strategy"] = NodeAffinitySchedulingStrategy( + node_id=node_id, + soft=True, + ) + helper = _NeMoGymRayGPUSchedulingHelper.options(**helper_options).remote(cfg) + ray.get(helper._post_init.remote()) + return helper + + +def get_global_ray_gpu_scheduling_helper() -> ActorProxy: # pragma: no cover + cfg = get_global_config_dict() + while True: + try: + get_actor_args = { + "name": "_NeMoGymRayGPUSchedulingHelper", + } + ray_namespace = cfg.get("ray_namespace", None) + if ray_namespace is None: + ray_namespace = "nemo_gym" + get_actor_args["namespace"] = ray_namespace + worker = ray.get_actor(**get_actor_args) + return worker + except ValueError: + sleep(3) + + +@ray.remote +class _NeMoGymRayGPUSchedulingHelper: # pragma: no cover + def __init__(self, cfg): + self.cfg = cfg + self.avail_gpus_dict = defaultdict(int) + self.used_gpus_dict = defaultdict(int) + + def _post_init(self) -> None: + # If value of RAY_GPU_NODES_KEY_NAME is None, then Gym will use all Ray GPU nodes + # for scheduling GPU actors. + # Otherwise if value of RAY_GPU_NODES_KEY_NAME is a list, then Gym will only use + # the listed Ray GPU nodes for scheduling GPU actors. 
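For orientation, a sketch of the Ray-related keys this helper reads from the global config (values below are placeholders, not defaults shipped with any particular deployment):

```python
# Illustrative config shape only; key names come from global_config.py above.
from omegaconf import OmegaConf

ray_cfg = OmegaConf.create({
    "ray_head_node_address": "10.0.0.1:6379",  # hypothetical head node address
    "ray_namespace": "nemo_gym",                # falls back to "nemo_gym" when unset
    "ray_gpu_nodes": None,                      # None -> every Ray GPU node is eligible
    "ray_num_gpus_per_node": 8,                 # per-node GPU budget used by spinup checks
})
```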
+ allowed_gpu_nodes = self.cfg.get(RAY_GPU_NODES_KEY_NAME, None) + if allowed_gpu_nodes is not None: + allowed_gpu_nodes = set(allowed_gpu_nodes) + + head = self.cfg["ray_head_node_address"] + node_states = ray.util.state.list_nodes(head, detail=True, limit=10000) + for state in node_states: + assert state.node_id is not None + avail_num_gpus = state.resources_total.get("GPU", 0) + if allowed_gpu_nodes is not None and state.node_id not in allowed_gpu_nodes: + continue + self.avail_gpus_dict[state.node_id] += avail_num_gpus + + def alloc_gpu_node(self, num_gpus: int) -> Optional[str]: + for node_id, avail_num_gpus in self.avail_gpus_dict.items(): + used_num_gpus = self.used_gpus_dict[node_id] + if used_num_gpus + num_gpus <= avail_num_gpus: + self.used_gpus_dict[node_id] += num_gpus + return node_id + return None + + +def lookup_ray_node_id_to_ip_dict() -> Dict[str, str]: # pragma: no cover + cfg = get_global_config_dict() + head = cfg["ray_head_node_address"] + id_to_ip = {} + node_states = ray.util.state.list_nodes(head, limit=10000) + for state in node_states: + id_to_ip[state.node_id] = state.node_ip + return id_to_ip + + +def lookup_current_ray_node_id() -> str: # pragma: no cover + return ray.get_runtime_context().get_node_id() + + +def lookup_current_ray_node_ip() -> str: # pragma: no cover + return lookup_ray_node_id_to_ip_dict()[lookup_current_ray_node_id()] + + +def spinup_single_ray_gpu_node_worker( + worker_cls: ActorClass, + num_gpus: int, + *worker_args, + **worker_kwargs, +) -> ActorProxy: # pragma: no cover + cfg = get_global_config_dict() + + num_gpus_per_node = cfg.get(RAY_NUM_GPUS_PER_NODE_KEY_NAME, 8) + assert num_gpus >= 1, f"Must request at least 1 GPU node for spinning up {worker_cls}" + assert num_gpus <= num_gpus_per_node, ( + f"Requested {num_gpus} > {num_gpus_per_node} GPU nodes for spinning up {worker_cls}" + ) + + helper = get_global_ray_gpu_scheduling_helper() + node_id = ray.get(helper.alloc_gpu_node.remote(num_gpus)) + if node_id is None: + raise RuntimeError(f"Cannot find an available Ray node with {num_gpus} GPUs to spin up {worker_cls}") + + worker_options = {} + worker_options["num_gpus"] = num_gpus + worker_options["scheduling_strategy"] = NodeAffinitySchedulingStrategy( + node_id=node_id, + soft=False, + ) + worker_options["runtime_env"] = { + "py_executable": sys.executable, + "env_vars": _prepare_ray_worker_env_vars(), + } + worker = worker_cls.options(**worker_options).remote(*worker_args, **worker_kwargs) + return worker diff --git a/nemo_gym/rollout_collection.py b/nemo_gym/rollout_collection.py index f522d2b07..5bbcdfad8 100644 --- a/nemo_gym/rollout_collection.py +++ b/nemo_gym/rollout_collection.py @@ -14,11 +14,12 @@ # limitations under the License. import asyncio import json +import logging from asyncio import Future, Semaphore -from collections import Counter +from collections import Counter, defaultdict from contextlib import nullcontext from itertools import chain, repeat -from typing import Any, Dict, Iterator, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Set, Tuple from pydantic import BaseModel, Field from tqdm.asyncio import tqdm @@ -34,6 +35,12 @@ ) +if TYPE_CHECKING: + from nemo_gym.comparison_strategies import ComparisonStrategy + +logger = logging.getLogger(__name__) + + class RolloutCollectionConfig(BaseNeMoGymCLIConfig): """ Perform a batch of rollout collection. 
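Hypothetical wiring of the comparison-strategy path added in the hunks below (assumes `collection` is an already-configured rollout collection object and `examples` are rollout rows carrying an `agent_ref`). Rows whose agent matches the strategy are generated `num_generations_per_prompt` times and scored via the genrm_compare server; all other rows still go through the standard `/run` path.

```python
# Sketch only; result dicts for strategy rows carry "reward" plus genrm_* metrics.
from nemo_gym.comparison_strategies import GenRMStrategy, GenRMStrategyConfig

strategy = GenRMStrategy(GenRMStrategyConfig(num_generations_per_prompt=16))

async def collect(collection, examples):
    rollouts = []
    for fut in collection.run_examples(examples, comparison_strategy=strategy):
        example, result = await fut
        rollouts.append((example, result))
    return rollouts
```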
@@ -122,13 +129,28 @@ async def _post_coroutine(row: dict) -> None: print(json.dumps(avg_metrics, indent=4)) def run_examples( - self, examples: List[Dict], head_server_config: Optional[BaseServerConfig] = None + self, + examples: List[Dict], + head_server_config: Optional[BaseServerConfig] = None, + comparison_strategy: Optional["ComparisonStrategy"] = None, ) -> Iterator[Future]: """ - We provide this function as a lower level interface for running rollout collection. + Run rollout collection with optional comparison strategy. + + When comparison_strategy is provided, samples matching strategy.agent_names + are processed with generation-only + buffering + comparison, while other + samples go through the standard agent /run path. Both run in parallel. """ server_client = self.setup_server_client(head_server_config) + if comparison_strategy: + return self._run_with_comparison_strategy(examples, server_client, comparison_strategy) + else: + return self._run_standard(examples, server_client) + + def _run_standard(self, examples: List[Dict], server_client: ServerClient) -> Iterator[Future]: + """Standard rollout collection - each sample through its agent.""" + async def _post_subroutine(row: Dict) -> Tuple[Dict, Dict]: res = await server_client.post(server_name=row["agent_ref"]["name"], url_path="/run", json=row) await raise_for_status(res) @@ -138,6 +160,113 @@ async def _post_subroutine(row: Dict) -> Tuple[Dict, Dict]: map(_post_subroutine, examples), desc="Collecting rollouts", miniters=10, total=len(examples) ) + def _run_with_comparison_strategy( + self, + examples: List[Dict], + server_client: ServerClient, + strategy: "ComparisonStrategy", + ) -> Iterator[Future]: + """Run with comparison strategy - strategy samples get generation + compare, others get /run.""" + from nemo_gym.comparison_strategies import ( + extract_conversation_history, + generate_response, + get_prompt_key, + ) + + strategy_agent_names = set(strategy.agent_names) + strategy_samples = [] + standard_samples = [] + + for idx, example in enumerate(examples): + agent_ref = example.get("agent_ref", {}) + agent_name = agent_ref.get("name", "") if isinstance(agent_ref, dict) else "" + if agent_name in strategy_agent_names: + strategy_samples.append((idx, example)) + else: + standard_samples.append((idx, example)) + + logger.info(f"Comparison strategy: {len(strategy_samples)} samples, Standard: {len(standard_samples)} samples") + + async def _run_all() -> List[Dict]: + results = [None] * len(examples) + + async def process_standard(): + async def _do(idx: int, ex: Dict): + ex_copy = ex.copy() + agent_name = ex_copy.pop("agent_ref")["name"] + res = await server_client.post(server_name=agent_name, url_path="/run", json=ex_copy) + await raise_for_status(res) + results[idx] = await res.json() + + if standard_samples: + await asyncio.gather(*[_do(idx, ex) for idx, ex in standard_samples]) + + async def process_strategy(): + if not strategy_samples: + return + num_gens = strategy.num_generations_per_prompt + policy_model = strategy.policy_model_server_name + prompt_buffers: Dict[str, List[tuple]] = defaultdict(list) + compare_tasks: List[asyncio.Task] = [] + compared: Set[str] = set() + lock = asyncio.Lock() + + async def on_gen_complete(idx: int, example: Dict, gen_result: Dict): + prompt_key = get_prompt_key(example) + async with lock: + prompt_buffers[prompt_key].append((idx, example, gen_result)) + if len(prompt_buffers[prompt_key]) == num_gens and prompt_key not in compared: + compared.add(prompt_key) + group = 
prompt_buffers[prompt_key] + task = asyncio.create_task(_compare_group(prompt_key, group)) + compare_tasks.append(task) + + async def _compare_group(prompt_key: str, group: List[tuple]): + first_example = group[0][1] + conv_history = extract_conversation_history(first_example) + # Extract principle from example data for principle-based GenRM + principle = first_example.get("principle") + + # Debug log: show whether GenRM is using principle-based judging + if principle: + print(f"[GenRM] Judging with PRINCIPLE (len={len(principle)}): {principle}") + else: + print("[GenRM] Judging WITHOUT principle") + + # Pass raw Response API objects - text extraction happens in genrm_compare + response_objs = [gr for _, _, gr in group] + rewards, genrm_metrics = await strategy.compare( + conv_history, response_objs, server_client, principle=principle + ) + + for i, (idx, _, gen_result) in enumerate(group): + # Include GenRM metrics in each result so they flow back to NeMo-RL + results[idx] = { + "response": gen_result, + "reward": rewards[i], + **{f"genrm_{k}": v for k, v in genrm_metrics.items()}, + } + + async def gen_and_notify(idx: int, example: Dict): + gen_result = await generate_response(example, server_client, policy_model) + await on_gen_complete(idx, example, gen_result) + + await asyncio.gather(*[gen_and_notify(idx, ex) for idx, ex in strategy_samples]) + if compare_tasks: + await asyncio.gather(*compare_tasks) + + await asyncio.gather(process_standard(), process_strategy()) + return results + + main_future = asyncio.ensure_future(_run_all()) + + async def _get_at(idx: int) -> Tuple[Dict, Dict]: + results = await main_future + return examples[idx], results[idx] + + futures = [asyncio.ensure_future(_get_at(i)) for i in range(len(examples))] + return tqdm.as_completed(futures, desc="Collecting rollouts", miniters=10, total=len(examples)) + def setup_server_client(self, head_server_config: Optional[BaseServerConfig] = None) -> ServerClient: server_client = ServerClient.load_from_global_config(head_server_config) diff --git a/nemo_gym/server_utils.py b/nemo_gym/server_utils.py index 58562ea58..59372d6ef 100644 --- a/nemo_gym/server_utils.py +++ b/nemo_gym/server_utils.py @@ -218,7 +218,13 @@ def load_from_global_config(cls, head_server_config: Optional[BaseServerConfig] ) from e global_config_dict_yaml = response.content.decode() - global_config_dict = OmegaConf.create(json.loads(global_config_dict_yaml)) + # TODO: json.loads below can fail. 
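The hunk that follows guards this call; a small illustration of why (the config shape here is hypothetical): `OmegaConf.create` accepts a YAML string directly, while `json.loads` only succeeds when the payload happens to be valid JSON.

```python
import json
from omegaconf import OmegaConf

yaml_payload = "policy_model:\n  responses_api_models:\n    vllm_model:\n      entrypoint: app.py\n"
try:
    cfg = OmegaConf.create(json.loads(yaml_payload))   # raises JSONDecodeError: not JSON
except json.JSONDecodeError:
    cfg = OmegaConf.create(yaml_payload)               # OmegaConf parses the YAML directly
```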
+ try: + global_config_dict = OmegaConf.create(json.loads(global_config_dict_yaml)) + except Exception as e: + print(f"Failed to load global config dict yaml: {type(e).__name__} {e}") + print("Retrying without json.loads...", flush=True) + global_config_dict = OmegaConf.create(global_config_dict_yaml) return cls(head_server_config=head_server_config, global_config_dict=global_config_dict) @@ -350,6 +356,7 @@ def initialize_ray() -> None: global_config_dict = get_global_config_dict() ray_head_node_address = global_config_dict.get("ray_head_node_address") + ray_namespace = global_config_dict.get("ray_namespace", None) ray_init_kwargs = dict(ignore_reinit_error=True) if ray_head_node_address: @@ -358,6 +365,11 @@ def initialize_ray() -> None: else: print("Starting Ray cluster...") + if ray_namespace is None: + ray_namespace = "nemo_gym" + print(f"Ray namespace: {ray_namespace}") + ray_init_kwargs["namespace"] = ray_namespace + ray.init(**ray_init_kwargs) if not ray_head_node_address: diff --git a/pyproject.toml b/pyproject.toml index 8973f9753..400c4cdd1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -147,7 +147,7 @@ dependencies = [ # Ray: Used for distributed processing # Updated Fri Oct 18, 2025 with ray[default]==2.46.0 # License: Apache 2.0 https://github.com/ray-project/ray/blob/master/LICENSE - "ray[default]", + "ray[default]==2.49.2", # psutil: Cross-platform process and system utilities # Updated: Fri Nov 07, 2025 with psutil==6.1.1 diff --git a/resources_servers/calendar/app.py b/resources_servers/calendar/app.py index c844d7f4c..f0afeab64 100644 --- a/resources_servers/calendar/app.py +++ b/resources_servers/calendar/app.py @@ -47,7 +47,27 @@ def setup_webserver(self) -> FastAPI: return app async def verify(self, body: CalendarVerifyRequest) -> BaseVerifyResponse: - assistant_response = body.response.output[-1].content[0].text + # Extract the assistant's text response from the last output item. + # + # For reasoning models (e.g., with deepseek_r1 reasoning_parser), the output + # structure is: [ReasoningItem, MessageItem] where: + # - ReasoningItem: has .reasoning attribute (thinking/CoT tokens) + # - MessageItem: has .content attribute (actual response text) + # + # The last item should be a MessageItem with .content, but if the model + # hit the token limit while still thinking, the last item will be a + # ReasoningItem without .content. In that case, we return reward=0. + assistant_response = "" + if body.response.output: + last_output = body.response.output[-1] + if hasattr(last_output, "content") and last_output.content: + assistant_response = last_output.content[0].text + + # If no valid response (e.g., model only produced thinking tokens), + # return zero reward + if not assistant_response: + return BaseVerifyResponse(**body.model_dump(), reward=0) + exp_cal_state = body.exp_cal_state try: reward, reason = grade_assistant_response(assistant_response, exp_cal_state) diff --git a/resources_servers/equivalence_llm_judge/app.py b/resources_servers/equivalence_llm_judge/app.py index e6d74153a..14cdc48fb 100644 --- a/resources_servers/equivalence_llm_judge/app.py +++ b/resources_servers/equivalence_llm_judge/app.py @@ -21,7 +21,9 @@ # limitations under the License. 
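Before the judge-server changes below, a minimal sketch of the concurrency-cap pattern they introduce: an `asyncio.Semaphore` when `judge_endpoint_max_concurrency` is set, and a no-op context otherwise.

```python
# Illustrative only; mirrors the Semaphore-or-nullcontext gate added to the judge server.
import asyncio
from contextlib import nullcontext

max_concurrency = 64                       # judge_endpoint_max_concurrency; None disables the cap
gate = asyncio.Semaphore(max_concurrency) if max_concurrency is not None else nullcontext()

async def call_judge(post_fn):
    async with gate:                       # at most `max_concurrency` judge requests in flight
        return await post_fn()
```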
from __future__ import annotations +import asyncio import re +from contextlib import nullcontext from typing import Any, Optional from fastapi import FastAPI @@ -58,8 +60,11 @@ class LLMJudgeResourcesServerConfig(BaseResourcesServerConfig): judge_model_server: ModelServerRef judge_responses_create_params: NeMoGymResponseCreateParamsNonStreaming + # Concurrency limit for judge endpoint requests. Set to None to disable limiting. + judge_endpoint_max_concurrency: Optional[int] = 64 + judge_system_message: Optional[str] = None - judge_prompt_template: str + judge_prompt_template_fpath: str = "prompt_templates/equivalence_llm_judge.txt" judge_equal_label: str = "[[A=B]]" judge_not_equal_label: str = "[[A!=B]]" # Optional regex to extract the question from the last user message. @@ -249,6 +254,17 @@ class LLMJudgeResourcesServer(SimpleResourcesServer): config: LLMJudgeResourcesServerConfig + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if self.config.judge_endpoint_max_concurrency is not None: + self._judge_endpoint_max_concurrency = asyncio.Semaphore(value=self.config.judge_endpoint_max_concurrency) + else: + self._judge_endpoint_max_concurrency = nullcontext() + + with open(self.config.judge_prompt_template_fpath, "r") as f: + self._judge_prompt_template = f.read().strip() + def setup_webserver(self) -> FastAPI: app = super().setup_webserver() return app @@ -419,7 +435,7 @@ async def _generate_judge_evaluation( not_equal_label = cfg.judge_not_equal_label responses_create_params = cfg.judge_responses_create_params.model_copy(deep=True) - prompt_template = cfg.judge_prompt_template + prompt_template = self._judge_prompt_template system_message = cfg.judge_system_message user_prompt = prompt_template.format( @@ -432,12 +448,21 @@ async def _generate_judge_evaluation( msgs.append(NeMoGymEasyInputMessage(role="user", content=user_prompt)) responses_create_params.input = msgs - response = await self.server_client.post( - server_name=cfg.judge_model_server.name, - url_path="/v1/responses", - json=responses_create_params, - ) - judge_response = NeMoGymResponse.model_validate(await response.json()) + async with self._judge_endpoint_max_concurrency: + try: + response = await self.server_client.post( + server_name=cfg.judge_model_server.name, + url_path="/v1/responses", + json=responses_create_params, + ) + judge_response = NeMoGymResponse.model_validate(await response.json()) + except Exception as e: + print( + f"DEBUG: LLMJudgeResourcesServer: judge model server HTTP POST error: {type(e).__name__} {e}", + flush=True, + ) + raise e + eval_record = JudgeEvaluation( responses_create_params=responses_create_params, response=judge_response, diff --git a/resources_servers/equivalence_llm_judge/configs/equivalence_llm_judge.yaml b/resources_servers/equivalence_llm_judge/configs/equivalence_llm_judge.yaml index 55bfb96d2..d2275fe85 100644 --- a/resources_servers/equivalence_llm_judge/configs/equivalence_llm_judge.yaml +++ b/resources_servers/equivalence_llm_judge/configs/equivalence_llm_judge.yaml @@ -7,64 +7,9 @@ equivalence_llm_judge: name: policy_model judge_responses_create_params: input: [] - judge_prompt_template: |- - ===== System role ===== - You are a meticulous STEM grader. Compare a candidate answer to a GOLD reference for a scientific question and decide strict equivalence. - - Grading priorities (in order): - 1) Factual equivalence to GOLD (accept algebraically/formally equivalent formulations). 
- 2) Completeness on required parts — the candidate must include the same core parts/subclaims as the GOLD. - - Rules: - - Treat GOLD as authoritative for what counts as correct. - - If GOLD is a range or set, the candidate is equivalent only if it lies within that range or is a member of that set. - - For formulas/derivations, accept mathematically identical transformations (e.g., symbol reordering, factoring, equivalent identities). - - Multi-part: all essential parts must match for “equivalent”; otherwise they are not equivalent. - - Be concise. Do NOT reveal or rewrite the GOLD. - - Show your reason why they are equivalent or not equivalent first and then provide the output. - - Output (at the end after double newlines): - - If equivalent: [[A=B]] they are equivalent - - If not equivalent: [[A!=B]] they are not equivalent - - ===== Example 1 (equivalent) ===== - QUESTION: - State Avogadro’s constant (include units). - - GOLD: - 6.022 × 10^23 mol^-1 - - CANDIDATE: - 6.022e23 per mole. - - The candidate gives the same magnitude in scientific notation and the same “per mole” unit; no extra or missing qualifiers. - - [[A=B]] they are equivalent - - ===== Example 2 (not equivalent) ===== - QUESTION: - State the first law of thermodynamics for a closed system and identify what each symbol represents. - - GOLD: - ΔU = Q − W ; ΔU is change in internal energy, Q is heat added to the system, W is work done by the system. - - CANDIDATE: - ΔU = Q + W ; ΔU is change in internal energy, Q is heat added to the system, W is work done by the system. - - The candidate uses the opposite sign convention for work relative to the required relationship; one core part is incorrect, so the overall statement does not match. - - [[A!=B]] they are not equivalent - - ===== Inputs ===== - QUESTION: - {question} - - GOLD: - {expected_answer} - - CANDIDATE: - {generated_answer} + judge_prompt_template_fpath: prompt_templates/equivalence_llm_judge.txt + # Concurrency and rate limiting configuration + judge_endpoint_max_concurrency: 64 # Reduced from 64 to avoid rate limits judge_system_message: null judge_equal_label: "[[A=B]]" judge_not_equal_label: "[[A!=B]]" @@ -135,7 +80,7 @@ equivalence_llm_judge_simple_agent: license: "TBD" jsonl_fpath: resources_servers/equivalence_llm_judge/data/example_openqa.jsonl - name: train - type: train + type: example license: Apache 2.0 jsonl_fpath: resources_servers/equivalence_llm_judge/data/train.jsonl huggingface_identifier: diff --git a/resources_servers/equivalence_llm_judge/configs/lc.yaml b/resources_servers/equivalence_llm_judge/configs/lc.yaml new file mode 120000 index 000000000..b8218928a --- /dev/null +++ b/resources_servers/equivalence_llm_judge/configs/lc.yaml @@ -0,0 +1 @@ +lc_judge.yaml \ No newline at end of file diff --git a/resources_servers/equivalence_llm_judge/configs/lc_judge.yaml b/resources_servers/equivalence_llm_judge/configs/lc_judge.yaml new file mode 100644 index 000000000..db3843b50 --- /dev/null +++ b/resources_servers/equivalence_llm_judge/configs/lc_judge.yaml @@ -0,0 +1,35 @@ +lc_judge: + resources_servers: + equivalence_llm_judge: + entrypoint: app.py + judge_model_server: + type: responses_api_models + name: policy_model + judge_responses_create_params: + input: [] + judge_prompt_template_fpath: prompt_templates/lc_judge.txt + judge_endpoint_max_concurrency: 64 + judge_system_message: null + judge_equal_label: CORRECT + judge_not_equal_label: INCORRECT + check_twice_swap: false + reward_if_swap_fails: 0.0 + question_extract_regex: 
^QUESTION:\s*(.*)$ + response_extract_regex: null + domain: knowledge + verified: false +lc_judge_simple_agent: + responses_api_agents: + simple_agent: + entrypoint: app.py + resources_server: + type: resources_servers + name: lc_judge + model_server: + type: responses_api_models + name: policy_model + datasets: + - name: example + type: example + license: TBD + jsonl_fpath: resources_servers/equivalence_llm_judge/data/example.jsonl diff --git a/resources_servers/equivalence_llm_judge/prompt_templates/equivalence_llm_judge.txt b/resources_servers/equivalence_llm_judge/prompt_templates/equivalence_llm_judge.txt new file mode 100644 index 000000000..7a7de001b --- /dev/null +++ b/resources_servers/equivalence_llm_judge/prompt_templates/equivalence_llm_judge.txt @@ -0,0 +1,58 @@ + +===== System role ===== +You are a meticulous STEM grader. Compare a candidate answer to a GOLD reference for a scientific question and decide strict equivalence. + +Grading priorities (in order): +1) Factual equivalence to GOLD (accept algebraically/formally equivalent formulations). +2) Completeness on required parts — the candidate must include the same core parts/subclaims as the GOLD. + +Rules: +- Treat GOLD as authoritative for what counts as correct. +- If GOLD is a range or set, the candidate is equivalent only if it lies within that range or is a member of that set. +- For formulas/derivations, accept mathematically identical transformations (e.g., symbol reordering, factoring, equivalent identities). +- Multi-part: all essential parts must match for “equivalent”; otherwise they are not equivalent. +- Be concise. Do NOT reveal or rewrite the GOLD. + +Show your reason why they are equivalent or not equivalent first and then provide the output. + +Output (at the end after double newlines): +- If equivalent: [[A=B]] they are equivalent +- If not equivalent: [[A!=B]] they are not equivalent + +===== Example 1 (equivalent) ===== +QUESTION: +State Avogadro’s constant (include units). + +GOLD: +6.022 × 10^23 mol^-1 + +CANDIDATE: +6.022e23 per mole. + +The candidate gives the same magnitude in scientific notation and the same “per mole” unit; no extra or missing qualifiers. + +[[A=B]] they are equivalent + +===== Example 2 (not equivalent) ===== +QUESTION: +State the first law of thermodynamics for a closed system and identify what each symbol represents. + +GOLD: +ΔU = Q − W ; ΔU is change in internal energy, Q is heat added to the system, W is work done by the system. + +CANDIDATE: +ΔU = Q + W ; ΔU is change in internal energy, Q is heat added to the system, W is work done by the system. + +The candidate uses the opposite sign convention for work relative to the required relationship; one core part is incorrect, so the overall statement does not match. 
+ +[[A!=B]] they are not equivalent + +===== Inputs ===== +QUESTION: +{question} + +GOLD: +{expected_answer} + +CANDIDATE: +{generated_answer} \ No newline at end of file diff --git a/resources_servers/equivalence_llm_judge/prompt_templates/lc.txt b/resources_servers/equivalence_llm_judge/prompt_templates/lc.txt new file mode 120000 index 000000000..b200b884c --- /dev/null +++ b/resources_servers/equivalence_llm_judge/prompt_templates/lc.txt @@ -0,0 +1 @@ +lc_judge.txt \ No newline at end of file diff --git a/resources_servers/equivalence_llm_judge/prompt_templates/lc_judge.txt b/resources_servers/equivalence_llm_judge/prompt_templates/lc_judge.txt new file mode 100644 index 000000000..bb44669d3 --- /dev/null +++ b/resources_servers/equivalence_llm_judge/prompt_templates/lc_judge.txt @@ -0,0 +1,9 @@ + +Assess whether the following CANDIDATE ANSWER is CORRECT or INCORRECT. +For the CANDIDATE ANSWER to be correct, it must be consistent with the OFFICIAL ANSWER. + +The question, for reference only: {question} +The OFFICIAL ANSWER: {expected_answer} +CANDIDATE ANSWER TO ASSESS: {generated_answer} + +Reply only with CORRECT or INCORRECT. \ No newline at end of file diff --git a/resources_servers/genrm_compare/README.md b/resources_servers/genrm_compare/README.md new file mode 100644 index 000000000..0535d8ad6 --- /dev/null +++ b/resources_servers/genrm_compare/README.md @@ -0,0 +1,252 @@ +# GenRM Pairwise Comparison Resources Server + +A resources server that compares multiple candidate responses using a **Generative Reward Model (GenRM)** via pairwise comparisons. This module is designed for RLHF (Reinforcement Learning from Human Feedback) training workflows, particularly for GRPO (Group Relative Policy Optimization). + +## Overview + +The GenRM compare server evaluates multiple candidate responses by: + +1. **Generating comparison pairs** based on a configurable strategy +2. **Sending pairs to a GenRM model** using special roles (`response_1`, `response_2`) +3. **Parsing JSON scores** from the GenRM output +4. **Aggregating pairwise results** into per-response rewards + +### Expected GenRM Output Format + +The GenRM model should output JSON in the following format: + +```json +{ + "score_1": 4, // Individual helpfulness score for response 1 (1-5) + "score_2": 3, // Individual helpfulness score for response 2 (1-5) + "ranking": 2 // Relative ranking: 1=R1 much better, 6=R2 much better +} +``` + +### Score Interpretation + +- **Individual helpfulness scores** (`score_1`, `score_2`): Range from 1 to 5, where higher means better. +- **Ranking score**: Range from 1 to 6: + - 1 = Response 1 is much better than Response 2 + - 2 = Response 1 is better than Response 2 + - 3 = Response 1 is slightly better than Response 2 + - 4 = Response 2 is slightly better than Response 1 + - 5 = Response 2 is better than Response 1 + - 6 = Response 2 is much better than Response 1 + +### Compatible GenRM Models + +| Model | Principle Support | Notes | +|-------|-------------------|-------| +| [nvidia/Qwen3-Nemotron-235B-A22B-GenRM](https://huggingface.co/nvidia/Qwen3-Nemotron-235B-A22B-GenRM) | ❌ No | 235B MoE model (22B active). Used for training Nemotron-3-Nano. Supports `response_1` and `response_2` roles. | + +> **Note**: The GenRM model must have a chat template that supports the special roles `response_1` and `response_2`. The conversation history should use standard `user` and `assistant` roles, with the last turn being a user turn. + +## Quick Start + +### 1. 
Configuration + +Create or modify the config file to point to your GenRM model: + +```yaml +genrm_compare: + resources_servers: + genrm_compare: + entrypoint: app.py + + genrm_model_server: + type: responses_api_models + name: your_genrm_model # Point to your GenRM model server + + genrm_responses_create_params: + input: [] + max_output_tokens: 16384 + temperature: 0.6 + top_p: 0.95 + + comparison_strategy: circular + num_judges_per_comparison: 1 +``` + +### 2. API Usage + +Send a POST request to the `/compare` endpoint: + +```json +{ + "conversation_history": [ + {"role": "user", "content": "What is the capital of France?"} + ], + "response_objs": [ + {"output": [{"type": "message", "content": [{"type": "output_text", "text": "Paris is the capital."}]}]}, + {"output": [{"type": "message", "content": [{"type": "output_text", "text": "The capital of France is Paris."}]}]} + ], + "principle": "The response should be concise and accurate." +} +``` + +### 3. Response Format + +```json +{ + "rewards": [3.5, 4.2], + "comparison_results": [ + { + "response_i": 0, + "response_j": 1, + "judge_idx": 0, + "score_1": 3.0, + "score_2": 4.0, + "ranking": 4.0 + } + ], + "metrics": { + "mean_individual_score": 3.5, + "std_individual_score": 0.5, + "tiebreak_usage_rate": 0.0 + } +} +``` + +## Configuration Options + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `genrm_model_server` | ModelServerRef | *required* | Reference to the GenRM model server | +| `genrm_responses_create_params` | object | *required* | Generation parameters for GenRM calls | +| `comparison_strategy` | string | `"circular"` | Pair generation strategy: `"circular"` or `"all_pairs"` | +| `num_judges_per_comparison` | int | `1` | Number of judge passes per pair (for majority voting) | +| `use_principle` | bool | `false` | Enable principle-based comparison | +| `default_principle` | string | *(see config)* | Default principle when none provided in request | +| `aggregator_method` | string | `"simple_tiebreaker"` | Score aggregation method | +| `reasoning_bonus` | float | `0.0` | Bonus for shortest reasoning among top performers | +| `answer_bonus` | float | `0.0` | Bonus for shortest answer among top performers | +| `top_percentile` | float | `0.2` | Percentile threshold for applying bonuses | +| `group_reasoning_length_penalty_coeff` | float | `0.0` | Coefficient for reasoning length penalty | +| `group_answer_length_penalty_coeff` | float | `0.0` | Coefficient for answer length penalty | +| `default_score` | float | `3.0` | Default score when parsing fails | +| `default_ranking` | float | `3.5` | Default ranking when parsing fails | +| `debug_logging` | bool | `false` | Enable verbose logging | +| `genrm_parse_retries` | int | `3` | Number of retries on parse failures | +| `genrm_parse_retry_sleep_s` | float | `0.2` | Sleep duration between retries | + +## Comparison Strategies + +### Circular Strategy (`circular`) + +Each response is compared with the next in a circular fashion. For N responses, this produces exactly N comparisons. + +``` +Responses: [R0, R1, R2, R3] +Pairs: (0,1), (1,2), (2,3), (3,0) +``` + +**Use case**: Efficient for large batches where full pairwise comparison is too expensive. + +### All Pairs Strategy (`all_pairs`) + +Every pair of responses is compared. For N responses, this produces C(N,2) = N×(N-1)/2 comparisons. 
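
A minimal sketch of how both pair sets can be generated (the server's actual `generate_comparison_pairs` helper may differ in validation and ordering); with four responses, the all-pairs strategy yields the six pairs listed below:

```python
from itertools import combinations
from typing import List, Tuple

def sketch_pairs(strategy: str, n: int) -> List[Tuple[int, int]]:
    if strategy == "circular":
        return [(i, (i + 1) % n) for i in range(n)]    # N comparisons
    if strategy == "all_pairs":
        return list(combinations(range(n), 2))         # N*(N-1)/2 comparisons
    raise ValueError(f"Unknown comparison strategy: {strategy}")
```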
+ +``` +Responses: [R0, R1, R2, R3] +Pairs: (0,1), (0,2), (0,3), (1,2), (1,3), (2,3) +``` + +**Use case**: More accurate rankings when computational budget allows. + +## Score Aggregation + +### Simple Tiebreaker Method + +The `simple_tiebreaker` aggregator: + +1. **Collects scores** from all pairwise comparisons for each response +2. **Breaks ties** using the ranking field when `score_1 == score_2`: + - `ranking < 3.5` → response_1 is better (boost score_1, penalize score_2) + - `ranking > 3.5` → response_2 is better (boost score_2, penalize score_1) +3. **Averages scores** across all comparisons for each response +4. **Applies length bonuses** (if configured) + +### Length-Based Adjustments + +Two types of length adjustments are supported: + +1. **Top-performer bonuses**: Shortest reasoning/answer among top scorers gets a bonus +2. **Group-relative penalties**: Scores adjusted based on relative length within the group (shorter = bonus, longer = penalty, zero-centered) + +## Principle-Based Comparison + +When `use_principle: true`, a principle message is added to the GenRM input, guiding the comparison criteria. The principle can be: + +- Provided per-request via the `principle` field +- Defaulted to `default_principle` in config + +Example principle: +> "The response should be helpful, relevant, and concise. Prefer responses that correctly answer the question without unnecessary verbosity." + +> **Note**: Your GenRM model's chat template must support the `principle` role for this feature to work. The server sends a message with `role: "principle"` containing the principle text. If your model's chat template does not handle this role, the principle will be ignored or may cause errors. + +## File Structure + +``` +genrm_compare/ +├── app.py # Main server implementation +├── utils.py # Utility functions (parsing, aggregation, etc.) +├── configs/ +│ └── genrm_compare.yaml # Default configuration +├── data/ +│ └── example.jsonl # Example dataset +├── tests/ +│ ├── test_app.py # Server tests +│ └── test_utils.py # Utility function tests +├── requirements.txt # Dependencies +└── README.md # This file +``` + +## API Endpoints + +### POST `/compare` + +Compare multiple candidate responses. + +**Request Body** (`GenRMCompareRequest`): +- `conversation_history`: List of `{"role": str, "content": str}` messages +- `response_objs`: List of Response API objects to compare +- `principle` (optional): Custom principle for this comparison + +**Response** (`GenRMCompareResponse`): +- `rewards`: List of rewards (one per response, same order as input) +- `comparison_results`: Detailed pairwise comparison results +- `metrics`: Aggregation statistics + +### POST `/verify` + +Stub endpoint for base class compatibility. Returns the default score. 
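
For quick manual testing, a hedged example of exercising `/compare` over plain HTTP (base URL and port are placeholders; inside Gym the call normally goes through the head server's `ServerClient`):

```python
import requests

payload = {
    "conversation_history": [{"role": "user", "content": "What is the capital of France?"}],
    "response_objs": [
        {"output": [{"type": "message", "content": [{"type": "output_text", "text": "Paris is the capital."}]}]},
        {"output": [{"type": "message", "content": [{"type": "output_text", "text": "The capital of France is Paris."}]}]},
    ],
}
resp = requests.post("http://localhost:8000/compare", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["rewards"])    # e.g. [3.5, 4.2]
```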
+ +## Error Handling + +The server handles failures gracefully: + +- **Parse failures**: Retries up to `genrm_parse_retries` times with sleep between attempts +- **Connection errors**: Falls back to default scores +- **Single response**: Returns default score (no comparison possible) + +## Development + +### Running Tests + +```bash +cd resources_servers/genrm_compare +pytest tests/ -v +``` + +### Running the Server + +```bash +python app.py --config configs/genrm_compare.yaml +``` + +## License + +Code: Apache 2.0 + diff --git a/resources_servers/genrm_compare/__init__.py b/resources_servers/genrm_compare/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/resources_servers/genrm_compare/app.py b/resources_servers/genrm_compare/app.py new file mode 100644 index 000000000..34aa0b1a6 --- /dev/null +++ b/resources_servers/genrm_compare/app.py @@ -0,0 +1,359 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +GenRM Pairwise Comparison Resources Server. + +Compares multiple candidate responses using a GenRM model via pairwise comparisons. +The GenRM model expects OpenAI-format messages with special roles 'response_1' and 'response_2'. + +Input: +- conversation_history: List of user/assistant messages +- responses: List of N candidate response strings to compare + +Output: +- Per-response rewards after pairwise aggregation +""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Any, Dict, List, Optional, Tuple + +from fastapi import FastAPI +from pydantic import BaseModel + +from nemo_gym.base_resources_server import ( + BaseResourcesServerConfig, + SimpleResourcesServer, + BaseVerifyRequest, + BaseVerifyResponse, +) +from nemo_gym.config_types import ModelServerRef +from nemo_gym.openai_utils import ( + NeMoGymEasyInputMessage, + NeMoGymResponseCreateParamsNonStreaming, +) +from resources_servers.genrm_compare.utils import ( + GenRMOutputParseError, + aggregate_scores, + extract_output_text, + generate_comparison_pairs, + parse_genrm_output, +) + +logger = logging.getLogger(__name__) + + +class GenRMCompareConfig(BaseResourcesServerConfig): + """Configuration for the GenRM compare server. 
+ + Attributes: + genrm_model_server: Target model server (GenRM model with special chat template) + genrm_responses_create_params: Base create params for GenRM calls + comparison_strategy: "all_pairs" or "circular" + num_judges_per_comparison: Number of judge passes per pair (majority voting) + aggregator_method: Method for aggregating scores + reasoning_bonus: Bonus for shortest reasoning content among top performers + answer_bonus: Bonus for shortest answer among top performers + top_percentile: Percentile threshold for applying bonuses + group_reasoning_length_penalty_coeff: Coefficient for reasoning length penalty + group_answer_length_penalty_coeff: Coefficient for answer length penalty + default_score: Default neutral score when parsing fails + default_ranking: Default neutral ranking when parsing fails + debug_logging: Enable verbose logging for debugging + genrm_parse_retries: Number of retries on parse failures + genrm_parse_retry_sleep_s: Sleep duration between parse retries + use_principle: Enable principle-based comparison + default_principle: Default principle when none provided in request + """ + + name: str = "genrm_compare" + genrm_model_server: ModelServerRef + genrm_responses_create_params: NeMoGymResponseCreateParamsNonStreaming + + # Comparison strategy + comparison_strategy: str = "circular" # "all_pairs" or "circular" + num_judges_per_comparison: int = 1 + + # Principle-based GenRM settings + use_principle: bool = False + default_principle: str = ( + "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants " + "to the user prompt. Begin your evaluation by generating your own answer to the prompt. You must provide " + "your answer before judging any answers. When evaluating the assistants' answers, compare both assistants' " + "answers with your answer. You must identify and correct any mistakes or inaccurate information. Then " + "consider if the assistant's answers are helpful, relevant, and concise. Helpful means the answer correctly " + "responds to the prompt or follows the instructions. Note when user prompt has any ambiguity or more than " + "one interpretation, it is more helpful and appropriate to ask for clarifications or more information from " + "the user than providing an answer based on assumptions. Relevant means all parts of the response closely " + "connect or are appropriate to what is being asked. Concise means the response is clear and not verbose or " + "excessive. Then consider the creativity and novelty of the assistant's answers when needed. Finally, " + "identify any missing important information in the assistants' answers that would be beneficial to include " + "when responding to the user prompt." 
+ ) + + # Aggregator settings (only "simple_tiebreaker" is currently implemented) + aggregator_method: str = "simple_tiebreaker" + + # Length bonus config (only for simple_tiebreaker) + reasoning_bonus: float = 0.0 + answer_bonus: float = 0.0 + top_percentile: float = 0.2 + group_reasoning_length_penalty_coeff: float = 0.0 + group_answer_length_penalty_coeff: float = 0.0 + + # Default neutral scores when parsing fails + default_score: float = 3.0 + default_ranking: float = 3.5 + + # Debug logging + debug_logging: bool = False + + # Retry config for parse failures + genrm_parse_retries: int = 3 + genrm_parse_retry_sleep_s: float = 0.2 + + +class GenRMCompareRequest(BaseModel): + """Request payload for GenRM pairwise comparison.""" + + conversation_history: List[Dict[str, str]] # User/assistant messages before the responses + response_objs: List[Dict[str, Any]] # Raw Response API objects from policy model + principle: Optional[str] = None # Principle for principle-based GenRM (e.g., "The response should be helpful") + + +class GenRMCompareResponse(BaseModel): + """Response payload with per-response rewards.""" + + rewards: List[float] # One reward per response, in same order as input + comparison_results: Optional[List[Dict[str, Any]]] = None # Detailed pairwise results + metrics: Optional[Dict[str, float]] = None # Aggregation metrics + + +class GenRMCompareResourcesServer(SimpleResourcesServer): + """Resources server for GenRM pairwise comparison of multiple responses.""" + + config: GenRMCompareConfig + + async def verify(self, body: BaseVerifyRequest) -> BaseVerifyResponse: + """Stub verify to satisfy abstract base; GenRMCompare uses /compare instead.""" + return BaseVerifyResponse( + responses_create_params=body.responses_create_params, + response=body.response, + reward=self.config.default_score, + ) + + def setup_webserver(self) -> FastAPI: + app = super().setup_webserver() + app.post("/compare")(self.compare) + return app + + async def compare(self, body: GenRMCompareRequest) -> GenRMCompareResponse: + """Compare multiple responses using GenRM pairwise comparisons. 
+ + Args: + body: Request with conversation_history and response_objs + + Returns: + GenRMCompareResponse with per-response rewards + """ + cfg = self.config + response_objs = body.response_objs + conversation_history = body.conversation_history + num_responses = len(response_objs) + + if cfg.debug_logging: + logger.info(f"[GenRM] Compare request: {num_responses} responses") + + # Single response case - return neutral reward (no comparison possible) + if num_responses < 2: + return GenRMCompareResponse( + rewards=[cfg.default_score], + comparison_results=None, + metrics=None, + ) + + # Generate comparison pairs + try: + comparison_pairs = generate_comparison_pairs( + cfg.comparison_strategy, num_responses + ) + if cfg.debug_logging: + logger.info(f"[GenRM] Strategy '{cfg.comparison_strategy}': {len(comparison_pairs)} pairs") + except ValueError as e: + raise ValueError(f"Configuration error: {e}") + + # Build comparison tasks - one task per (pair, judge) combination + # Multiple judges per pair enables majority voting for more robust scores + comparison_tasks = [] + comparison_metadata = [] + + for judge_idx in range(cfg.num_judges_per_comparison): + for i, j in comparison_pairs: + task = self._run_single_comparison( + conversation_history, + response_objs[i], + response_objs[j], + pair_idx=(i, j), + principle=body.principle, + ) + comparison_tasks.append(task) + comparison_metadata.append((i, j, judge_idx)) + + # Run all comparisons concurrently + comparison_results = await asyncio.gather(*comparison_tasks) + + # Aggregate pairwise scores into per-response rewards + rewards, metrics, base_rewards, bonuses = aggregate_scores( + comparison_results=comparison_results, + comparison_metadata=comparison_metadata, + response_objs=response_objs, + aggregator_method=cfg.aggregator_method, + default_score=cfg.default_score, + reasoning_bonus=cfg.reasoning_bonus, + answer_bonus=cfg.answer_bonus, + top_percentile=cfg.top_percentile, + group_reasoning_length_penalty_coeff=cfg.group_reasoning_length_penalty_coeff, + group_answer_length_penalty_coeff=cfg.group_answer_length_penalty_coeff, + ) + + # Format detailed results + detailed_results = [ + { + "response_i": i, + "response_j": j, + "judge_idx": judge_idx, + "score_1": score_1, + "score_2": score_2, + "ranking": ranking, + } + for (score_1, score_2, ranking), (i, j, judge_idx) in zip( + comparison_results, comparison_metadata + ) + ] + + if cfg.debug_logging: + logger.info(f"[GenRM] Final rewards: {[f'{r:.4f}' for r in rewards]}") + + return GenRMCompareResponse( + rewards=rewards, + comparison_results=detailed_results, + metrics=metrics, + ) + + async def _run_single_comparison( + self, + conversation_history: List[Dict[str, str]], + response_obj_1: Dict[str, Any], + response_obj_2: Dict[str, Any], + pair_idx: Tuple[int, int] = (0, 0), + principle: Optional[str] = None, + ) -> Tuple[float, float, float]: + """Run a single pairwise comparison via GenRM. 
+ + Args: + conversation_history: The conversation context + response_obj_1: First Response API object + response_obj_2: Second Response API object + pair_idx: Tuple of (i, j) for logging + principle: Optional principle for principle-based comparison + + Returns: + Tuple of (score_1, score_2, ranking) + """ + cfg = self.config + + # Extract final answer from Response API objects (GenRM only takes the final answer, not reasoning) + response_1 = extract_output_text(response_obj_1) + response_2 = extract_output_text(response_obj_2) + + # Format messages for GenRM using special roles 'response_1' and 'response_2' + # The GenRM model's chat template handles these custom roles + messages: List[NeMoGymEasyInputMessage] = [] + for msg in conversation_history: + messages.append( + NeMoGymEasyInputMessage( + role=msg.get("role", "user"), + content=msg.get("content", ""), + type="message", + ) + ) + + # Add principle message if enabled + if cfg.use_principle: + principle_text = principle if principle else cfg.default_principle + messages.append( + NeMoGymEasyInputMessage(role="principle", content=principle_text, type="message") + ) + + messages.extend( + [ + NeMoGymEasyInputMessage(role="response_1", content=response_1, type="message"), + NeMoGymEasyInputMessage(role="response_2", content=response_2, type="message"), + ] + ) + + # Build the request params + responses_create_params = cfg.genrm_responses_create_params.model_copy(deep=True) + responses_create_params.input = messages + + try: + # Retry logic for parse failures (not connection errors, which are handled elsewhere) + max_attempts = max(1, int(cfg.genrm_parse_retries) + 1) + + for attempt_idx in range(max_attempts): + # Call the GenRM model via /v1/responses endpoint + response = await self.server_client.post( + server_name=cfg.genrm_model_server.name, + url_path="/v1/responses", + json=responses_create_params, + ) + raw_response = await response.json() + + # Extract output_text from GenRM response (skip reasoning, only parse the final JSON scores) + genrm_answer = extract_output_text(raw_response) + + try: + score_1, score_2, ranking = parse_genrm_output( + genrm_answer, + cfg.default_score, + cfg.default_ranking, + raise_on_fail=True, + ) + return score_1, score_2, ranking + + except GenRMOutputParseError: + if attempt_idx < max_attempts - 1: + await asyncio.sleep(float(cfg.genrm_parse_retry_sleep_s)) + continue + + # Give up: fall back to defaults + logger.warning( + f"[GenRM] Parse failed for pair {pair_idx} after {max_attempts} attempts; " + f"falling back to defaults." + ) + return cfg.default_score, cfg.default_score, cfg.default_ranking + + return cfg.default_score, cfg.default_score, cfg.default_ranking + + except Exception as e: + logger.error(f"[GenRM] Error in comparison for pair {pair_idx}: {e}") + return cfg.default_score, cfg.default_score, cfg.default_ranking + + +if __name__ == "__main__": + GenRMCompareResourcesServer.run_webserver() diff --git a/resources_servers/genrm_compare/configs/genrm_compare.yaml b/resources_servers/genrm_compare/configs/genrm_compare.yaml new file mode 100644 index 000000000..989df445b --- /dev/null +++ b/resources_servers/genrm_compare/configs/genrm_compare.yaml @@ -0,0 +1,76 @@ +# GenRM Pairwise Comparison Resources Server Config +# +# This server compares multiple candidate responses using a GenRM model. +# The GenRM model uses a special chat template with roles 'response_1' and 'response_2'. 
+# +# Expected GenRM output format (JSON): +# { +# "score_1": <1-5>, # Individual helpfulness score for response 1 +# "score_2": <1-5>, # Individual helpfulness score for response 2 +# "ranking": <1-6> # 1=R1 much better, 6=R2 much better +# } + +genrm_compare: + resources_servers: + genrm_compare: + entrypoint: app.py + + # GenRM model server reference + # Point this to your GenRM model (e.g., vllm_model with GenRM checkpoint) + genrm_model_server: + type: responses_api_models + name: genrm_model # Override this to point to your GenRM model server + + # Generation params for GenRM calls + genrm_responses_create_params: + input: [] + max_output_tokens: 16384 + temperature: 0.6 + top_p: 0.95 + + # Comparison strategy: "all_pairs" (C(n,2) comparisons) or "circular" (n comparisons) + comparison_strategy: circular + + # Number of judge passes per pair (for majority voting) + num_judges_per_comparison: 1 + + # Aggregator method (only "simple_tiebreaker" is currently supported) + aggregator_method: simple_tiebreaker + + # Bonus for shortest reasoning trace if in top percentile + reasoning_bonus: 0.0 + # Bonus for shortest final answer if in top percentile + answer_bonus: 0.0 + # Top percentile threshold (e.g., 0.2 = top 20%) + top_percentile: 0.2 + # Group-relative length penalty coefficients + # Shorter responses get bonus, longer get penalty (mean-centered) + group_reasoning_length_penalty_coeff: 0.0 + group_answer_length_penalty_coeff: 0.0 + + # Default scores when parsing fails + default_score: 3.0 + default_ranking: 3.5 + + # Server metadata + domain: rlhf + verified: false + description: GenRM pairwise comparison for RLHF training + value: Compare multiple candidate responses using GenRM model + +# GenRM Simple Agent - wraps genrm_compare for GRPO training +genrm_simple_agent: + responses_api_agents: + simple_agent: + entrypoint: app.py + resources_server: + type: resources_servers + name: genrm_compare + model_server: + type: responses_api_models + name: policy_model + datasets: + - name: example + type: example + license: Apache 2.0 + jsonl_fpath: resources_servers/genrm_compare/data/example.jsonl diff --git a/resources_servers/genrm_compare/data/example.jsonl b/resources_servers/genrm_compare/data/example.jsonl new file mode 100644 index 000000000..0d7ef4e86 --- /dev/null +++ b/resources_servers/genrm_compare/data/example.jsonl @@ -0,0 +1,3 @@ +{"responses_create_params": {"input": [{"role": "user", "content": "11. To mean something clearly and directly\nilliterate\nesthetically\nphonetic\ndenote"}, {"role": "assistant", "content": "The correct word to mean something clearly and directly is \"denote.\"\n\nDenote means to serve as a sign or symbol for something, or to indicate or represent something clearly. 
It is often used to describe the specific meaning or reference of a word, phrase, or symbol.\n\nHere are some examples of how to use \"denote\" in a sentence:\n\n* The red traffic light denotes that drivers must stop.\n* The symbol π denotes the ratio of a circle's circumference to its diameter.\n* The term \"laissez-faire\" denotes a hands-off approach to economic policy.\n\nThe other words you provided have different meanings:\n\n* Illiterate means unable to read or write, or lacking knowledge or education in a particular subject.\n* Esthetically means relating to or concerned with beauty or the appreciation of beauty.\n* Phonetic means relating to the sounds of speech, or using a system of symbols to represent those sounds."}, {"role": "user", "content": "Choose the right option:\n12. Difficult to wear or carry because of weight or shape \ncumbersome\nincompetent\ncapricious\nstatus quo"}, {"role": "assistant", "content": "The correct option is \"cumbersome.\"\n\nCumbersome means difficult to carry or use because of weight, size, or shape. It can also refer to something that is complicated or difficult to deal with.\n\nHere are some examples of how to use \"cumbersome\" in a sentence:\n\n* The suitcase was too cumbersome to carry up the stairs.\n* The old computer was cumbersome and slow compared to modern models.\n* The legal process was cumbersome and time-consuming.\n\nThe other options you provided have different meanings:\n\n* Incompetent means lacking the necessary skills or ability to do something.\n* Capricious means unpredictable or subject to whim.\n* Status quo means the existing state of affairs, especially regarding social or political issues."}, {"role": "user", "content": "Choose the right option:\n13. To fit so poorly together that the differences cause a problem\nsymbolic\nhedonistic\nconflict\nphonetic"}], "tools": [], "parallel_tool_calls": false}, "agent_ref": {"type": "responses_api_agents", "name": "genrm_simple_agent"}, "dataset": "hs3"} +{"responses_create_params": {"input": [{"role": "user", "content": "What is \"SKILL\"?"}], "tools": [], "parallel_tool_calls": false}, "agent_ref": {"type": "responses_api_agents", "name": "genrm_simple_agent"}, "dataset": "hs3"} +{"responses_create_params": {"input": [{"role": "user", "content": "Write a university essay as a student answering the question below. The course module is sustainable organization\n\nQuestion: Discuss, using examples, the main differences between positive and negative externalities. What forms of regulation might governments use to mitigate the effects of negative externalities? Note it is very important to write 1000 words or more or the answer is a failed answer\n \nYour answer should have the information below\nI. Introduction\n•\tDefinition of positive and negative externalities\nII. Positive Externalities\n•\tDefinition and examples\n•\tBenefits of positive externalities to third parties\n•\tExample: A well-maintained garden\n•\tEvidence: Hackett and Dissanayake (2014) explain that positive externalities can enhance the well-being of third parties.\nIII. Negative Externalities\n•\tDefinition and examples\n•\tHarmful effects of negative externalities to third parties\n•\tExample: A factory that emits pollutants\n•\tEvidence: Tietenberg and Lewis (2016) argue that negative externalities can have significant impacts on the environment and public health.\nIV. 
Regulation of Negative Externalities\n•\tTaxes as a form of regulation\n•\tSubsidies as a form of regulation\n•\tRegulations as a form of regulation\n•\tExample: Emission standards for factories\n•\tEvidence: Ayres and Kneese (1969) suggest that regulation is an effective way to internalize negative externalities.\nV. Conclusion\n•\tSummary of the differences between positive and negative externalities\n•\tDiscussion of the importance of regulation to address negative externalities\n•\tImplications for sustainable resource management\n•\tEvidence: Stiglitz (1989) emphasizes the importance of addressing negative externalities to promote sustainable economic development.\n"}], "tools": [], "parallel_tool_calls": false}, "agent_ref": {"type": "responses_api_agents", "name": "genrm_simple_agent"}, "dataset": "hs3"} diff --git a/resources_servers/genrm_compare/requirements.txt b/resources_servers/genrm_compare/requirements.txt new file mode 100644 index 000000000..00ed83213 --- /dev/null +++ b/resources_servers/genrm_compare/requirements.txt @@ -0,0 +1 @@ +-e nemo-gym[dev] @ ../../ diff --git a/resources_servers/genrm_compare/tests/__init__.py b/resources_servers/genrm_compare/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/resources_servers/genrm_compare/tests/test_app.py b/resources_servers/genrm_compare/tests/test_app.py new file mode 100644 index 000000000..47cd01fad --- /dev/null +++ b/resources_servers/genrm_compare/tests/test_app.py @@ -0,0 +1,274 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
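For context before the unit tests below, here is a minimal sketch of what a `/compare` call looks like over plain HTTP once the resources server is running; the host and port (`localhost:8080`) are illustrative, and the two candidate `response_objs` are reduced to the only fields `extract_output_text` actually reads:

```python
import json
import urllib.request


def make_response_obj(text: str) -> dict:
    # Minimal Response-API-shaped candidate; only the message/output_text path is needed.
    return {"output": [{"type": "message", "content": [{"type": "output_text", "text": text}]}]}


payload = {
    "conversation_history": [{"role": "user", "content": "Explain what an externality is."}],
    "response_objs": [
        make_response_obj("An externality is a cost or benefit imposed on an uninvolved third party."),
        make_response_obj("It is when markets fail."),
    ],
    # Only used by the server when use_principle is enabled.
    "principle": "The response should be helpful and accurate.",
}

request = urllib.request.Request(
    "http://localhost:8080/compare",  # illustrative address
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(request) as response:
    result = json.loads(response.read())

print(result["rewards"])  # one float per candidate, in input order
```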
+from typing import Any, Dict +from unittest.mock import AsyncMock, MagicMock + +from pytest import approx, fixture + +from nemo_gym.config_types import ModelServerRef +from nemo_gym.openai_utils import NeMoGymResponseCreateParamsNonStreaming +from nemo_gym.server_utils import ServerClient +from resources_servers.genrm_compare.app import ( + GenRMCompareConfig, + GenRMCompareRequest, + GenRMCompareResourcesServer, +) + + +class TestGenRMCompareApp: + """Tests for GenRMCompareResourcesServer.""" + + @fixture + def config(self) -> GenRMCompareConfig: + """Create a test configuration.""" + return GenRMCompareConfig( + host="0.0.0.0", + port=8080, + entrypoint="app.py", + name="genrm_compare", + genrm_model_server=ModelServerRef(type="responses_api_models", name="genrm_model"), + genrm_responses_create_params=NeMoGymResponseCreateParamsNonStreaming(input=[]), + comparison_strategy="circular", + num_judges_per_comparison=1, + aggregator_method="simple_tiebreaker", + default_score=3.0, + default_ranking=3.5, + ) + + def _make_response_obj(self, output_text: str) -> Dict[str, Any]: + """Helper to create a Response API object.""" + return { + "id": "resp_123", + "output": [ + { + "type": "message", + "content": [{"type": "output_text", "text": output_text}] + } + ] + } + + def _make_genrm_response(self, score_1: float, score_2: float, ranking: float) -> Dict[str, Any]: + """Helper to create a mock GenRM model response.""" + return { + "id": "genrm_resp", + "output": [ + { + "type": "message", + "content": [ + { + "type": "output_text", + "text": f'{{"score_1": {score_1}, "score_2": {score_2}, "ranking": {ranking}}}' + } + ] + } + ] + } + + async def test_compare_single_response_returns_default(self, config: GenRMCompareConfig) -> None: + """Single response returns default score (no comparison possible).""" + server_mock = MagicMock(spec=ServerClient) + rs = GenRMCompareResourcesServer(config=config, server_client=server_mock) + + req = GenRMCompareRequest( + conversation_history=[{"role": "user", "content": "Hello"}], + response_objs=[self._make_response_obj("Response 1")], + ) + + res = await rs.compare(req) + + assert len(res.rewards) == 1 + assert res.rewards[0] == approx(3.0) # default score + # No model calls should be made + server_mock.post.assert_not_called() + + async def test_compare_two_responses_circular(self, config: GenRMCompareConfig) -> None: + """Two responses with circular strategy (2 comparisons).""" + server_mock = MagicMock(spec=ServerClient) + rs = GenRMCompareResourcesServer(config=config, server_client=server_mock) + + # Mock GenRM model responses + post_mock = MagicMock() + post_mock.json = AsyncMock() + server_mock.post = AsyncMock(return_value=post_mock) + + # Circular: (0,1) and (1,0) + # First: response 0 is better (score_1=5, score_2=3, ranking=2) + # Second: response 0 is better (now as response_2, score_1=3, score_2=5, ranking=5) + post_mock.json.side_effect = [ + self._make_genrm_response(5, 3, 2), # (0,1): 0 is better + self._make_genrm_response(3, 5, 5), # (1,0): 0 is better (as response_2) + ] + + req = GenRMCompareRequest( + conversation_history=[{"role": "user", "content": "Hello"}], + response_objs=[ + self._make_response_obj("Good response"), + self._make_response_obj("Bad response"), + ], + ) + + res = await rs.compare(req) + + assert len(res.rewards) == 2 + # Response 0 should have higher reward than response 1 + assert res.rewards[0] > res.rewards[1] + + async def test_compare_with_tiebreaker(self, config: GenRMCompareConfig) -> None: + """Test tiebreaker 
when scores are equal.""" + server_mock = MagicMock(spec=ServerClient) + rs = GenRMCompareResourcesServer(config=config, server_client=server_mock) + + post_mock = MagicMock() + post_mock.json = AsyncMock() + server_mock.post = AsyncMock(return_value=post_mock) + + # Equal scores but ranking favors response 0 + post_mock.json.side_effect = [ + self._make_genrm_response(3, 3, 2), # Tied, ranking=2 favors response_1 (idx 0) + self._make_genrm_response(3, 3, 5), # Tied, ranking=5 favors response_2 (idx 0) + ] + + req = GenRMCompareRequest( + conversation_history=[{"role": "user", "content": "Hello"}], + response_objs=[ + self._make_response_obj("Response A"), + self._make_response_obj("Response B"), + ], + ) + + res = await rs.compare(req) + + assert len(res.rewards) == 2 + # With tiebreaker applied, response 0 should have higher score + assert res.rewards[0] > res.rewards[1] + + async def test_compare_with_principle(self, config: GenRMCompareConfig) -> None: + """Test comparison with principle parameter.""" + config.use_principle = True + server_mock = MagicMock(spec=ServerClient) + rs = GenRMCompareResourcesServer(config=config, server_client=server_mock) + + post_mock = MagicMock() + post_mock.json = AsyncMock() + server_mock.post = AsyncMock(return_value=post_mock) + + post_mock.json.side_effect = [ + self._make_genrm_response(5, 3, 2), + self._make_genrm_response(3, 5, 5), + ] + + req = GenRMCompareRequest( + conversation_history=[{"role": "user", "content": "Hello"}], + response_objs=[ + self._make_response_obj("Response A"), + self._make_response_obj("Response B"), + ], + principle="The response should be helpful and accurate.", + ) + + res = await rs.compare(req) + + assert len(res.rewards) == 2 + # Verify model was called (principle should be included in prompts) + assert server_mock.post.call_count == 2 + + async def test_compare_parse_failure_uses_defaults(self, config: GenRMCompareConfig) -> None: + """GenRM output parse failure uses default scores.""" + server_mock = MagicMock(spec=ServerClient) + rs = GenRMCompareResourcesServer(config=config, server_client=server_mock) + + post_mock = MagicMock() + post_mock.json = AsyncMock() + server_mock.post = AsyncMock(return_value=post_mock) + + # Return invalid JSON that can't be parsed + post_mock.json.side_effect = [ + {"id": "resp", "output": [{"type": "message", "content": [{"type": "output_text", "text": "No JSON here"}]}]}, + {"id": "resp", "output": [{"type": "message", "content": [{"type": "output_text", "text": "Still no JSON"}]}]}, + ] + + req = GenRMCompareRequest( + conversation_history=[{"role": "user", "content": "Hello"}], + response_objs=[ + self._make_response_obj("Response A"), + self._make_response_obj("Response B"), + ], + ) + + res = await rs.compare(req) + + # Both should get default scores since parsing failed + assert len(res.rewards) == 2 + # Scores should be around default (3.0) since parsing failed + assert all(2.0 <= r <= 4.0 for r in res.rewards) + + async def test_compare_three_responses_all_pairs(self, config: GenRMCompareConfig) -> None: + """Three responses with all_pairs strategy (3 comparisons).""" + config.comparison_strategy = "all_pairs" + server_mock = MagicMock(spec=ServerClient) + rs = GenRMCompareResourcesServer(config=config, server_client=server_mock) + + post_mock = MagicMock() + post_mock.json = AsyncMock() + server_mock.post = AsyncMock(return_value=post_mock) + + # all_pairs: (0,1), (0,2), (1,2) + post_mock.json.side_effect = [ + self._make_genrm_response(5, 3, 2), # (0,1): 0 wins + 
self._make_genrm_response(5, 2, 1), # (0,2): 0 wins + self._make_genrm_response(4, 2, 2), # (1,2): 1 wins + ] + + req = GenRMCompareRequest( + conversation_history=[{"role": "user", "content": "Hello"}], + response_objs=[ + self._make_response_obj("Best response"), + self._make_response_obj("Medium response"), + self._make_response_obj("Worst response"), + ], + ) + + res = await rs.compare(req) + + assert len(res.rewards) == 3 + # Response 0 should be best, response 2 should be worst + assert res.rewards[0] > res.rewards[1] > res.rewards[2] + # Verify 3 comparisons were made + assert server_mock.post.call_count == 3 + + async def test_verify_returns_default(self, config: GenRMCompareConfig) -> None: + """Verify endpoint returns default score (stub implementation).""" + from nemo_gym.base_resources_server import BaseVerifyRequest + from nemo_gym.openai_utils import NeMoGymResponse + + server_mock = MagicMock(spec=ServerClient) + rs = GenRMCompareResourcesServer(config=config, server_client=server_mock) + + req = BaseVerifyRequest( + responses_create_params=NeMoGymResponseCreateParamsNonStreaming(input=[]), + response=NeMoGymResponse( + id="resp", + created_at=0.0, + model="m", + object="response", + output=[], + parallel_tool_calls=False, + tool_choice="none", + tools=[], + ), + ) + + res = await rs.verify(req) + assert res.reward == approx(3.0) # default_score diff --git a/resources_servers/genrm_compare/tests/test_utils.py b/resources_servers/genrm_compare/tests/test_utils.py new file mode 100644 index 000000000..cd9b65761 --- /dev/null +++ b/resources_servers/genrm_compare/tests/test_utils.py @@ -0,0 +1,412 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
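Before the parser and aggregator tests below, a small end-to-end sketch of how the helpers compose when driven directly; the judge verdicts are invented, and the arithmetic in the closing comment is just the plain averaging `simple_tiebreaker` does when no tie occurs:

```python
from resources_servers.genrm_compare.utils import aggregate_scores, generate_comparison_pairs

# Circular strategy over 3 candidates -> 3 judged pairs: (0, 1), (1, 2), (2, 0).
pairs = generate_comparison_pairs("circular", 3)

# Invented judge verdicts, one (score_1, score_2, ranking) tuple per pair.
results = [(5.0, 3.0, 2.0), (3.0, 4.0, 5.0), (2.0, 5.0, 6.0)]
metadata = [(i, j, 0) for i, j in pairs]  # single judge, so judge_idx is always 0

response_objs = [
    {"output": [{"type": "message", "content": [{"type": "output_text", "text": text}]}]}
    for text in ("answer A", "answer B", "answer C")
]

rewards, metrics, base_scores, bonuses = aggregate_scores(
    comparison_results=results,
    comparison_metadata=metadata,
    response_objs=response_objs,
    aggregator_method="simple_tiebreaker",
    default_score=3.0,
    reasoning_bonus=0.0,
    answer_bonus=0.0,
    top_percentile=0.2,
    group_reasoning_length_penalty_coeff=0.0,
    group_answer_length_penalty_coeff=0.0,
)

# Each reward is the mean of the scores a candidate received across its comparisons:
# candidate 0 -> (5 + 5) / 2 = 5.0, candidate 1 -> (3 + 3) / 2 = 3.0, candidate 2 -> (4 + 2) / 2 = 3.0.
print(rewards)
```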
+import pytest +from pytest import approx + +from resources_servers.genrm_compare.utils import ( + GenRMOutputParseError, + aggregate_scores, + extract_from_response_obj, + extract_output_text, + generate_comparison_pairs, + parse_genrm_output, + apply_length_bonuses, + EMPTY_OUTPUT_PLACEHOLDER, +) + + +class TestGenerateComparisonPairs: + """Tests for generate_comparison_pairs function.""" + + def test_circular_strategy_3_responses(self) -> None: + """Circular strategy with 3 responses: (0,1), (1,2), (2,0).""" + pairs = generate_comparison_pairs("circular", 3) + assert pairs == [(0, 1), (1, 2), (2, 0)] + + def test_circular_strategy_4_responses(self) -> None: + """Circular strategy with 4 responses: (0,1), (1,2), (2,3), (3,0).""" + pairs = generate_comparison_pairs("circular", 4) + assert pairs == [(0, 1), (1, 2), (2, 3), (3, 0)] + + def test_circular_strategy_2_responses(self) -> None: + """Circular strategy with 2 responses: (0,1), (1,0).""" + pairs = generate_comparison_pairs("circular", 2) + assert pairs == [(0, 1), (1, 0)] + + def test_all_pairs_strategy_3_responses(self) -> None: + """All pairs strategy with 3 responses: C(3,2) = 3 pairs.""" + pairs = generate_comparison_pairs("all_pairs", 3) + assert pairs == [(0, 1), (0, 2), (1, 2)] + + def test_all_pairs_strategy_4_responses(self) -> None: + """All pairs strategy with 4 responses: C(4,2) = 6 pairs.""" + pairs = generate_comparison_pairs("all_pairs", 4) + assert pairs == [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)] + + def test_unsupported_strategy_raises(self) -> None: + """Unsupported strategy raises ValueError.""" + with pytest.raises(ValueError, match="Unknown comparison strategy"): + generate_comparison_pairs("unknown", 3) + + def test_less_than_2_responses_raises(self) -> None: + """Less than 2 responses raises ValueError.""" + with pytest.raises(ValueError, match="Need at least 2 responses"): + generate_comparison_pairs("circular", 1) + + +class TestParseGenRMOutput: + """Tests for parse_genrm_output function.""" + + def test_valid_json_fenced(self) -> None: + """Parse JSON from fenced code block.""" + output = """Here's my evaluation: +```json +{"score_1": 4, "score_2": 3, "ranking": 2} +``` +""" + score_1, score_2, ranking = parse_genrm_output(output, 3.0, 3.5) + assert score_1 == approx(4.0) + assert score_2 == approx(3.0) + assert ranking == approx(2.0) + + def test_valid_json_unfenced(self) -> None: + """Parse JSON from unfenced block.""" + output = 'The result is {"score_1": 5, "score_2": 2, "ranking": 1}' + score_1, score_2, ranking = parse_genrm_output(output, 3.0, 3.5) + assert score_1 == approx(5.0) + assert score_2 == approx(2.0) + assert ranking == approx(1.0) + + def test_partial_json_uses_defaults(self) -> None: + """Missing keys use default values.""" + output = '{"score_1": 4}' + score_1, score_2, ranking = parse_genrm_output(output, 3.0, 3.5) + assert score_1 == approx(4.0) + assert score_2 == approx(3.0) # default + assert ranking == approx(3.5) # default + + def test_no_json_returns_defaults(self) -> None: + """No JSON returns all defaults.""" + output = "This is just plain text without any JSON." 
+ score_1, score_2, ranking = parse_genrm_output(output, 3.0, 3.5) + assert score_1 == approx(3.0) + assert score_2 == approx(3.0) + assert ranking == approx(3.5) + + def test_invalid_json_returns_defaults(self) -> None: + """Invalid JSON returns defaults.""" + output = '{"score_1": invalid}' + score_1, score_2, ranking = parse_genrm_output(output, 3.0, 3.5) + assert score_1 == approx(3.0) + assert score_2 == approx(3.0) + assert ranking == approx(3.5) + + def test_raise_on_fail(self) -> None: + """raise_on_fail=True raises GenRMOutputParseError.""" + output = "No JSON here" + with pytest.raises(GenRMOutputParseError): + parse_genrm_output(output, 3.0, 3.5, raise_on_fail=True) + + def test_multiple_json_uses_last_valid(self) -> None: + """Multiple JSON blocks uses the last valid one.""" + output = '{"score_1": 1} more text {"score_1": 5, "score_2": 4, "ranking": 3}' + score_1, score_2, ranking = parse_genrm_output(output, 3.0, 3.5) + assert score_1 == approx(5.0) + assert score_2 == approx(4.0) + assert ranking == approx(3.0) + + def test_pretty_printed_json(self) -> None: + """Parse pretty-printed JSON.""" + output = """ +{ + "score_1": 4, + "score_2": 3, + "ranking": 2 +} +""" + score_1, score_2, ranking = parse_genrm_output(output, 3.0, 3.5) + assert score_1 == approx(4.0) + assert score_2 == approx(3.0) + assert ranking == approx(2.0) + + +class TestExtractFromResponseObj: + """Tests for extract_from_response_obj function.""" + + def test_extract_output_text_only(self) -> None: + """Extract output text from message type.""" + response_obj = { + "output": [ + { + "type": "message", + "content": [ + {"type": "output_text", "text": "The answer is 42."} + ] + } + ] + } + reasoning, output = extract_from_response_obj(response_obj) + assert reasoning == "" + assert output == "The answer is 42." + + def test_extract_reasoning_and_output(self) -> None: + """Extract both reasoning and output.""" + response_obj = { + "output": [ + { + "type": "reasoning", + "summary": [{"text": "Let me think step by step."}] + }, + { + "type": "message", + "content": [ + {"type": "output_text", "text": "Final answer: 42"} + ] + } + ] + } + reasoning, output = extract_from_response_obj(response_obj) + assert reasoning == "Let me think step by step." 
+ assert output == "Final answer: 42" + + def test_empty_response_obj(self) -> None: + """Empty response object returns empty strings.""" + response_obj = {} + reasoning, output = extract_from_response_obj(response_obj) + assert reasoning == "" + assert output == "" + + def test_non_dict_response_obj(self) -> None: + """Non-dict response object returns empty strings.""" + reasoning, output = extract_from_response_obj("not a dict") + assert reasoning == "" + assert output == "" + + +class TestExtractOutputText: + """Tests for extract_output_text function.""" + + def test_extract_text(self) -> None: + """Extract output text from response object.""" + response_obj = { + "output": [ + { + "type": "message", + "content": [{"type": "output_text", "text": "Hello world"}] + } + ] + } + text = extract_output_text(response_obj) + assert text == "Hello world" + + def test_empty_returns_placeholder(self) -> None: + """Empty output returns placeholder.""" + response_obj = {"output": []} + text = extract_output_text(response_obj) + assert text == EMPTY_OUTPUT_PLACEHOLDER + + def test_whitespace_only_returns_placeholder(self) -> None: + """Whitespace-only output returns placeholder.""" + response_obj = { + "output": [ + { + "type": "message", + "content": [{"type": "output_text", "text": " "}] + } + ] + } + text = extract_output_text(response_obj) + assert text == EMPTY_OUTPUT_PLACEHOLDER + + +class TestAggregateScores: + """Tests for aggregate_scores function.""" + + def _make_response_obj(self, output_text: str) -> dict: + """Helper to create a minimal Response API object.""" + return { + "output": [ + { + "type": "message", + "content": [{"type": "output_text", "text": output_text}] + } + ] + } + + def test_simple_tiebreaker_no_tie(self) -> None: + """Aggregate scores without tiebreaker.""" + # Two comparisons: (0,1) and (1,2) with clear winners + comparison_results = [ + (5.0, 3.0, 2.0), # Response 0 wins against 1 + (4.0, 2.0, 1.0), # Response 1 wins against 2 + ] + comparison_metadata = [(0, 1, 0), (1, 2, 0)] + response_objs = [ + self._make_response_obj("Answer 0"), + self._make_response_obj("Answer 1"), + self._make_response_obj("Answer 2"), + ] + + final_scores, metrics, base_scores, bonuses = aggregate_scores( + comparison_results=comparison_results, + comparison_metadata=comparison_metadata, + response_objs=response_objs, + aggregator_method="simple_tiebreaker", + default_score=3.0, + reasoning_bonus=0.0, + answer_bonus=0.0, + top_percentile=0.2, + group_reasoning_length_penalty_coeff=0.0, + group_answer_length_penalty_coeff=0.0, + ) + + # Response 0: score 5 from one comparison -> avg = 5 + # Response 1: score 3 + 4 from two comparisons -> avg = 3.5 + # Response 2: score 2 from one comparison -> avg = 2 + assert final_scores[0] == approx(5.0) + assert final_scores[1] == approx(3.5) + assert final_scores[2] == approx(2.0) + + def test_simple_tiebreaker_with_tie(self) -> None: + """Aggregate scores with tiebreaker activation.""" + # Tied scores with ranking deciding winner + comparison_results = [ + (3.0, 3.0, 2.0), # Tied, ranking=2 means response_1 (idx 0) is better + ] + comparison_metadata = [(0, 1, 0)] + response_objs = [ + self._make_response_obj("Answer 0"), + self._make_response_obj("Answer 1"), + ] + + final_scores, metrics, base_scores, bonuses = aggregate_scores( + comparison_results=comparison_results, + comparison_metadata=comparison_metadata, + response_objs=response_objs, + aggregator_method="simple_tiebreaker", + default_score=3.0, + reasoning_bonus=0.0, + answer_bonus=0.0, + 
top_percentile=0.2, + group_reasoning_length_penalty_coeff=0.0, + group_answer_length_penalty_coeff=0.0, + ) + + # Tiebreaker: adjustment = 3.5 - 2.0 = 1.5 + # Response 0: 3.0 + 1.5 = 4.5 + # Response 1: 3.0 - 1.5 = 1.5 + assert final_scores[0] == approx(4.5) + assert final_scores[1] == approx(1.5) + assert metrics["tiebreak_usage_rate"] == approx(1.0) + + def test_unsupported_aggregator_raises(self) -> None: + """Unsupported aggregator method raises ValueError.""" + with pytest.raises(ValueError, match="Unsupported aggregator_method"): + aggregate_scores( + comparison_results=[], + comparison_metadata=[], + response_objs=[], + aggregator_method="unknown", + default_score=3.0, + reasoning_bonus=0.0, + answer_bonus=0.0, + top_percentile=0.2, + group_reasoning_length_penalty_coeff=0.0, + group_answer_length_penalty_coeff=0.0, + ) + + def test_no_comparisons_returns_default(self) -> None: + """No comparisons for a response returns default score.""" + response_objs = [ + self._make_response_obj("Answer 0"), + ] + + final_scores, metrics, base_scores, bonuses = aggregate_scores( + comparison_results=[], + comparison_metadata=[], + response_objs=response_objs, + aggregator_method="simple_tiebreaker", + default_score=3.0, + reasoning_bonus=0.0, + answer_bonus=0.0, + top_percentile=0.2, + group_reasoning_length_penalty_coeff=0.0, + group_answer_length_penalty_coeff=0.0, + ) + + assert final_scores[0] == approx(3.0) + + +class TestApplyLengthBonuses: + """Tests for apply_length_bonuses function.""" + + def _make_response_obj(self, output_text: str, reasoning_text: str = "") -> dict: + """Helper to create a Response API object with reasoning and output.""" + output = [] + if reasoning_text: + output.append({ + "type": "reasoning", + "summary": [{"text": reasoning_text}] + }) + output.append({ + "type": "message", + "content": [{"type": "output_text", "text": output_text}] + }) + return {"output": output} + + def test_answer_bonus_shortest_among_top_gets_bonus(self) -> None: + """Bonus goes to the shortest answer within the top scorers; others unchanged.""" + scores = [5.0, 5.0, 4.0] # Top scorers: idx 0 and 1 (both 5.0) + response_objs = [ + self._make_response_obj("a"), # 1 char (shortest and top) + self._make_response_obj("much longer texttt"), # longer but also top scorer + self._make_response_obj("mid length here"), # not top scorer, longer than idx0 + ] + + adjusted, bonuses = apply_length_bonuses( + scores=scores, + response_objs=response_objs, + reasoning_bonus=0.0, + answer_bonus=0.5, + top_percentile=0.67, # Top ~2 of 3 + group_reasoning_length_penalty_coeff=0.0, + group_answer_length_penalty_coeff=0.0, + ) + + # Only idx 0 is shortest among the top scorers -> gets the bonus + assert adjusted == approx([5.5, 5.0, 4.0]) + assert bonuses == approx([0.5, 0.0, 0.0]) + + def test_no_bonus_when_disabled(self) -> None: + """No bonuses applied when all bonus configs are 0.""" + scores = [5.0, 4.0] + response_objs = [ + self._make_response_obj("short"), + self._make_response_obj("much longer text here"), + ] + + adjusted, bonuses = apply_length_bonuses( + scores=scores, + response_objs=response_objs, + reasoning_bonus=0.0, + answer_bonus=0.0, + top_percentile=0.2, + group_reasoning_length_penalty_coeff=0.0, + group_answer_length_penalty_coeff=0.0, + ) + + assert adjusted[0] == approx(5.0) + assert adjusted[1] == approx(4.0) + assert bonuses == [0.0, 0.0] diff --git a/resources_servers/genrm_compare/utils.py b/resources_servers/genrm_compare/utils.py new file mode 100644 index 000000000..397478a0b 
--- /dev/null +++ b/resources_servers/genrm_compare/utils.py @@ -0,0 +1,497 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for GenRM comparison server. + +This module provides: +- Comparison pair generation strategies (circular, all_pairs) +- GenRM output parsing (JSON score extraction) +- Response API object text extraction +- Score aggregation with tiebreaker logic +- Length-based bonus/penalty computation +""" + +from __future__ import annotations + +import itertools +import json +import logging +import re +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Constants +# ============================================================================= + +SUPPORTED_COMPARISON_STRATEGIES = frozenset({"circular", "all_pairs"}) +SUPPORTED_AGGREGATOR_METHODS = frozenset({"simple_tiebreaker"}) + +# GenRM ranking midpoint - rankings are 1-6, midpoint is 3.5 +# Used in tiebreaker: ranking < 3.5 means response_1 is better, > 3.5 means response_2 is better +RANKING_MIDPOINT = 3.5 + +# Default placeholder for empty/missing output +EMPTY_OUTPUT_PLACEHOLDER = "None" + + +# ============================================================================= +# Exceptions +# ============================================================================= + +class GenRMOutputParseError(ValueError): + """Raised when GenRM output cannot be parsed into expected JSON scores. + + Expected format: {"score_1": <1-5>, "score_2": <1-5>, "ranking": <1-6>} + """ + pass + + +# ============================================================================= +# Comparison Strategy +# ============================================================================= + +def generate_comparison_pairs(strategy: str, num_responses: int) -> List[Tuple[int, int]]: + """Generate pairs of response indices for pairwise comparison. + + Args: + strategy: Comparison strategy - "circular" or "all_pairs" + - "circular": Each response compared with next (N comparisons) + - "all_pairs": Every pair compared (N*(N-1)/2 comparisons) + num_responses: Number of responses to compare + + Returns: + List of (i, j) tuples where i < j for all_pairs, or circular neighbors + + Raises: + ValueError: If strategy is not supported or num_responses < 2 + """ + if strategy not in SUPPORTED_COMPARISON_STRATEGIES: + raise ValueError( + f"Unknown comparison strategy: '{strategy}'. 
" + f"Supported: {sorted(SUPPORTED_COMPARISON_STRATEGIES)}" + ) + + if num_responses < 2: + raise ValueError(f"Need at least 2 responses for comparison, got {num_responses}") + + if strategy == "all_pairs": + return list(itertools.combinations(range(num_responses), 2)) + else: # circular + return [(i, (i + 1) % num_responses) for i in range(num_responses)] + + +# ============================================================================= +# GenRM Output Parsing +# ============================================================================= + +def parse_genrm_output( + output: str, + default_score: float, + default_ranking: float, + *, + raise_on_fail: bool = False, +) -> Tuple[float, float, float]: + """Parse GenRM output to extract scores from JSON format. + + Searches for JSON in the output text, trying: + 1. Fenced JSON blocks (```json {...} ```) + 2. Any {...} JSON objects, taking the last valid one + + Expected JSON format: + {"score_1": <1-5>, "score_2": <1-5>, "ranking": <1-6>} + + Args: + output: Raw text output from GenRM model + default_score: Default score if parsing fails + default_ranking: Default ranking if parsing fails + raise_on_fail: If True, raise GenRMOutputParseError on failure + + Returns: + Tuple of (score_1, score_2, ranking) + + Raises: + GenRMOutputParseError: If raise_on_fail=True and parsing fails + """ + def _try_parse(json_str: str) -> Optional[Tuple[float, float, float]]: + """Attempt to parse a JSON string into scores.""" + try: + parsed = json.loads(json_str) + except json.JSONDecodeError: + return None + + if not isinstance(parsed, dict): + return None + + # Must have at least one expected key + if not any(k in parsed for k in ("score_1", "score_2", "ranking")): + return None + + try: + score_1 = float(parsed.get("score_1", default_score)) + score_2 = float(parsed.get("score_2", default_score)) + ranking = float(parsed.get("ranking", default_ranking)) + return score_1, score_2, ranking + except (TypeError, ValueError): + return None + + try: + # Strategy 1: Look for fenced JSON blocks (```json {...} ```) + for match in re.finditer(r"```json\s*(\{[\s\S]*?\})\s*```", output, flags=re.IGNORECASE): + result = _try_parse(match.group(1)) + if result is not None: + return result + + # Strategy 2: Find all {...} and take the last valid one. + # + # Note: We keep this intentionally permissive because model outputs can include + # extra prose around a JSON blob, and the JSON may be pretty-printed across lines. + last_valid: Optional[Tuple[float, float, float]] = None + for match in re.finditer(r"(\{[\s\S]*?\})", output): + result = _try_parse(match.group(1)) + if result is not None: + last_valid = result + + if last_valid is not None: + return last_valid + + # Parsing failed + preview = output[:200] + "..." if len(output) > 200 else output + msg = f"No parseable JSON found in GenRM output: {preview}" + + if raise_on_fail: + raise GenRMOutputParseError(msg) + + logger.warning(msg) + return default_score, default_score, default_ranking + except Exception as e: + preview = output[:200] + "..." if len(output) > 200 else output + msg = f"Error parsing GenRM output: {e}. 
Output: {preview}" + if raise_on_fail: + raise GenRMOutputParseError(msg) from e + logger.exception(msg) + return default_score, default_score, default_ranking + + +# ============================================================================= +# Response API Object Extraction +# ============================================================================= + +def extract_from_response_obj(response_obj: Dict[str, Any]) -> Tuple[str, str]: + """Extract reasoning and output text from a Response API object. + + Parses the nested Response API structure to find: + - Reasoning content from "reasoning" type items + - Output text from "message" type items with "output_text" content + + Args: + response_obj: Raw Response API object with "output" field + + Returns: + Tuple of (reasoning_content, output_text) + """ + reasoning_content = "" + output_text = "" + + if not isinstance(response_obj, dict): + return reasoning_content, output_text + + output = response_obj.get("output", []) + if not isinstance(output, list): + return reasoning_content, output_text + + for item in output: + if not isinstance(item, dict): + continue + + item_type = item.get("type", "") + + if item_type == "reasoning": + # Extract from summary field + summary = item.get("summary", []) + if isinstance(summary, list): + for s in summary: + if isinstance(s, dict) and "text" in s: + reasoning_content += s.get("text", "") + + elif item_type == "message": + # Extract from content field + content = item.get("content", []) + if isinstance(content, list): + for c in content: + if isinstance(c, dict) and c.get("type") == "output_text": + output_text += c.get("text", "") + + return reasoning_content, output_text + + +def extract_output_text(response_obj: Dict[str, Any]) -> str: + """Extract only the output text (final answer) from Response API object. + + Args: + response_obj: Raw Response API object + + Returns: + The output text, or "None" if empty/missing + """ + _, output = extract_from_response_obj(response_obj) + + if not output or not output.strip(): + return EMPTY_OUTPUT_PLACEHOLDER + + return output + + +# ============================================================================= +# Length-Based Bonuses +# ============================================================================= + +def apply_length_bonuses( + scores: List[float], + response_objs: List[Dict[str, Any]], + reasoning_bonus: float, + answer_bonus: float, + top_percentile: float, + group_reasoning_length_penalty_coeff: float, + group_answer_length_penalty_coeff: float, +) -> Tuple[List[float], List[float]]: + """Apply length-based bonuses/penalties to scores. + + Two types of adjustments: + 1. Top-performer bonuses: Shortest reasoning/answer among top scorers gets bonus + 2. 
Group-relative penalties: Scores adjusted based on relative length within group + + Args: + scores: Base scores to adjust (modified in place) + response_objs: Response API objects to extract lengths from + reasoning_bonus: Bonus for shortest reasoning among top performers + answer_bonus: Bonus for shortest answer among top performers + top_percentile: Fraction of top scorers eligible for bonuses (0.0-1.0) + group_reasoning_length_penalty_coeff: Coefficient for reasoning length penalty + group_answer_length_penalty_coeff: Coefficient for answer length penalty + + Returns: + Tuple of (adjusted_scores, bonuses_per_response) + """ + num_responses = len(response_objs) + + if num_responses < 2: + return scores, [0.0] * len(scores) + + # Extract lengths from Response API objects + bonuses = [0.0] * num_responses + reasoning_lengths: List[int] = [] + answer_lengths: List[int] = [] + + for obj in response_objs: + reasoning, answer = extract_from_response_obj(obj) + reasoning_lengths.append(len(reasoning.strip())) + answer_lengths.append(len(answer.strip())) + + logger.debug(f"Reasoning lengths: {reasoning_lengths}") + logger.debug(f"Answer lengths: {answer_lengths}") + + # Determine top percentile threshold + sorted_scores = sorted(scores, reverse=True) + threshold_idx = max(0, int(len(scores) * top_percentile) - 1) + top_threshold = sorted_scores[threshold_idx] + + # Bonus for shortest non-empty reasoning among top performers + if reasoning_bonus > 0: + valid = [(i, length) for i, length in enumerate(reasoning_lengths) if length > 0] + if valid: + idx, length = min(valid, key=lambda x: x[1]) + if scores[idx] >= top_threshold: + scores[idx] += reasoning_bonus + bonuses[idx] += reasoning_bonus + logger.debug(f"Reasoning bonus +{reasoning_bonus} to response {idx} (len={length})") + + # Bonus for shortest non-empty answer among top performers + if answer_bonus > 0: + valid = [(i, length) for i, length in enumerate(answer_lengths) if length > 0] + if valid: + idx, length = min(valid, key=lambda x: x[1]) + if scores[idx] >= top_threshold: + scores[idx] += answer_bonus + bonuses[idx] += answer_bonus + logger.debug(f"Answer bonus +{answer_bonus} to response {idx} (len={length})") + + # Group-relative length adjustment (shorter = higher weight, zero-centered) + if group_reasoning_length_penalty_coeff > 0 or group_answer_length_penalty_coeff > 0: + reasoning_weights = _compute_length_weights(reasoning_lengths) + answer_weights = _compute_length_weights(answer_lengths) + + for idx in range(num_responses): + reasoning_adj = reasoning_weights[idx] * group_reasoning_length_penalty_coeff + answer_adj = answer_weights[idx] * group_answer_length_penalty_coeff + total_adj = reasoning_adj + answer_adj + + if total_adj != 0: + scores[idx] += total_adj + bonuses[idx] += total_adj + logger.debug( + f"Length adjustment {total_adj:+.4f} to response {idx} " + f"(reasoning={reasoning_adj:+.4f}, answer={answer_adj:+.4f})" + ) + + return scores, bonuses + + +def _compute_length_weights(lengths: List[int]) -> List[float]: + """Compute zero-centered weights where shorter = higher weight. 
+ + Args: + lengths: List of lengths + + Returns: + List of weights, zero-centered (sum to ~0) + """ + max_len, min_len = max(lengths), min(lengths) + + if max_len == min_len: + return [0.0] * len(lengths) + + span = max_len - min_len + # Shorter = weight closer to 1, longer = weight closer to 0 + raw_weights = [1.0 - ((length - min_len) / span) for length in lengths] + # Zero-center + mean_weight = sum(raw_weights) / len(raw_weights) + return [w - mean_weight for w in raw_weights] + + +# ============================================================================= +# Score Aggregation +# ============================================================================= + +def aggregate_scores( + comparison_results: List[Tuple[float, float, float]], + comparison_metadata: List[Tuple[int, int, int]], + response_objs: List[Dict[str, Any]], + aggregator_method: str, + default_score: float, + reasoning_bonus: float, + answer_bonus: float, + top_percentile: float, + group_reasoning_length_penalty_coeff: float, + group_answer_length_penalty_coeff: float, +) -> Tuple[List[float], Dict[str, float], List[float], List[float]]: + """Aggregate pairwise comparison results into per-response rewards. + + For "simple_tiebreaker" method: + - When score_1 == score_2, use ranking to break the tie + - ranking < 3.5 means response_1 is better → boost score_1, penalize score_2 + - ranking > 3.5 means response_2 is better → boost score_2, penalize score_1 + + Args: + comparison_results: List of (score_1, score_2, ranking) from pairwise comparisons + comparison_metadata: List of (response_i, response_j, judge_idx) for each comparison + response_objs: Raw Response API objects for length bonus computation + aggregator_method: Only "simple_tiebreaker" is supported + default_score: Default score when no comparisons exist for a response + reasoning_bonus: Bonus for shortest reasoning among top performers + answer_bonus: Bonus for shortest answer among top performers + top_percentile: Percentile threshold for length bonuses + group_reasoning_length_penalty_coeff: Coefficient for reasoning length penalty + group_answer_length_penalty_coeff: Coefficient for answer length penalty + + Returns: + Tuple of: + - final_scores: Per-response rewards after all adjustments + - metrics: Aggregation statistics (mean, std, tiebreak rate) + - base_scores: Scores before length bonuses + - bonuses: Length bonus applied to each response + + Raises: + ValueError: If aggregator_method is not supported + """ + if aggregator_method not in SUPPORTED_AGGREGATOR_METHODS: + raise ValueError( + f"Unsupported aggregator_method: '{aggregator_method}'. 
" + f"Supported: {sorted(SUPPORTED_AGGREGATOR_METHODS)}" + ) + + num_responses = len(response_objs) + + # Initialize accumulators + accumulated_scores = [0.0] * num_responses + comparison_counts = [0] * num_responses + + # Track metrics + all_individual_scores: List[float] = [] + tiebreak_count = 0 + + # Process each comparison + for (score_1, score_2, ranking), (i, j, _judge_idx) in zip( + comparison_results, comparison_metadata + ): + all_individual_scores.extend([score_1, score_2]) + + # Apply tiebreaker when scores are equal + if score_1 == score_2: + tiebreak_count += 1 + # ranking < 3.5 → response_1 better, ranking > 3.5 → response_2 better + adjustment = RANKING_MIDPOINT - ranking + score_1 = score_1 + adjustment + score_2 = score_2 - adjustment + + # Accumulate + accumulated_scores[i] += score_1 + accumulated_scores[j] += score_2 + comparison_counts[i] += 1 + comparison_counts[j] += 1 + + # Compute average scores + final_scores = [ + accumulated_scores[idx] / comparison_counts[idx] if comparison_counts[idx] > 0 else default_score + for idx in range(num_responses) + ] + + # Store base scores before length adjustments + base_scores = list(final_scores) + bonuses = [0.0] * num_responses + + # Apply length bonuses if any are configured + if any([ + reasoning_bonus > 0, + answer_bonus > 0, + group_reasoning_length_penalty_coeff > 0, + group_answer_length_penalty_coeff > 0, + ]): + final_scores, bonuses = apply_length_bonuses( + scores=final_scores, + response_objs=response_objs, + reasoning_bonus=reasoning_bonus, + answer_bonus=answer_bonus, + top_percentile=top_percentile, + group_reasoning_length_penalty_coeff=group_reasoning_length_penalty_coeff, + group_answer_length_penalty_coeff=group_answer_length_penalty_coeff, + ) + + # Compute metrics + metrics: Dict[str, float] = {} + + if all_individual_scores: + scores_array = np.array(all_individual_scores) + metrics["mean_individual_score"] = float(np.mean(scores_array)) + metrics["std_individual_score"] = float(np.std(scores_array)) + + if comparison_results: + metrics["tiebreak_usage_rate"] = tiebreak_count / len(comparison_results) + + return final_scores, metrics, base_scores, bonuses diff --git a/responses_api_models/vllm_model/app.py b/responses_api_models/vllm_model/app.py index 1f875ea6a..8c6ed19f0 100644 --- a/responses_api_models/vllm_model/app.py +++ b/responses_api_models/vllm_model/app.py @@ -12,12 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import json
+import os
 import re
+import urllib.request
 from copy import deepcopy
-from time import time
+from multiprocessing import Process
+from time import sleep, time
 from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
 from uuid import uuid4
 
+import ray
 from aiohttp.client_exceptions import ClientResponseError
 from fastapi import Request
 from pydantic import BaseModel, Field
@@ -27,6 +32,7 @@
     Body,
     SimpleResponsesAPIModel,
 )
+from nemo_gym.global_config import find_open_port
 from nemo_gym.openai_utils import (
     RESPONSES_TO_TRAIN,
     NeMoGymAsyncOpenAI,
@@ -34,6 +40,7 @@
     NeMoGymChatCompletionAssistantMessageForTrainingParam,
     NeMoGymChatCompletionAssistantMessageParam,
     NeMoGymChatCompletionCreateParamsNonStreaming,
+    NeMoGymChatCompletionCustomRoleMessageParam,
     NeMoGymChatCompletionDeveloperMessageParam,
     NeMoGymChatCompletionMessage,
     NeMoGymChatCompletionMessageParam,
@@ -55,6 +62,10 @@
     NeMoGymSummary,
     TokenIDLogProbMixin,
 )
+from nemo_gym.ray_utils import (
+    lookup_current_ray_node_ip,
+    spinup_single_ray_gpu_node_worker,
+)
 from nemo_gym.server_utils import SESSION_ID_KEY
 
 
@@ -72,23 +83,193 @@ class VLLMModelConfig(BaseResponsesAPIModelConfig):
     # Corresponds to the extra_body of OpenAI Client.
     extra_body: Optional[Dict[str, Any]] = None
 
+    spinup_server: bool = False
+    server_args: Optional[Dict[str, Any]] = None
+    server_env: Optional[Dict[str, str]] = None
+
+    router_dp_size: int = 1
+
     def model_post_init(self, context):
         if isinstance(self.base_url, str):
             self.base_url = [self.base_url]
         return super().model_post_init(context)
 
 
+def _start_vllm_server(config: VLLMModelConfig, server_host: str, server_port: int, router_dp_rank: int) -> None:
+    for k, v in (config.server_env or {}).items():
+        os.environ[k] = v
+
+    import uvloop
+    import vllm.engine.arg_utils
+    import vllm.entrypoints.openai.api_server
+    import vllm.entrypoints.openai.cli_args
+    import vllm.utils.argparse_utils
+
+    argv = []
+    argv.append("--model")
+    argv.append(config.model)
+    argv.append("--host")
+    argv.append(server_host)
+    argv.append("--port")
+    argv.append(f"{server_port}")
+    argv.append("--distributed-executor-backend")
+    argv.append("mp")
+    for k, v in (config.server_args or {}).items():
+        k2 = k.replace("_", "-")
+        if v is None:
+            pass
+        elif isinstance(v, bool):
+            if not v:
+                arg_key = f"--no-{k2}"
+            else:
+                arg_key = f"--{k2}"
+            argv.append(arg_key)
+        elif isinstance(v, dict):
+            # Dict values must be passed as JSON strings to vLLM CLI
+            arg_key = f"--{k2}"
+            argv.append(arg_key)
+            argv.append(json.dumps(v))
+        else:
+            arg_key = f"--{k2}"
+            argv.append(arg_key)
+            argv.append(f"{v}")
+
+    server_args = vllm.utils.argparse_utils.FlexibleArgumentParser()
+    server_args = vllm.entrypoints.openai.cli_args.make_arg_parser(server_args)
+    server_args = server_args.parse_args(argv)
+    vllm.entrypoints.openai.cli_args.validate_parsed_serve_args(server_args)
+
+    uvloop.run(vllm.entrypoints.openai.api_server.run_server(server_args))
+
+
+@ray.remote
+class VLLMServerSpinupWorker:
+    def __init__(self, config: VLLMModelConfig, working_dir: Optional[str], router_dp_rank: int):
+        self.config = config
+        self.working_dir = working_dir
+        self.router_dp_rank = router_dp_rank
+        self._server_host = lookup_current_ray_node_ip()
+        self._server_port = find_open_port()
+
+        if self.working_dir is not None:
+            os.chdir(self.working_dir)
+
+        server_proc = Process(
+            target=_start_vllm_server,
+            args=(
+                self.config,
+                self._server_host,
+                self._server_port,
+                self.router_dp_rank,
+            ),
+            daemon=False,
+        )
+        server_proc.start()
+        self._server_proc = server_proc
+
+    def _get_ip(self) -> str:
+        return self._server_host
+
+    def _get_port(self) -> int:
+        return self._server_port
+
+
+# Use this to query the VLLM servers during spinup without having to start an
+# asyncio event loop for the async client.
+def _vllm_server_heartbeat(base_url: str):
+    req_headers = {
+        "Content-Type": "application/json",
+        "Accept": "application/json",
+    }
+    req_body = {
+        "messages": [
+            {
+                "role": "user",
+                "content": "hi",
+            }
+        ],
+        "max_tokens": 8,
+        "temperature": 1.0,
+    }
+    req_data = json.dumps(req_body).encode("utf-8")
+    req_url = f"{base_url}/chat/completions"
+    req = urllib.request.Request(
+        req_url,
+        headers=req_headers,
+        data=req_data,
+    )
+    with urllib.request.urlopen(req, timeout=5) as out:
+        out_status = out.status
+        out_data = out.read()
+    output = out_data.decode("utf-8")
+    return {
+        "_status": out_status,
+        "output": output,
+        "except": None,
+    }
+
+
 class VLLMModel(SimpleResponsesAPIModel):
     config: VLLMModelConfig
 
     def model_post_init(self, context):
-        self._clients = [
-            NeMoGymAsyncOpenAI(
-                base_url=base_url,
-                api_key=self.config.api_key,
-            )
-            for base_url in self.config.base_url
-        ]
+        working_dir = os.getcwd()
+
+        if self.config.spinup_server:
+            self._server_urls = []
+            self._server_workers = []
+            self._clients = []
+
+            # TODO: support for other parallel sizes.
+            server_tp_size = (self.config.server_args or {}).get("tensor_parallel_size", 1)
+            server_dp_size = (self.config.server_args or {}).get("data_parallel_size", 1)
+
+            assert server_dp_size == 1
+
+            router_dp_size = max(1, self.config.router_dp_size)
+
+            for router_dp_rank in range(router_dp_size):
+                server_worker = spinup_single_ray_gpu_node_worker(
+                    VLLMServerSpinupWorker,
+                    server_tp_size,
+                    config=self.config,
+                    working_dir=working_dir,
+                    router_dp_rank=router_dp_rank,
+                )
+
+                server_ip = ray.get(server_worker._get_ip.remote())
+                server_port = ray.get(server_worker._get_port.remote())
+                server_url = f"http://{server_ip}:{server_port}/v1"
+
+                self._server_urls.append(server_url)
+                self._server_workers.append(server_worker)
+
+                self._clients.append(
+                    NeMoGymAsyncOpenAI(
+                        base_url=server_url,
+                        api_key=self.config.api_key,
+                    )
+                )
+
+            for server_url in self._server_urls:
+                while True:
+                    try:
+                        _vllm_server_heartbeat(server_url)
+                        break
+                    except Exception:
+                        sleep(3)
+                        continue
+
+        else:
+            self._server_urls = None
+            self._server_workers = None
+            self._clients = [
+                NeMoGymAsyncOpenAI(
+                    base_url=base_url,
+                    api_key=self.config.api_key,
+                )
+                for base_url in self.config.base_url
+            ]
 
         self._session_id_to_client: Dict[str, NeMoGymAsyncOpenAI] = dict()
 
@@ -523,6 +704,10 @@ def _format_message(
                     role="developer",
                 )
             ]
+        # Custom roles (e.g., GenRM response_1/response_2 for pairwise comparison)
+        case "response_1" | "response_2" | "principle":
+            state.flush_assistant()
+            converted = [NeMoGymChatCompletionCustomRoleMessageParam(role=m["role"], content=content)]
         case _:  # pragma: no cover
             raise NotImplementedError(f"Unrecognized role for message: `{m['role']}`")
 
diff --git a/tests/unit_tests/test_server_utils.py b/tests/unit_tests/test_server_utils.py
index dfd39da65..7313db752 100644
--- a/tests/unit_tests/test_server_utils.py
+++ b/tests/unit_tests/test_server_utils.py
@@ -194,7 +194,9 @@ def test_initialize_ray_with_address(self, monkeypatch: MonkeyPatch) -> None:
         ray_is_initialized_mock.assert_called_once()
         get_global_config_dict_mock.assert_called_once()
 
-        ray_init_mock.assert_called_once_with(address="ray://test-address:10001", ignore_reinit_error=True)
+        ray_init_mock.assert_called_once_with(
+            address="ray://test-address:10001", ignore_reinit_error=True, namespace="nemo_gym"
+        )
 
     def test_initialize_ray_without_address(self, monkeypatch: MonkeyPatch) -> None:
         ray_is_initialized_mock = self._mock_ray_return_value(monkeypatch, False)
@@ -217,5 +219,5 @@ def test_initialize_ray_without_address(self, monkeypatch: MonkeyPatch) -> None:
         ray_is_initialized_mock.assert_called_once()
         get_global_config_dict_mock.assert_called_once()
 
-        ray_init_mock.assert_called_once_with(ignore_reinit_error=True)
+        ray_init_mock.assert_called_once_with(ignore_reinit_error=True, namespace="nemo_gym")
         ray_get_runtime_context_mock.assert_called_once()
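
For reference, the readiness probe added in _vllm_server_heartbeat together with the retry loop in VLLMModel.model_post_init reduces to the stdlib-only pattern below. This is a minimal sketch, not part of the change itself: wait_until_chat_ready and its parameters are illustrative names, and it assumes an OpenAI-compatible /chat/completions endpoint like the one the spun-up vLLM server exposes.

import json
import urllib.request
from time import sleep


def wait_until_chat_ready(base_url: str, request_timeout_s: float = 5.0, poll_interval_s: float = 3.0) -> None:
    # Same probe shape as _vllm_server_heartbeat: a tiny synchronous chat completion request.
    body = json.dumps(
        {
            "messages": [{"role": "user", "content": "hi"}],
            "max_tokens": 8,
            "temperature": 1.0,
        }
    ).encode("utf-8")
    headers = {"Content-Type": "application/json", "Accept": "application/json"}
    while True:
        req = urllib.request.Request(f"{base_url}/chat/completions", headers=headers, data=body)
        try:
            with urllib.request.urlopen(req, timeout=request_timeout_s) as resp:
                resp.read()
            return
        except Exception:
            # Connection refused, timeout, or an error status while the server is still loading: retry.
            sleep(poll_interval_s)

Blocking on plain urllib keeps the spinup path free of an asyncio event loop, which is why the heartbeat does not reuse the async OpenAI client.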