diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index df8cb279907..1cd0467be93 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -412,15 +412,16 @@ explorer: engine_type: vllm engine_num: 1 tensor_parallel_size: 1 - enable_history: False + enable_history: false auxiliary_models: - model_path: Qwen/Qwen2.5-7B-Instruct tensor_parallel_size: 1 eval_interval: 100 - eval_on_startup: True + eval_on_startup: true over_rollout: ratio: 0.0 wait_after_min: 30.0 + return_partial_tasks: false dynamic_timeout: enable: false ratio: 3.0 @@ -443,13 +444,14 @@ explorer: - `external`: Use external API-based model engine. - `rollout_model.engine_num`: Number of inference engines. - `rollout_model.tensor_parallel_size`: Degree of tensor parallelism. -- `rollout_model.enable_history`: Whether to enable model call history recording. If set to `True`, the model wrapper automatically records the return experiences of model calls. Please periodically extract the history via `extract_experience_from_history` to avoid out-of-memory issues. Default is `False`. +- `rollout_model.enable_history`: Whether to enable model call history recording. If set to `true`, the model wrapper automatically records the experiences returned by model calls. Please extract the history periodically via `extract_experience_from_history` to avoid out-of-memory issues. Default is `false`. - `auxiliary_models`: Additional models used for custom workflows. - `eval_interval`: Interval (in steps) for evaluating the model. - `eval_on_startup`: Whether to evaluate the model on startup. More precisely, at step 0 with the original model, so it will not be triggered when restarting. - `over_rollout`: [Experimental] Configurations for over-rollout mechanism, which allows the explorer to proceed with fewer tasks than the full batch size. It effectively increases throughput in scenarios where some tasks take significantly longer to complete than others. Only applicable when dynamic synchronization (`synchronizer.sync_style` is not `fixed`) is used. - `ratio`: Explorer will only wait for `(1 - ratio) * batch_size` of tasks at each step. Default is `0.0`, meaning waiting for all tasks. - `wait_after_min`: After reaching the minimum task threshold, wait for this many seconds before proceeding. Default is `30.0` seconds. + - `return_partial_tasks`: Whether to return the results of tasks that have only partially completed (e.g., tasks where only some runs of a GRPO group finished). Default is `false`, meaning only results of tasks that have completed all of their runs are returned; see the sketch below. - `dynamic_timeout`: [Experimental] Configurations for dynamic timeout mechanism, which adjusts the timeout for each task based on the average time taken for successful tasks. - `enable`: Whether to enable dynamic timeout. Default is `false`. - `ratio`: The timeout for each task is dynamically set to `average_time_per_success_task * ratio`. Default is `3.0`.
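For readers of this docs change, the new option is toggled the same way the other `over_rollout` fields are, as the touched tests in this patch do. A minimal sketch; the import paths for `get_template_config` and `SyncStyle` are assumptions (not confirmed by the patch), while the config fields and `check_and_update()` appear verbatim in the diff below:

```python
# Minimal sketch: enable partial-task returns during over-rollout cleanup.
# The two import paths below are assumed, not confirmed by this diff.
from trinity.common.config import get_template_config  # assumed location of the test helper
from trinity.common.constants import SyncStyle  # assumed location of SyncStyle

config = get_template_config()
config.explorer.over_rollout.ratio = 0.5  # wait for (1 - 0.5) * batch_size tasks per step
config.explorer.over_rollout.wait_after_min = 0.5  # grace period (seconds) after the minimum is reached
config.explorer.over_rollout.return_partial_tasks = True  # emit partially completed tasks at cleanup
# over_rollout requires a non-fixed sync style; the validator change in this
# patch now enforces this whenever return_partial_tasks is set as well.
config.synchronizer.sync_style = SyncStyle.EXPLORER_DRIVEN
config.check_and_update()  # runs validation, including the new over_rollout check
```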
diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md index eea9cb74226..5192f1b2786 100644 --- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md @@ -409,15 +409,16 @@ explorer: engine_type: vllm engine_num: 1 tensor_parallel_size: 1 - enable_history: False + enable_history: false auxiliary_models: - model_path: Qwen/Qwen2.5-7B-Instruct tensor_parallel_size: 1 eval_interval: 100 - eval_on_startup: True + eval_on_startup: true over_rollout: ratio: 0.0 wait_after_min: 30.0 + return_partial_tasks: false dynamic_timeout: enable: false ratio: 3.0 @@ -447,6 +448,7 @@ explorer: - `over_rollout`: [实验性] 超量 rollout 机制的配置,允许 explorer 在每个步骤中使用少于完整批次大小的任务继续进行。这在某些任务显著耗时较长的场景中能有效地提高吞吐量。仅当使用动态同步(`synchronizer.sync_style` 不是 `fixed`)时适用。 - `ratio`: explorer 在每个步骤中仅等待 `(1 - ratio) * batch_size` 的任务。默认为 `0.0`,表示等待所有任务。 - `wait_after_min`: 达到最小任务阈值后,等待此秒数后再继续。 + - `return_partial_tasks`: 是否返回仅部分完成的任务结果(例如,在 GRPO 中仅完成部分 run 的任务)。默认为 `false`,表示仅返回已完成组内所有 run 的任务结果。 - `dynamic_timeout`: [实验性] 动态超时机制的配置,根据成功任务的平均耗时调整每个任务的超时时间。 - `enable`: 是否启用动态超时。默认为 `false`。 - `ratio`: 每个任务的超时时间动态设置为 `average_time_per_success_task * ratio`。默认为 `3.0`。 diff --git a/pyproject.toml b/pyproject.toml index c53f7b03cd8..b5463a4104b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,7 @@ data = [ "py-data-juicer>=1.4.3" ] agent = [ - "agentscope[tuner]>=1.0.18" + "agentscope[tuner]>=1.0.19" ] openjudge = [ "py-openjudge>=0.2.2" diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 2e45804de4c..430eb7ec7d6 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -3,6 +3,7 @@ import unittest from collections import defaultdict from typing import Dict, List, Optional, Sequence +from unittest.mock import AsyncMock, patch import ray import torch @@ -99,6 +100,90 @@ def run(self) -> List[Experience]: return exps +@WORKFLOWS.register_module("dummy_partial_snapshot_workflow") +class DummyPartialSnapshotWorkflow(Workflow): + can_reset: bool = True + + def __init__(self, *, task, model, auxiliary_models): + super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) + self.reset(task) + + def reset(self, task: Task): + self.task = task + + def run(self) -> List[Experience]: + actions = self.task.workflow_args.get("actions_by_repeat_times", {}) + action = actions.get(self.task.rollout_args.n, "success") + + if action.startswith("sleep_"): + time.sleep(float(action.split("_", 1)[1])) + elif action == "exception": + raise ValueError("planned partial failure") + + return [ + Experience( + eid=EID(step=0), + tokens=torch.zeros(5), + prompt_length=2, + prompt_text=action, + metrics={"run_metrics": float(self.task.rollout_args.n * 10)}, + ) + ] + + +@WORKFLOWS.register_module("dummy_async_partial_snapshot_workflow") +class DummyAsyncPartialSnapshotWorkflow(Workflow): + can_reset: bool = True + is_async: bool = True + _run_index_by_key = defaultdict(int) + + def __init__(self, *, task, model, auxiliary_models): + super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) + self.reset(task) + + def reset(self, task: Task): + self.task = task + + async def run_async(self) -> List[Experience]: + actions = self.task.workflow_args.get("actions_by_repeat_times", {}) + action_sequences = self.task.workflow_args.get("action_sequence_by_repeat_times", {}) + metric_sequences = 
self.task.workflow_args.get("metric_sequence_by_repeat_times", {}) + + key = (self.task.batch_id, self.task.task_id, self.task.rollout_args.n) + run_index = DummyAsyncPartialSnapshotWorkflow._run_index_by_key[key] + DummyAsyncPartialSnapshotWorkflow._run_index_by_key[key] += 1 + + sequence_actions = action_sequences.get(self.task.rollout_args.n) + if sequence_actions and run_index < len(sequence_actions): + action = sequence_actions[run_index] + else: + action = actions.get(self.task.rollout_args.n, "success") + + if action.startswith("sleep_"): + await asyncio.sleep(float(action.split("_", 1)[1])) + elif action == "exception": + raise ValueError("planned partial failure") + + sequence = metric_sequences.get(self.task.rollout_args.n) + if sequence and run_index < len(sequence): + run_metric = float(sequence[run_index]) + else: + run_metric = float(self.task.rollout_args.n * 10) + + return [ + Experience( + eid=EID(step=0), + tokens=torch.zeros(5), + prompt_length=2, + prompt_text=action, + metrics={"run_metrics": run_metric}, + ) + ] + + def run(self): + raise RuntimeError("This method should not be called") + + @WORKFLOWS.register_module("dummy_async_workflow") class DummyAsyncWorkflow(Workflow): can_repeat: bool = True @@ -327,6 +412,7 @@ def generate_tasks( class SchedulerTest(unittest.IsolatedAsyncioTestCase): def setUp(self): ray.init(ignore_reinit_error=True) + DummyAsyncPartialSnapshotWorkflow._run_index_by_key.clear() self.config = get_template_config() self.config.explorer.max_retry_times = 1 self.config.explorer.max_timeout = 5 @@ -787,8 +873,9 @@ async def test_metric_calculation_with_repeatable_workflow(self, max_repeat_time statuses, exps = await scheduler.get_results(batch_id=0) self.assertEqual(len(statuses), 2) self.assertEqual(len(exps), 1 * 4 * 1 + 1 * 8 * 4) - self.assertAlmostEqual(statuses[0].metrics[0]["run_metrics"], 1.5) # (0+1+2+3)/4 - self.assertAlmostEqual(statuses[1].metrics[0]["run_metrics"], 3.5) # (0+1+2+3+4+5+6+7)/8 + expected_run_metrics = {1.5, 3.5} # (0+1+2+3)/4 and (0+1+2+3+4+5+6+7)/8 + actual_run_metrics = {status.metrics[0]["run_metrics"] for status in statuses} + self.assertSetEqual(expected_run_metrics, actual_run_metrics) @parameterized.expand( [ @@ -836,6 +923,270 @@ async def test_over_rollout_min_wait(self): self.assertEqual(len(statuses), 3) self.assertEqual(len(exps), 3 * 1) + async def test_over_rollout_return_partial_tasks(self): + self.config.explorer.over_rollout.ratio = 0.5 + self.config.explorer.over_rollout.wait_after_min = 0.5 + self.config.explorer.over_rollout.return_partial_tasks = True + self.config.explorer.max_repeat_times_per_runner = 2 + self.config.synchronizer.sync_style = SyncStyle.EXPLORER_DRIVEN + self.config.buffer.batch_size = 2 + self.config.check_and_update() + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) + await scheduler.start() + + tasks = [ + Task( + workflow=DummyWorkflow, # type: ignore[type-abstract] + workflow_args={"step_num": 1}, + repeat_times=1, + raw_task={}, + ), + Task( + # Split into 2 sub-tasks with repeat_times=2 and 1. + # Inside the repeat_times=2 sub-task, one internal run succeeds and + # the other raises, so the runner returns a partial sub-task result. + # The repeat_times=1 sub-task remains in sleep_5 and is later cleaned + # up by over-rollout before it can contribute any metric.
+ workflow=DummyAsyncPartialSnapshotWorkflow, # type: ignore[type-abstract] + workflow_args={ + "action_sequence_by_repeat_times": { + 2: ["success", "exception"], + }, + "actions_by_repeat_times": { + 1: "sleep_5", + }, + "metric_sequence_by_repeat_times": { + 2: [10.0, 30.0], + }, + }, + repeat_times=3, + raw_task={}, + ), + ] + scheduler.schedule(tasks, batch_id=0) + + start_time = time.time() + statuses, exps = await scheduler.get_results( + batch_id=0, + min_num=1, + timeout=2, + return_partial_tasks=True, + ) + end_time = time.time() + + self.assertLess(end_time - start_time, 5) + self.assertEqual(len(statuses), 2) + self.assertEqual(len(exps), 2) + + full_status = next(status for status in statuses if status.total_runs == 1) + partial_status = next(status for status in statuses if status.total_runs == 3) + self.assertTrue(full_status.ok) + self.assertEqual(partial_status.completed_runs, 1) + self.assertFalse(partial_status.ok) + metric_bearing_exps = [exp for exp in exps if exp.metrics] + full_task_run_metrics = [ + exp.metrics["run_metrics"] + for exp in metric_bearing_exps + if exp.info and exp.info.get("repeat_times") == 1 + ] + partial_task_run_metrics = [ + exp.metrics["run_metrics"] + for exp in metric_bearing_exps + if not (exp.info and exp.info.get("repeat_times") == 1) + ] + + # The normal task should keep its own task-level metric, proving scheduler + # does not mix metrics between different tasks in the same batch. + self.assertEqual(full_task_run_metrics, [0.0]) + self.assertEqual(full_status.metrics[0]["run_metrics"], 0.0) + + # End-to-end check: the partially returned task contributes only the single + # successful run from the partially failed repeat_times=2 sub-task. The + # failed internal run and the cancelled repeat_times=1 sub-task must not be + # counted when scheduler computes the task-level metric. This also guards + # against retrying the partially successful sub-task, which would add extra + # successful runs and change both the emitted experiences and task metric. + # We identify partial-task experiences by excluding the known full-task + # marker (`info["repeat_times"] == 1`) rather than assuming empty info. 
+ self.assertEqual(sorted(partial_task_run_metrics), [10.0]) + self.assertEqual( + partial_status.metrics[0]["run_metrics"], + sum(partial_task_run_metrics) / len(partial_task_run_metrics), + ) + self.assertEqual(partial_status.metrics[0]["run_metrics"], 10.0) + self.assertIn("1/3 runs completed successfully", partial_status.message) + + statuses, exps = await scheduler.get_results(batch_id=0, timeout=1) + self.assertEqual(len(statuses), 0) + self.assertEqual(len(exps), 0) + + await scheduler.stop() + + async def test_over_rollout_async_cancelled_runner_accepts_next_batch(self): + self.config.explorer.over_rollout.ratio = 0.5 + self.config.explorer.over_rollout.wait_after_min = 0.5 + self.config.explorer.over_rollout.return_partial_tasks = True + self.config.explorer.max_repeat_times_per_runner = 2 + self.config.explorer.runner_per_model = 1 + self.config.synchronizer.sync_style = SyncStyle.EXPLORER_DRIVEN + self.config.buffer.batch_size = 2 + self.config.check_and_update() + scheduler = Scheduler(self.config, [DummyModel.remote()]) + await scheduler.start() + + tasks = [ + Task( + workflow=DummyWorkflow, # type: ignore[type-abstract] + workflow_args={"step_num": 1}, + repeat_times=1, + raw_task={}, + ), + Task( + workflow=DummyAsyncPartialSnapshotWorkflow, # type: ignore[type-abstract] + workflow_args={ + "actions_by_repeat_times": { + 2: "success", + 1: "sleep_5", + } + }, + repeat_times=3, + raw_task={}, + ), + ] + scheduler.schedule(tasks, batch_id=0) + + statuses, exps = await scheduler.get_results( + batch_id=0, + min_num=1, + timeout=3, + return_partial_tasks=True, + ) + + self.assertEqual(len(statuses), 2) + self.assertEqual(len(exps), 3) + + follow_up_tasks = generate_tasks(2) + scheduler.schedule(follow_up_tasks, batch_id=1) + start_time = time.time() + next_statuses, next_exps = await scheduler.get_results(batch_id=1, min_num=2, timeout=2) + elapsed = time.time() - start_time + + self.assertEqual(len(next_statuses), 2) + self.assertEqual(len(next_exps), 2) + self.assertLess(elapsed, 1.5) + + await scheduler.stop() + + async def test_over_rollout_sync_cancel_does_not_imply_immediate_runner_reuse(self): + self.config.explorer.over_rollout.ratio = 0.5 + self.config.explorer.over_rollout.wait_after_min = 0.5 + self.config.explorer.over_rollout.return_partial_tasks = True + self.config.explorer.max_repeat_times_per_runner = 2 + self.config.explorer.runner_per_model = 1 + self.config.synchronizer.sync_style = SyncStyle.EXPLORER_DRIVEN + self.config.buffer.batch_size = 2 + self.config.check_and_update() + scheduler = Scheduler(self.config, [DummyModel.remote()]) + await scheduler.start() + + tasks = [ + Task( + workflow=DummyWorkflow, # type: ignore[type-abstract] + workflow_args={"step_num": 1}, + repeat_times=1, + raw_task={}, + ), + Task( + workflow=DummyPartialSnapshotWorkflow, # type: ignore[type-abstract] + workflow_args={ + "actions_by_repeat_times": { + 2: "success", + 1: "sleep_2", + } + }, + repeat_times=3, + raw_task={}, + ), + ] + scheduler.schedule(tasks, batch_id=0) + + statuses, exps = await scheduler.get_results( + batch_id=0, + min_num=1, + timeout=3, + return_partial_tasks=True, + ) + + self.assertEqual(len(statuses), 2) + self.assertEqual(len(exps), 3) + + follow_up_tasks = generate_tasks(2) + scheduler.schedule(follow_up_tasks, batch_id=1) + start_time = time.time() + next_statuses, next_exps = await scheduler.get_results( + batch_id=1, + min_num=2, + timeout=0.5, + clear_timeout_tasks=False, + ) + elapsed = time.time() - start_time + + 
self.assertEqual(len(next_statuses), 0) + self.assertEqual(len(next_exps), 0) + self.assertGreaterEqual(elapsed, 0.5) + self.assertTrue(scheduler.has_step(1)) + + next_statuses, next_exps = await scheduler.get_results(batch_id=1, min_num=2, timeout=4) + self.assertEqual(len(next_statuses), 2) + self.assertEqual(len(next_exps), 2) + + await scheduler.stop() + + async def test_timeout_cleanup_still_restarts_runner(self): + self.config.explorer.over_rollout.wait_after_min = 100 + self.config.explorer.max_repeat_times_per_runner = None + self.config.synchronizer.sync_style = SyncStyle.EXPLORER_DRIVEN + self.config.check_and_update() + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) + await scheduler.start() + + tasks = generate_tasks(0, timeout_num=2, repeat_times=1, timeout_seconds=10) + scheduler.schedule(tasks, batch_id=0) + + with patch.object(scheduler, "_restart_runner", new=AsyncMock()) as restart_runner_mock: + await scheduler.get_results(batch_id=0, timeout=1) + + self.assertGreaterEqual(restart_runner_mock.await_count, 1) + + await scheduler.stop() + + async def test_unexpected_task_exception_restarts_runner(self): + self.config.explorer.runner_per_model = 1 + self.config.check_and_update() + scheduler = Scheduler(self.config, [DummyModel.remote()]) + await scheduler.start() + + scheduler.runners[0].run_with_retry = AsyncMock(side_effect=RuntimeError("boom")) + restart_event = asyncio.Event() + + async def fake_restart_runner(runner_id): + scheduler.busy_runners.pop(runner_id, None) + scheduler.idle_runners.add(runner_id) + restart_event.set() + + with patch.object( + scheduler, "_restart_runner", side_effect=fake_restart_runner + ) as restart_runner_mock: + scheduler.schedule(generate_tasks(1), batch_id=0) + await asyncio.wait_for(restart_event.wait(), timeout=2) + + self.assertEqual(restart_runner_mock.await_count, 1) + self.assertEqual(len(scheduler.busy_runners), 0) + self.assertEqual(len(scheduler.idle_runners), scheduler.runner_num) + self.assertNotIn(0, scheduler.running_tasks) + + await scheduler.stop() + async def test_dynamic_timeout(self): self.config.explorer.dynamic_timeout.enable = True self.config.explorer.dynamic_timeout.ratio = 3.0 diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index 1c6e881a5ea..c52f9dbf61f 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -3,13 +3,14 @@ import asyncio import os import shutil +import threading import time import unittest from collections import defaultdict from dataclasses import dataclass, field from typing import Dict, Optional from unittest import mock -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import openai import ray @@ -531,41 +532,6 @@ async def monitor_routine(): class TestAgentScopeWorkflowAdapter(unittest.IsolatedAsyncioTestCase): - async def test_adapter_v0(self): - try: - from agentscope.model import TrinityChatModel - except ImportError: - self.skipTest("agentscope >= 1.0.9 is not installed") - - async def as_workflow_func(task, model) -> float: - self.assertIsInstance(task, dict) - self.assertIsInstance(model, TrinityChatModel) - return task["reward"] - - model = MagicMock() - openai_client = MagicMock() - openai_client.model_path = "Qwen/Qwen3-8B" - model.get_openai_async_client.return_value = openai_client - model.extract_experience_from_history.return_value = [ - Experience(tokens=Tensor([0, 1, 2]), prompt_length=1, logprobs=Tensor([0.1, 0.2])), - Experience(tokens=Tensor([3, 
4, 5]), prompt_length=2, logprobs=Tensor([0.3])), - ] - - as_adapter_cls = WORKFLOWS.get("agentscope_workflow_adapter") - as_adapter = as_adapter_cls( - task=Task( - raw_task={"reward": 0.1}, - workflow_args={"workflow_func": as_workflow_func}, - ), - model=model, - ) - result = await as_adapter.run_async() - self.assertEqual(len(result), 2) - self.assertEqual(result[0].reward, 0.1) - self.assertEqual(result[0].prompt_length, 1) - self.assertEqual(result[1].reward, 0.1) - self.assertEqual(result[1].prompt_length, 2) - async def test_adapter_v1(self): try: from agentscope.model import ChatModelBase @@ -660,6 +626,45 @@ async def run_async(self): return exps +class PartialFailureWorkflow(Workflow): + can_reset: bool = True + + _call_lock = threading.Lock() + _call_count = 0 + + def __init__(self, model: ModelWrapper, task: Task, auxiliary_models=None): + super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) + self.fail_call_ids = set(task.raw_task.get("fail_call_ids", [])) + + def reset(self, task: Task): + self.fail_call_ids = set(task.raw_task.get("fail_call_ids", [])) + + @classmethod + def reset_call_count(cls): + with cls._call_lock: + cls._call_count = 0 + + @classmethod + def next_call_id(cls) -> int: + with cls._call_lock: + call_id = cls._call_count + cls._call_count += 1 + return call_id + + def run(self): + call_id = self.next_call_id() + if call_id in self.fail_call_ids: + raise RuntimeError(f"Intentional failure for run call {call_id}") + + exp = Experience( + tokens=Tensor([0, 1, 2]), + prompt_length=1, + metrics={"run_metrics": float(call_id)}, + ) + exp.response_text = str(call_id) + return [exp] + + class TestWorkflowRunner(unittest.IsolatedAsyncioTestCase): async def test_workflow_runner(self): config = get_template_config() @@ -705,6 +710,133 @@ async def test_workflow_runner(self): self.assertIsInstance(exps, list) self.assertEqual(len(exps), 2) + @parameterized.expand( + [ + ("sequential", 2), + ("asynchronous", 2), + ("multi-threading", 2), + ] + ) + async def test_workflow_runner_partial_success_non_repeatable( + self, concurrent_mode: str, expected_success_runs: int + ): + config = get_template_config() + config.explorer.concurrent_mode = concurrent_mode + + with mock.patch( + "trinity.explorer.workflow_runner.ModelWrapper", + DummyModelWrapper, + ): + runner = WorkflowRunner( + config, + model=MagicMock(), + auxiliary_models=[], + runner_id=0, + ) + await runner.prepare() + + PartialFailureWorkflow.reset_call_count() + task = Task( + workflow=PartialFailureWorkflow, + repeat_times=3, + raw_task={"fail_call_ids": [1]}, + ) + + status, exps = await runner.run_task( + task, batch_id="test", repeat_times=3, run_id_base=0 + ) + + self.assertFalse(status.ok) + self.assertEqual(status.completed_runs, expected_success_runs) + self.assertEqual(status.total_runs, 3) + self.assertEqual(len(exps), expected_success_runs) + + # One internal run fails with call_id=1, so runner-level metrics should + # retain only the successful runs from this single subtask: call_id=0 and 2. + self.assertEqual(len(status.metrics), expected_success_runs) + self.assertEqual( + sorted(metric["run_metrics"] for metric in status.metrics), + [0.0, 2.0], + ) + + # Experiences returned from the runner should match the same successful + # run set, proving failed runs do not leak into partial-return outputs. 
+ self.assertEqual( + sorted(exp.metrics["run_metrics"] for exp in exps if exp.metrics), + [0.0, 2.0], + ) + self.assertIn(f"{expected_success_runs}/3 runs completed successfully", status.message) # type: ignore[arg-type] + + @parameterized.expand( + [ + ("sequential",), + ("asynchronous",), + ("multi-threading",), + ] + ) + async def test_workflow_runner_fail_fast_without_partial_collection(self, concurrent_mode: str): + config = get_template_config() + config.explorer.concurrent_mode = concurrent_mode + + with mock.patch( + "trinity.explorer.workflow_runner.ModelWrapper", + DummyModelWrapper, + ): + runner = WorkflowRunner( + config, + model=MagicMock(), + auxiliary_models=[], + runner_id=0, + ) + await runner.prepare() + + task = Task( + workflow=PartialFailureWorkflow, + repeat_times=3, + raw_task={"fail_call_ids": []}, + ) + + async def mock_execute_single_run( + workflow: Workflow, + task: Task, + run_index: int, + run_id_base: int, + ): + if run_index == 0: + await asyncio.sleep(0.01) + exp = Experience( + tokens=Tensor([0, 1, 2]), + prompt_length=1, + metrics={"run_metrics": 0.0}, + ) + return True, [exp], {"run_metrics": 0.0}, None + if run_index == 1: + await asyncio.sleep(0.02) + return False, [], None, "planned failure" + await asyncio.sleep(0.5) + exp = Experience( + tokens=Tensor([0, 1, 2]), + prompt_length=1, + metrics={"run_metrics": 2.0}, + ) + return True, [exp], {"run_metrics": 2.0}, None + + runner._execute_single_run = AsyncMock(side_effect=mock_execute_single_run) + + status, exps = await runner.run_task( + task, + batch_id="test", + repeat_times=3, + run_id_base=0, + collect_partial_runs=False, + ) + + self.assertFalse(status.ok) + self.assertEqual(status.completed_runs, 1) + self.assertEqual(status.total_runs, 3) + self.assertEqual(len(exps), 1) + self.assertIn("1/3 runs completed successfully", status.message) # type: ignore[arg-type] + async def test_workflow_runner_get_state(self): config = get_template_config() diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 17c333f8695..66b938c6e3a 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -1367,6 +1367,7 @@ def test_trainer(self): self.config.buffer.eval_interval = 4 # only eval on start self.config.name = f"explore-over-rollout-{datetime.now().strftime('%Y%m%d%H%M%S')}" self.config.explorer.over_rollout.ratio = 0.5 # set over rollout rate to 50%, which means only wait for 2 (4 * 50%) tasks in each steps + self.config.explorer.over_rollout.return_partial_tasks = True self.config.explorer.over_rollout.wait_after_min = 0 self.config.explorer.dynamic_timeout.enable = True self.config.explorer.dynamic_timeout.ratio = 2 diff --git a/trinity/common/config.py b/trinity/common/config.py index 749a4f637f2..c3ac8ab1bfd 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -145,6 +145,9 @@ class OverRolloutConfig: ratio: float = 0.0 # explorer will only wait for (1 - over_rollout.ratio) * batch_size of tasks at each step wait_after_min: float = 30.0 # wait 30 s after reaching minimum task threshold + return_partial_tasks: bool = ( + False # return tasks with partially successful runs during over-rollout cleanup + ) # more settings will be added in the future # e.g., postpone tasks into the next step if not finished in time diff --git a/trinity/common/config_validator.py b/trinity/common/config_validator.py index b96f4169386..57069ad313f 100644 --- a/trinity/common/config_validator.py +++ b/trinity/common/config_validator.py @@ -655,12 +655,15 @@ def
validate(self, config: Config) -> None: for args in model_args: set_if_none(aux_model, args, getattr(config.model, args)) - if config.explorer.over_rollout.ratio > 0.0: + if ( + config.explorer.over_rollout.ratio > 0.0 + or config.explorer.over_rollout.return_partial_tasks + ): if not (0.0 <= config.explorer.over_rollout.ratio < 1.0): raise ValueError("over_rollout_ratio should be in [0.0, 1.0)") if config.synchronizer.sync_style == SyncStyle.FIXED: raise ValueError( - "over_rollout_ratio is not compatible with fixed sync_style, please set " + "over_rollout is not compatible with fixed sync_style, please set " "`synchronizer.sync_style` to `explorer_driven` or `trainer_driven`." ) diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index b78ba7defeb..db82c83067a 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -391,7 +391,9 @@ async def _finish_explore_step(self, step: int, model_version: int) -> None: metric = {"rollout/model_version": model_version} with Timer(metric, "time/wait_explore_step"): statuses, exps = await self.scheduler.get_results( - batch_id=step, min_num=self.min_wait_num + batch_id=step, + min_num=self.min_wait_num, + return_partial_tasks=self.config.explorer.over_rollout.return_partial_tasks, ) if self.experience_pipeline is not None: pipeline_metrics = await self.experience_pipeline.process.remote( @@ -415,7 +417,10 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva if eval_step != step: return self.pending_eval_tasks.popleft() - statuses, _ = await self.scheduler.get_results(batch_id=f"{step}/{eval_task_name}") + statuses, _ = await self.scheduler.get_results( + batch_id=f"{step}/{eval_task_name}", + return_partial_tasks=self.config.explorer.over_rollout.return_partial_tasks, + ) metric[f"{prefix}/{eval_task_name}/finished_task_count"] = len(statuses) metric.update( gather_eval_metrics( diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index e6d7bb46e4d..ea3206ef00b 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -29,7 +29,13 @@ class TaskWrapper: batch_id: Union[int, str] sub_task_num: int = 1 # number of sub tasks splitted from this task # if max_repeat_times_per_runner is set, one task may be splitted into multiple sub tasks - results: List[Tuple[Status, List[Experience]]] = field(default_factory=list) + finished_sub_task_num: int = 0 + completed_runs: int = 0 + total_runs: int = 0 # total planned runs for the whole task + metrics: List[Dict[str, float]] = field(default_factory=list) + experiences: List[Experience] = field(default_factory=list) + first_error: Optional[str] = None + emitted: bool = False # Adapted from verl/trainer/ppo/metric_utils.py @@ -167,7 +173,12 @@ async def update_state(self) -> None: self.state["running_time"] = time.time() - self.state.get("begin_time", time.time()) async def run_with_retry( - self, task: TaskWrapper, repeat_times: int, run_id_base: int, timeout: float + self, + task: TaskWrapper, + repeat_times: int, + run_id_base: int, + timeout: float, + collect_partial_runs: bool, ) -> Tuple[Status, List, int, float]: """ Args: @@ -185,8 +196,9 @@ async def run_with_retry( last_exception_msg = None await self.runner.__ray_ready__.remote() start_time = time.time() - status = Status(ok=False, metrics=list()) + status = Status(completed_runs=0, total_runs=repeat_times, metrics=list()) exps = [] + run_task_ref = None task2run = replace( task.task, rollout_args=replace( @@ -197,29 +209,55 @@ async 
def run_with_retry( try: for attempt in range(self.retry_times + 1): try: + run_task_ref = self.runner.run_task.remote( + task=task2run, + batch_id=str(task.batch_id), + repeat_times=repeat_times, + run_id_base=run_id_base, + collect_partial_runs=collect_partial_runs, + ) status, exps = await asyncio.wait_for( - self.runner.run_task.remote( - task=task2run, - batch_id=str(task.batch_id), - repeat_times=repeat_times, - run_id_base=run_id_base, - ), + run_task_ref, timeout=timeout, ) + run_task_ref = None if status.ok: break + if collect_partial_runs and status.completed_runs > 0: + self.logger.warning( + "Task returned partial success; skipping retry to avoid " + "re-running successful runs. %s", + status.message, + ) + break else: self.logger.error(status.message) except asyncio.TimeoutError: + run_task_ref = None last_exception_msg = f"Timeout when running task of batch {task.batch_id} at runner {self.runner_id} at attempt {attempt + 1}: {task.task}" self.logger.error(last_exception_msg) - status = Status(ok=False, metrics=list(), message=last_exception_msg) + status = Status( + completed_runs=0, + total_runs=repeat_times, + metrics=list(), + message=last_exception_msg, + ) + except asyncio.CancelledError: + if run_task_ref is not None: + ray.cancel(run_task_ref, force=False) + raise except Exception: + run_task_ref = None last_exception_msg = traceback.format_exc() self.logger.warning( f"Task execution attempt {attempt + 1} failed:\n{last_exception_msg}" ) - status = Status(ok=False, metrics=list(), message=last_exception_msg) + status = Status( + completed_runs=0, + total_runs=repeat_times, + metrics=list(), + message=last_exception_msg, + ) finally: end_time = time.time() status.metrics.append({"time/task_execution": end_time - start_time}) @@ -288,11 +326,14 @@ def __init__( int ) # batch_id -> tasks scheduled under this batch_id self.running_task_map: Dict[asyncio.Future, TaskWrapper] = dict() # future -> task + self.running_task_runner_map: Dict[asyncio.Future, int] = dict() # future -> runner_id + self.cancelled_task_restart_map: Dict[asyncio.Future, bool] = dict() self.completed_tasks: Dict[ Union[int, str], deque[Tuple[Status, List[Experience]]] ] = defaultdict( deque ) # batch_id -> results + self.background_tasks: set[asyncio.Task] = set() self.scheduler_task: Optional[asyncio.Task] = None self.running = False @@ -330,6 +371,11 @@ async def _restart_runner(self, runner_id: int): self.idle_runners.add(runner_id) self.logger.info(f"Runner {runner_id} restarted.") + def _schedule_runner_restart(self, runner_id: int) -> None: + restart_task = asyncio.create_task(self._restart_runner(runner_id)) + self.background_tasks.add(restart_task) + restart_task.add_done_callback(self.background_tasks.discard) + async def _scheduler_loop(self) -> None: self.logger.info("Scheduler loop started.") while self.running: @@ -377,9 +423,11 @@ async def _schedule_pending_tasks(self) -> None: repeat_times=repeat_times, run_id_base=run_id_base, timeout=self.dynamic_timeout(), + collect_partial_runs=self.config.explorer.over_rollout.return_partial_tasks, ) ) self.running_task_map[future] = task + self.running_task_runner_map[future] = runner_id future.add_done_callback(self.task_done_callback) self.running_tasks[batch_id].add(future) @@ -388,35 +436,28 @@ async def _schedule_pending_tasks(self) -> None: def task_done_callback(self, async_task: asyncio.Task): task = self.running_task_map.pop(async_task) + runner_id = self.running_task_runner_map.pop(async_task) if async_task.cancelled(): - return + 
should_restart = self.cancelled_task_restart_map.pop(async_task, True) + if not should_restart: + self.busy_runners.pop(runner_id, None) + self.idle_runners.add(runner_id) elif async_task.exception(): self.logger.error(f"Task {task.task.task_id} failed: {async_task.exception()}") - return + self.cancelled_task_restart_map.pop(async_task, None) + self._schedule_runner_restart(runner_id) else: + self.cancelled_task_restart_map.pop(async_task, None) status, exps, runner_id, run_time = async_task.result() if not task.task.is_eval: # only count running time for non-eval tasks self.total_running_time += run_time self.total_completed_tasks += 1 - task.results.append((status, exps)) - self.busy_runners.pop(runner_id) + self._accumulate_task_result(task, status, exps) + self.busy_runners.pop(runner_id, None) self.idle_runners.add(runner_id) # If all sub runs in a task are completed - if len(task.results) == task.sub_task_num: - task_experiences = [] - task_metrics = [] - all_success = True - for s, exp in task.results: - task_metrics.extend(s.metrics) - task_experiences.extend(exp) - if not s.ok: - all_success = False - # calculate task level metrics - task_status = Status( - ok=all_success, - metrics=[calculate_task_level_metrics(task_metrics, task.task.is_eval)], - ) - self.completed_tasks[task.batch_id].appendleft((task_status, task_experiences)) + if task.finished_sub_task_num == task.sub_task_num: + self._emit_task_result(task) self.logger.debug(f"Task completed (batch_id {task.batch_id}).") if task.batch_id in self.running_tasks: @@ -424,16 +465,77 @@ def task_done_callback(self, async_task: asyncio.Task): if not self.running_tasks[task.batch_id]: del self.running_tasks[task.batch_id] - def _clear_timeout_tasks(self, batch_id: Union[int, str]) -> None: + def _accumulate_task_result( + self, task: TaskWrapper, status: Status, experiences: List[Experience] + ) -> None: + task.finished_sub_task_num += 1 + task.completed_runs += status.completed_runs + task.metrics.extend(status.metrics) + task.experiences.extend(experiences) + if not status.ok and task.first_error is None: + task.first_error = status.message + + def _build_task_result(self, task: TaskWrapper) -> Tuple[Status, List[Experience]]: + if task.completed_runs < task.total_runs: + message = f"{task.completed_runs}/{task.total_runs} runs completed successfully." + if task.first_error: + message = f"{message} First error: {task.first_error}" + else: + message = f"{message} Remaining runs were cancelled during scheduler cleanup." 
+ else: + message = None + status = Status( + completed_runs=task.completed_runs, + total_runs=task.total_runs, + metrics=[calculate_task_level_metrics(task.metrics, task.task.is_eval)], + message=message, + ) + return status, list(task.experiences) + + def _emit_task_result(self, task: TaskWrapper) -> None: + if task.emitted: + return + self.completed_tasks[task.batch_id].appendleft(self._build_task_result(task)) + task.emitted = True + + def _collect_incomplete_tasks(self, batch_id: Union[int, str]) -> List[TaskWrapper]: + tasks = {} + for task, _, _ in self.pending_tasks.get(batch_id, deque()): + tasks[id(task)] = task + for future in self.running_tasks.get(batch_id, set()): + task = self.running_task_map.get(future) + if task is not None: + tasks[id(task)] = task + return list(tasks.values()) + + def _emit_partial_tasks_for_batch(self, batch_id: Union[int, str]) -> None: + for task in self._collect_incomplete_tasks(batch_id): + if task.emitted or task.completed_runs <= 0: + continue + self._emit_task_result(task) + self.logger.debug( + f"Task partially completed and emitted (batch_id {task.batch_id}, task_id {task.task.task_id})." + ) + + def _clear_timeout_tasks(self, batch_id: Union[int, str]) -> List[asyncio.Future]: + cancelled_futures = [] if batch_id in self.pending_tasks: self.logger.info(f"Clear timeout pending tasks at batch_id {batch_id}.") del self.pending_tasks[batch_id] if batch_id in self.running_tasks: self.logger.info(f"Clear timeout running tasks at batch_id {batch_id}.") for future in self.running_tasks[batch_id]: + cancelled_futures.append(future) future.cancel() del self.running_tasks[batch_id] self.task_num_map.pop(batch_id, None) + return cancelled_futures + + def _mark_running_tasks_for_cleanup( + self, batch_id: Union[int, str], restart_runners: bool + ) -> None: + for future in self.running_tasks.get(batch_id, set()): + self.cancelled_task_restart_map[future] = restart_runners async def start(self) -> None: if self.running: @@ -459,6 +561,9 @@ async def stop(self) -> None: self.logger.info(f"Waiting for {len(all_running_futures)} running tasks to complete...") await asyncio.gather(*all_running_futures, return_exceptions=True) + if self.background_tasks: + await asyncio.gather(*list(self.background_tasks), return_exceptions=True) + if self.scheduler_task: self.scheduler_task.cancel() try: @@ -493,6 +598,7 @@ def _split_and_submit_tasks(self, tasks: List[Task], batch_id: Union[int, str]) task_wrapper = TaskWrapper( task=replace(task, batch_id=batch_id, task_id=i), batch_id=batch_id, + total_runs=task.repeat_times, ) if self.max_repeat_times is None: task_wrapper.sub_task_num = 1 @@ -518,9 +624,21 @@ def dynamic_timeout(self, timeout: Optional[float] = None) -> float: avg_time_per_task * self.config.explorer.dynamic_timeout.ratio, ) - async def _cleanup_batch_and_restart_runners(self, batch_id: Union[int, str]) -> None: - """Clear timeout tasks for a batch and restart associated runners.""" - self._clear_timeout_tasks(batch_id=batch_id) + async def _cleanup_batch( + self, + batch_id: Union[int, str], + return_partial_tasks: bool = False, + restart_runners: bool = True, + ) -> None: + """Clear unfinished tasks for a batch and optionally restart associated runners.""" + if return_partial_tasks: + self._emit_partial_tasks_for_batch(batch_id) + self._mark_running_tasks_for_cleanup(batch_id, restart_runners=restart_runners) + cancelled_futures = self._clear_timeout_tasks(batch_id=batch_id) + if cancelled_futures: + await asyncio.gather(*cancelled_futures, 
return_exceptions=True) + if not restart_runners: + return runners_to_restart = [ runner_id for runner_id, task in list(self.busy_runners.items()) @@ -535,6 +653,7 @@ async def get_results( min_num: Optional[int] = None, timeout: Optional[float] = None, clear_timeout_tasks: bool = True, + return_partial_tasks: bool = False, ) -> Tuple[List[Status], List[Experience]]: """Get the result of tasks at the specific batch_id. @@ -543,6 +662,7 @@ async def get_results( min_num (`int`): The minimum number of tasks to wait for. If `None`, wait for all tasks at `batch_id`. timeout (`float`): The timeout for waiting for tasks to finish. If `None`, wait for default timeout. clear_timeout_tasks (`bool`): Whether to clear timeout tasks. + return_partial_tasks (`bool`): Whether to emit tasks with partially successful runs when cleaning up unfinished tasks. """ timeout = timeout or self.default_timeout start_time = time.time() @@ -568,7 +688,11 @@ async def get_results( >= self.config.explorer.over_rollout.wait_after_min ): if clear_timeout_tasks: - await self._cleanup_batch_and_restart_runners(batch_id) + await self._cleanup_batch( + batch_id, + return_partial_tasks=return_partial_tasks, + restart_runners=False, + ) break await asyncio.sleep(0.1) @@ -577,7 +701,11 @@ async def get_results( f"Timed out waiting for tasks at batch {batch_id} to complete after {timeout} seconds" ) if clear_timeout_tasks: - await self._cleanup_batch_and_restart_runners(batch_id) + await self._cleanup_batch( + batch_id, + return_partial_tasks=return_partial_tasks, + restart_runners=True, + ) statuses = [] experiences = [] diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index fc14bdd9f78..5bd7fc667c3 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -22,11 +22,24 @@ class Status: """Status of the task running result.""" - ok: bool + completed_runs: int + total_runs: int metrics: List[Dict[str, float]] # A list of metric dictionaries, where each dictionary is from a single run. message: Optional[str] = None + @property + def ok(self) -> bool: + return self.completed_runs == self.total_runs + + +@dataclass(frozen=True) +class RunnerExecutionResult: + """Execution result for one runner task.""" + + status: Status + experiences: List[Experience] + def calculate_run_level_metrics(experiences: List[Experience]) -> Dict[str, float]: """Calculate metrics from experiences. @@ -136,9 +149,163 @@ async def _run_workflow(self, workflow_instance: Workflow) -> List[Experience]: exps = workflow_instance.run() return exps + def _create_isolated_workflow_instance(self, task: Task) -> Workflow: + return task.to_workflow( + self.model_wrapper.clone_with_isolated_history() + if self.config.explorer.rollout_model.enable_history + else self.model_wrapper, + self.auxiliary_model_wrappers, + ) + + def _build_execution_result( + self, + total_runs: int, + completed_runs: int, + metrics: List[Dict[str, float]], + experiences: List[Experience], + first_error: Optional[str] = None, + ) -> RunnerExecutionResult: + if first_error is None: + message = None + elif completed_runs > 0: + message = ( + f"{completed_runs}/{total_runs} runs completed successfully.
" + f"First error: {first_error}" + ) + else: + message = first_error + + return RunnerExecutionResult( + status=Status( + completed_runs=completed_runs, + total_runs=total_runs, + metrics=list(metrics), + message=message, + ), + experiences=experiences, + ) + + def _aggregate_run_results( + self, + total_runs: int, + results: List[Tuple[bool, List[Experience], Optional[Dict[str, float]], Optional[str]]], + ) -> RunnerExecutionResult: + exps = [] + run_metrics = [] + first_error = None + + for ok, new_exps, run_metric, error in results: + if ok: + exps.extend(new_exps) + if run_metric is not None: + run_metrics.append(run_metric) + continue + if first_error is None: + first_error = error + + return self._build_execution_result( + total_runs=total_runs, + completed_runs=len(run_metrics), + metrics=run_metrics, + experiences=exps, + first_error=first_error, + ) + + async def _run_parallel_runs( + self, + task: Task, + repeat_times: int, + run_id_base: int, + collect_partial_runs: bool = True, + use_threads: bool = False, + ) -> RunnerExecutionResult: + async def run_single( + i: int, + ) -> Tuple[bool, List[Experience], Optional[Dict[str, float]], Optional[str]]: + workflow = self._create_isolated_workflow_instance(task) + return await self._execute_single_run(workflow, task, i, run_id_base) + + if collect_partial_runs: + if use_threads: + results = await asyncio.gather( + *( + asyncio.to_thread(lambda idx=i: asyncio.run(run_single(idx))) # type: ignore[misc] + for i in range(repeat_times) + ) + ) + else: + results = await asyncio.gather(*(run_single(i) for i in range(repeat_times))) + return self._aggregate_run_results(repeat_times, results) + + future_to_run_index = {} + for i in range(repeat_times): + if use_threads: + future = asyncio.create_task( + asyncio.to_thread(lambda idx=i: asyncio.run(run_single(idx))) # type: ignore[misc] + ) + else: + future = asyncio.create_task(run_single(i)) + future_to_run_index[future] = i + + results = [] + while future_to_run_index: + done, pending = await asyncio.wait( + future_to_run_index.keys(), + return_when=asyncio.FIRST_COMPLETED, + ) + should_stop = False + for future in done: + future_to_run_index.pop(future) + result = future.result() + results.append(result) + ok, _, _, _ = result + if not ok: + should_stop = True + if should_stop: + for future in pending: + future.cancel() + if pending: + await asyncio.gather(*pending, return_exceptions=True) + break + + return self._aggregate_run_results(repeat_times, results) + + async def _execute_single_run( + self, + workflow: Workflow, + task: Task, + run_index: int, + run_id_base: int, + ) -> Tuple[bool, List[Experience], Optional[Dict[str, float]], Optional[str]]: + st = time.time() + await self.model_wrapper.clean_workflow_state() + self.runner_state["workflow_id"] = f"{task.batch_id}/{task.task_id}/{run_index}" + self.runner_state["terminate_time"] = None + self.runner_state["begin_time"] = st + try: + new_exps = await self._run_workflow(workflow) + et = time.time() + self.runner_state["terminate_time"] = et + run_metric = calculate_run_level_metrics(new_exps) + run_metric["time/run_execution"] = et - st + for exp in new_exps: + exp.eid.run = run_id_base + run_index + return True, new_exps, run_metric, None + except Exception as exc: + self.runner_state["terminate_time"] = time.time() + error_trace_back = traceback.format_exc() + self.logger.error( + "WorkflowRunner single run error: " f"{exc}\nTraceback:\n{error_trace_back}" + ) + return False, [], None, error_trace_back.rstrip() + async def 
_run_task( - self, task: Task, repeat_times: int, run_id_base: int - ) -> Tuple[List[Experience], List[Dict]]: + self, + task: Task, + repeat_times: int, + run_id_base: int, + collect_partial_runs: bool = True, + ) -> RunnerExecutionResult: """Init workflow from the task and run it.""" if task.workflow.can_repeat: workflow_instance = self._create_workflow_instance(task) @@ -155,114 +322,69 @@ async def _run_task( run_metrics = [exp.metrics for exp in exps if exp.metrics] for metric in run_metrics: metric["time/run_execution"] = et - st + return self._build_execution_result( + total_runs=repeat_times, + completed_runs=repeat_times, + metrics=run_metrics, + experiences=exps, + ) else: - exps, run_metrics = await self.concurrent_run_fn(task, repeat_times, run_id_base) - return exps, run_metrics + return await self.concurrent_run_fn( + task, + repeat_times, + run_id_base, + collect_partial_runs=collect_partial_runs, + ) async def _sequential_run( self, task: Task, repeat_times: int, run_id_base: int, - ) -> Tuple[List[Experience], List[Dict]]: - exps = [] - run_metrics = [] + collect_partial_runs: bool = True, + ) -> RunnerExecutionResult: + results = [] for i in range(repeat_times): - st = time.time() workflow = self._create_workflow_instance(task) - await self.model_wrapper.clean_workflow_state() - self.runner_state["workflow_id"] = f"{task.batch_id}/{task.task_id}/{i}" - self.runner_state["terminate_time"] = None - self.runner_state["begin_time"] = st - new_exps = await self._run_workflow(workflow) - et = time.time() - self.runner_state["terminate_time"] = et - run_metric = calculate_run_level_metrics(new_exps) - run_metric["time/run_execution"] = et - st - run_metrics.append(run_metric) - for exp in new_exps: - exp.eid.run = run_id_base + i - exps.extend(new_exps) - return exps, run_metrics + result = await self._execute_single_run(workflow, task, i, run_id_base) + results.append(result) + if collect_partial_runs: + continue + ok, _, _, _ = result + if ok: + continue + break + return self._aggregate_run_results(repeat_times, results) async def _asynchronous_run( self, task: Task, repeat_times: int, run_id_base: int, - ) -> Tuple[List[Experience], List[Dict]]: - async def run_single(i: int) -> Tuple[List[Experience], Dict]: - st = time.time() - workflow = task.to_workflow( - self.model_wrapper.clone_with_isolated_history() - if self.config.explorer.rollout_model.enable_history - else self.model_wrapper, - self.auxiliary_model_wrappers, - ) - await self.model_wrapper.clean_workflow_state() - self.runner_state["workflow_id"] = f"{task.batch_id}/{task.task_id}/{i}" - self.runner_state["terminate_time"] = None - self.runner_state["begin_time"] = st - new_exps = await self._run_workflow(workflow) - et = time.time() - self.runner_state["terminate_time"] = et - run_metric = calculate_run_level_metrics(new_exps) - run_metric["time/run_execution"] = et - st - for exp in new_exps: - exp.eid.run = run_id_base + i - return new_exps, run_metric - - tasks = [run_single(i) for i in range(repeat_times)] - results = await asyncio.gather(*tasks) - exps = [] - run_metrics = [] - for new_exps, run_metric in results: - exps.extend(new_exps) - run_metrics.append(run_metric) - return exps, run_metrics + collect_partial_runs: bool = True, + ) -> RunnerExecutionResult: + return await self._run_parallel_runs( + task, + repeat_times, + run_id_base, + collect_partial_runs=collect_partial_runs, + ) async def _multi_threading_run( self, task: Task, repeat_times: int, run_id_base: int, - ) -> Tuple[List[Experience], 
List[Dict]]: - async def run_single(i: int) -> Tuple[List[Experience], Dict]: - st = time.time() - await self.model_wrapper.clean_workflow_state() - workflow = task.to_workflow( - self.model_wrapper.clone_with_isolated_history() - if self.config.explorer.rollout_model.enable_history - else self.model_wrapper, - self.auxiliary_model_wrappers, - ) - self.runner_state["workflow_id"] = f"{task.batch_id}/{task.task_id}/{i}" - self.runner_state["terminate_time"] = None - self.runner_state["begin_time"] = st - new_exps = await self._run_workflow(workflow) - et = time.time() - self.runner_state["terminate_time"] = et - run_metric = calculate_run_level_metrics(new_exps) - run_metric["time/run_execution"] = et - st - for exp in new_exps: - exp.eid.run = run_id_base + i - return new_exps, run_metric - - # Use asyncio.to_thread to run async tasks in threads - results = await asyncio.gather( - *( - asyncio.to_thread(lambda idx=i: asyncio.run(run_single(idx))) # type: ignore[misc] - for i in range(repeat_times) - ) + collect_partial_runs: bool = True, + ) -> RunnerExecutionResult: + return await self._run_parallel_runs( + task, + repeat_times, + run_id_base, + collect_partial_runs=collect_partial_runs, + use_threads=True, ) - exps = [] - run_metrics = [] - for new_exps, run_metric in results: - exps.extend(new_exps) - run_metrics.append(run_metric) - return exps, run_metrics - async def get_runner_state(self) -> Dict: """Get the runner state.""" runner_state = self.runner_state.copy() @@ -275,6 +397,7 @@ async def run_task( batch_id: str, repeat_times: int = 1, run_id_base: int = 0, + collect_partial_runs: bool = True, ) -> Tuple[Status, List[Experience]]: """Run the task and return the states.""" # TODO: avoid sending the experiences back to the scheduler to reduce the communication overhead @@ -285,8 +408,15 @@ async def run_task( self.logger.info( f"Starting task: step={batch_id}, model_version={model_version}, repeat_times={repeat_times}, run_id_base={run_id_base}" ) - exps, metrics = await self._run_task(task, repeat_times, run_id_base) - assert exps is not None and len(exps) > 0, "An empty experience is generated" + execution_result = await self._run_task( + task, + repeat_times, + run_id_base, + collect_partial_runs=collect_partial_runs, + ) + exps = execution_result.experiences + if execution_result.status.completed_runs > 0: + assert exps is not None and len(exps) > 0, "An empty experience is generated" # set eid for each experience for exp in exps: exp.eid.batch = task.batch_id @@ -302,17 +432,24 @@ async def run_task( if not hasattr(exp, "metrics") or exp.metrics is None: exp.metrics = {} + status = execution_result.status + if task.is_eval: # If the task is an evaluation task, we do not record the experiences to the buffer - return Status(True, metrics=metrics), [] + return status, [] else: - return Status(True, metrics=metrics), exps + return status, exps except Exception as e: error_trace_back = traceback.format_exc() self.logger.error(f"WorkflowRunner run task error: {e}\nTraceback:\n{error_trace_back}") return ( - Status(False, metrics=[{"time/run_execution": time.time() - st}], message=str(e)), + Status( + completed_runs=0, + total_runs=repeat_times, + metrics=[{"time/run_execution": time.time() - st}], + message=error_trace_back.rstrip(), + ), [], )
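Taken together, the `workflow_runner.py` changes replace the stored `ok: bool` with a value derived from run counts, which is what lets partial results flow through the scheduler. A quick illustration of the reworked `Status` contract; the field names, the `ok` property, and the module path come from this diff, while the concrete values are invented for the example:

```python
# Sketch of the new Status semantics: `ok` is a property computed as
# completed_runs == total_runs, so partial results need no separate flag.
from trinity.explorer.workflow_runner import Status

partial = Status(
    completed_runs=1,
    total_runs=3,
    metrics=[{"run_metrics": 10.0}],
    message="1/3 runs completed successfully.",
)
assert not partial.ok  # one of three planned runs succeeded -> not fully ok

full = Status(completed_runs=3, total_runs=3, metrics=[])
assert full.ok  # all planned runs completed -> ok
```

This is why `run_with_retry` can now skip retries once `status.completed_runs > 0`: retrying a partially successful sub-task would re-run already-successful runs and inflate both the emitted experiences and the task-level metric, exactly the regression the new scheduler tests guard against.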