Skip to content
Merged
56 changes: 34 additions & 22 deletions trinity/buffer/reader/file_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import List, Optional

import datasets
from datasets import Dataset, load_dataset
from datasets import Dataset, IterableDataset, load_dataset

from trinity.buffer.buffer_reader import BufferReader
from trinity.buffer.schema.formatter import FORMATTER
Expand Down Expand Up @@ -32,18 +32,24 @@ def __init__(
drop_last: bool = True,
total_steps: Optional[int] = None,
enable_progress_bar: Optional[bool] = True,
shuffle: bool = False,
base_seed: Optional[int] = 42,
):
self.dataset = dataset
self.dataset_size = len(dataset)
self.name = name
self.current_batch_size = None
self.drop_last = drop_last
self.shuffle = shuffle
self.base_seed = base_seed

self.current_offset = offset
self.iter = iter(self.dataset)

for _ in range(self.current_offset % self.dataset_size):
next(self.iter)
if self.shuffle:
assert not isinstance(
dataset, IterableDataset
), "Shuffle is not supported for IterableDataset"
self.dataset = dataset.shuffle(seed=self.current_seed)
Comment thread
chenyushuo marked this conversation as resolved.
Outdated
else:
self.dataset = dataset

# convert epochs/steps to sample number
if total_steps:
Expand All @@ -63,29 +69,29 @@ def __init__(

self.progress_bar.update(self.current_offset)

@property
def current_seed(self):
    """Seed used to shuffle the epoch that ``current_offset`` falls in.

    Exposed as a property because the call sites pass it straight through
    as ``seed=self.current_seed``; as a plain method that expression would
    hand a bound method, not an integer, to ``Dataset.shuffle``.

    Returns:
        ``base_seed`` plus the number of complete passes over the dataset
        (``current_offset // dataset_size``), so each epoch gets a
        different but reproducible seed. ``None`` when ``base_seed`` is
        ``None`` (the constructor types it as ``Optional[int]``).
    """
    if self.base_seed is None:
        # No base seed configured: there is nothing to derive from, and
        # ``None + int`` would raise TypeError.
        return None
    return self.base_seed + self.current_offset // self.dataset_size

def read_batch(self, batch_size: int) -> List:
if self.current_offset >= self.total_samples:
self.progress_bar.close()
raise StopIteration
batch = []

while len(batch) < batch_size:
try:
item = next(self.iter)
batch.append(item)
self.current_offset += 1
except StopIteration:
if self.current_offset >= self.total_samples:
# No more data to read
if not self.drop_last and len(batch) > 0:
# return last batch
self.progress_bar.update(len(batch))
return batch
else:
self.progress_bar.close()
raise StopIteration
# Step to the next epoch
self.iter = iter(self.dataset)
batch.append(self.dataset[self.current_offset % self.dataset_size])
self.current_offset += 1
if self.shuffle and self.current_offset % self.dataset_size == 0:
self.dataset = self.dataset.shuffle(seed=self.current_seed)
Comment thread
chenyushuo marked this conversation as resolved.
Outdated
if self.current_offset >= self.total_samples:
# No more data to read
if not self.drop_last and len(batch) > 0:
# return last batch
self.progress_bar.update(len(batch))
return batch
else:
self.progress_bar.close()
raise StopIteration
self.progress_bar.update(batch_size)
return batch

Expand Down Expand Up @@ -144,9 +150,15 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
drop_last=not self.meta.is_eval,
total_steps=meta.total_steps,
enable_progress_bar=meta.enable_progress_bar,
shuffle=meta.shuffle,
base_seed=meta.seed,
)
self.formatter = FORMATTER.get("task")(meta)

@property
def index(self) -> int:
    """Current read offset of the wrapped dataset (its ``current_offset``)."""
    wrapped = self.dataset
    return wrapped.current_offset

def read(self, batch_size: Optional[int] = None) -> List:
batch_size = batch_size or self.read_batch_size
tasks = []
Expand Down
62 changes: 62 additions & 0 deletions trinity/buffer/task_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-
"""The taskset scheduler."""

from collections import deque
from typing import Dict, List, Optional
from trinity.buffer.buffer import get_buffer_reader
from trinity.common.config import Config


class TasksetScheduler:
    """Round-robin reader over the explorer's training tasksets.

    One buffer reader is built per entry of
    ``config.buffer.explorer_input.tasksets``. Readers are kept in a queue;
    each ``read`` takes the reader at the front, and a reader that raises
    ``StopIteration`` is dropped from the rotation.
    """

    def __init__(self, explorer_state, config: "Config"):
        """Build one reader per configured taskset, resuming saved offsets.

        Args:
            explorer_state: Mapping loaded from the explorer checkpoint.
                It is only read, never modified.
            config: Global config; ``config.buffer.explorer_input.tasksets``
                lists the tasksets and ``config.buffer`` is passed to
                ``get_buffer_reader``.

        Raises:
            ValueError: If a taskset is flagged ``is_eval`` or the saved
                state does not match the configured tasksets.
        """
        tasksets_config = config.buffer.explorer_input.tasksets
        tasksets_state = self._resolve_taskset_states(explorer_state, len(tasksets_config))

        self.tasksets = []
        for taskset_config, taskset_state in zip(tasksets_config, tasksets_state):
            if taskset_config.is_eval:
                raise ValueError("Eval tasksets cannot be scheduled for exploration.")
            # The reader resumes from the offset stored in its config.
            taskset_config.index = taskset_state["index"]
            self.tasksets.append(get_buffer_reader(taskset_config, config.buffer))
        self.tasksets_queue = deque(self.tasksets)

    @staticmethod
    def _resolve_taskset_states(explorer_state, num_tasksets: int) -> List[Dict]:
        """Return one ``{"index": ...}`` entry per taskset from a checkpoint.

        Lookup order:
          1. ``"taskset_state"`` - the key the state manager writes.
          2. ``"taskset"`` - the key this class used to read.
          3. ``"latest_task_index"`` - old single-taskset format.
          4. No saved state - every taskset starts at index 0.

        Raises:
            ValueError: If the old format is used with several tasksets, or
                the number of saved entries differs from ``num_tasksets``
                (``zip`` would otherwise silently drop tasksets).
        """
        if "taskset_state" in explorer_state:
            states = list(explorer_state["taskset_state"])
        elif "taskset" in explorer_state:
            states = list(explorer_state["taskset"])
        elif "latest_task_index" in explorer_state:
            if num_tasksets != 1:
                raise ValueError(
                    "Old-format explorer state (`latest_task_index`) can only be "
                    f"resumed with exactly one taskset, got {num_tasksets}."
                )
            states = [{"index": explorer_state["latest_task_index"]}]
        else:
            states = [{"index": 0} for _ in range(num_tasksets)]

        if len(states) != num_tasksets:
            raise ValueError(
                f"Explorer state holds {len(states)} taskset entries but "
                f"{num_tasksets} tasksets are configured."
            )
        return states

    def read(self, batch_size: Optional[int] = None) -> List:
        """Read one batch from the next taskset that still has data.

        Args:
            batch_size: Number of tasks to request; ``None`` is forwarded to
                the reader unchanged so it can apply its own default.

        Returns:
            The batch produced by the first reader in the rotation that does
            not raise ``StopIteration``.

        Raises:
            StopIteration: If every remaining reader is exhausted or the
                selected reader returned an empty batch.
        """
        batch = []
        for _ in range(len(self.tasksets_queue)):
            taskset = self.tasksets_queue.popleft()
            try:
                batch = taskset.read(batch_size)
            except StopIteration:
                # Exhausted reader: leave it out of the rotation.
                continue
            # The size can only be checked when the caller requested one;
            # comparing against ``None`` would always fail.
            if batch_size is not None:
                assert len(batch) == batch_size
            self.tasksets_queue.append(taskset)
            break
        if not batch:
            raise StopIteration
        return batch

    async def read_async(self, batch_size: Optional[int] = None) -> List:
        """Async wrapper around ``read``.

        Raises:
            StopAsyncIteration: When ``read`` raises ``StopIteration``.
        """
        try:
            return self.read(batch_size)
        except StopIteration as e:
            raise StopAsyncIteration from e

    def save_state(self) -> List[Dict]:
        """Return one ``{"index": reader.index}`` dict per taskset, in config order."""
        return [
            {
                "index": taskset.index,
            }
            for taskset in self.tasksets
        ]

    def update(self, experiences, explore_metric, eval_metric) -> None:
        """Feedback hook called after each finished step; currently a no-op."""
        pass
83 changes: 46 additions & 37 deletions trinity/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

from omegaconf import OmegaConf

Expand Down Expand Up @@ -108,6 +108,10 @@ class StorageConfig:
path: Optional[str] = None
repeat_times: Optional[int] = None

# For shuffle
shuffle: bool = False
seed: int = 42

# For continuing training
index: int = 0

Expand Down Expand Up @@ -369,7 +373,8 @@ class ClusterConfig:
class ExplorerInput:
"""Config for explorer input."""

taskset: StorageConfig = field(default_factory=StorageConfig)
taskset: Optional[StorageConfig] = None
tasksets: List[StorageConfig] = field(default_factory=list)
eval_tasksets: List[StorageConfig] = field(default_factory=list)
# The following args provide default values for the corresponding args in `taskset` and `eval_tasksets`
default_workflow_type: Optional[str] = None
Expand Down Expand Up @@ -630,40 +635,44 @@ def _check_buffer(self) -> None: # noqa: C901
trainer_input = self.buffer.trainer_input
experience_buffer = trainer_input.experience_buffer
explorer_input = self.buffer.explorer_input
taskset = explorer_input.taskset

if self.mode != "train" and not taskset.path:
raise ValueError(
"`buffer.explorer_input.taskset.path` is required, please set it to the path of the taskset."
)
if not taskset.name:
taskset.name = "taskset"
if taskset.repeat_times is None or taskset.repeat_times != self.algorithm.repeat_times:
taskset.repeat_times = self.algorithm.repeat_times
logger.info(
"`buffer.explorer_input.taskset.repeat_times` is set to `algorithm.repeat_times`"
f" (={self.algorithm.repeat_times})."
if len(explorer_input.tasksets) == 0 and explorer_input.taskset:
explorer_input.tasksets.append(explorer_input.taskset)
tasksets = explorer_input.tasksets

for taskset in tasksets:
if self.mode != "train" and not taskset.path:
raise ValueError(
"`buffer.explorer_input.taskset.path` is required, please set it to the path of the taskset."
)
if not taskset.name:
taskset.name = "taskset"
if taskset.repeat_times is None or taskset.repeat_times != self.algorithm.repeat_times:
taskset.repeat_times = self.algorithm.repeat_times
logger.info(
"`buffer.explorer_input.taskset.repeat_times` is set to `algorithm.repeat_times`"
f" (={self.algorithm.repeat_times})."
)
if self.mode == "train":
assert (
experience_buffer is not None
), "`buffer.trainer_input.experience_buffer` is required when `mode` is `train`."
experience_buffer.total_epochs = self.buffer.total_epochs
experience_buffer.total_steps = self.buffer.total_steps
else:
taskset.is_eval = False
taskset.total_epochs = self.buffer.total_epochs
taskset.total_steps = self.buffer.total_steps

set_if_none(taskset, "default_workflow_type", explorer_input.default_workflow_type)
set_if_none(
taskset, "default_eval_workflow_type", explorer_input.default_eval_workflow_type
)
if self.mode == "train":
assert (
experience_buffer is not None
), "`buffer.trainer_input.experience_buffer` is required when `mode` is `train`."
experience_buffer.total_epochs = self.buffer.total_epochs
experience_buffer.total_steps = self.buffer.total_steps
else:
taskset.is_eval = False
taskset.total_epochs = self.buffer.total_epochs
taskset.total_steps = self.buffer.total_steps

set_if_none(taskset, "default_workflow_type", explorer_input.default_workflow_type)
set_if_none(
taskset, "default_eval_workflow_type", explorer_input.default_eval_workflow_type
)
set_if_none(taskset, "default_reward_fn_type", explorer_input.default_reward_fn_type)
set_if_none(taskset.format, "system_prompt", explorer_input.system_prompt)
set_if_none(taskset.format, "reply_prefix", explorer_input.reply_prefix)
set_if_none(taskset, "ray_namespace", self.ray_namespace)
set_if_none(taskset.rollout_args, "max_tokens", self.model.max_response_tokens)
set_if_none(taskset, "default_reward_fn_type", explorer_input.default_reward_fn_type)
set_if_none(taskset.format, "system_prompt", explorer_input.system_prompt)
set_if_none(taskset.format, "reply_prefix", explorer_input.reply_prefix)
set_if_none(taskset, "ray_namespace", self.ray_namespace)
set_if_none(taskset.rollout_args, "max_tokens", self.model.max_response_tokens)

remained_tasksets = []
for idx, dataset in enumerate(explorer_input.eval_tasksets):
Expand Down Expand Up @@ -730,8 +739,8 @@ def _check_buffer(self) -> None: # noqa: C901
task_pipeline = self.data_processor.task_pipeline
if task_pipeline is not None:
if task_pipeline.output is None:
if taskset.path is not None:
task_pipeline.output = taskset
if tasksets[0].path is not None:
Comment thread
chenyushuo marked this conversation as resolved.
Outdated
task_pipeline.output = tasksets[0]
elif (
experience_buffer.schema_type in {"dpo", "sft"}
and experience_buffer.path is not None
Expand All @@ -740,7 +749,7 @@ def _check_buffer(self) -> None: # noqa: C901
else:
raise ValueError(
"`data_processor.task_pipeline.output` is required when both "
"`buffer.explorer_input.taskset.path` and `buffer.trainer_input.experience_buffer.path` are "
"`buffer.explorer_input.tasksets[0].path` and `buffer.trainer_input.experience_buffer.path` are "
"None"
)
if task_pipeline.output.path and os.path.exists(task_pipeline.output.path):
Expand Down
23 changes: 12 additions & 11 deletions trinity/explorer/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from trinity.buffer.buffer import get_buffer_reader
from trinity.buffer.pipelines.experience_pipeline import ExperiencePipeline
from trinity.buffer.task_scheduler import TasksetScheduler
from trinity.common.config import Config
from trinity.common.constants import (
ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
Expand Down Expand Up @@ -49,12 +50,7 @@ def __init__(self, config: Config):
self.config = config
self.models, self.auxiliary_models = create_inference_models(config)
self.experience_pipeline = self._init_experience_pipeline()
self.config.buffer.explorer_input.taskset.index = explorer_state.get("latest_task_index", 0)
self.taskset = (
get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer)
if self.config.mode != "serve"
else None
)
self.taskset = TasksetScheduler(explorer_state, config)
self.scheduler = None
self.monitor = MONITOR.get(self.config.monitor.monitor_type)(
project=self.config.project,
Expand Down Expand Up @@ -324,7 +320,7 @@ async def save_checkpoint(self, sync_weight: bool = False) -> None:
# save explore checkpoint
self.state.save_explorer(
current_step=self.explore_step_num,
current_task_index=self.explore_step_num * self.config.buffer.batch_size,
taskset_state=self.taskset.save_state(),
)

async def sync_weight(self) -> None:
Expand All @@ -335,19 +331,23 @@ async def sync_weight(self) -> None:
async def _finish_steps(self, start_step: int, end_step: int, model_version: int) -> None:
for step in range(start_step, end_step + 1):
self.logger.info(f"Log metrics of step {step}")
await self._finish_explore_step(step=step, model_version=model_version)
await self._finish_eval_step(step=step)
explore_metric, exps = await self._finish_explore_step(
step=step, model_version=model_version
)
eval_metric = await self._finish_eval_step(step=step)
self.taskset.update(exps, explore_metric, eval_metric)

async def _finish_explore_step(self, step: int, model_version: int) -> None:
async def _finish_explore_step(self, step: int, model_version: int):
Comment thread
chenyushuo marked this conversation as resolved.
Outdated
statuses, exps = await self.scheduler.get_results(batch_id=step)
metric = {"rollout/model_version": model_version}
pipeline_metrics = await self.experience_pipeline.process.remote(exps)
metric.update(pipeline_metrics)
if statuses:
metric.update(gather_metrics([status.metric for status in statuses], "rollout"))
self.monitor.log(metric, step=step)
return metric, exps

async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval") -> None:
async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval"):
Comment thread
chenyushuo marked this conversation as resolved.
Outdated
if not self.pending_eval_tasks:
return
step = step or self.explore_step_num
Expand All @@ -366,6 +366,7 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva
)
metric[f"{prefix}/total_time"] = time.time() - st
self.monitor.log(metric, step)
return metric

async def shutdown(self) -> None:
if self.scheduler:
Expand Down
4 changes: 2 additions & 2 deletions trinity/manager/state_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ def _check_config_consistency(self, config: Config) -> None:
def save_explorer(
self,
current_task_index: int,
current_step: int,
taskset_state: dict,
) -> None:
with open(self.explorer_state_path, "w", encoding="utf-8") as f:
json.dump(
{
"latest_task_index": current_task_index,
"latest_iteration": current_step,
"taskset_state": taskset_state,
},
f,
indent=2,
Expand Down
Loading