From 946963252197dc7641df60a273d882845094196f Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Fri, 24 Oct 2025 15:11:16 +0800 Subject: [PATCH 01/10] add total_time and fix metric --- .../source/tutorial/example_mix_algo.md | 2 +- .../source_zh/tutorial/example_mix_algo.md | 2 +- trinity/algorithm/advantage_fn/advantage_fn.py | 5 +---- trinity/algorithm/advantage_fn/grpo_advantage.py | 10 +++++----- .../advantage_fn/multi_step_grpo_advantage.py | 7 ++----- .../sample_strategy/mix_sample_strategy.py | 4 ++-- trinity/explorer/explorer.py | 16 ++++++++++++++-- trinity/service/data_juicer/client.py | 2 ++ trinity/trainer/trainer.py | 4 ++++ trinity/utils/monitor.py | 2 ++ 10 files changed, 34 insertions(+), 20 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md index 56312a7b987..da08f753345 100644 --- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md +++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md @@ -106,7 +106,7 @@ class MixSampleStrategy(SampleStrategy): async def sample(self, step: int) -> Tuple[Experiences, Dict, List]: metrics = {} - with Timer(metrics, "read_time"): + with Timer(metrics, "time/read_experience"): usual_exp_list = await self.usual_exp_buffer.read_async() for exp in usual_exp_list: if exp.info is None: diff --git a/docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md b/docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md index 4c0b575fd5a..6f640afca17 100644 --- a/docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md +++ b/docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md @@ -98,7 +98,7 @@ class MixSampleStrategy(SampleStrategy): async def sample(self, step: int) -> Tuple[Experiences, Dict, List]: metrics = {} - with Timer(metrics, "read_time"): + with Timer(metrics, "time/read_experience"): usual_exp_list = await self.usual_exp_buffer.read_async() for exp in usual_exp_list: if exp.info is None: diff --git a/trinity/algorithm/advantage_fn/advantage_fn.py b/trinity/algorithm/advantage_fn/advantage_fn.py index 31f3b20bb4d..6810feda72f 100644 --- a/trinity/algorithm/advantage_fn/advantage_fn.py +++ b/trinity/algorithm/advantage_fn/advantage_fn.py @@ -76,10 +76,7 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]: for group_id, group_exps in exp_groups.items(): group_exps, group_metrics = self.calculate_group_advantage(group_id, group_exps) metric_list.append(group_metrics) - try: - metrics = gather_metrics(metric_list, "group_advantages") - except ValueError: - metrics = {} # empty metric list causes ValueError, ignore it + metrics = gather_metrics(metric_list, "group_advantages") exps = [exp for group in exp_groups.values() for exp in group] # Flatten the list return exps, metrics diff --git a/trinity/algorithm/advantage_fn/grpo_advantage.py b/trinity/algorithm/advantage_fn/grpo_advantage.py index e60d05faff2..a438c028692 100644 --- a/trinity/algorithm/advantage_fn/grpo_advantage.py +++ b/trinity/algorithm/advantage_fn/grpo_advantage.py @@ -214,11 +214,11 @@ def process(self, exps): group_id, group_exps, precomputed_std=precomputed_std ) metric_list.append(group_metrics) - try: - # TODO: sum skipped count - metrics = gather_metrics(metric_list, "group_advantages") - except ValueError: - metrics = {} # empty metric list causes ValueError, ignore it + + # Update the filtered_count metric + filtered_count = sum(metric.pop("skipped_count", 0) for metric in metric_list) + metrics = gather_metrics(metric_list, "group_advantages") + metrics["filtered_count"] = filtered_count if self.duplicate_experiences and self.std_threshold is not None: exps = self._duplicate_experiences(exp_groups) else: diff --git a/trinity/algorithm/advantage_fn/multi_step_grpo_advantage.py b/trinity/algorithm/advantage_fn/multi_step_grpo_advantage.py index 19ed4529c03..a085ee1ae2b 100644 --- a/trinity/algorithm/advantage_fn/multi_step_grpo_advantage.py +++ b/trinity/algorithm/advantage_fn/multi_step_grpo_advantage.py @@ -142,11 +142,8 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]: cnt += len(exps) result_exps.extend(exps) - try: - metrics = gather_metrics(metric_list, "group_advantages") - metrics["experience_count"] = cnt - except ValueError: - metrics = {} # empty metric list causes ValueError, ignore it + metrics = gather_metrics(metric_list, "group_advantages") + metrics["experience_count"] = cnt return result_exps, metrics def __call__(self, exps, **kwargs): diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py index 5e535a6d25b..2c213a12d8f 100644 --- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -60,7 +60,7 @@ def __init__(self, buffer_config: BufferConfig, **kwargs): async def sample(self, step: int) -> Tuple[Experiences, Dict, List]: metrics = {} - with Timer(metrics, "read_time"): + with Timer(metrics, "time/read_experience"): usual_exp_list = await self.usual_exp_buffer.read_async() for exp in usual_exp_list: if exp.info is None: @@ -86,7 +86,7 @@ async def sample(self, step: int) -> Tuple[Experiences, Dict, List]: exp_list = usual_exp_list + expert_exp_list repr_samples = representative_sample(exp_list) - with Timer(metrics, "gather_time"): + with Timer(metrics, "time/gather_experience"): exps = Experiences.gather_experiences( experiences=exp_list, pad_token_id=self.pad_token_id, # type: ignore [arg-type] diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index a10b523af22..e8774adea9a 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -76,6 +76,8 @@ def __init__(self, config: Config): self.enable_lora = self.config.explorer.rollout_model.enable_lora self.model_version = -1 self.last_sync_successful = True + self.explore_start_time = None + self.eval_start_time = None self.logger.info("Finished initializing Explorer.") async def setup_weight_sync_group( @@ -207,6 +209,8 @@ async def explore(self) -> str: return self.config.explorer.name async def explore_step(self) -> bool: + if self.explore_start_time is None: + self.explore_start_time = time.time() try: tasks = await self.taskset.read_async() except StopAsyncIteration: @@ -252,6 +256,7 @@ def need_eval(self) -> bool: async def eval(self): """Evaluation on all evaluation data samples.""" + self.eval_start_time = time.time() if len(self.config.buffer.explorer_input.eval_tasksets) == 0: self.logger.warning("No evaluation data samples. Skip evaluation.") return @@ -338,6 +343,12 @@ async def _finish_steps(self, start_step: int, end_step: int, model_version: int await self._finish_explore_step(step=step, model_version=model_version) await self._finish_eval_step(step=step) + # Record the time: read_task + explore_step (>=1) + eval (if any) + if self.explore_start_time is not None: + metric = {"time/explore_total_time": time.time() - self.explore_start_time} + self.explore_start_time = None + self.monitor.log(metric, step=end_step) + async def _finish_explore_step(self, step: int, model_version: int) -> None: statuses, exps = await self.scheduler.get_results(batch_id=step) metric = {"rollout/model_version": model_version} @@ -351,7 +362,6 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva if not self.pending_eval_tasks: return step = step or self.explore_step_num - st = time.time() metric = {} while self.pending_eval_tasks: eval_step, eval_task_name = self.pending_eval_tasks[0] @@ -364,7 +374,9 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva [status.metric for status in eval_results], f"{prefix}/{eval_task_name}" ) ) - metric[f"{prefix}/total_time"] = time.time() - st + if self.eval_start_time is not None: + metric.update({"time/eval": time.time() - self.eval_start_time}) + self.eval_start_time = None self.monitor.log(metric, step) async def shutdown(self) -> None: diff --git a/trinity/service/data_juicer/client.py b/trinity/service/data_juicer/client.py index 44238eab607..2ff5dc41b61 100644 --- a/trinity/service/data_juicer/client.py +++ b/trinity/service/data_juicer/client.py @@ -77,6 +77,7 @@ def initialize(self, config: dict): self.session_id = response.json().get("session_id") def process_experience(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]: + st = time.time() if not self.session_id: raise ValueError("DataJuicer session is not initialized.") @@ -101,6 +102,7 @@ def process_experience(self, exps: List[Experience]) -> Tuple[List[Experience], exp.info = {} for stats_key in sample["__dj__stats__"]: exp.info[stats_key] = sample["__dj__stats__"][stats_key] + metrics["time/dj_process_experience"] = time.time() - st return exps, metrics def process_task(self) -> Dict: diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index d9036e94861..797dc114f0c 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -5,6 +5,7 @@ from __future__ import annotations import asyncio +import time import traceback from abc import ABC, abstractmethod from typing import Dict, List, Tuple @@ -67,6 +68,7 @@ async def train(self) -> str: """Train the model.""" while self.train_step_num < self.total_steps: try: + st = time.time() # sample may be blocked due to explorer does not generate enough data self.logger.info(f"Sample data for step {self.train_step_num + 1} started.") sample_task = asyncio.create_task(self._sample_data()) @@ -80,6 +82,8 @@ async def train(self) -> str: self.logger.info(f"Sample data for step {self.train_step_num + 1} finished.") metrics.update(await self.train_step(exps)) if await self.need_sync(): + # Record the time: sample_experience + train_step (>=1) + metrics.update({"time/train_total_time": time.time() - st}) metrics.update(await self.sync_weight()) if self.need_save(): metrics.update(self.save_checkpoint()) diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py index 880d035f836..ab799426ff9 100644 --- a/trinity/utils/monitor.py +++ b/trinity/utils/monitor.py @@ -26,6 +26,8 @@ def gather_metrics(metric_list: List[Dict], prefix: str) -> Dict: + if not metric_list: + return {} # empty metric list causes ValueError, return empty dict df = pd.DataFrame(metric_list) numeric_df = df.select_dtypes(include=[np.number]) stats_df = numeric_df.agg(["mean", "max", "min"]) From c130278ef4cd4e284b4fb3d772685e8cacc4219a Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Fri, 24 Oct 2025 16:58:04 +0800 Subject: [PATCH 02/10] fix comments and add explore_step --- .../source/tutorial/example_mix_algo.md | 2 +- .../source_zh/tutorial/example_mix_algo.md | 2 +- trinity/explorer/explorer.py | 10 ++++++-- trinity/trainer/trainer.py | 2 +- trinity/utils/monitor.py | 23 +++++++++++-------- 5 files changed, 24 insertions(+), 15 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md index da08f753345..d95aac45466 100644 --- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md +++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md @@ -132,7 +132,7 @@ class MixSampleStrategy(SampleStrategy): exp_list = usual_exp_list + expert_exp_list repr_samples = representative_sample(exp_list) - with Timer(metrics, "gather_time"): + with Timer(metrics, "time/gather_experience"): exps = Experiences.gather_experiences( experiences=exp_list, pad_token_id=self.pad_token_id, # type: ignore [arg-type] diff --git a/docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md b/docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md index 6f640afca17..f6f4be9810f 100644 --- a/docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md +++ b/docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md @@ -124,7 +124,7 @@ class MixSampleStrategy(SampleStrategy): exp_list = usual_exp_list + expert_exp_list repr_samples = representative_sample(exp_list) - with Timer(metrics, "gather_time"): + with Timer(metrics, "time/gather_experience"): exps = Experiences.gather_experiences( experiences=exp_list, pad_token_id=self.pad_token_id, # type: ignore [arg-type] diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 1b4155a61fc..5dd30814ea5 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -74,8 +74,9 @@ def __init__(self, config: Config): self.enable_lora = self.config.explorer.rollout_model.enable_lora self.model_version = -1 self.last_sync_successful = True - self.explore_start_time = None self.eval_start_time = None + self.explore_start_time = None + self.explore_step_start_time = dict() # {step_num: start_time} self.logger.info("Finished initializing Explorer.") async def setup_weight_sync_group( @@ -222,6 +223,7 @@ async def explore_step(self) -> bool: ) await self.shutdown() return False + self.explore_step_start_time.update({self.explore_step_num + 1: time.time()}) self.scheduler.schedule(tasks, batch_id=self.explore_step_num + 1) self.explore_step_num += 1 return True @@ -343,7 +345,7 @@ async def _finish_steps(self, start_step: int, end_step: int, model_version: int # Record the time: read_task + explore_step (>=1) + eval (if any) if self.explore_start_time is not None: - metric = {"time/explore_total_time": time.time() - self.explore_start_time} + metric = {"time/explorer_sync_interval": time.time() - self.explore_start_time} self.explore_start_time = None self.monitor.log(metric, step=end_step) @@ -353,6 +355,10 @@ async def _finish_explore_step(self, step: int, model_version: int) -> None: pipeline_metrics = await self.experience_pipeline.process.remote(exps) self.taskset.update(pipeline_metrics) metric.update(pipeline_metrics) + explore_step_start_time = self.explore_step_start_time.get(step, None) + if explore_step_start_time is not None: + metric.update({"time/explore_step": time.time() - explore_step_start_time}) + self.explore_step_start_time.pop(step) if statuses: metric.update(gather_metrics([status.metric for status in statuses], "rollout")) self.monitor.log(metric, step=step) diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index 797dc114f0c..a0e455d50ef 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -83,7 +83,7 @@ async def train(self) -> str: metrics.update(await self.train_step(exps)) if await self.need_sync(): # Record the time: sample_experience + train_step (>=1) - metrics.update({"time/train_total_time": time.time() - st}) + metrics.update({"time/trainer_sync_interval": time.time() - st}) metrics.update(await self.sync_weight()) if self.need_save(): metrics.update(self.save_checkpoint()) diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py index ab799426ff9..f0e847f4d2d 100644 --- a/trinity/utils/monitor.py +++ b/trinity/utils/monitor.py @@ -27,16 +27,19 @@ def gather_metrics(metric_list: List[Dict], prefix: str) -> Dict: if not metric_list: - return {} # empty metric list causes ValueError, return empty dict - df = pd.DataFrame(metric_list) - numeric_df = df.select_dtypes(include=[np.number]) - stats_df = numeric_df.agg(["mean", "max", "min"]) - metric = {} - for col in stats_df.columns: - metric[f"{prefix}/{col}/mean"] = stats_df.loc["mean", col].item() - metric[f"{prefix}/{col}/max"] = stats_df.loc["max", col].item() - metric[f"{prefix}/{col}/min"] = stats_df.loc["min", col].item() - return metric + return {} + try: + df = pd.DataFrame(metric_list) + numeric_df = df.select_dtypes(include=[np.number]) + stats_df = numeric_df.agg(["mean", "max", "min"]) + metric = {} + for col in stats_df.columns: + metric[f"{prefix}/{col}/mean"] = stats_df.loc["mean", col].item() + metric[f"{prefix}/{col}/max"] = stats_df.loc["max", col].item() + metric[f"{prefix}/{col}/min"] = stats_df.loc["min", col].item() + return metric + except Exception as e: + raise ValueError(f"Failed to gather metrics: {e}") from e class Monitor(ABC): From e5624fdcab09748c2f59f92574153d04c94b0172 Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Fri, 24 Oct 2025 17:29:46 +0800 Subject: [PATCH 03/10] replace prefix timing_s to time --- trinity/algorithm/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/trinity/algorithm/utils.py b/trinity/algorithm/utils.py index a2898c094cb..f71835f00e7 100644 --- a/trinity/algorithm/utils.py +++ b/trinity/algorithm/utils.py @@ -102,4 +102,7 @@ def prefix_metrics(src_metrics: dict, prefix: str, dst_metrics: dict = None) -> dst_metrics = {} for k, v in src_metrics.items(): dst_metrics[f"{prefix}/{k}"] = v + + # unify the time metrics + dst_metrics = {k.replace("timing_s/", "time/"): v for k, v in dst_metrics.items()} return dst_metrics From 1e8f46c77d5f782a85928f8453b396b5036d3b5d Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Fri, 24 Oct 2025 17:46:42 +0800 Subject: [PATCH 04/10] fix comment --- trinity/algorithm/utils.py | 2 -- trinity/trainer/verl_trainer.py | 3 ++- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/trinity/algorithm/utils.py b/trinity/algorithm/utils.py index f71835f00e7..271cc00352e 100644 --- a/trinity/algorithm/utils.py +++ b/trinity/algorithm/utils.py @@ -103,6 +103,4 @@ def prefix_metrics(src_metrics: dict, prefix: str, dst_metrics: dict = None) -> for k, v in src_metrics.items(): dst_metrics[f"{prefix}/{k}"] = v - # unify the time metrics - dst_metrics = {k.replace("timing_s/", "time/"): v for k, v in dst_metrics.items()} return dst_metrics diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 54dfba2b7a5..bc879e6afc1 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -472,7 +472,8 @@ def train_step(self, batch: Experiences) -> Dict: # noqa C901 # collect metrics metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) - metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + timing_metrics = compute_timing_metrics(batch=batch, timing_raw=timing_raw) + metrics.update({k.replace("timing_s/", "time/"): v for k, v in timing_metrics.items()}) n_gpus = self.resource_pool_manager.get_n_gpus() metrics.update( compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus) From 1de687fd2c806de7722940bc23eeb0cb86a77c01 Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Mon, 27 Oct 2025 10:03:46 +0800 Subject: [PATCH 05/10] fix test --- tests/algorithm/advantage_fn_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/algorithm/advantage_fn_test.py b/tests/algorithm/advantage_fn_test.py index 70a15ddd76e..7a3297ff29f 100644 --- a/tests/algorithm/advantage_fn_test.py +++ b/tests/algorithm/advantage_fn_test.py @@ -107,8 +107,8 @@ def test_grpo_reward_std(self): exps, metrics = advantage_fn(exps) self.assertEqual(len(exps), 0) - self.assertIn("group_advantages/skipped_count/mean", metrics) - self.assertEqual(metrics["group_advantages/skipped_count/mean"], 5) + self.assertIn("group_advantages/filtered_count/mean", metrics) + self.assertEqual(metrics["group_advantages/filtered_count/mean"], 5) def test_grpo_correct_bias(self): advantage_fn_cls = ADVANTAGE_FN.get("grpo") From 53ea0eba68a9efee9fb37916f4578e27ec88e08e Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Mon, 27 Oct 2025 10:22:30 +0800 Subject: [PATCH 06/10] fix test --- tests/algorithm/advantage_fn_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/algorithm/advantage_fn_test.py b/tests/algorithm/advantage_fn_test.py index 7a3297ff29f..5d03ca55c1b 100644 --- a/tests/algorithm/advantage_fn_test.py +++ b/tests/algorithm/advantage_fn_test.py @@ -107,8 +107,8 @@ def test_grpo_reward_std(self): exps, metrics = advantage_fn(exps) self.assertEqual(len(exps), 0) - self.assertIn("group_advantages/filtered_count/mean", metrics) - self.assertEqual(metrics["group_advantages/filtered_count/mean"], 5) + self.assertIn("filtered_count", metrics) + self.assertEqual(metrics["filtered_count"], 15) def test_grpo_correct_bias(self): advantage_fn_cls = ADVANTAGE_FN.get("grpo") From abab7605b7cb83ffa72de201c97f19a91a537c09 Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Mon, 27 Oct 2025 11:43:54 +0800 Subject: [PATCH 07/10] add time/wait_explore_step --- trinity/explorer/explorer.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 5dd30814ea5..f6ae53c1d36 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -32,6 +32,7 @@ from trinity.utils.log import get_logger from trinity.utils.monitor import MONITOR, gather_metrics from trinity.utils.plugin_loader import load_plugins +from trinity.utils.timer import Timer class Explorer: @@ -223,7 +224,6 @@ async def explore_step(self) -> bool: ) await self.shutdown() return False - self.explore_step_start_time.update({self.explore_step_num + 1: time.time()}) self.scheduler.schedule(tasks, batch_id=self.explore_step_num + 1) self.explore_step_num += 1 return True @@ -350,15 +350,12 @@ async def _finish_steps(self, start_step: int, end_step: int, model_version: int self.monitor.log(metric, step=end_step) async def _finish_explore_step(self, step: int, model_version: int) -> None: - statuses, exps = await self.scheduler.get_results(batch_id=step) metric = {"rollout/model_version": model_version} + with Timer(metric, "time/wait_explore_step"): + statuses, exps = await self.scheduler.get_results(batch_id=step) pipeline_metrics = await self.experience_pipeline.process.remote(exps) self.taskset.update(pipeline_metrics) metric.update(pipeline_metrics) - explore_step_start_time = self.explore_step_start_time.get(step, None) - if explore_step_start_time is not None: - metric.update({"time/explore_step": time.time() - explore_step_start_time}) - self.explore_step_start_time.pop(step) if statuses: metric.update(gather_metrics([status.metric for status in statuses], "rollout")) self.monitor.log(metric, step=step) From dcbe8a49b1e6c427afebf2405fc15b395b9a16a5 Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Mon, 27 Oct 2025 11:54:03 +0800 Subject: [PATCH 08/10] fix comment --- trinity/explorer/explorer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index f6ae53c1d36..77391f7f946 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -77,7 +77,6 @@ def __init__(self, config: Config): self.last_sync_successful = True self.eval_start_time = None self.explore_start_time = None - self.explore_step_start_time = dict() # {step_num: start_time} self.logger.info("Finished initializing Explorer.") async def setup_weight_sync_group( From 4f6af70892e6ceaaabce17d8a7876c959a17599f Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Mon, 27 Oct 2025 14:25:44 +0800 Subject: [PATCH 09/10] fix experience_pipeline time --- .../buffer/pipelines/experience_pipeline.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/trinity/buffer/pipelines/experience_pipeline.py b/trinity/buffer/pipelines/experience_pipeline.py index ca08b758c51..b66573718af 100644 --- a/trinity/buffer/pipelines/experience_pipeline.py +++ b/trinity/buffer/pipelines/experience_pipeline.py @@ -1,3 +1,4 @@ +import time import traceback from typing import Dict, List, Optional @@ -15,6 +16,7 @@ from trinity.common.experience import Experience from trinity.utils.log import get_logger from trinity.utils.plugin_loader import load_plugins +from trinity.utils.timer import Timer def get_input_buffers( @@ -112,25 +114,32 @@ async def process(self, exps: List[Experience]) -> Dict: Returns: Dict: A dictionary containing metrics collected during the processing of experiences. """ + st = time.time() if self.input_store is not None: await self.input_store.write_async(exps) metrics = {} # Process experiences through operators - for operator in self.operators: - exps, metric = operator.process(exps) - metrics.update(metric) - + for idx, operator in enumerate(self.operators): + with Timer( + metrics, f"time/experience_pipeline/operator/{idx}_{operator.__class__.__name__}" + ): + exps, metric = operator.process(exps) + metrics.update(metric) metrics["experience_count"] = len(exps) # Write processed experiences to output buffer - await self.output.write_async(exps) + with Timer(metrics, "time/experience_pipeline/write"): + await self.output.write_async(exps) + metrics["time/experience_pipeline/total"] = time.time() - st # prefix metrics keys with 'pipeline/' result_metrics = {} for key, value in metrics.items(): - if isinstance(value, (int, float)): + if key.startswith("time/"): + result_metrics[key] = value + elif isinstance(value, (int, float)): result_metrics[f"pipeline/{key}"] = float(value) if SELECTOR_METRIC in metrics: result_metrics[SELECTOR_METRIC] = metrics[SELECTOR_METRIC] From 38b2ed870a5656a5a44ccfd40e5b50764d38b52f Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Mon, 27 Oct 2025 14:43:43 +0800 Subject: [PATCH 10/10] use experience_pipeline as prefix --- tests/buffer/experience_pipeline_test.py | 2 +- tests/explorer/explorer_test.py | 4 ++-- tests/trainer/trainer_test.py | 2 +- trinity/buffer/pipelines/experience_pipeline.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/buffer/experience_pipeline_test.py b/tests/buffer/experience_pipeline_test.py index 62d4df4ebac..af506f6781a 100644 --- a/tests/buffer/experience_pipeline_test.py +++ b/tests/buffer/experience_pipeline_test.py @@ -67,7 +67,7 @@ async def test_experience_pipeline(self): experiences = get_experiences(task_num=task_num, repeat_times=repeat_times) metrics = await pipeline.process.remote(experiences) self.assertEqual( - metrics["pipeline/experience_count"], task_num * (repeat_times - 1) + metrics["experience_pipeline/experience_count"], task_num * (repeat_times - 1) ) # first experience of each task will be filtered out by the reward filter # tests diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py index 27f5b0d4555..dfce872240e 100644 --- a/tests/explorer/explorer_test.py +++ b/tests/explorer/explorer_test.py @@ -109,8 +109,8 @@ def test_explorer(self): eval_metrics = parser.metric_list("eval") self.assertTrue(len(eval_metrics) == 0) self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4) - self.assertTrue(parser.metric_exist("pipeline/experience_count")) - experience_counts = parser.metric_values("pipeline/experience_count") + self.assertTrue(parser.metric_exist("experience_pipeline/experience_count")) + experience_counts = parser.metric_values("experience_pipeline/experience_count") self.assertTrue(len(experience_counts) == 4) for count in experience_counts: self.assertTrue(count >= 0) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 8815c3a1327..8b8b58467b6 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -819,7 +819,7 @@ def test_trainer(self): self.assertTrue(len(rollout_metrics) > 0) self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4) self.assertEqual( - parser.metric_values("pipeline/experience_count")[1], 16 + parser.metric_values("experience_pipeline/experience_count")[1], 16 ) # 16 rft experiences # test actor metrics actor_metrics = parser.metric_list("actor") diff --git a/trinity/buffer/pipelines/experience_pipeline.py b/trinity/buffer/pipelines/experience_pipeline.py index b66573718af..5d5fba59a48 100644 --- a/trinity/buffer/pipelines/experience_pipeline.py +++ b/trinity/buffer/pipelines/experience_pipeline.py @@ -140,7 +140,7 @@ async def process(self, exps: List[Experience]) -> Dict: if key.startswith("time/"): result_metrics[key] = value elif isinstance(value, (int, float)): - result_metrics[f"pipeline/{key}"] = float(value) + result_metrics[f"experience_pipeline/{key}"] = float(value) if SELECTOR_METRIC in metrics: result_metrics[SELECTOR_METRIC] = metrics[SELECTOR_METRIC]