Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/sphinx_doc/source/tutorial/example_mix_algo.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class MixSampleStrategy(SampleStrategy):

async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
metrics = {}
with Timer(metrics, "read_time"):
with Timer(metrics, "time/read_experience"):
usual_exp_list = await self.usual_exp_buffer.read_async()
for exp in usual_exp_list:
if exp.info is None:
Expand Down
2 changes: 1 addition & 1 deletion docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class MixSampleStrategy(SampleStrategy):

async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
metrics = {}
with Timer(metrics, "read_time"):
with Timer(metrics, "time/read_experience"):
usual_exp_list = await self.usual_exp_buffer.read_async()
for exp in usual_exp_list:
if exp.info is None:
Expand Down
5 changes: 1 addition & 4 deletions trinity/algorithm/advantage_fn/advantage_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,7 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
for group_id, group_exps in exp_groups.items():
group_exps, group_metrics = self.calculate_group_advantage(group_id, group_exps)
metric_list.append(group_metrics)
try:
metrics = gather_metrics(metric_list, "group_advantages")
except ValueError:
metrics = {} # empty metric list causes ValueError, ignore it
metrics = gather_metrics(metric_list, "group_advantages")
exps = [exp for group in exp_groups.values() for exp in group] # Flatten the list
return exps, metrics

Expand Down
10 changes: 5 additions & 5 deletions trinity/algorithm/advantage_fn/grpo_advantage.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,11 @@ def process(self, exps):
group_id, group_exps, precomputed_std=precomputed_std
)
metric_list.append(group_metrics)
try:
# TODO: sum skipped count
metrics = gather_metrics(metric_list, "group_advantages")
except ValueError:
metrics = {} # empty metric list causes ValueError, ignore it

# Update the filtered_count metric
filtered_count = sum(metric.pop("skipped_count", 0) for metric in metric_list)
metrics = gather_metrics(metric_list, "group_advantages")
metrics["filtered_count"] = filtered_count
if self.duplicate_experiences and self.std_threshold is not None:
exps = self._duplicate_experiences(exp_groups)
else:
Expand Down
7 changes: 2 additions & 5 deletions trinity/algorithm/advantage_fn/multi_step_grpo_advantage.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,8 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
cnt += len(exps)
result_exps.extend(exps)

try:
metrics = gather_metrics(metric_list, "group_advantages")
metrics["experience_count"] = cnt
except ValueError:
metrics = {} # empty metric list causes ValueError, ignore it
metrics = gather_metrics(metric_list, "group_advantages")
metrics["experience_count"] = cnt
return result_exps, metrics

def __call__(self, exps, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions trinity/algorithm/sample_strategy/mix_sample_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(self, buffer_config: BufferConfig, **kwargs):

async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
metrics = {}
with Timer(metrics, "read_time"):
with Timer(metrics, "time/read_experience"):
usual_exp_list = await self.usual_exp_buffer.read_async()
for exp in usual_exp_list:
if exp.info is None:
Expand All @@ -86,7 +86,7 @@ async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
exp_list = usual_exp_list + expert_exp_list
repr_samples = representative_sample(exp_list)

with Timer(metrics, "gather_time"):
with Timer(metrics, "time/gather_experience"):
exps = Experiences.gather_experiences(
experiences=exp_list,
pad_token_id=self.pad_token_id, # type: ignore [arg-type]
Expand Down
16 changes: 14 additions & 2 deletions trinity/explorer/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def __init__(self, config: Config):
self.enable_lora = self.config.explorer.rollout_model.enable_lora
self.model_version = -1
self.last_sync_successful = True
self.explore_start_time = None
self.eval_start_time = None
self.logger.info("Finished initializing Explorer.")

async def setup_weight_sync_group(
Expand Down Expand Up @@ -207,6 +209,8 @@ async def explore(self) -> str:
return self.config.explorer.name

async def explore_step(self) -> bool:
if self.explore_start_time is None:
self.explore_start_time = time.time()
try:
tasks = await self.taskset.read_async()
except StopAsyncIteration:
Expand Down Expand Up @@ -252,6 +256,7 @@ def need_eval(self) -> bool:

async def eval(self):
"""Evaluation on all evaluation data samples."""
self.eval_start_time = time.time()
if len(self.config.buffer.explorer_input.eval_tasksets) == 0:
self.logger.warning("No evaluation data samples. Skip evaluation.")
return
Expand Down Expand Up @@ -338,6 +343,12 @@ async def _finish_steps(self, start_step: int, end_step: int, model_version: int
await self._finish_explore_step(step=step, model_version=model_version)
await self._finish_eval_step(step=step)

# Record the time: read_task + explore_step (>=1) + eval (if any)
if self.explore_start_time is not None:
metric = {"time/explore_total_time": time.time() - self.explore_start_time}
self.explore_start_time = None
self.monitor.log(metric, step=end_step)

async def _finish_explore_step(self, step: int, model_version: int) -> None:
statuses, exps = await self.scheduler.get_results(batch_id=step)
metric = {"rollout/model_version": model_version}
Expand All @@ -351,7 +362,6 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva
if not self.pending_eval_tasks:
return
step = step or self.explore_step_num
st = time.time()
metric = {}
while self.pending_eval_tasks:
eval_step, eval_task_name = self.pending_eval_tasks[0]
Expand All @@ -364,7 +374,9 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva
[status.metric for status in eval_results], f"{prefix}/{eval_task_name}"
)
)
metric[f"{prefix}/total_time"] = time.time() - st
if self.eval_start_time is not None:
metric.update({"time/eval": time.time() - self.eval_start_time})
self.eval_start_time = None
self.monitor.log(metric, step)

async def shutdown(self) -> None:
Expand Down
2 changes: 2 additions & 0 deletions trinity/service/data_juicer/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def initialize(self, config: dict):
self.session_id = response.json().get("session_id")

def process_experience(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
st = time.time()
if not self.session_id:
raise ValueError("DataJuicer session is not initialized.")

Expand All @@ -101,6 +102,7 @@ def process_experience(self, exps: List[Experience]) -> Tuple[List[Experience],
exp.info = {}
for stats_key in sample["__dj__stats__"]:
exp.info[stats_key] = sample["__dj__stats__"][stats_key]
metrics["time/dj_process_experience"] = time.time() - st
return exps, metrics

def process_task(self) -> Dict:
Expand Down
4 changes: 4 additions & 0 deletions trinity/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import asyncio
import time
import traceback
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple
Expand Down Expand Up @@ -67,6 +68,7 @@ async def train(self) -> str:
"""Train the model."""
while self.train_step_num < self.total_steps:
try:
st = time.time()
# sample may be blocked due to explorer does not generate enough data
self.logger.info(f"Sample data for step {self.train_step_num + 1} started.")
sample_task = asyncio.create_task(self._sample_data())
Expand All @@ -80,6 +82,8 @@ async def train(self) -> str:
self.logger.info(f"Sample data for step {self.train_step_num + 1} finished.")
metrics.update(await self.train_step(exps))
if await self.need_sync():
# Record the time: sample_experience + train_step (>=1)
metrics.update({"time/train_total_time": time.time() - st})
metrics.update(await self.sync_weight())
if self.need_save():
metrics.update(self.save_checkpoint())
Expand Down
2 changes: 2 additions & 0 deletions trinity/utils/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@


def gather_metrics(metric_list: List[Dict], prefix: str) -> Dict:
if not metric_list:
return {} # empty metric list causes ValueError, return empty dict
df = pd.DataFrame(metric_list)
numeric_df = df.select_dtypes(include=[np.number])
stats_df = numeric_df.agg(["mean", "max", "min"])
Expand Down