4 changes: 2 additions & 2 deletions docs/sphinx_doc/source/tutorial/example_mix_algo.md
@@ -106,7 +106,7 @@ class MixSampleStrategy(SampleStrategy):

async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
metrics = {}
-        with Timer(metrics, "read_time"):
+        with Timer(metrics, "time/read_experience"):
usual_exp_list = await self.usual_exp_buffer.read_async()
for exp in usual_exp_list:
if exp.info is None:
@@ -132,7 +132,7 @@
exp_list = usual_exp_list + expert_exp_list
repr_samples = representative_sample(exp_list)

-        with Timer(metrics, "gather_time"):
+        with Timer(metrics, "time/gather_experience"):
exps = Experiences.gather_experiences(
experiences=exp_list,
pad_token_id=self.pad_token_id, # type: ignore [arg-type]
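For reviewers unfamiliar with `Timer`: throughout this PR it is used as a context manager that writes elapsed wall-clock seconds into the supplied metrics dict under the given key. Its implementation (`trinity/utils/timer.py`) is not part of this diff, so the following is only a minimal sketch of the assumed behavior:

```python
import time
from typing import Dict


class Timer:
    """Minimal sketch of the assumed Timer behavior; the real class lives in
    trinity/utils/timer.py and may differ."""

    def __init__(self, metrics: Dict[str, float], key: str):
        self.metrics = metrics
        self.key = key

    def __enter__(self) -> "Timer":
        self.start = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # Record elapsed wall-clock seconds under the given key,
        # e.g. metrics["time/read_experience"] = 0.42
        self.metrics[self.key] = time.time() - self.start
```

The rename from `read_time`/`gather_time` to `time/read_experience`/`time/gather_experience` moves these values into the shared `time/` namespace that the rest of this PR standardizes on.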
4 changes: 2 additions & 2 deletions docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md
@@ -98,7 +98,7 @@ class MixSampleStrategy(SampleStrategy):

async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
metrics = {}
-        with Timer(metrics, "read_time"):
+        with Timer(metrics, "time/read_experience"):
usual_exp_list = await self.usual_exp_buffer.read_async()
for exp in usual_exp_list:
if exp.info is None:
@@ -124,7 +124,7 @@
exp_list = usual_exp_list + expert_exp_list
repr_samples = representative_sample(exp_list)

-        with Timer(metrics, "gather_time"):
+        with Timer(metrics, "time/gather_experience"):
exps = Experiences.gather_experiences(
experiences=exp_list,
pad_token_id=self.pad_token_id, # type: ignore [arg-type]
4 changes: 2 additions & 2 deletions tests/algorithm/advantage_fn_test.py
@@ -107,8 +107,8 @@ def test_grpo_reward_std(self):

exps, metrics = advantage_fn(exps)
self.assertEqual(len(exps), 0)
-        self.assertIn("group_advantages/skipped_count/mean", metrics)
-        self.assertEqual(metrics["group_advantages/skipped_count/mean"], 5)
+        self.assertIn("filtered_count", metrics)
+        self.assertEqual(metrics["filtered_count"], 15)

def test_grpo_correct_bias(self):
advantage_fn_cls = ADVANTAGE_FN.get("grpo")
2 changes: 1 addition & 1 deletion tests/buffer/experience_pipeline_test.py
@@ -67,7 +67,7 @@ async def test_experience_pipeline(self):
experiences = get_experiences(task_num=task_num, repeat_times=repeat_times)
metrics = await pipeline.process.remote(experiences)
self.assertEqual(
-            metrics["pipeline/experience_count"], task_num * (repeat_times - 1)
+            metrics["experience_pipeline/experience_count"], task_num * (repeat_times - 1)
) # first experience of each task will be filtered out by the reward filter

# tests
4 changes: 2 additions & 2 deletions tests/explorer/explorer_test.py
@@ -109,8 +109,8 @@ def test_explorer(self):
eval_metrics = parser.metric_list("eval")
self.assertTrue(len(eval_metrics) == 0)
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4)
-        self.assertTrue(parser.metric_exist("pipeline/experience_count"))
-        experience_counts = parser.metric_values("pipeline/experience_count")
+        self.assertTrue(parser.metric_exist("experience_pipeline/experience_count"))
+        experience_counts = parser.metric_values("experience_pipeline/experience_count")
self.assertTrue(len(experience_counts) == 4)
for count in experience_counts:
self.assertTrue(count >= 0)
2 changes: 1 addition & 1 deletion tests/trainer/trainer_test.py
@@ -819,7 +819,7 @@ def test_trainer(self):
self.assertTrue(len(rollout_metrics) > 0)
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4)
self.assertEqual(
-            parser.metric_values("pipeline/experience_count")[1], 16
+            parser.metric_values("experience_pipeline/experience_count")[1], 16
) # 16 rft experiences
# test actor metrics
actor_metrics = parser.metric_list("actor")
5 changes: 1 addition & 4 deletions trinity/algorithm/advantage_fn/advantage_fn.py
@@ -76,10 +76,7 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
for group_id, group_exps in exp_groups.items():
group_exps, group_metrics = self.calculate_group_advantage(group_id, group_exps)
metric_list.append(group_metrics)
-        try:
-            metrics = gather_metrics(metric_list, "group_advantages")
-        except ValueError:
-            metrics = {}  # empty metric list causes ValueError, ignore it
+        metrics = gather_metrics(metric_list, "group_advantages")
exps = [exp for group in exp_groups.values() for exp in group] # Flatten the list
return exps, metrics

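Note: the try/except can be dropped here because `gather_metrics` itself now returns an empty dict for an empty `metric_list` instead of letting pandas raise a `ValueError` (see the `trinity/utils/monitor.py` change at the end of this diff).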
10 changes: 5 additions & 5 deletions trinity/algorithm/advantage_fn/grpo_advantage.py
@@ -214,11 +214,11 @@ def process(self, exps):
group_id, group_exps, precomputed_std=precomputed_std
)
metric_list.append(group_metrics)
-        try:
-            # TODO: sum skipped count
-            metrics = gather_metrics(metric_list, "group_advantages")
-        except ValueError:
-            metrics = {}  # empty metric list causes ValueError, ignore it
+
+        # Update the filtered_count metric
+        filtered_count = sum(metric.pop("skipped_count", 0) for metric in metric_list)
+        metrics = gather_metrics(metric_list, "group_advantages")
+        metrics["filtered_count"] = filtered_count
if self.duplicate_experiences and self.std_threshold is not None:
exps = self._duplicate_experiences(exp_groups)
else:
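Because `pop` removes `skipped_count` from every per-group dict before aggregation, the count is reported once as a flat sum (`filtered_count`) rather than as `group_advantages/skipped_count/{mean,max,min}` statistics. A small illustration with made-up group metrics:

```python
# Hypothetical per-group metrics; key names mirror the ones used above.
metric_list = [
    {"reward_std": 0.25, "skipped_count": 5},
    {"reward_std": 0.75, "skipped_count": 10},
    {"reward_std": 0.50},  # this group skipped nothing
]

# pop() reads and removes the key in one step, defaulting to 0 when absent.
filtered_count = sum(metric.pop("skipped_count", 0) for metric in metric_list)
assert filtered_count == 15

# The per-group dicts no longer contain "skipped_count", so gather_metrics
# emits no skipped_count statistics -- matching the updated test assertions
# in tests/algorithm/advantage_fn_test.py.
assert all("skipped_count" not in m for m in metric_list)
```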
7 changes: 2 additions & 5 deletions trinity/algorithm/advantage_fn/multi_step_grpo_advantage.py
@@ -142,11 +142,8 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
cnt += len(exps)
result_exps.extend(exps)

-        try:
-            metrics = gather_metrics(metric_list, "group_advantages")
-            metrics["experience_count"] = cnt
-        except ValueError:
-            metrics = {}  # empty metric list causes ValueError, ignore it
+        metrics = gather_metrics(metric_list, "group_advantages")
+        metrics["experience_count"] = cnt
return result_exps, metrics

def __call__(self, exps, **kwargs):
4 changes: 2 additions & 2 deletions trinity/algorithm/sample_strategy/mix_sample_strategy.py
@@ -60,7 +60,7 @@ def __init__(self, buffer_config: BufferConfig, **kwargs):

async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
metrics = {}
-        with Timer(metrics, "read_time"):
+        with Timer(metrics, "time/read_experience"):
usual_exp_list = await self.usual_exp_buffer.read_async()
for exp in usual_exp_list:
if exp.info is None:
@@ -86,7 +86,7 @@ async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
exp_list = usual_exp_list + expert_exp_list
repr_samples = representative_sample(exp_list)

-        with Timer(metrics, "gather_time"):
+        with Timer(metrics, "time/gather_experience"):
exps = Experiences.gather_experiences(
experiences=exp_list,
pad_token_id=self.pad_token_id, # type: ignore [arg-type]
1 change: 1 addition & 0 deletions trinity/algorithm/utils.py
@@ -102,4 +102,5 @@ def prefix_metrics(src_metrics: dict, prefix: str, dst_metrics: dict = None) ->
dst_metrics = {}
for k, v in src_metrics.items():
dst_metrics[f"{prefix}/{k}"] = v
+
return dst_metrics
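`prefix_metrics` (shown above) simply namespaces every key of `src_metrics` with the given prefix; a quick illustration with made-up values:

```python
metrics = prefix_metrics(
    {"experience_count": 32, "filtered_count": 3},
    prefix="experience_pipeline",
)
assert metrics == {
    "experience_pipeline/experience_count": 32,
    "experience_pipeline/filtered_count": 3,
}
```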
23 changes: 16 additions & 7 deletions trinity/buffer/pipelines/experience_pipeline.py
@@ -1,3 +1,4 @@
+import time
import traceback
from typing import Dict, List, Optional

@@ -15,6 +16,7 @@
from trinity.common.experience import Experience
from trinity.utils.log import get_logger
from trinity.utils.plugin_loader import load_plugins
from trinity.utils.timer import Timer


def get_input_buffers(
@@ -112,26 +114,33 @@ async def process(self, exps: List[Experience]) -> Dict:
Returns:
Dict: A dictionary containing metrics collected during the processing of experiences.
"""
+        st = time.time()
if self.input_store is not None:
await self.input_store.write_async(exps)

metrics = {}

# Process experiences through operators
-        for operator in self.operators:
-            exps, metric = operator.process(exps)
-            metrics.update(metric)
-
+        for idx, operator in enumerate(self.operators):
+            with Timer(
+                metrics, f"time/experience_pipeline/operator/{idx}_{operator.__class__.__name__}"
+            ):
+                exps, metric = operator.process(exps)
+                metrics.update(metric)
metrics["experience_count"] = len(exps)

# Write processed experiences to output buffer
-        await self.output.write_async(exps)
+        with Timer(metrics, "time/experience_pipeline/write"):
+            await self.output.write_async(exps)
+        metrics["time/experience_pipeline/total"] = time.time() - st

# prefix metrics keys with 'pipeline/'
result_metrics = {}
for key, value in metrics.items():
-            if isinstance(value, (int, float)):
-                result_metrics[f"pipeline/{key}"] = float(value)
+            if key.startswith("time/"):
+                result_metrics[key] = value
+            elif isinstance(value, (int, float)):
+                result_metrics[f"experience_pipeline/{key}"] = float(value)
if SELECTOR_METRIC in metrics:
result_metrics[SELECTOR_METRIC] = metrics[SELECTOR_METRIC]

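The rewritten loop at the end of `process` now routes keys three ways: `time/...` keys pass through unprefixed, other numeric values get the `experience_pipeline/` prefix (previously `pipeline/`), and the selector metric is copied verbatim. A standalone sketch of the loop with hypothetical values:

```python
# Hypothetical metrics accumulated during process().
metrics = {
    "time/experience_pipeline/write": 0.02,  # already namespaced, kept as-is
    "experience_count": 128,                 # gets the experience_pipeline/ prefix
    "filtered_count": 4,                     # gets the experience_pipeline/ prefix
}

result_metrics = {}
for key, value in metrics.items():
    if key.startswith("time/"):
        result_metrics[key] = value
    elif isinstance(value, (int, float)):
        result_metrics[f"experience_pipeline/{key}"] = float(value)

assert result_metrics == {
    "time/experience_pipeline/write": 0.02,
    "experience_pipeline/experience_count": 128.0,
    "experience_pipeline/filtered_count": 4.0,
}
```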
20 changes: 17 additions & 3 deletions trinity/explorer/explorer.py
@@ -32,6 +32,7 @@
from trinity.utils.log import get_logger
from trinity.utils.monitor import MONITOR, gather_metrics
from trinity.utils.plugin_loader import load_plugins
from trinity.utils.timer import Timer


class Explorer:
@@ -74,6 +75,8 @@ def __init__(self, config: Config):
self.enable_lora = self.config.explorer.rollout_model.enable_lora
self.model_version = -1
self.last_sync_successful = True
+        self.eval_start_time = None
+        self.explore_start_time = None
self.logger.info("Finished initializing Explorer.")

async def setup_weight_sync_group(
@@ -205,6 +208,8 @@ async def explore(self) -> str:
return self.config.explorer.name

async def explore_step(self) -> bool:
+        if self.explore_start_time is None:
+            self.explore_start_time = time.time()
try:
tasks = await self.taskset.read_async()
except StopAsyncIteration:
@@ -250,6 +255,7 @@ def need_eval(self) -> bool:

async def eval(self):
"""Evaluation on all evaluation data samples."""
+        self.eval_start_time = time.time()
if len(self.config.buffer.explorer_input.eval_tasksets) == 0:
self.logger.warning("No evaluation data samples. Skip evaluation.")
return
@@ -336,9 +342,16 @@ async def _finish_steps(self, start_step: int, end_step: int, model_version: int
await self._finish_explore_step(step=step, model_version=model_version)
await self._finish_eval_step(step=step)

+        # Record the time: read_task + explore_step (>=1) + eval (if any)
+        if self.explore_start_time is not None:
+            metric = {"time/explorer_sync_interval": time.time() - self.explore_start_time}
+            self.explore_start_time = None
+            self.monitor.log(metric, step=end_step)
+
async def _finish_explore_step(self, step: int, model_version: int) -> None:
-        statuses, exps = await self.scheduler.get_results(batch_id=step)
metric = {"rollout/model_version": model_version}
+        with Timer(metric, "time/wait_explore_step"):
+            statuses, exps = await self.scheduler.get_results(batch_id=step)
pipeline_metrics = await self.experience_pipeline.process.remote(exps)
self.taskset.update(pipeline_metrics)
metric.update(pipeline_metrics)
@@ -350,7 +363,6 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva
if not self.pending_eval_tasks:
return
step = step or self.explore_step_num
-        st = time.time()
metric = {}
while self.pending_eval_tasks:
eval_step, eval_task_name = self.pending_eval_tasks[0]
@@ -363,7 +375,9 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva
[status.metric for status in eval_results], f"{prefix}/{eval_task_name}"
)
)
-        metric[f"{prefix}/total_time"] = time.time() - st
+        if self.eval_start_time is not None:
+            metric.update({"time/eval": time.time() - self.eval_start_time})
+            self.eval_start_time = None
self.monitor.log(metric, step)

async def shutdown(self) -> None:
2 changes: 2 additions & 0 deletions trinity/service/data_juicer/client.py
@@ -77,6 +77,7 @@ def initialize(self, config: dict):
self.session_id = response.json().get("session_id")

def process_experience(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
+        st = time.time()
if not self.session_id:
raise ValueError("DataJuicer session is not initialized.")

@@ -101,6 +102,7 @@ def process_experience(self, exps: List[Experience]) -> Tuple[List[Experience],
exp.info = {}
for stats_key in sample["__dj__stats__"]:
exp.info[stats_key] = sample["__dj__stats__"][stats_key]
+        metrics["time/dj_process_experience"] = time.time() - st
return exps, metrics

def process_task(self) -> Dict:
4 changes: 4 additions & 0 deletions trinity/trainer/trainer.py
@@ -5,6 +5,7 @@
from __future__ import annotations

import asyncio
+import time
import traceback
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple
@@ -67,6 +68,7 @@ async def train(self) -> str:
"""Train the model."""
while self.train_step_num < self.total_steps:
try:
+                st = time.time()
# sample may be blocked due to explorer does not generate enough data
self.logger.info(f"Sample data for step {self.train_step_num + 1} started.")
sample_task = asyncio.create_task(self._sample_data())
@@ -80,6 +82,8 @@
self.logger.info(f"Sample data for step {self.train_step_num + 1} finished.")
metrics.update(await self.train_step(exps))
if await self.need_sync():
+                    # Record the time: sample_experience + train_step (>=1)
+                    metrics.update({"time/trainer_sync_interval": time.time() - st})
metrics.update(await self.sync_weight())
if self.need_save():
metrics.update(self.save_checkpoint())
3 changes: 2 additions & 1 deletion trinity/trainer/verl_trainer.py
@@ -472,7 +472,8 @@ def train_step(self, batch: Experiences) -> Dict:  # noqa C901

# collect metrics
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
-        metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
+        timing_metrics = compute_timing_metrics(batch=batch, timing_raw=timing_raw)
+        metrics.update({k.replace("timing_s/", "time/"): v for k, v in timing_metrics.items()})
n_gpus = self.resource_pool_manager.get_n_gpus()
metrics.update(
compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)
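verl's `compute_timing_metrics` reports values under a `timing_s/` prefix; the added comprehension rewrites those keys into the `time/` namespace used by the rest of this PR. With hypothetical keys and values (the concrete key names are assumptions, not taken from verl):

```python
timing_metrics = {"timing_s/gen": 12.3, "timing_s/update_actor": 4.5}

renamed = {k.replace("timing_s/", "time/"): v for k, v in timing_metrics.items()}
assert renamed == {"time/gen": 12.3, "time/update_actor": 4.5}
```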
23 changes: 14 additions & 9 deletions trinity/utils/monitor.py
@@ -26,15 +26,20 @@


def gather_metrics(metric_list: List[Dict], prefix: str) -> Dict:
-    df = pd.DataFrame(metric_list)
-    numeric_df = df.select_dtypes(include=[np.number])
-    stats_df = numeric_df.agg(["mean", "max", "min"])
-    metric = {}
-    for col in stats_df.columns:
-        metric[f"{prefix}/{col}/mean"] = stats_df.loc["mean", col].item()
-        metric[f"{prefix}/{col}/max"] = stats_df.loc["max", col].item()
-        metric[f"{prefix}/{col}/min"] = stats_df.loc["min", col].item()
-    return metric
+    if not metric_list:
+        return {}
+    try:
+        df = pd.DataFrame(metric_list)
+        numeric_df = df.select_dtypes(include=[np.number])
+        stats_df = numeric_df.agg(["mean", "max", "min"])
+        metric = {}
+        for col in stats_df.columns:
+            metric[f"{prefix}/{col}/mean"] = stats_df.loc["mean", col].item()
+            metric[f"{prefix}/{col}/max"] = stats_df.loc["max", col].item()
+            metric[f"{prefix}/{col}/min"] = stats_df.loc["min", col].item()
+        return metric
+    except Exception as e:
+        raise ValueError(f"Failed to gather metrics: {e}") from e


class Monitor(ABC):
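With the early return for an empty list, callers no longer need their own try/except around `gather_metrics`, which is what the advantage-function changes above rely on. A quick usage sketch with made-up numbers (chosen to be exact in binary floating point so the assertions hold):

```python
# Empty input now yields {} instead of a pandas-induced ValueError.
assert gather_metrics([], "group_advantages") == {}

# Non-empty input: every numeric column is aggregated into mean/max/min.
out = gather_metrics(
    [{"reward_std": 0.25}, {"reward_std": 0.75}],
    "group_advantages",
)
assert out == {
    "group_advantages/reward_std/mean": 0.5,
    "group_advantages/reward_std/max": 0.75,
    "group_advantages/reward_std/min": 0.25,
}
```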