diff --git a/examples/grpo_math/README.md b/examples/grpo_gsm8k_openjudge/README.md similarity index 100% rename from examples/grpo_math/README.md rename to examples/grpo_gsm8k_openjudge/README.md diff --git a/examples/grpo_gsm8k_openjudge/gsm8k.yaml b/examples/grpo_gsm8k_openjudge/gsm8k.yaml new file mode 100644 index 00000000000..4d734765c6b --- /dev/null +++ b/examples/grpo_gsm8k_openjudge/gsm8k.yaml @@ -0,0 +1,69 @@ +project: "Trinity-RFT" +name: "qwen2.5-1.5B-gsm8k-openjudge" +checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} +algorithm: + algorithm_type: grpo + repeat_times: 8 + optimizer: + lr: 1e-5 +model: + model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct} + max_response_tokens: 1024 + max_model_len: 2048 +cluster: + node_num: 1 + gpu_per_node: 2 +buffer: + total_epochs: 1 + batch_size: 96 + explorer_input: + taskset: + name: gsm8k + storage_type: file + path: ${oc.env:TRINITY_TASKSET_PATH,openai/gsm8k} + subset_name: 'main' + split: 'train' + format: + prompt_key: 'question' + response_key: 'answer' + rollout_args: + temperature: 1.0 + reward_fn_args: + model_name: ${oc.env:TRINITY_JUDGE_MODEL_NAME,qwen3-max} + judge_api_base_url_env: OPENAI_BASE_URL + judge_api_key_env: OPENAI_API_KEY + normalize_score: true + score_min: 1.0 + score_max: 3.0 + eval_tasksets: + - name: gsm8k-eval + storage_type: file + path: ${oc.env:TRINITY_TASKSET_PATH,openai/gsm8k} + subset_name: 'main' + split: 'test' + format: + prompt_key: 'question' + response_key: 'answer' + default_workflow_type: 'async_math_openjudge_workflow' + default_reward_fn_type: 'trajectory_accuracy_grader_reward' +explorer: + eval_interval: 50 + runner_per_model: 8 + rollout_model: + engine_num: 1 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 +synchronizer: + sync_method: 'nccl' + sync_interval: 1 + sync_timeout: 1200 +trainer: + trainer_type: 'verl' + save_interval: 100 + grad_clip: 1.0 + use_dynamic_bsz: true + max_token_len_per_gpu: 16384 + ulysses_sequence_parallel_size: 1 diff --git a/examples/grpo_math/math.yaml b/examples/grpo_math/math.yaml deleted file mode 100644 index 1ec35ce86cf..00000000000 --- a/examples/grpo_math/math.yaml +++ /dev/null @@ -1,58 +0,0 @@ -project: grpo_math -name: grpo_math_example -checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} -model: - model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct} - max_response_tokens: 3072 - max_model_len: 4096 -algorithm: - algorithm_type: grpo - repeat_times: 8 - optimizer: - lr: 5e-7 -cluster: - node_num: 1 - gpu_per_node: 8 -buffer: - total_epochs: 20 - batch_size: 288 - explorer_input: - taskset: - name: math - storage_type: file - path: ${oc.env:TRINITY_TASKSET_PATH} - format: - prompt_key: 'question' - response_key: 'gt_answer' - rollout_args: - temperature: 1.0 - logprobs: 0 - reward_fn_args: - reward_name: math_verify_reward - default_workflow_type: 'math_rm_workflow' - default_reward_fn_type: 'rm_gallery_reward' - trainer_input: - experience_buffer: - name: math_buffer - storage_type: queue - path: 'sqlite:///math.db' -explorer: - eval_interval: 10 - runner_per_model: 8 - rollout_model: - engine_num: 2 - tensor_parallel_size: 1 - enable_prefix_caching: false - enforce_eager: true - dtype: bfloat16 - seed: 42 -synchronizer: - sync_method: 'nccl' - sync_interval: 1 - sync_timeout: 1200 -trainer: - save_interval: 100 - grad_clip: 1.0 - use_dynamic_bsz: true - max_token_len_per_gpu: 16384 - ulysses_sequence_parallel_size: 1 diff --git a/pyproject.toml b/pyproject.toml index e140fd1d6c3..4e705aecbf5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,8 +62,8 @@ data = [ agent = [ "agentscope[tuner]>=1.0.18" ] -rm_gallery = [ - "rm-gallery>=0.1.5" +openjudge = [ + "py-openjudge>=0.2.2" ] dev = [ "pre-commit>=2.17.0", diff --git a/trinity/common/rewards/__init__.py b/trinity/common/rewards/__init__.py index c129b6b3bf5..46fddb78bf7 100644 --- a/trinity/common/rewards/__init__.py +++ b/trinity/common/rewards/__init__.py @@ -1,24 +1,23 @@ # -*- coding: utf-8 -*- """Reward functions for RFT""" - from trinity.common.rewards.reward_fn import RewardFn from trinity.utils.registry import Registry REWARD_FUNCTIONS = Registry( "reward_functions", default_mapping={ - "rm_gallery_reward": "trinity.common.rewards.reward_fn.RMGalleryFn", "math_reward": "trinity.common.rewards.math_reward.MathRewardFn", "math_boxed_reward": "trinity.common.rewards.math_reward.MathBoxedRewardFn", "format_reward": "trinity.common.rewards.format_reward.FormatReward", "countdown_reward": "trinity.common.rewards.countdown_reward.CountDownRewardFn", "accuracy_reward": "trinity.common.rewards.accuracy_reward.AccuracyReward", "math_dapo_reward": "trinity.common.rewards.dapo_reward.MathDAPORewardFn", + "trajectory_accuracy_grader_reward": "trinity.common.rewards.open_judge_reward.TrajectoryAccuracyGrader", + "openjudge_multi_grader_reward": "trinity.common.rewards.open_judge_reward.OpenJudgeRewardFn", }, ) __all__ = [ "RewardFn", - "RMGalleryFn", "REWARD_FUNCTIONS", ] diff --git a/trinity/common/rewards/open_judge_reward.py b/trinity/common/rewards/open_judge_reward.py new file mode 100644 index 00000000000..69bba9aaaf9 --- /dev/null +++ b/trinity/common/rewards/open_judge_reward.py @@ -0,0 +1,224 @@ +"""OpenJudge reward function classes.""" + +import asyncio +import os +from typing import Any, Dict, List, Optional + +from trinity.common.experience import Experience +from trinity.common.rewards.reward_fn import RewardFn + + +class OpenJudgeRewardFn(RewardFn): + """Reward Function using OpenJudge multi-grader pipeline. + + Args: + grader_configs: Dict mapping grader name to a GraderConfig, a + BaseGrader instance, or a (BaseGrader, mapper) tuple. + When *None* a default pair (CorrectnessGrader + RelevanceGrader) + is used so the class is usable out of the box. + model_name: Default judge model for any grader that needs one. + max_concurrency: Passed to GradingRunner. + score_aggregation: How to combine per-grader scores into the final + ``reward`` key. ``"mean"`` (default) or ``"sum"``. + judge_api_base_url_env: Env-var holding the judge API base URL. + judge_api_key_env: Env-var holding the judge API key. + """ + + def __init__( + self, + grader_configs: Optional[Dict[str, Any]] = None, + model_name: str = "qwen3-32b", + max_concurrency: int = 8, + score_aggregation: str = "mean", + judge_api_base_url_env: str = "OPENAI_BASE_URL", + judge_api_key_env: str = "OPENAI_API_KEY", + **kwargs, + ): + try: + from openjudge.models.openai_chat_model import ( # pyright: ignore[reportMissingImports] + OpenAIChatModel, + ) + from openjudge.runner.grading_runner import ( # pyright: ignore[reportMissingImports] + GradingRunner, + ) + except ImportError as e: + raise ImportError( + "OpenJudge dependencies are not installed. " + "Please install with `pip install -e .[openjudge]`." + ) from e + + self.score_aggregation = score_aggregation + + if grader_configs is None: + from openjudge.graders.common.correctness import ( # pyright: ignore[reportMissingImports] + CorrectnessGrader, + ) + from openjudge.graders.common.relevance import ( # pyright: ignore[reportMissingImports] + RelevanceGrader, + ) + + judge_base_url = os.getenv(judge_api_base_url_env, "") + if not judge_base_url: + raise ValueError(f"Judge base URL is missing. Set env `{judge_api_base_url_env}`.") + model_kwargs: Dict[str, Any] = { + "model": model_name, + "base_url": judge_base_url, + "api_key": os.getenv(judge_api_key_env, ""), + } + model = OpenAIChatModel(**model_kwargs) + grader_configs = { + "correctness": CorrectnessGrader(model=model), + "relevance": RelevanceGrader(model=model), + } + + self.runner = GradingRunner( + grader_configs=grader_configs, + max_concurrency=max_concurrency, + show_progress=False, + ) + + def __call__( # type: ignore[override] + self, + experience: Any, + messages: List[Dict[str, Any]], + **kwargs, + ) -> Dict[str, float]: + """Evaluate a single experience and return a reward dict.""" + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(self.acall(experience, messages, **kwargs)) + + raise RuntimeError( + "OpenJudgeRewardFn.__call__ cannot be used inside a running event loop. " + "Use `await reward_fn.acall(...)` in async workflows." + ) + + async def acall( # type: ignore[override] + self, + experience: Any, + messages: List[Dict[str, Any]], + **kwargs, + ) -> Dict[str, float]: + """Async evaluation for event-loop contexts.""" + merged_messages = list(messages) + if not merged_messages or merged_messages[-1].get("role") != "assistant": + merged_messages.append( + { + "role": "assistant", + "content": str(getattr(experience, "response_text", "") or ""), + } + ) + + data = {"messages": merged_messages} + batch_results = await self.runner.arun([data]) + return self._extract_reward(batch_results) + + def _extract_reward(self, batch_results: Dict[str, Any]) -> Dict[str, float]: + from openjudge.graders.schema import ( # pyright: ignore[reportMissingImports] + GraderError, + GraderScore, + ) + + reward_dict: Dict[str, float] = {} + scores: List[float] = [] + + for grader_name, grader_results in batch_results.items(): + if not grader_results: + continue + result = grader_results[0] + if isinstance(result, GraderScore): + reward_dict[f"{grader_name}_score"] = result.score + scores.append(result.score) + elif isinstance(result, GraderError): + reward_dict[f"{grader_name}_score"] = 0.0 + scores.append(0.0) + + if scores: + reward_dict["reward"] = ( + sum(scores) / len(scores) if self.score_aggregation == "mean" else sum(scores) + ) + else: + reward_dict["reward"] = 0.0 + + return reward_dict + + +class TrajectoryAccuracyGrader(OpenJudgeRewardFn): + """Single-grader reward using OpenJudge TrajectoryAccuracyGrader. + + Args: + reward_name: Logical name for this reward (used for logging/registry). + model_name: Judge model passed to OpenAIChatModel. + normalize_score: When True, linearly maps the raw score to [0, 1] + using score_min / score_max. + score_min: Lower bound for normalisation (default 1.0). + score_max: Upper bound for normalisation (default 3.0). + judge_api_base_url_env: Env-var holding the judge API base URL. + judge_api_key_env: Env-var holding the judge API key. + """ + + def __init__( + self, + reward_name: str = "openjudge_trajectory_accuracy_reward", + model_name: str = "qwen3-max", + normalize_score: bool = True, + score_min: float = 1.0, + score_max: float = 3.0, + judge_api_base_url_env: str = "OPENAI_BASE_URL", + judge_api_key_env: str = "OPENAI_API_KEY", + **kwargs, + ): + try: + from openjudge.graders.agent.trajectory.trajectory_accuracy import ( + TrajectoryAccuracyGrader as _TrajectoryAccuracyGrader, # pyright: ignore[reportMissingImports] + ) + from openjudge.models.openai_chat_model import ( # pyright: ignore[reportMissingImports] + OpenAIChatModel, + ) + except ImportError as e: + raise ImportError( + "OpenJudge dependencies are not installed. " + "Please install with `pip install -e .[openjudge]`." + ) from e + + judge_base_url = os.getenv(judge_api_base_url_env, "") + if not judge_base_url: + raise ValueError(f"Judge base URL is missing. Set env `{judge_api_base_url_env}`.") + judge_model = OpenAIChatModel( + model=kwargs.get("judge_model_name", model_name), + base_url=judge_base_url, + api_key=os.getenv(judge_api_key_env, ""), + temperature=kwargs.get("temperature", 0.0), + ) + + super().__init__( + grader_configs={"trajectory": _TrajectoryAccuracyGrader(model=judge_model)}, + max_concurrency=kwargs.get("max_concurrency", 8), + judge_api_base_url_env=judge_api_base_url_env, + judge_api_key_env=judge_api_key_env, + ) + + self.reward_name = reward_name + self.normalize_score = normalize_score + self.score_min = float(score_min) + self.score_max = float(score_max) + + def __call__( # type: ignore[override] + self, + experience: Experience, + messages: List[Dict[str, Any]], + **kwargs, + ) -> Dict[str, float]: + return super().__call__(experience, messages, **kwargs) + + def _extract_reward(self, batch_results: Dict[str, Any]) -> Dict[str, float]: + reward_dict = super()._extract_reward(batch_results) + + if not self.normalize_score or self.score_max <= self.score_min: + return reward_dict + + raw = reward_dict.get("reward", 0.0) + normalized = (raw - self.score_min) / (self.score_max - self.score_min) + reward_dict["reward"] = max(0.0, min(1.0, normalized)) + return reward_dict diff --git a/trinity/common/rewards/reward_fn.py b/trinity/common/rewards/reward_fn.py index 90acfa340a1..411a397e755 100644 --- a/trinity/common/rewards/reward_fn.py +++ b/trinity/common/rewards/reward_fn.py @@ -1,10 +1,7 @@ # -*- coding: utf-8 -*- -"""Base Reward Function Class.""" +"""Base reward function classes.""" from abc import ABC, abstractmethod -from typing import Any, Dict, List - -from trinity.common.experience import Experience -from trinity.common.rewards.utils import to_rm_gallery_messages +from typing import Dict class RewardFn(ABC): @@ -18,84 +15,10 @@ def __init__(self, **kwargs) -> None: def __call__(self, **kwargs) -> Dict[str, float]: pass + async def acall(self, **kwargs) -> Dict[str, float]: + """Async reward entrypoint. -class RMGalleryFn(RewardFn): - """Reward Function from RMGallery. - https://github.com/modelscope/RM-Gallery - - TODO: Update to OpenJudgeFn - """ - - def __init__( - self, - reward_name, - **kwargs, - ): - from rm_gallery.core.reward.registry import RewardRegistry - - self.reward_model = RewardRegistry.get(reward_name)(**kwargs) - - def __call__(self, experience: Experience, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, float]: # type: ignore - """Call the reward function.""" - - sample = self._build_sample_from_experience(experience, messages, **kwargs) - - sample_with_reward = self.reward_model.evaluate(sample, **kwargs) - - return self._extract_reward(sample_with_reward) - - def _build_sample_from_experience( - self, experience: Experience, messages: List[Dict[str, Any]], **kwargs - ) -> Any: - """Convert experience to sample. - Ref: https://github.com/modelscope/RM-Gallery/blob/main/rm_gallery/core/data/schema.py + Default implementation delegates to sync __call__ for backward compatibility. + Async-capable reward functions can override this to avoid blocking the event loop. """ - from rm_gallery.core.data.schema import DataOutput, DataSample, Step - - output = [ - DataOutput( - answer=Step( - role="assistant", - content=str(experience.response_text), - label={"reference": kwargs.get("ground_truth", "")}, - ), - ) - ] - - sample = DataSample( - unique_id=experience.eid.uid, - input=to_rm_gallery_messages(messages), - output=output, - metadata=experience.info, - ) - return sample - - def _extract_reward(self, sample: Any) -> Dict[str, float]: - """ - Extract reward from DataSample in rm-gallery - """ - reward_dict = {} - - try: - reward_obj = sample.output[0].answer.reward - except Exception as e: - raise ValueError(f"No reward is found in sample: {e}") - - from rm_gallery.core.reward.schema import ( - RewardDimensionWithRank, - RewardDimensionWithScore, - ) - - if reward_obj.details: - for detail in reward_obj.details: - if isinstance(detail, RewardDimensionWithScore): - reward_dict[detail.name] = detail.score - elif isinstance(detail, RewardDimensionWithRank): - # TODO: support multi-ranked dimension - if detail: - top_ranked_item = detail[0] - reward_dict[top_ranked_item.name] = top_ranked_item.score - else: - reward_dict["reward"] = reward_obj.score - - return reward_dict + return self.__call__(**kwargs) diff --git a/trinity/common/rewards/utils.py b/trinity/common/rewards/utils.py index 0b66e6700cb..1c0f2262ace 100644 --- a/trinity/common/rewards/utils.py +++ b/trinity/common/rewards/utils.py @@ -2,16 +2,16 @@ def to_rm_gallery_messages(messages: List[Dict[str, Any]]) -> Any: - """ - Converts string list to structured ChatMessage list for debugging. - - Args: - messages: List of alternating user/assistant messages + """Deprecated: was used by the removed RMGalleryFn. - Returns: - List of structured ChatMessage objects + Converts a list of ``{"role": ..., "content": ...}`` dicts to + rm_gallery ChatMessage objects. Kept for any external callers that + may still reference this helper; remove once confirmed unused. """ - from rm_gallery.core.model.message import ChatMessage, MessageRole + from rm_gallery.core.model.message import ( # pyright: ignore[reportMissingImports] + ChatMessage, + MessageRole, + ) role_map = { "system": MessageRole.SYSTEM, diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py index fac7ba54e59..079ac4f1561 100644 --- a/trinity/common/workflows/__init__.py +++ b/trinity/common/workflows/__init__.py @@ -15,8 +15,8 @@ "async_math_boxed_workflow": "trinity.common.workflows.customized_math_workflows.AsyncMathBoxedWorkflow", "math_eval_workflow": "trinity.common.workflows.eval_workflow.MathEvalWorkflow", "async_math_eval_workflow": "trinity.common.workflows.eval_workflow.AsyncMathEvalWorkflow", - "math_rm_workflow": "trinity.common.workflows.math_rm_workflow.MathRMWorkflow", - "async_math_rm_workflow": "trinity.common.workflows.math_rm_workflow.AsyncMathRMWorkflow", + "math_openjudge_workflow": "trinity.common.workflows.math_openjudge_workflow.MathOpenJudgeWorkflow", + "async_math_openjudge_workflow": "trinity.common.workflows.math_openjudge_workflow.AsyncMathOpenJudgeWorkflow", # tool_call "tool_call_workflow": "trinity.common.workflows.customized_toolcall_workflows.ToolCallWorkflow", # agentscope diff --git a/trinity/common/workflows/math_rm_workflow.py b/trinity/common/workflows/math_openjudge_workflow.py similarity index 53% rename from trinity/common/workflows/math_rm_workflow.py rename to trinity/common/workflows/math_openjudge_workflow.py index 13e2fd4cb28..8ac7b90f6e4 100644 --- a/trinity/common/workflows/math_rm_workflow.py +++ b/trinity/common/workflows/math_openjudge_workflow.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -"""We include the math workflow with rm-gallery reward in this file.""" +"""We include the math workflow with OpenJudge reward in this file.""" from typing import List, Optional @@ -8,8 +8,8 @@ from trinity.common.workflows.workflow import SimpleWorkflow, Task -class MathRMWorkflow(SimpleWorkflow): - """A workflow for math tasks as introduced in DeepSeek-R1.""" +class AsyncMathOpenJudgeWorkflow(SimpleWorkflow): + is_async: bool = True def __init__( self, @@ -25,41 +25,13 @@ def __init__( auxiliary_models=auxiliary_models, ) - def run(self) -> List[Experience]: - messages = self.format_messages() - - self.logger.debug("start chat") - responses = self.model.chat(messages, **self.rollout_args) - for i, response in enumerate(responses): - reward_dict = self.reward_fn( # type: ignore - response, - messages, - ground_truth=self.truth, - ) - - if response.metrics is None: - response.metrics = {} - response.metrics.update(reward_dict) - reward = sum(reward_dict.values()) - response.reward = reward - response.eid.run = i + self.run_id_base - - self.logger.debug( - f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}" - ) - return responses - - -class AsyncMathRMWorkflow(MathRMWorkflow): - is_async: bool = True - async def run_async(self) -> List[Experience]: messages = self.format_messages() self.logger.debug("start chat") responses = await self.model.chat_async(messages, **self.rollout_args) for i, response in enumerate(responses): - reward_dict = self.reward_fn( # type: ignore + reward_dict = await self.reward_fn.acall( # type: ignore response, messages, ground_truth=self.truth, @@ -68,7 +40,7 @@ async def run_async(self) -> List[Experience]: if response.metrics is None: response.metrics = {} response.metrics.update(reward_dict) - reward = sum(reward_dict.values()) + reward = reward_dict.get("reward", sum(reward_dict.values())) response.reward = reward response.eid.run = i + self.run_id_base