diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md
index 56312a7b987..d95aac45466 100644
--- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md
+++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md
@@ -106,7 +106,7 @@ class MixSampleStrategy(SampleStrategy):
 
     async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
         metrics = {}
-        with Timer(metrics, "read_time"):
+        with Timer(metrics, "time/read_experience"):
             usual_exp_list = await self.usual_exp_buffer.read_async()
             for exp in usual_exp_list:
                 if exp.info is None:
@@ -132,7 +132,7 @@ class MixSampleStrategy(SampleStrategy):
         exp_list = usual_exp_list + expert_exp_list
         repr_samples = representative_sample(exp_list)
 
-        with Timer(metrics, "gather_time"):
+        with Timer(metrics, "time/gather_experience"):
             exps = Experiences.gather_experiences(
                 experiences=exp_list,
                 pad_token_id=self.pad_token_id,  # type: ignore [arg-type]
diff --git a/docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md b/docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md
index 4c0b575fd5a..f6f4be9810f 100644
--- a/docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md
+++ b/docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md
@@ -98,7 +98,7 @@ class MixSampleStrategy(SampleStrategy):
 
     async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
         metrics = {}
-        with Timer(metrics, "read_time"):
+        with Timer(metrics, "time/read_experience"):
             usual_exp_list = await self.usual_exp_buffer.read_async()
             for exp in usual_exp_list:
                 if exp.info is None:
@@ -124,7 +124,7 @@ class MixSampleStrategy(SampleStrategy):
         exp_list = usual_exp_list + expert_exp_list
         repr_samples = representative_sample(exp_list)
 
-        with Timer(metrics, "gather_time"):
+        with Timer(metrics, "time/gather_experience"):
             exps = Experiences.gather_experiences(
                 experiences=exp_list,
                 pad_token_id=self.pad_token_id,  # type: ignore [arg-type]
diff --git a/tests/algorithm/advantage_fn_test.py b/tests/algorithm/advantage_fn_test.py
index 70a15ddd76e..5d03ca55c1b 100644
--- a/tests/algorithm/advantage_fn_test.py
+++ b/tests/algorithm/advantage_fn_test.py
@@ -107,8 +107,8 @@ def test_grpo_reward_std(self):
 
         exps, metrics = advantage_fn(exps)
         self.assertEqual(len(exps), 0)
-        self.assertIn("group_advantages/skipped_count/mean", metrics)
-        self.assertEqual(metrics["group_advantages/skipped_count/mean"], 5)
+        self.assertIn("filtered_count", metrics)
+        self.assertEqual(metrics["filtered_count"], 15)
 
     def test_grpo_correct_bias(self):
         advantage_fn_cls = ADVANTAGE_FN.get("grpo")
diff --git a/tests/buffer/experience_pipeline_test.py b/tests/buffer/experience_pipeline_test.py
index 62d4df4ebac..af506f6781a 100644
--- a/tests/buffer/experience_pipeline_test.py
+++ b/tests/buffer/experience_pipeline_test.py
@@ -67,7 +67,7 @@ async def test_experience_pipeline(self):
         experiences = get_experiences(task_num=task_num, repeat_times=repeat_times)
         metrics = await pipeline.process.remote(experiences)
         self.assertEqual(
-            metrics["pipeline/experience_count"], task_num * (repeat_times - 1)
+            metrics["experience_pipeline/experience_count"], task_num * (repeat_times - 1)
         )  # first experience of each task will be filtered out by the reward filter
 
         # tests
diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py
index 27f5b0d4555..dfce872240e 100644
--- a/tests/explorer/explorer_test.py
+++ b/tests/explorer/explorer_test.py
@@ -109,8 +109,8 @@ def test_explorer(self):
         eval_metrics = parser.metric_list("eval")
         self.assertTrue(len(eval_metrics) == 0)
         self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4)
-        self.assertTrue(parser.metric_exist("pipeline/experience_count"))
-        experience_counts = parser.metric_values("pipeline/experience_count")
+        self.assertTrue(parser.metric_exist("experience_pipeline/experience_count"))
+        experience_counts = parser.metric_values("experience_pipeline/experience_count")
         self.assertTrue(len(experience_counts) == 4)
         for count in experience_counts:
             self.assertTrue(count >= 0)
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index 8815c3a1327..8b8b58467b6 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -819,7 +819,7 @@ def test_trainer(self):
         self.assertTrue(len(rollout_metrics) > 0)
         self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4)
         self.assertEqual(
-            parser.metric_values("pipeline/experience_count")[1], 16
+            parser.metric_values("experience_pipeline/experience_count")[1], 16
         )  # 16 rft experiences
         # test actor metrics
         actor_metrics = parser.metric_list("actor")
diff --git a/trinity/algorithm/advantage_fn/advantage_fn.py b/trinity/algorithm/advantage_fn/advantage_fn.py
index 31f3b20bb4d..6810feda72f 100644
--- a/trinity/algorithm/advantage_fn/advantage_fn.py
+++ b/trinity/algorithm/advantage_fn/advantage_fn.py
@@ -76,10 +76,7 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
         for group_id, group_exps in exp_groups.items():
             group_exps, group_metrics = self.calculate_group_advantage(group_id, group_exps)
             metric_list.append(group_metrics)
-        try:
-            metrics = gather_metrics(metric_list, "group_advantages")
-        except ValueError:
-            metrics = {}  # empty metric list causes ValueError, ignore it
+        metrics = gather_metrics(metric_list, "group_advantages")
         exps = [exp for group in exp_groups.values() for exp in group]  # Flatten the list
         return exps, metrics
 
diff --git a/trinity/algorithm/advantage_fn/grpo_advantage.py b/trinity/algorithm/advantage_fn/grpo_advantage.py
index e60d05faff2..a438c028692 100644
--- a/trinity/algorithm/advantage_fn/grpo_advantage.py
+++ b/trinity/algorithm/advantage_fn/grpo_advantage.py
@@ -214,11 +214,11 @@ def process(self, exps):
                 group_id, group_exps, precomputed_std=precomputed_std
             )
             metric_list.append(group_metrics)
-        try:
-            # TODO: sum skipped count
-            metrics = gather_metrics(metric_list, "group_advantages")
-        except ValueError:
-            metrics = {}  # empty metric list causes ValueError, ignore it
+
+        # Update the filtered_count metric
+        filtered_count = sum(metric.pop("skipped_count", 0) for metric in metric_list)
+        metrics = gather_metrics(metric_list, "group_advantages")
+        metrics["filtered_count"] = filtered_count
         if self.duplicate_experiences and self.std_threshold is not None:
            exps = self._duplicate_experiences(exp_groups)
         else:
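Note on the `grpo_advantage.py` hunk above: per-group `skipped_count` values are popped out of each metrics dict before `gather_metrics` aggregates the remainder, then reported as one summed `filtered_count`. A minimal standalone sketch of that pop-and-sum pattern (the metric values are made up; three groups of five skipped experiences match the updated test expectation of 15):

```python
# Made-up per-group metrics: three degenerate groups, each skipping 5 experiences.
metric_list = [
    {"reward_std": 0.0, "skipped_count": 5},
    {"reward_std": 0.0, "skipped_count": 5},
    {"reward_std": 0.0, "skipped_count": 5},
]

# Pop the counts so they are no longer averaged into
# "group_advantages/skipped_count/mean"; report their sum instead.
filtered_count = sum(metric.pop("skipped_count", 0) for metric in metric_list)

assert filtered_count == 15
assert all("skipped_count" not in m for m in metric_list)
```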
diff --git a/trinity/algorithm/advantage_fn/multi_step_grpo_advantage.py b/trinity/algorithm/advantage_fn/multi_step_grpo_advantage.py
index 19ed4529c03..a085ee1ae2b 100644
--- a/trinity/algorithm/advantage_fn/multi_step_grpo_advantage.py
+++ b/trinity/algorithm/advantage_fn/multi_step_grpo_advantage.py
@@ -142,11 +142,8 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
             cnt += len(exps)
             result_exps.extend(exps)
 
-        try:
-            metrics = gather_metrics(metric_list, "group_advantages")
-            metrics["experience_count"] = cnt
-        except ValueError:
-            metrics = {}  # empty metric list causes ValueError, ignore it
+        metrics = gather_metrics(metric_list, "group_advantages")
+        metrics["experience_count"] = cnt
         return result_exps, metrics
 
     def __call__(self, exps, **kwargs):
diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py
index 5e535a6d25b..2c213a12d8f 100644
--- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py
+++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py
@@ -60,7 +60,7 @@ def __init__(self, buffer_config: BufferConfig, **kwargs):
 
     async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
         metrics = {}
-        with Timer(metrics, "read_time"):
+        with Timer(metrics, "time/read_experience"):
             usual_exp_list = await self.usual_exp_buffer.read_async()
             for exp in usual_exp_list:
                 if exp.info is None:
@@ -86,7 +86,7 @@ async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
         exp_list = usual_exp_list + expert_exp_list
         repr_samples = representative_sample(exp_list)
 
-        with Timer(metrics, "gather_time"):
+        with Timer(metrics, "time/gather_experience"):
             exps = Experiences.gather_experiences(
                 experiences=exp_list,
                 pad_token_id=self.pad_token_id,  # type: ignore [arg-type]
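The `read_time`/`gather_time` keys move under a shared `time/` namespace so downstream logging can recognize timing metrics by prefix. For orientation, here is a minimal sketch of what a context-manager `Timer(metrics, name)` like the one imported from `trinity.utils.timer` plausibly does, namely recording elapsed wall-clock seconds under the given key; the real implementation may differ:

```python
import time
from typing import Dict


class Timer:
    """Sketch only: records elapsed seconds under `name` in `metrics` on exit."""

    def __init__(self, metrics: Dict[str, float], name: str):
        self.metrics = metrics
        self.name = name

    def __enter__(self) -> "Timer":
        self.start = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.metrics[self.name] = time.time() - self.start


metrics: Dict[str, float] = {}
with Timer(metrics, "time/read_experience"):
    time.sleep(0.01)  # stand-in for reading from the experience buffer
print(metrics)  # e.g. {'time/read_experience': 0.0101}
```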
""" + st = time.time() if self.input_store is not None: await self.input_store.write_async(exps) metrics = {} # Process experiences through operators - for operator in self.operators: - exps, metric = operator.process(exps) - metrics.update(metric) - + for idx, operator in enumerate(self.operators): + with Timer( + metrics, f"time/experience_pipeline/operator/{idx}_{operator.__class__.__name__}" + ): + exps, metric = operator.process(exps) + metrics.update(metric) metrics["experience_count"] = len(exps) # Write processed experiences to output buffer - await self.output.write_async(exps) + with Timer(metrics, "time/experience_pipeline/write"): + await self.output.write_async(exps) + metrics["time/experience_pipeline/total"] = time.time() - st # prefix metrics keys with 'pipeline/' result_metrics = {} for key, value in metrics.items(): - if isinstance(value, (int, float)): - result_metrics[f"pipeline/{key}"] = float(value) + if key.startswith("time/"): + result_metrics[key] = value + elif isinstance(value, (int, float)): + result_metrics[f"experience_pipeline/{key}"] = float(value) if SELECTOR_METRIC in metrics: result_metrics[SELECTOR_METRIC] = metrics[SELECTOR_METRIC] diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index b80807213f8..77391f7f946 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -32,6 +32,7 @@ from trinity.utils.log import get_logger from trinity.utils.monitor import MONITOR, gather_metrics from trinity.utils.plugin_loader import load_plugins +from trinity.utils.timer import Timer class Explorer: @@ -74,6 +75,8 @@ def __init__(self, config: Config): self.enable_lora = self.config.explorer.rollout_model.enable_lora self.model_version = -1 self.last_sync_successful = True + self.eval_start_time = None + self.explore_start_time = None self.logger.info("Finished initializing Explorer.") async def setup_weight_sync_group( @@ -205,6 +208,8 @@ async def explore(self) -> str: return self.config.explorer.name async def explore_step(self) -> bool: + if self.explore_start_time is None: + self.explore_start_time = time.time() try: tasks = await self.taskset.read_async() except StopAsyncIteration: @@ -250,6 +255,7 @@ def need_eval(self) -> bool: async def eval(self): """Evaluation on all evaluation data samples.""" + self.eval_start_time = time.time() if len(self.config.buffer.explorer_input.eval_tasksets) == 0: self.logger.warning("No evaluation data samples. 
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index b80807213f8..77391f7f946 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -32,6 +32,7 @@
 from trinity.utils.log import get_logger
 from trinity.utils.monitor import MONITOR, gather_metrics
 from trinity.utils.plugin_loader import load_plugins
+from trinity.utils.timer import Timer
 
 
 class Explorer:
@@ -74,6 +75,8 @@
         self.enable_lora = self.config.explorer.rollout_model.enable_lora
         self.model_version = -1
         self.last_sync_successful = True
+        self.eval_start_time = None
+        self.explore_start_time = None
         self.logger.info("Finished initializing Explorer.")
 
     async def setup_weight_sync_group(
@@ -205,6 +208,8 @@ async def explore(self) -> str:
         return self.config.explorer.name
 
     async def explore_step(self) -> bool:
+        if self.explore_start_time is None:
+            self.explore_start_time = time.time()
         try:
             tasks = await self.taskset.read_async()
         except StopAsyncIteration:
@@ -250,6 +255,7 @@ def need_eval(self) -> bool:
 
     async def eval(self):
         """Evaluation on all evaluation data samples."""
+        self.eval_start_time = time.time()
         if len(self.config.buffer.explorer_input.eval_tasksets) == 0:
             self.logger.warning("No evaluation data samples. Skip evaluation.")
             return
@@ -336,9 +342,16 @@ async def _finish_steps(self, start_step: int, end_step: int, model_version: int
             await self._finish_explore_step(step=step, model_version=model_version)
             await self._finish_eval_step(step=step)
 
+        # Record the time: read_task + explore_step (>=1) + eval (if any)
+        if self.explore_start_time is not None:
+            metric = {"time/explorer_sync_interval": time.time() - self.explore_start_time}
+            self.explore_start_time = None
+            self.monitor.log(metric, step=end_step)
+
     async def _finish_explore_step(self, step: int, model_version: int) -> None:
-        statuses, exps = await self.scheduler.get_results(batch_id=step)
         metric = {"rollout/model_version": model_version}
+        with Timer(metric, "time/wait_explore_step"):
+            statuses, exps = await self.scheduler.get_results(batch_id=step)
         pipeline_metrics = await self.experience_pipeline.process.remote(exps)
         self.taskset.update(pipeline_metrics)
         metric.update(pipeline_metrics)
@@ -350,7 +363,6 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva
         if not self.pending_eval_tasks:
             return
         step = step or self.explore_step_num
-        st = time.time()
         metric = {}
         while self.pending_eval_tasks:
             eval_step, eval_task_name = self.pending_eval_tasks[0]
@@ -363,7 +375,9 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva
                     [status.metric for status in eval_results], f"{prefix}/{eval_task_name}"
                 )
             )
-        metric[f"{prefix}/total_time"] = time.time() - st
+        if self.eval_start_time is not None:
+            metric.update({"time/eval": time.time() - self.eval_start_time})
+            self.eval_start_time = None
         self.monitor.log(metric, step)
 
     async def shutdown(self) -> None:
diff --git a/trinity/service/data_juicer/client.py b/trinity/service/data_juicer/client.py
index 44238eab607..2ff5dc41b61 100644
--- a/trinity/service/data_juicer/client.py
+++ b/trinity/service/data_juicer/client.py
@@ -77,6 +77,7 @@ def initialize(self, config: dict):
         self.session_id = response.json().get("session_id")
 
     def process_experience(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
+        st = time.time()
         if not self.session_id:
             raise ValueError("DataJuicer session is not initialized.")
 
@@ -101,6 +102,7 @@ def process_experience(self, exps: List[Experience]) -> Tuple[List[Experience],
                 exp.info = {}
             for stats_key in sample["__dj__stats__"]:
                 exp.info[stats_key] = sample["__dj__stats__"][stats_key]
+        metrics["time/dj_process_experience"] = time.time() - st
         return exps, metrics
 
     def process_task(self) -> Dict:
diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py
index d9036e94861..a0e455d50ef 100644
--- a/trinity/trainer/trainer.py
+++ b/trinity/trainer/trainer.py
@@ -5,6 +5,7 @@
 from __future__ import annotations
 
 import asyncio
+import time
 import traceback
 from abc import ABC, abstractmethod
 from typing import Dict, List, Tuple
@@ -67,6 +68,7 @@ async def train(self) -> str:
         """Train the model."""
         while self.train_step_num < self.total_steps:
             try:
+                st = time.time()
                 # sample may be blocked due to explorer does not generate enough data
                 self.logger.info(f"Sample data for step {self.train_step_num + 1} started.")
                 sample_task = asyncio.create_task(self._sample_data())
@@ -80,6 +82,8 @@
                 self.logger.info(f"Sample data for step {self.train_step_num + 1} finished.")
                 metrics.update(await self.train_step(exps))
                 if await self.need_sync():
+                    # Record the time: sample_experience + train_step (>=1)
+                    metrics.update({"time/trainer_sync_interval": time.time() - st})
                     metrics.update(await self.sync_weight())
                 if self.need_save():
                     metrics.update(self.save_checkpoint())
diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py
index 54dfba2b7a5..bc879e6afc1 100644
--- a/trinity/trainer/verl_trainer.py
+++ b/trinity/trainer/verl_trainer.py
@@ -472,7 +472,8 @@ def train_step(self, batch: Experiences) -> Dict:  # noqa C901
 
         # collect metrics
         metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
-        metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
+        timing_metrics = compute_timing_metrics(batch=batch, timing_raw=timing_raw)
+        metrics.update({k.replace("timing_s/", "time/"): v for k, v in timing_metrics.items()})
         n_gpus = self.resource_pool_manager.get_n_gpus()
         metrics.update(
             compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)
diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py
index 880d035f836..f0e847f4d2d 100644
--- a/trinity/utils/monitor.py
+++ b/trinity/utils/monitor.py
@@ -26,15 +26,20 @@
 
 
 def gather_metrics(metric_list: List[Dict], prefix: str) -> Dict:
-    df = pd.DataFrame(metric_list)
-    numeric_df = df.select_dtypes(include=[np.number])
-    stats_df = numeric_df.agg(["mean", "max", "min"])
-    metric = {}
-    for col in stats_df.columns:
-        metric[f"{prefix}/{col}/mean"] = stats_df.loc["mean", col].item()
-        metric[f"{prefix}/{col}/max"] = stats_df.loc["max", col].item()
-        metric[f"{prefix}/{col}/min"] = stats_df.loc["min", col].item()
-    return metric
+    if not metric_list:
+        return {}
+    try:
+        df = pd.DataFrame(metric_list)
+        numeric_df = df.select_dtypes(include=[np.number])
+        stats_df = numeric_df.agg(["mean", "max", "min"])
+        metric = {}
+        for col in stats_df.columns:
+            metric[f"{prefix}/{col}/mean"] = stats_df.loc["mean", col].item()
+            metric[f"{prefix}/{col}/max"] = stats_df.loc["max", col].item()
+            metric[f"{prefix}/{col}/min"] = stats_df.loc["min", col].item()
+        return metric
+    except Exception as e:
+        raise ValueError(f"Failed to gather metrics: {e}") from e
 
 
 class Monitor(ABC):
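With the empty-list guard in place, `gather_metrics` returns `{}` instead of raising on empty input, which is what lets the three advantage-function call sites above drop their `try/except ValueError` wrappers. Expected behavior, with made-up values:

```python
from trinity.utils.monitor import gather_metrics

# Empty input no longer raises, so callers need no try/except.
assert gather_metrics([], "group_advantages") == {}

# Numeric fields are aggregated to mean/max/min under the prefix.
out = gather_metrics([{"reward": 1.0}, {"reward": 3.0}], "group_advantages")
assert out == {
    "group_advantages/reward/mean": 2.0,
    "group_advantages/reward/max": 3.0,
    "group_advantages/reward/min": 1.0,
}
```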