2 changes: 1 addition & 1 deletion pyproject.toml
@@ -60,7 +60,7 @@ data = [
"py-data-juicer>=1.4.3"
]
agent = [
"agentscope[tuner]>=1.0.18"
"agentscope[tuner]>=1.0.19"
]
openjudge = [
"py-openjudge>=0.2.2"
94 changes: 92 additions & 2 deletions tests/explorer/scheduler_test.py
@@ -99,6 +99,37 @@ def run(self) -> List[Experience]:
return exps


@WORKFLOWS.register_module("dummy_partial_snapshot_workflow")
class DummyPartialSnapshotWorkflow(Workflow):
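"""Test workflow whose per-run action (succeed, sleep, or raise) is selected by the run's rollout index (`rollout_args.n`)."""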
can_reset: bool = True

def __init__(self, *, task, model, auxiliary_models):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.reset(task)

def reset(self, task: Task):
self.task = task

def run(self) -> List[Experience]:
actions = self.task.workflow_args.get("actions_by_repeat_times", {})
action = actions.get(self.task.rollout_args.n, "success")

if action.startswith("sleep_"):
time.sleep(float(action.split("_", 1)[1]))
elif action == "exception":
raise ValueError("planned partial failure")

return [
Experience(
eid=EID(step=0),
tokens=torch.zeros(5),
prompt_length=2,
prompt_text=action,
metrics={"run_metrics": float(self.task.rollout_args.n * 10)},
)
]


@WORKFLOWS.register_module("dummy_async_workflow")
class DummyAsyncWorkflow(Workflow):
can_repeat: bool = True
@@ -787,8 +818,9 @@ async def test_metric_calculation_with_repeatable_workflow(self, max_repeat_time
statuses, exps = await scheduler.get_results(batch_id=0)
self.assertEqual(len(statuses), 2)
self.assertEqual(len(exps), 1 * 4 * 1 + 1 * 8 * 4)
- self.assertAlmostEqual(statuses[0].metrics[0]["run_metrics"], 1.5) # (0+1+2+3)/4
- self.assertAlmostEqual(statuses[1].metrics[0]["run_metrics"], 3.5) # (0+1+2+3+4+5+6+7)/8
+ expected_run_metrics = {1.5, 3.5} # (0+1+2+3)/4 and (0+1+2+3+4+5+6+7)/8
+ actual_run_metrics = {status.metrics[0]["run_metrics"] for status in statuses}
+ self.assertSetEqual(expected_run_metrics, actual_run_metrics)

@parameterized.expand(
[
@@ -836,6 +868,64 @@ async def test_over_rollout_min_wait(self):
self.assertEqual(len(statuses), 3)
self.assertEqual(len(exps), 3 * 1)

async def test_over_rollout_return_partial_tasks(self):
self.config.explorer.over_rollout.ratio = 0.5
self.config.explorer.over_rollout.wait_after_min = 0.5
self.config.explorer.over_rollout.return_partial_tasks = True
self.config.explorer.max_repeat_times_per_runner = 2
self.config.synchronizer.sync_style = SyncStyle.EXPLORER_DRIVEN
self.config.buffer.batch_size = 2
self.config.check_and_update()
scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()])
await scheduler.start()

tasks = [
Task(
workflow=DummyWorkflow, # type: ignore[type-abstract]
workflow_args={"step_num": 1},
repeat_times=1,
raw_task={},
),
Task(
workflow=DummyPartialSnapshotWorkflow, # type: ignore[type-abstract]
workflow_args={
"actions_by_repeat_times": {
2: "success",
1: "sleep_5",
}
},
repeat_times=3,
raw_task={},
),
]
scheduler.schedule(tasks, batch_id=0)

start_time = time.time()
statuses, exps = await scheduler.get_results(
batch_id=0,
min_num=1,
timeout=3,
return_partial_tasks=True,
)
end_time = time.time()

self.assertLess(end_time - start_time, 5)
self.assertEqual(len(statuses), 2)
self.assertEqual(len(exps), 3)

full_status = next(status for status in statuses if status.total_runs == 1)
partial_status = next(status for status in statuses if status.total_runs == 3)
self.assertTrue(full_status.ok)
self.assertEqual(partial_status.completed_runs, 2)
self.assertFalse(partial_status.ok)
self.assertIn("2/3 runs completed successfully", partial_status.message)

statuses, exps = await scheduler.get_results(batch_id=0, timeout=1)
self.assertEqual(len(statuses), 0)
self.assertEqual(len(exps), 0)

await scheduler.stop()

async def test_dynamic_timeout(self):
self.config.explorer.dynamic_timeout.enable = True
self.config.explorer.dynamic_timeout.ratio = 3.0
192 changes: 155 additions & 37 deletions tests/explorer/workflow_test.py
@@ -3,13 +3,14 @@
import asyncio
import os
import shutil
import threading
import time
import unittest
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, Optional
from unittest import mock
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock

import openai
import ray
@@ -531,46 +532,12 @@ async def monitor_routine():


class TestAgentScopeWorkflowAdapter(unittest.IsolatedAsyncioTestCase):
- async def test_adapter_v0(self):
- try:
- from agentscope.model import TrinityChatModel
- except ImportError:
- self.skipTest("agentscope >= 1.0.9 is not installed")
-
- async def as_workflow_func(task, model) -> float:
- self.assertIsInstance(task, dict)
- self.assertIsInstance(model, TrinityChatModel)
- return task["reward"]
-
- model = MagicMock()
- openai_client = MagicMock()
- openai_client.model_path = "Qwen/Qwen3-8B"
- model.get_openai_async_client.return_value = openai_client
- model.extract_experience_from_history.return_value = [
- Experience(tokens=Tensor([0, 1, 2]), prompt_length=1, logprobs=Tensor([0.1, 0.2])),
- Experience(tokens=Tensor([3, 4, 5]), prompt_length=2, logprobs=Tensor([0.3])),
- ]
-
- as_adapter_cls = WORKFLOWS.get("agentscope_workflow_adapter")
- as_adapter = as_adapter_cls(
- task=Task(
- raw_task={"reward": 0.1},
- workflow_args={"workflow_func": as_workflow_func},
- ),
- model=model,
- )
- result = await as_adapter.run_async()
- self.assertEqual(len(result), 2)
- self.assertEqual(result[0].reward, 0.1)
- self.assertEqual(result[0].prompt_length, 1)
- self.assertEqual(result[1].reward, 0.1)
- self.assertEqual(result[1].prompt_length, 2)
-
async def test_adapter_v1(self):
try:
from agentscope.model import ChatModelBase
from agentscope.tuner import JudgeOutput, WorkflowOutput
- except ImportError:
+ except ImportError as e:
+ print(str(e))
self.skipTest("agentscope >= 1.0.12 is not installed")

async def as_workflow_func(task, model) -> WorkflowOutput:
@@ -660,6 +627,45 @@ async def run_async(self):
return exps


class PartialFailureWorkflow(Workflow):
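"""Test workflow that raises on configured call ids (`fail_call_ids`) to exercise partial-success handling."""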
can_reset: bool = True

_call_lock = threading.Lock()
_call_count = 0

def __init__(self, model: ModelWrapper, task: Task, auxiliary_models=None):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.fail_call_ids = set(task.raw_task.get("fail_call_ids", []))

def reset(self, task: Task):
self.fail_call_ids = set(task.raw_task.get("fail_call_ids", []))

@classmethod
def reset_call_count(cls):
with cls._call_lock:
cls._call_count = 0

@classmethod
def next_call_id(cls) -> int:
with cls._call_lock:
call_id = cls._call_count
cls._call_count += 1
return call_id

def run(self):
call_id = self.next_call_id()
if call_id in self.fail_call_ids:
raise RuntimeError(f"Intentional failure for run call {call_id}")

exp = Experience(
tokens=Tensor([0, 1, 2]),
prompt_length=1,
metrics={"run_metrics": float(call_id)},
)
exp.response_text = str(call_id)
return [exp]


class TestWorkflowRunner(unittest.IsolatedAsyncioTestCase):
async def test_workflow_runner(self):
config = get_template_config()
@@ -705,6 +711,118 @@ async def test_workflow_runner(self):
self.assertIsInstance(exps, list)
self.assertEqual(len(exps), 2)

@parameterized.expand(
[
("sequential", 2),
("asynchronous", 2),
("multi-threading", 2),
]
)
async def test_workflow_runner_partial_success_non_repeatable(
self, concurrent_mode: str, expected_success_runs: int
):
config = get_template_config()
config.explorer.concurrent_mode = concurrent_mode

with mock.patch(
"trinity.explorer.workflow_runner.ModelWrapper",
DummyModelWrapper,
):
runner = WorkflowRunner(
config,
model=MagicMock(),
auxiliary_models=[],
runner_id=0,
)
await runner.prepare()

PartialFailureWorkflow.reset_call_count()
task = Task(
workflow=PartialFailureWorkflow,
repeat_times=3,
raw_task={"fail_call_ids": [1]},
)

status, exps = await runner.run_task(
task, batch_id="test", repeat_times=3, run_id_base=0
)

self.assertFalse(status.ok)
self.assertEqual(status.completed_runs, expected_success_runs)
self.assertEqual(status.total_runs, 3)
self.assertEqual(len(exps), expected_success_runs)
self.assertIn(f"{expected_success_runs}/3 runs completed successfully", status.message) # type: ignore[arg-type]

@parameterized.expand(
[
("sequential",),
("asynchronous",),
("multi-threading",),
]
)
async def test_workflow_runner_fail_fast_without_partial_collection(self, concurrent_mode: str):
config = get_template_config()
config.explorer.concurrent_mode = concurrent_mode

with mock.patch(
"trinity.explorer.workflow_runner.ModelWrapper",
DummyModelWrapper,
):
runner = WorkflowRunner(
config,
model=MagicMock(),
auxiliary_models=[],
runner_id=0,
)
await runner.prepare()

task = Task(
workflow=PartialFailureWorkflow,
repeat_times=3,
raw_task={"fail_call_ids": []},
)

async def mock_execute_single_run(
workflow: Workflow,
task: Task,
run_index: int,
run_id_base: int,
):
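# Each branch below returns (ok, experiences, run metrics, error message) for one simulated run.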
if run_index == 0:
await asyncio.sleep(0.01)
exp = Experience(
tokens=Tensor([0, 1, 2]),
prompt_length=1,
metrics={"run_metrics": 0.0},
)
return True, [exp], {"run_metrics": 0.0}, None
if run_index == 1:
await asyncio.sleep(0.02)
return False, [], None, "planned failure"
await asyncio.sleep(0.5)
exp = Experience(
tokens=Tensor([0, 1, 2]),
prompt_length=1,
metrics={"run_metrics": 2.0},
)
return True, [exp], {"run_metrics": 2.0}, None

runner._execute_single_run = AsyncMock(side_effect=mock_execute_single_run)

status, exps = await runner.run_task(
task,
batch_id="test",
repeat_times=3,
run_id_base=0,
collect_partial_runs=False,
)

self.assertFalse(status.ok)
self.assertEqual(status.completed_runs, 1)
self.assertEqual(status.total_runs, 3)
self.assertEqual(len(exps), 1)
self.assertIn("1/3 runs completed successfully", status.message) # type: ignore[arg-type]

async def test_workflow_runner_get_state(self):
config = get_template_config()

1 change: 1 addition & 0 deletions tests/trainer/trainer_test.py
@@ -1367,6 +1367,7 @@ def test_trainer(self):
self.config.buffer.eval_interval = 4 # only eval on start
self.config.name = f"explore-over-rollout-{datetime.now().strftime('%Y%m%d%H%M%S')}"
self.config.explorer.over_rollout.ratio = 0.5 # set the over-rollout ratio to 50%, i.e., wait for only 2 (4 * 50%) tasks at each step
self.config.explorer.over_rollout.return_partial_tasks = True
self.config.explorer.over_rollout.wait_after_min = 0
self.config.explorer.dynamic_timeout.enable = True
self.config.explorer.dynamic_timeout.ratio = 2
3 changes: 3 additions & 0 deletions trinity/common/config.py
@@ -145,6 +145,9 @@ class OverRolloutConfig:

ratio: float = 0.0 # the explorer waits for only (1 - over_rollout.ratio) * batch_size tasks at each step
wait_after_min: float = 30.0 # seconds to wait after reaching the minimum task threshold
return_partial_tasks: bool = (
False # return tasks with partially successful runs during over-rollout cleanup
)
# more settings will be added in the future
# e.g., postpone tasks into the next step if not finished in time
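
For orientation, a minimal sketch of enabling the new flag, using the field names from this diff and the `check_and_update()` entry point the tests above call:

config.explorer.over_rollout.ratio = 0.5 # wait for only (1 - 0.5) * batch_size tasks per step
config.explorer.over_rollout.wait_after_min = 0.5 # then wait 0.5 s more before cleanup
config.explorer.over_rollout.return_partial_tasks = True # collect partially completed tasks on cleanup
config.synchronizer.sync_style = SyncStyle.EXPLORER_DRIVEN # FIXED is rejected by the validator
config.check_and_update()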

7 changes: 5 additions & 2 deletions trinity/common/config_validator.py
@@ -655,12 +655,15 @@ def validate(self, config: Config) -> None:
for args in model_args:
set_if_none(aux_model, args, getattr(config.model, args))

- if config.explorer.over_rollout.ratio > 0.0:
+ if (
+ config.explorer.over_rollout.ratio > 0.0
+ or config.explorer.over_rollout.return_partial_tasks
+ ):
if not (0.0 <= config.explorer.over_rollout.ratio < 1.0):
raise ValueError("over_rollout_ratio should be in [0.0, 1.0)")
if config.synchronizer.sync_style == SyncStyle.FIXED:
raise ValueError(
"over_rollout_ratio is not compatible with fixed sync_style, please set "
"over_rollout is not compatible with fixed sync_style, please set "
"`synchronizer.sync_style` to `explorer_driven` or `trainer_driven`."
)
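
A minimal sketch of the failure mode this check guards against, assuming `check_and_update()` invokes this validator as in the tests above:

config = get_template_config()
config.explorer.over_rollout.return_partial_tasks = True # the flag alone now triggers the check, even with ratio = 0.0
config.synchronizer.sync_style = SyncStyle.FIXED
config.check_and_update() # raises ValueError: over_rollout is not compatible with fixed sync_style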
