80 changes: 80 additions & 0 deletions perf/scripts/explorer/perf_workflow.py
@@ -0,0 +1,80 @@
import time
from typing import Any, List, Optional, cast

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.workflows.workflow import Task, Workflow


class PerfWorkflow(Workflow):
"""A workflow for performance testing of Explorer with OpenAI API calls."""

is_async: bool = True
can_reset: bool = True

def __init__(
self,
*,
task: Task,
model: ModelWrapper,
auxiliary_models: Optional[List[ModelWrapper]] = None,
):
super().__init__(
task=task,
model=model,
auxiliary_models=auxiliary_models,
)
self.client = self.model.get_openai_async_client()
self.model_path = getattr(self.client, "model_path")
self.reset(task)

def reset(self, task: Task) -> None:
raw_task = task.raw_task or {}
self.messages = raw_task.get("messages") or []
if not self.messages:
raise ValueError("PerfWorkflow requires task.raw_task['messages'].")
self.tools = raw_task.get("tools")

async def run_async(self) -> List[Experience]:
request_latencies = []
usage_prompt_tokens = 0.0
usage_completion_tokens = 0.0
for i in range(len(self.messages)):
if self.messages[i].get("role") == "assistant":
                # Replay the conversation prefix up to this assistant turn to measure latency; the request is real, but its response content is ignored
request_kwargs = {
"model": self.model_path,
"messages": self.messages[:i],
}
if self.tools is not None:
request_kwargs["tools"] = self.tools

request_start = time.perf_counter()
responses = await self.client.chat.completions.create(**request_kwargs)
request_latency = time.perf_counter() - request_start
request_latencies.append(request_latency)

usage = cast(Any, getattr(responses, "usage", None))
prompt_tokens = getattr(usage, "prompt_tokens", None)
completion_tokens = getattr(usage, "completion_tokens", None)
if isinstance(prompt_tokens, (int, float)):
usage_prompt_tokens += float(prompt_tokens)
if isinstance(completion_tokens, (int, float)):
usage_completion_tokens += float(completion_tokens)

self.logger.info("Received response: %s", responses.choices[0].message)
exps = self.model.extract_experience_from_history()
total_request_latency = sum(request_latencies)
exps[0].metrics = {
"prompt_length": usage_prompt_tokens,
"response_length": usage_completion_tokens,
"api_call_prompt_tokens_per_second": (
usage_prompt_tokens / total_request_latency if total_request_latency > 0 else 0.0
),
"api_call_response_tokens_per_second": (
usage_completion_tokens / total_request_latency
if total_request_latency > 0
else 0.0
),
}
return exps
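
For orientation, a minimal sketch of the raw_task payload this workflow expects; the concrete field values below are illustrative and not taken from the PR:

# Illustrative only: PerfWorkflow reads task.raw_task["messages"] and the optional
# "tools" entry, then fires one timed request per assistant turn using the messages
# that precede it.
example_raw_task = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Ping?"},
        {"role": "assistant", "content": "Pong."},  # this turn triggers one timed request
    ],
    "tools": None,  # or an OpenAI-style tool list
}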
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -44,7 +44,7 @@ dependencies = [
"matplotlib",
"psutil",
"nvidia-ml-py",
"transformers>=5.6.2",
"transformers>=5.8.0",
"datasets>=4.0.0",
"typer>=0.20.1",
]
@@ -56,6 +56,9 @@ trinity = "trinity.cli.launcher:main"
vllm = [
"vllm>=0.19.1,<=0.20.1",
]
sglang = [
"sglang==0.5.11",
]
data = [
"py-data-juicer>=1.4.3"
]
2 changes: 1 addition & 1 deletion scripts/docker/Dockerfile.uv
@@ -19,7 +19,7 @@ RUN chmod 1777 /tmp && apt update && apt install -y \
curl git wget vim tmux net-tools cmake \
python3 python3-pip python3-dev python3-packaging python3-venv \
libomp-dev libnuma1 infiniband-diags libibverbs-dev librdmacm-dev rdma-core perftest \
libnuma-dev \
libnuma-dev protobuf-compiler \
&& rm -rf /var/lib/apt/lists/* \
&& ln -sf /usr/bin/python3 /usr/bin/python \
&& ln -sf /usr/bin/pip3 /usr/bin/pip
58 changes: 58 additions & 0 deletions tests/common/config_test.py
@@ -4,17 +4,75 @@
import math
import os
import shutil
import socket
import unittest
from unittest.mock import patch

import torch

from tests.tools import get_template_config, get_unittest_dataset_config
from trinity.common.config import InferenceModelConfig, load_config
from trinity.common.constants import SyncMethod
from trinity.common.models.model import InferenceModel

CHECKPOINT_ROOT_DIR = os.path.join(os.path.dirname(__file__), "temp_checkpoint_dir")


class DummyInferenceModel(InferenceModel):
async def generate(self, prompt: str, **kwargs):
raise NotImplementedError

async def chat(self, messages, **kwargs):
raise NotImplementedError

async def logprobs(self, token_ids, **kwargs):
raise NotImplementedError

async def convert_messages_to_experience(self, messages, tools=None, temperature=None):
raise NotImplementedError

async def sync_model(
self, model_version: int, sync_method: SyncMethod, timeout: float = 1200
) -> int:
return model_version

def get_model_version(self) -> int:
return 0


class TestConfig(unittest.TestCase):
def test_inference_model_base_port_uses_engine_id(self):
model = DummyInferenceModel(InferenceModelConfig(base_port=9000, engine_id=3))

_, port = model.get_available_address()

self.assertEqual(port, 9003)

def test_inference_model_base_port_falls_back_when_unavailable(self):
requested_port = 9004
model = DummyInferenceModel(InferenceModelConfig(base_port=9000, engine_id=4))

with socket.socket() as occupied_socket:
occupied_socket.bind(("", requested_port))

with patch.object(model.logger, "warning") as mock_warning:
_, port = model.get_available_address()

self.assertNotEqual(port, requested_port)
self.assertGreater(port, 0)
mock_warning.assert_called_once_with(
"Configured port %s is unavailable for engine %s; falling back to an ephemeral port.",
requested_port,
4,
)

def test_inference_model_without_base_port_uses_ephemeral_port(self):
model = DummyInferenceModel(InferenceModelConfig())

_, port = model.get_available_address()

self.assertGreater(port, 0)

def test_load_default_config(self):
config = get_template_config()
config.buffer.batch_size = 8
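
The three tests above pin down the intended port-selection behaviour. A rough sketch of logic consistent with those assertions follows; the helper name is an assumption, not the actual implementation of get_available_address:

import socket

def pick_port(base_port, engine_id, logger):
    # Sketch only: mirrors what the tests assert, not the real code path.
    if base_port is not None:
        candidate = base_port + engine_id
        try:
            with socket.socket() as probe:
                probe.bind(("", candidate))
            return candidate
        except OSError:
            logger.warning(
                "Configured port %s is unavailable for engine %s; falling back to an ephemeral port.",
                candidate,
                engine_id,
            )
    # No base_port configured (or the candidate was taken): let the OS pick one.
    with socket.socket() as probe:
        probe.bind(("", 0))
        return probe.getsockname()[1]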
4 changes: 2 additions & 2 deletions tests/explorer/scheduler_test.py
@@ -282,7 +282,7 @@ def __init__(self):

super().__init__(InferenceModelConfig(model_path="dummy_model"))

def sync_model(self, model_version, update_weight_args_list):
def sync_model(self, model_version, sync_method, timeout):
return True

async def prepare(self):
@@ -329,7 +329,7 @@ async def generate(self, prompt: str, lora_request=None, **kwargs) -> Sequence[E

@ray.remote
class DummyAuxiliaryModel(InferenceModel):
def sync_model(self, model_version, update_weight_args_list):
def sync_model(self, model_version, sync_method, timeout):
return True

def get_model_version(self):
13 changes: 10 additions & 3 deletions tests/manager/synchronizer_test.py
@@ -66,12 +66,14 @@ async def new_finish_explore_step(self: Explorer, step: int, model_version: int)
await asyncio.sleep(explore_step_time_list[step - 1])
dummy_exps = [
Experience(
tokens=torch.tensor([0, 0, 0]),
tokens=torch.tensor([0, 1, 2]),
info={"model_version": model_version},
)
for _ in range(self.config.buffer.train_batch_size)
]
await self.experience_pipeline.process.remote(Experience.serialize_many(dummy_exps))
await self.rollout_coordinator.process_experiences.remote(
[Experience.serialize_many(dummy_exps)]
)
self.monitor.log(metric, step=step)

Explorer.explore_step = new_explore_step
@@ -347,6 +349,7 @@ class TestPullLatestWeights(unittest.IsolatedAsyncioTestCase):

def setUp(self):
self.explorer = object.__new__(Explorer)
self.explorer.config = Config()
self.explorer.logger = MagicMock()
self.explorer.models = [MagicMock(), MagicMock()]
self.explorer.synchronizer = MagicMock()
@@ -378,7 +381,11 @@ async def test_pull_latest_weights(self, model_version, new_version, expect_sync

for m in self.explorer.models:
if expect_sync:
m.sync_model.remote.assert_called_once_with(new_version)
m.sync_model.remote.assert_called_once_with(
new_version,
self.explorer.config.synchronizer.sync_method,
timeout=self.explorer.config.synchronizer.sync_timeout,
)
else:
m.sync_model.remote.assert_not_called()

2 changes: 1 addition & 1 deletion tests/template/config.yaml
@@ -44,7 +44,7 @@ explorer:
gpu_memory_utilization: 0.8
trainer:
trainer_type: verl
trainer_strategy: fsdp
trainer_strategy: fsdp2
save_interval: 100
save_hf_checkpoint: never
grad_clip: 1.0
12 changes: 5 additions & 7 deletions tests/trainer/trainer_test.py
@@ -78,7 +78,7 @@ def setUp(self):
@parameterized_class(
("strategy",),
[
("fsdp",),
("fsdp2",),
("megatron",),
],
)
@@ -576,7 +576,7 @@ def run_serve(config: Config, stop_event=None) -> None:

@parameterized_class(
("use_priority_queue", "strategy"),
[(False, "fsdp"), (True, "fsdp"), (True, "megatron")],
[(False, "fsdp"), (True, "fsdp2"), (True, "megatron")],
)
class TestFullyAsyncMode(unittest.TestCase):
def setUp(self):
@@ -603,16 +603,14 @@ def test_fully_async_mode(self):
config.synchronizer.sync_method = SyncMethod.CHECKPOINT
config.synchronizer.sync_style = SyncStyle.EXPLORER_DRIVEN
config.synchronizer.sync_interval = 8
config.trainer.trainer_strategy = self.strategy
config.monitor.monitor_type = "tensorboard"
trainer_config = deepcopy(config)
trainer_config.mode = "train"
trainer_config.buffer.train_batch_size = 4
if self.strategy == "megatron":
trainer_config.trainer.trainer_strategy = "megatron"
trainer_config.check_and_update()
if self.strategy == "megatron":
_trainer_config = trainer_config.trainer.trainer_config
_trainer_config.critic.strategy = "megatron"
_trainer_config = trainer_config.trainer.trainer_config
_trainer_config.critic.strategy = self.strategy

explorer1_config = deepcopy(config)
explorer1_config.trainer = deepcopy(trainer_config.trainer)
16 changes: 15 additions & 1 deletion trinity/buffer/reader/queue_reader.py
@@ -1,5 +1,6 @@
"""Reader of the Queue buffer."""

import traceback
from typing import Dict, List, Optional

import ray
@@ -33,11 +34,24 @@ def read(self, batch_size: Optional[int] = None, **kwargs) -> List[Experience]:
)
except StopAsyncIteration:
raise StopIteration()
except Exception as e:
if "StopAsyncIteration" in traceback.format_exc():
raise StopIteration() from e
else:
raise e
return exps

async def read_async(self, batch_size: Optional[int] = None, **kwargs) -> List[Experience]:
batch_size = self.read_batch_size if batch_size is None else batch_size
exp_bytes = await self.queue.get_batch.remote(batch_size, timeout=self.timeout, **kwargs)
try:
exp_bytes = await self.queue.get_batch.remote(
batch_size, timeout=self.timeout, **kwargs
)
except Exception as e:
if "StopAsyncIteration" in traceback.format_exc():
raise StopAsyncIteration() from e
else:
raise e
exps = Experience.deserialize_many(exp_bytes)
if len(exps) != batch_size:
raise TimeoutError(
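
A hedged usage sketch of how callers might consume the reader after this change; the drain helpers are hypothetical and not part of the PR:

def drain_sync(reader):
    # Collect batches until the upstream queue is closed, which surfaces as StopIteration.
    batches = []
    while True:
        try:
            batches.append(reader.read())
        except StopIteration:
            break
    return batches

async def drain_async(reader):
    # Async variant: a closed queue surfaces as StopAsyncIteration instead.
    batches = []
    while True:
        try:
            batches.append(await reader.read_async())
        except StopAsyncIteration:
            break
    return batches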
4 changes: 4 additions & 0 deletions trinity/common/config.py
@@ -546,6 +546,8 @@ class InferenceModelConfig:
repetition_penalty: Optional[float] = None
# used for testing very long response generation, do not set it unless you know what you are doing
ignore_eos: bool = False
# for multi-modal models
enable_multimodal: bool = False

# override chat template in model
chat_template: Optional[str] = None
@@ -559,6 +561,7 @@ class InferenceModelConfig:
# For OpenAI API
enable_openai_api: bool = False
enable_log_requests: bool = False # whether to enable request logging in vLLM API server
base_port: Optional[int] = None

# For tool calls in OpenAI API
enable_auto_tool_choice: bool = False
@@ -572,6 +575,7 @@ class InferenceModelConfig:

# ! DO NOT SET
bundle_indices: str = ""
engine_id: int = 0
ray_namespace: Optional[str] = None
cuda_visible_devices: Optional[str] = None
