Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions examples/openai_eval/bench.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Example bench-mode config: evaluates a model served behind an
# OpenAI-compatible API (engine_type: openai_api) on a GSM8K eval taskset.
mode: bench
project: Trinity-RFT-openai-eval
name: openai-api-mvp
# Resolved from TRINITY_CHECKPOINT_ROOT_DIR; falls back to ./checkpoints when unset.
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
model:
model_path: qwen3-max
max_model_len: 4096
max_prompt_tokens: 2048
max_response_tokens: 1024
buffer:
batch_size: 16
total_epochs: 1
explorer_input:
eval_tasksets:
- name: gsm8k_eval
storage_type: file
# Resolved from TRINITY_TASKSET_PATH; falls back to openai/gsm8k when unset.
path: ${oc.env:TRINITY_TASKSET_PATH,openai/gsm8k}
subset_name: main
split: test
format:
prompt_key: question
response_key: answer
rollout_args:
temperature: 1.0
logprobs: 0
default_eval_workflow_type: math_workflow
explorer:
rollout_model:
engine_type: openai_api
# The two *_env keys hold the NAMES of environment variables to read,
# not the endpoint URL or the API key themselves.
api_base_url_env: OPENAI_BASE_URL
api_key_env: DASHSCOPE_API_KEY
api_model_name: qwen3-max
api_max_concurrent_requests: 16
auxiliary_models: []
monitor:
monitor_type: tensorboard
101 changes: 101 additions & 0 deletions tests/common/openai_api_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import asyncio
import gc
import os
import unittest
from pathlib import Path

import ray

from trinity.common.config import InferenceModelConfig, load_config
from trinity.common.constants import MODEL_PATH_ENV_VAR
from trinity.common.models import create_explorer_models
from trinity.common.models.model import ModelWrapper
from trinity.common.models.openai_api_model import OpenaiAPIModel


async def prepare_engines(engines, auxiliary_engines):
    """Kick off `prepare` on every engine and wait until all of them finish.

    Args:
        engines: Iterable of rollout engine handles exposing `prepare.remote()`.
        auxiliary_engines: Iterable of engine groups (one iterable of handles
            per auxiliary model); every handle in every group is prepared too.
    """
    # Rollout engines are launched first, then auxiliary engines group by group.
    pending = [rollout.prepare.remote() for rollout in engines]
    pending.extend(
        member.prepare.remote() for group in auxiliary_engines for member in group
    )
    await asyncio.gather(*pending)


class TestOpenaiAPIModel(unittest.IsolatedAsyncioTestCase):
    """Exercise `OpenaiAPIModel` against an OpenAI-compatible endpoint served by vLLM.

    The test is skipped unless the environment variable named by
    `MODEL_PATH_ENV_VAR` is set.
    """

    @classmethod
    def setUpClass(cls):
        ray.init(ignore_reinit_error=True, namespace="trinity_unittest")
        gc.collect()

    @classmethod
    def tearDownClass(cls):
        ray.shutdown(_exiting_interpreter=True)

    async def asyncSetUp(self):
        model_path = os.environ.get(MODEL_PATH_ENV_VAR)
        if not model_path:
            raise unittest.SkipTest(
                f"Please set `export {MODEL_PATH_ENV_VAR}=<your_model_dir>` before running this test."
            )

        # Part 1: bootstrap a local OpenAI-compatible endpoint via vLLM.
        config_path = Path(__file__).resolve().parents[1] / "template" / "config.yaml"
        config = load_config(str(config_path))
        config.mode = "explore"
        config.ray_namespace = "trinity_unittest"
        config.model.model_path = model_path
        config.explorer.rollout_model.engine_type = "vllm"
        config.explorer.rollout_model.engine_num = 1
        config.explorer.rollout_model.tensor_parallel_size = 1
        config.explorer.rollout_model.enable_openai_api = True
        config.check_and_update()

        self.engines, self.auxiliary_engines = create_explorer_models(config)
        self.vllm_wrapper = ModelWrapper(self.engines[0], enable_history=False)
        await prepare_engines(self.engines, self.auxiliary_engines)
        await self.vllm_wrapper.prepare()

        openai_client = self.vllm_wrapper.get_openai_client()
        self.model_name = openai_client.models.list().data[0].id

        self.base_url_env = "TRINITY_OPENAI_BASE_URL_TEST"
        self.api_key_env = "TRINITY_OPENAI_API_KEY_TEST"
        test_env = {
            self.base_url_env: f"{self.vllm_wrapper.api_address}/v1",
            self.api_key_env: "EMPTY",
        }
        for env_name, env_value in test_env.items():
            os.environ[env_name] = env_value
            # os.environ is process-wide: drop the variable after the test so
            # it cannot leak into tests that run later in the same process.
            self.addCleanup(os.environ.pop, env_name, None)
        self.model_path = model_path

    async def test_openai_api_model_basic(self):
        # Part 2: verify OpenaiAPIModel can call the endpoint correctly.
        model = OpenaiAPIModel(
            InferenceModelConfig(
                engine_type="openai_api",
                model_path=self.model_path,
                api_base_url_env=self.base_url_env,
                api_key_env=self.api_key_env,
                api_model_name=self.model_name,
                api_max_concurrent_requests=2,
                max_prompt_tokens=8,
                enable_prompt_truncation=True,
            )
        )

        generate_exps = await model.generate("Say hello in one sentence.", n=1, max_tokens=16)
        self.assertEqual(len(generate_exps), 1)
        self.assertGreater(len(generate_exps[0].response_text), 0)

        messages = [
            {"role": "system", "content": "You are concise."},
            {"role": "user", "content": [{"type": "text", "text": "What is 1+1?"}]},
        ]
        chat_exps = await model.chat(messages, n=1, max_tokens=16)
        self.assertEqual(len(chat_exps), 1)
        self.assertGreater(len(chat_exps[0].response_text), 0)

        # The prompt is far longer than max_prompt_tokens=8, so every returned
        # experience is expected to be flagged as prompt-truncated.
        long_prompt_exps = await model.generate("hello " * 1024, n=2)
        self.assertEqual(len(long_prompt_exps), 2)
        self.assertTrue(all(exp.truncate_status == "prompt_truncated" for exp in long_prompt_exps))

        token_len = await model.get_message_token_len(messages)
        self.assertGreater(token_len, 0)
8 changes: 8 additions & 0 deletions trinity/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,14 @@ class InferenceModelConfig:

reasoning_parser: Optional[str] = None

# For external OpenAI-compatible API engine
api_base_url_env: str = "OPENAI_BASE_URL"
api_key_env: str = "OPENAI_API_KEY"
api_model_name: Optional[str] = None
api_timeout: int = 60
api_max_concurrent_requests: int = 8
api_support_logprobs: bool = False

# ! DO NOT SET
bundle_indices: str = ""
ray_namespace: Optional[str] = None
Expand Down
36 changes: 30 additions & 6 deletions trinity/common/config_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,14 +226,16 @@ def _set_gpu_allocation_info(self, config: Config) -> None:
the current mode and available resources.
"""
cluster = config.cluster

def _required_gpus(model_config) -> int:
    """Return how many GPUs a single inference-model config needs.

    Args:
        model_config: An inference-model config exposing `engine_type`,
            `tensor_parallel_size` and `engine_num`.

    Returns:
        0 for the `openai_api` engine type; otherwise
        `tensor_parallel_size * engine_num`.
    """
    # Callers pass the inference-model config itself (the rollout model or an
    # auxiliary model), so `engine_type` is read from it directly.
    if model_config.engine_type == "openai_api":
        return 0
    return model_config.tensor_parallel_size * model_config.engine_num

if config.mode != "train":
cluster.rollout_gpu_num = (
config.explorer.rollout_model.tensor_parallel_size
* config.explorer.rollout_model.engine_num
)
cluster.rollout_gpu_num = _required_gpus(config.explorer.rollout_model)
cluster.auxiliary_model_gpu_num = sum(
model.tensor_parallel_size * model.engine_num
for model in config.explorer.auxiliary_models
_required_gpus(model) for model in config.explorer.auxiliary_models
)
cluster.explorer_gpu_num = cluster.rollout_gpu_num + cluster.auxiliary_model_gpu_num
cluster.total_gpu_num = cluster.node_num * cluster.gpu_per_node
Expand Down Expand Up @@ -388,6 +390,9 @@ def validate(self, config: Config) -> None:
if not model.critic_model_path:
model.critic_model_path = model.model_path

if model.explorer.rollout_model.engine_type == "openai_api":
self._check_openai_api(config)

if model.tinker.enable:
self._check_tinker(config)

Expand All @@ -404,6 +409,25 @@ def validate(self, config: Config) -> None:
# check max_model_len, max_prompt_tokens, max_response_tokens
self._check_model_len(config)

def _check_openai_api(self, config: Config) -> None:
    """Validate OpenAI API-specific configuration settings.

    - Sets trainer type to `tinker` for skipping trainer validation
    - Validates that `api_base_url_env` and `api_key_env` are provided and
      that the environment variables they name are set
    - Validates that `api_model_name` is provided

    Args:
        config: Global config; the API settings are read from
            `config.explorer.rollout_model`.

    Raises:
        ValueError: If any of the settings above is missing.
    """
    if config.trainer.trainer_type != "tinker":
        config.trainer.trainer_type = "tinker"
        self.logger.debug("Trainer type is set to `tinker` for skipping trainer validation.")

    # The `api_*` fields are declared on the inference-model config, so they
    # are read from the rollout model rather than from `config.model`.
    rollout_model = config.explorer.rollout_model
    if not rollout_model.api_base_url_env or os.getenv(rollout_model.api_base_url_env) is None:
        raise ValueError(
            "`explorer.rollout_model.api_base_url_env` must be provided and the environment "
            "variable it names must be set when engine_type=openai_api."
        )
    if not rollout_model.api_key_env or os.getenv(rollout_model.api_key_env) is None:
        raise ValueError(
            "`explorer.rollout_model.api_key_env` must be provided and the environment "
            "variable it names must be set when engine_type=openai_api."
        )
    if not rollout_model.api_model_name:
        raise ValueError(
            "`explorer.rollout_model.api_model_name` must be provided when engine_type=openai_api."
        )

def _check_tinker(self, config: Config) -> None:
"""Validate Tinker-specific configuration settings.

Expand Down
36 changes: 36 additions & 0 deletions trinity/common/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,19 @@ def create_explorer_models(
from trinity.common.models.vllm_model import vLLMRolloutModel

engine_cls = vLLMRolloutModel
elif config.explorer.rollout_model.engine_type == "openai_api":
rollout_engines = create_api_inference_models(
config=config.explorer.rollout_model,
actor_name=f"{config.explorer.name}_rollout_model",
)
auxiliary_engines = []
for i, model_config in enumerate(config.explorer.auxiliary_models):
engines = create_api_inference_models(
config=model_config,
actor_name=f"{config.explorer.name}_auxiliary_model_{model_config.name or i}",
)
auxiliary_engines.append(engines)
return rollout_engines, auxiliary_engines
elif config.explorer.rollout_model.engine_type == "tinker":
from trinity.common.models.tinker_model import TinkerModel

Expand Down Expand Up @@ -183,6 +196,29 @@ def create_vllm_inference_models(
return models


def create_api_inference_models(
    config: InferenceModelConfig,
    actor_name: str,
) -> List:
    """Create the Ray actors that wrap an external OpenAI-compatible API model.

    Args:
        config: Inference-model config; `engine_num` actors are created and
            each receives this config.
        actor_name: Prefix for actor names; actor `i` is named `{actor_name}_{i}`.

    Returns:
        The list of created actor handles, in index order.
    """
    from trinity.common.models.openai_api_model import OpenaiAPIModel

    def _spawn(index: int):
        # The actors request no CPU or GPU resources from Ray.
        actor_options = {
            "name": f"{actor_name}_{index}",
            "num_cpus": 0,
            "num_gpus": 0,
            "namespace": config.ray_namespace,
        }
        return ray.remote(OpenaiAPIModel).options(**actor_options).remote(config=config)

    return [_spawn(index) for index in range(config.engine_num)]


async def create_debug_explorer_model(config: Config) -> None:
"""Create explorer inference models for debugging."""
logger = get_logger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion trinity/common/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ async def prepare(self) -> None:
if self.api_address is None:
self.logger.info("API server is not enabled for inference model.")
return
if self._engine_type == "tinker":
if self._engine_type in {"tinker", "openai_api"}:
return
max_retries = 30
interval = 2 # seconds
Expand Down
Loading
Loading