Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions tests/buffer/experience_pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from tests.tools import RayUnittestBaseAysnc, get_template_config
from trinity.buffer import get_buffer_reader
from trinity.buffer.pipelines.experience_pipeline import ExperiencePipeline
from trinity.buffer.task_scheduler import SELECTOR_METRIC
from trinity.common.config import ExperiencePipelineConfig, OperatorConfig
from trinity.common.experience import EID, Experience

Expand Down Expand Up @@ -79,3 +80,39 @@ async def test_experience_pipeline(self):
with open(config.data_processor.experience_pipeline.input_save_path, "r") as f:
input_data = f.readlines()
self.assertEqual(len(input_data), len(experiences))

async def test_pass_rate_calculation(self) -> None:
config = get_template_config()
config.data_processor.experience_pipeline = ExperiencePipelineConfig(
save_input=True,
input_save_path=BUFFER_FILE_PATH,
operators=[
OperatorConfig(
name="pass_rate_calculator",
)
],
)
config.check_and_update()
config.buffer.trainer_input.experience_buffer.name = "pipeline_test_experience_buffer"
config.buffer.trainer_input.experience_buffer.max_read_timeout = 3

pipeline = (
ray.remote(ExperiencePipeline)
.options(name=f"{config.explorer.name}_pipeline")
.remote(config)
)
await pipeline.prepare.remote()
task_num = 8
repeat_times = 4
experiences = get_experiences(task_num=task_num, repeat_times=repeat_times)
for exp in experiences:
exp.info["task_index"] = {
"taskset_id": 0,
"index": exp.eid.task,
}
metrics = await pipeline.process.remote(experiences)
self.assertIn(SELECTOR_METRIC, metrics)
selector_metrics = metrics[SELECTOR_METRIC]
self.assertEqual(len(selector_metrics), 1)
self.assertEqual(set(selector_metrics[0]["indices"]), set(range(task_num)))
self.assertEqual(selector_metrics[0]["values"], [(repeat_times - 1.0) / 2] * task_num)
259 changes: 259 additions & 0 deletions tests/buffer/task_scheduler_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
import os
import unittest
from typing import Dict, List

from parameterized import parameterized

from tests.tools import get_template_config
from trinity.buffer.task_scheduler import TasksetScheduler
from trinity.common.config import DataSelectorConfig, FormatConfig, StorageConfig
from trinity.common.workflows.workflow import Task


class TestTaskScheduler(unittest.IsolatedAsyncioTestCase):
temp_output_path = "tmp/test_task_scheduler/"

@classmethod
def setUpClass(cls):
super().setUpClass()
os.makedirs(cls.temp_output_path, exist_ok=True)

@classmethod
def tearDownClass(cls):
super().tearDownClass()
if os.path.exists(cls.temp_output_path):
os.system(f"rm -rf {cls.temp_output_path}")
Comment thread
chenyushuo marked this conversation as resolved.
Outdated

def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, int]]) -> None:
for task, index in zip(batch_tasks, indices):
self.assertEqual(task.index["taskset_id"], index["taskset_id"])
self.assertEqual(task.index["index"], index["index"])
self.assertEqual(
task.raw_task["question"], # type: ignore
f"Question {index['index'] + 1} in subset {index['taskset_id'] + 1}.",
)
self.assertEqual(
task.raw_task["answer"], # type: ignore
f"Answer {index['index'] + 1} in subset {index['taskset_id'] + 1}.",
)

@parameterized.expand(
[
(
{"selector_type": "sequential"},
[
{"index": 0, "taskset_id": 1},
{"index": 0, "taskset_id": 0},
{"index": 1, "taskset_id": 1},
{"index": 1, "taskset_id": 0},
{"index": 2, "taskset_id": 1},
{"index": 3, "taskset_id": 1},
{"index": 2, "taskset_id": 0},
{"index": 3, "taskset_id": 0},
{"index": 4, "taskset_id": 1},
{"index": 5, "taskset_id": 1},
{"index": 6, "taskset_id": 1},
{"index": 4, "taskset_id": 0},
{"index": 0, "taskset_id": 1},
{"index": 1, "taskset_id": 1},
{"index": 0, "taskset_id": 0},
{"index": 2, "taskset_id": 1},
{"index": 3, "taskset_id": 1},
{"index": 1, "taskset_id": 0},
{"index": 2, "taskset_id": 0},
{"index": 4, "taskset_id": 1},
{"index": 3, "taskset_id": 0},
{"index": 5, "taskset_id": 1},
{"index": 6, "taskset_id": 1},
{"index": 4, "taskset_id": 0},
],
),
(
{"selector_type": "shuffle", "seed": 42},
[
{"index": 3, "taskset_id": 1},
{"index": 4, "taskset_id": 0},
{"index": 2, "taskset_id": 1},
{"index": 2, "taskset_id": 0},
{"index": 6, "taskset_id": 1},
{"index": 4, "taskset_id": 1},
{"index": 3, "taskset_id": 0},
{"index": 1, "taskset_id": 0},
{"index": 1, "taskset_id": 1},
{"index": 5, "taskset_id": 1},
{"index": 0, "taskset_id": 1},
{"index": 0, "taskset_id": 0},
{"index": 2, "taskset_id": 1},
{"index": 6, "taskset_id": 1},
{"index": 4, "taskset_id": 0},
{"index": 5, "taskset_id": 1},
{"index": 1, "taskset_id": 1},
{"index": 1, "taskset_id": 0},
{"index": 2, "taskset_id": 0},
{"index": 4, "taskset_id": 1},
{"index": 0, "taskset_id": 0},
{"index": 0, "taskset_id": 1},
{"index": 3, "taskset_id": 1},
{"index": 3, "taskset_id": 0},
],
),
(
{"selector_type": "random", "seed": 42},
[
{"index": 0, "taskset_id": 1},
{"index": 0, "taskset_id": 0},
{"index": 3, "taskset_id": 1},
{"index": 2, "taskset_id": 0},
{"index": 4, "taskset_id": 1},
{"index": 0, "taskset_id": 1},
{"index": 2, "taskset_id": 0},
{"index": 0, "taskset_id": 0},
{"index": 6, "taskset_id": 1},
{"index": 3, "taskset_id": 1},
{"index": 0, "taskset_id": 1},
{"index": 2, "taskset_id": 0},
{"index": 0, "taskset_id": 1},
{"index": 2, "taskset_id": 1},
{"index": 0, "taskset_id": 0},
{"index": 2, "taskset_id": 1},
{"index": 6, "taskset_id": 1},
{"index": 0, "taskset_id": 0},
{"index": 0, "taskset_id": 0},
{"index": 5, "taskset_id": 1},
{"index": 3, "taskset_id": 0},
{"index": 2, "taskset_id": 1},
{"index": 6, "taskset_id": 1},
{"index": 1, "taskset_id": 0},
],
),
(
{"selector_type": "offline_easy2hard", "feature_keys": ["feature_offline"]},
[
{"index": 3, "taskset_id": 1},
{"index": 3, "taskset_id": 0},
{"index": 4, "taskset_id": 1},
{"index": 0, "taskset_id": 0},
{"index": 1, "taskset_id": 1},
{"index": 0, "taskset_id": 1},
{"index": 2, "taskset_id": 0},
{"index": 4, "taskset_id": 0},
{"index": 6, "taskset_id": 1},
{"index": 5, "taskset_id": 1},
{"index": 2, "taskset_id": 1},
{"index": 1, "taskset_id": 0},
{"index": 3, "taskset_id": 1},
{"index": 4, "taskset_id": 1},
{"index": 3, "taskset_id": 0},
{"index": 1, "taskset_id": 1},
{"index": 0, "taskset_id": 1},
{"index": 0, "taskset_id": 0},
{"index": 2, "taskset_id": 0},
{"index": 6, "taskset_id": 1},
{"index": 4, "taskset_id": 0},
{"index": 5, "taskset_id": 1},
{"index": 2, "taskset_id": 1},
{"index": 1, "taskset_id": 0},
],
),
(
{"selector_type": "diff_based", "feature_keys": ["feat_1", "feat_2"]},
[
{"index": 3, "taskset_id": 1},
{"index": 3, "taskset_id": 0},
{"index": 6, "taskset_id": 1},
{"index": 2, "taskset_id": 0},
{"index": 2, "taskset_id": 1},
{"index": 3, "taskset_id": 1},
{"index": 2, "taskset_id": 0},
{"index": 3, "taskset_id": 0},
{"index": 2, "taskset_id": 1},
{"index": 1, "taskset_id": 1},
{"index": 4, "taskset_id": 1},
{"index": 2, "taskset_id": 0},
{"index": 3, "taskset_id": 1},
{"index": 2, "taskset_id": 1},
{"index": 4, "taskset_id": 0},
{"index": 4, "taskset_id": 1},
{"index": 5, "taskset_id": 1},
{"index": 4, "taskset_id": 0},
{"index": 3, "taskset_id": 0},
{"index": 5, "taskset_id": 1},
{"index": 1, "taskset_id": 0},
{"index": 6, "taskset_id": 1},
{"index": 6, "taskset_id": 1},
{"index": 4, "taskset_id": 0},
],
),
]
)
async def test_task_scheduler(self, task_selector_kwargs, batch_tasks_orders) -> None:
config = get_template_config()
config.buffer.batch_size = 2
config.buffer.total_epochs = 2
config.buffer.explorer_input.taskset = None
config.buffer.explorer_input.tasksets = [
StorageConfig(
name="subset_1",
path=os.path.join(
os.path.dirname(__file__),
"..",
"template",
"data",
"task_scheduler",
"subset_1",
),
split="train",
enable_progress_bar=False,
format=FormatConfig(
prompt_key="question",
response_key="answer",
),
default_workflow_type="math_workflow",
default_reward_fn_type="math_reward",
task_selector=DataSelectorConfig(
**task_selector_kwargs,
),
),
StorageConfig(
name="subset_2",
path=os.path.join(
os.path.dirname(__file__),
"..",
"template",
"data",
"task_scheduler",
"subset_2",
),
split="train",
enable_progress_bar=False,
format=FormatConfig(
prompt_key="question",
response_key="answer",
),
default_workflow_type="math_workflow",
default_reward_fn_type="math_reward",
task_selector=DataSelectorConfig(
**task_selector_kwargs,
),
),
]
config.check_and_update()

task_scheduler = TasksetScheduler({}, config)
self.assertEqual(len(batch_tasks_orders) % config.buffer.batch_size, 0)
for i, start_id in enumerate(range(0, len(batch_tasks_orders), config.buffer.batch_size)):
batch_tasks_indices = batch_tasks_orders[start_id : start_id + config.buffer.batch_size]
batch_tasks = await task_scheduler.read_async()
# for task in batch_tasks: # used for debug
# print(f"{task.index},")
self._check_batch_tasks(batch_tasks, batch_tasks_indices)
if i % 3 == 2:
# test resume
state_dict = {
"latest_iteration": task_scheduler.step,
"taskset_states": task_scheduler.state_dict(),
}
task_scheduler = TasksetScheduler(state_dict, config)

with self.assertRaises(StopAsyncIteration):
batch_tasks = await task_scheduler.read_async()
2 changes: 1 addition & 1 deletion tests/cli/launcher_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def test_debug_mode(self, mock_load):
except Exception:
time.sleep(3)
output_file = os.path.join(self.config.checkpoint_job_dir, "debug.html")
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
self.config.buffer.explorer_input.tasksets = [get_unittest_dataset_config("gsm8k")]
mock_load.return_value = self.config
with mock.patch(
"argparse.ArgumentParser.parse_args",
Expand Down
4 changes: 2 additions & 2 deletions tests/common/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_load_default_config(self):
self.assertEqual(config.trainer.trainer_config.trainer.project_name, config.project)
self.assertEqual(config.trainer.trainer_config.trainer.experiment_name, config.name)
self.assertEqual(
config.buffer.explorer_input.taskset.repeat_times, config.algorithm.repeat_times
config.buffer.explorer_input.tasksets[0].repeat_times, config.algorithm.repeat_times
)
self.assertEqual(config.model.model_path, config.model.critic_model_path)
self.assertEqual(config.model.model_path, config.explorer.rollout_model.model_path)
Expand Down Expand Up @@ -110,7 +110,7 @@ def test_default_workflow(self):
"math_boxed_workflow",
)
self.assertEqual(
config.buffer.explorer_input.taskset.default_workflow_type,
config.buffer.explorer_input.tasksets[0].default_workflow_type,
"simple_workflow",
)

Expand Down
5 changes: 5 additions & 0 deletions tests/template/data/task_scheduler/subset_1/train.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{"question": "Question 1 in subset 1.", "answer": "Answer 1 in subset 1.", "feature_offline": 0.5, "feat_1": 0.4, "feat_2": 0.3}
{"question": "Question 2 in subset 1.", "answer": "Answer 2 in subset 1.", "feature_offline": 0.1, "feat_1": 0.1, "feat_2": 0.1}
{"question": "Question 3 in subset 1.", "answer": "Answer 3 in subset 1.", "feature_offline": 0.4, "feat_1": 0.5, "feat_2": 0.3}
{"question": "Question 4 in subset 1.", "answer": "Answer 4 in subset 1.", "feature_offline": 0.5, "feat_1": 0.3, "feat_2": 0.5}
{"question": "Question 5 in subset 1.", "answer": "Answer 5 in subset 1.", "feature_offline": 0.2, "feat_1": 0.1, "feat_2": 0.5}
7 changes: 7 additions & 0 deletions tests/template/data/task_scheduler/subset_2/train.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{"question": "Question 1 in subset 2.", "answer": "Answer 1 in subset 2.", "feature_offline": 0.2, "feat_1": 0.5, "feat_2": 0.2}
{"question": "Question 2 in subset 2.", "answer": "Answer 2 in subset 2.", "feature_offline": 0.3, "feat_1": 0.6, "feat_2": 0.2}
{"question": "Question 3 in subset 2.", "answer": "Answer 3 in subset 2.", "feature_offline": 0.1, "feat_1": 0.7, "feat_2": 0.4}
{"question": "Question 4 in subset 2.", "answer": "Answer 4 in subset 2.", "feature_offline": 0.5, "feat_1": 0.1, "feat_2": 0.4}
{"question": "Question 5 in subset 2.", "answer": "Answer 5 in subset 2.", "feature_offline": 0.3, "feat_1": 0.1, "feat_2": 0.7}
{"question": "Question 6 in subset 2.", "answer": "Answer 6 in subset 2.", "feature_offline": 0.1, "feat_1": 0.7, "feat_2": 0.4}
{"question": "Question 7 in subset 2.", "answer": "Answer 7 in subset 2.", "feature_offline": 0.1, "feat_1": 0.7, "feat_2": 0.6}
4 changes: 4 additions & 0 deletions tests/trainer/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
AlgorithmConfig,
BufferConfig,
Config,
DataSelectorConfig,
ExplorerInput,
StageConfig,
StorageConfig,
Expand Down Expand Up @@ -73,6 +74,9 @@ def test_trainer(self):
"""Test the both and bench mode."""
# test both mode
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
self.config.buffer.explorer_input.taskset.task_selector = DataSelectorConfig(
selector_type="shuffle", seed=42
)
self.config.buffer.explorer_input.eval_tasksets.append(
get_unittest_dataset_config("countdown", "test")
)
Expand Down
17 changes: 16 additions & 1 deletion trinity/buffer/buffer_reader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Reader of the buffer."""
from abc import ABC, abstractmethod
from typing import List, Optional
from typing import Dict, List, Optional


class BufferReader(ABC):
Expand All @@ -13,3 +13,18 @@ def read(self, batch_size: Optional[int] = None) -> List:
@abstractmethod
async def read_async(self, batch_size: Optional[int] = None) -> List:
"""Read from buffer asynchronously."""

def __len__(self) -> int:
"""Get the number of samples in buffer."""
raise NotImplementedError

@property
@abstractmethod
def index(self) -> int:
Comment thread
chenyushuo marked this conversation as resolved.
Outdated
"""Get the current index."""

def state_dict(self) -> Dict:
return {}

def load_state_dict(self, state_dict: Dict) -> None:
pass
Loading