Merged
Changes from 8 commits
37 changes: 37 additions & 0 deletions examples/bench/external_model.yaml
@@ -0,0 +1,37 @@
mode: bench
project: Trinity-RFT
name: external-model-qwen3-max
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
model:
model_path: qwen3-max
max_model_len: 4096
max_prompt_tokens: 2048
max_response_tokens: 1024
external_model:
enable: true
model_name: qwen3-max
base_url_env: API_BASE_URL
api_key_env: OPENAI_API_KEY
max_concurrent_requests: 16
buffer:
batch_size: 16
total_epochs: 1
explorer_input:
eval_tasksets:
- name: gsm8k_eval
storage_type: file
path: ${oc.env:TRINITY_TASKSET_PATH,openai/gsm8k}
subset_name: main
split: test
repeat_times: 16
format:
prompt_key: question
response_key: answer
rollout_args:
temperature: 1.0
logprobs: 0
default_eval_workflow_type: math_workflow
explorer:
auxiliary_models: []
monitor:
monitor_type: tensorboard
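
As a quick orientation (not part of the diff): the `${oc.env:...}` entries are OmegaConf env interpolations with defaults, while `base_url_env` / `api_key_env` only *name* the environment variables the external engine reads at call time. A minimal sketch, assuming the config is loaded via OmegaConf and using placeholder endpoint/key values:

```python
import os

from omegaconf import OmegaConf

# Placeholders; any OpenAI-compatible endpoint and key would work here.
os.environ["API_BASE_URL"] = "https://example.com/v1"
os.environ["OPENAI_API_KEY"] = "sk-placeholder"

cfg = OmegaConf.load("examples/bench/external_model.yaml")
# ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} falls back to its default:
print(cfg.checkpoint_root_dir)
# base_url_env is indirection, not a URL: resolve it through the environment.
print(os.environ[cfg.model.external_model.base_url_env])
```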
1 change: 1 addition & 0 deletions tests/common/config_test.py
@@ -48,6 +48,7 @@ def test_all_examples_are_valid(self):
or filename.startswith("verl_")
or filename.startswith("dj_")
or filename.startswith("tinker")
or filename.startswith("external")
):
print(f"Checking config: {filename}")
config_path = os.path.join(example_dir, example_name, filename)
97 changes: 97 additions & 0 deletions tests/common/external_model_test.py
@@ -0,0 +1,97 @@
import asyncio
import gc
import os
import unittest

import ray

from tests.tools import get_model_path, get_template_config
from trinity.common.config import ExternalModelConfig, InferenceModelConfig
from trinity.common.models import create_explorer_models
from trinity.common.models.external_model import ExternalModel
from trinity.common.models.model import ModelWrapper


async def prepare_engines(engines, auxiliary_engines):
prepare_refs = []
for engine in engines:
prepare_refs.append(engine.prepare.remote())
for models in auxiliary_engines:
for engine in models:
prepare_refs.append(engine.prepare.remote())
await asyncio.gather(*prepare_refs)


class TestExternalModel(unittest.IsolatedAsyncioTestCase):
@classmethod
def setUpClass(cls):
ray.init(ignore_reinit_error=True, namespace="trinity_unittest")
gc.collect()

@classmethod
def tearDownClass(cls):
ray.shutdown(_exiting_interpreter=True)

async def asyncSetUp(self):
model_path = get_model_path()
# Part 1: bootstrap a local OpenAI-compatible endpoint via vLLM.
config = get_template_config()
config.mode = "explore"
config.model.model_path = model_path
config.explorer.rollout_model.engine_type = "vllm"
config.explorer.rollout_model.engine_num = 1
config.explorer.rollout_model.tensor_parallel_size = 1
config.explorer.rollout_model.enable_openai_api = True
config.check_and_update()

self.engines, self.auxiliary_engines = create_explorer_models(config)
self.vllm_wrapper = ModelWrapper(self.engines[0], enable_history=False)
await prepare_engines(self.engines, self.auxiliary_engines)
await self.vllm_wrapper.prepare()

openai_client = self.vllm_wrapper.get_openai_client()
self.model_name = openai_client.models.list().data[0].id

self.base_url_env = "TRINITY_OPENAI_BASE_URL_TEST"
self.api_key_env = "TRINITY_OPENAI_API_KEY_TEST"
os.environ[self.base_url_env] = f"{self.vllm_wrapper.api_address}/v1"
os.environ[self.api_key_env] = "EMPTY"
self.model_path = model_path
print(
f"Model is prepared at {self.vllm_wrapper.api_address}/v1, model_name: {self.model_name}"
)

async def test_external_model(self):
# Part 2: verify ExternalModel can call the endpoint correctly.
model = ExternalModel(
InferenceModelConfig(
model_path=self.model_path,
external_model_config=ExternalModelConfig(
base_url_env=self.base_url_env,
api_key_env=self.api_key_env,
model_name=self.model_name,
),
)
)

generate_exps = await model.generate("Say hello in one sentence.", n=1, max_tokens=16)
self.assertEqual(len(generate_exps), 1)
self.assertTrue(len(generate_exps[0].response_text) > 0)
self.assertEqual(generate_exps[0].reward, 0.0)
self.assertIn("usage/prompt_tokens", generate_exps[0].metrics)
self.assertIn("usage/completion_tokens", generate_exps[0].metrics)
self.assertIn("usage/total_tokens", generate_exps[0].metrics)
self.assertGreater(generate_exps[0].metrics["usage/total_tokens"], 0.0)

messages = [
{"role": "system", "content": "You are an assistant. Answer the question briefly."},
{"role": "user", "content": [{"type": "text", "text": "What is 1+1?"}]},
]
chat_exps = await model.chat(messages, n=4, max_tokens=32)
self.assertEqual(len(chat_exps), 4)
self.assertTrue(len(chat_exps[0].response_text) > 0)
self.assertEqual(chat_exps[0].reward, 0.0)
self.assertIn("usage/prompt_tokens", chat_exps[0].metrics)
self.assertIn("usage/completion_tokens", chat_exps[0].metrics)
self.assertIn("usage/total_tokens", chat_exps[0].metrics)
self.assertGreater(chat_exps[0].metrics["usage/total_tokens"], 0.0)
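
One detail worth flagging in the chat test: the user message uses the list-of-parts content form, which OpenAI-compatible endpoints accept alongside the plain-string form. Both shapes, for reference:

```python
# Plain-string content:
messages = [{"role": "user", "content": "What is 1+1?"}]
# Typed-parts content (as used in the test above):
messages = [{"role": "user", "content": [{"type": "text", "text": "What is 1+1?"}]}]
```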
17 changes: 17 additions & 0 deletions trinity/common/config.py
@@ -442,6 +442,17 @@ class TinkerConfig:
base_url: Optional[str] = None


@dataclass
class ExternalModelConfig:
enable: bool = False
base_url_env: str = "OPENAI_BASE_URL"
api_key_env: str = "OPENAI_API_KEY"
model_name: Optional[str] = None
timeout: int = 60
max_concurrent_requests: int = 8
support_logprobs: bool = False


@dataclass
class ModelConfig:
# source model path
@@ -491,6 +502,9 @@ class ModelConfig:
# tinker config
tinker: TinkerConfig = field(default_factory=TinkerConfig)

# external API-based engine
external_model: ExternalModelConfig = field(default_factory=ExternalModelConfig)


@dataclass
class InferenceModelConfig:
@@ -551,6 +565,9 @@ class InferenceModelConfig:

reasoning_parser: Optional[str] = None

# For external API-based engine
external_model_config: ExternalModelConfig = field(default_factory=ExternalModelConfig)

# ! DO NOT SET
bundle_indices: str = ""
ray_namespace: Optional[str] = None
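
A minimal sketch (illustrative, not from the PR) of constructing the new dataclass with the values used in `examples/bench/external_model.yaml`; the remaining fields keep the defaults defined above:

```python
from trinity.common.config import ExternalModelConfig

ext = ExternalModelConfig(
    enable=True,
    base_url_env="API_BASE_URL",   # name of the env var holding the base URL
    api_key_env="OPENAI_API_KEY",  # name of the env var holding the API key
    model_name="qwen3-max",
    max_concurrent_requests=16,
)
assert ext.timeout == 60              # default request timeout in seconds
assert ext.support_logprobs is False  # logprobs disabled by default
```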
39 changes: 37 additions & 2 deletions trinity/common/config_validator.py
@@ -149,7 +149,7 @@ def validate(self, config: Config) -> None:
if config.ray_namespace is None or len(config.ray_namespace) == 0:
config.ray_namespace = f"{config.project}/{config.name}"

if config.model.tinker.enable:
if config.model.tinker.enable or config.model.external_model.enable:
return

# check cluster information
@@ -226,6 +226,10 @@ def _set_gpu_allocation_info(self, config: Config) -> None:
the current mode and available resources.
"""
cluster = config.cluster

if config.explorer.rollout_model.engine_type == "external":
return

if config.mode != "train":
cluster.rollout_gpu_num = (
config.explorer.rollout_model.tensor_parallel_size
@@ -388,6 +392,9 @@ def validate(self, config: Config) -> None:
if not model.critic_model_path:
model.critic_model_path = model.model_path

if model.external_model.enable:
self._check_external_model(config)

if model.tinker.enable:
self._check_tinker(config)

@@ -404,6 +411,31 @@ def validate(self, config: Config) -> None:
# check max_model_len, max_prompt_tokens, max_response_tokens
self._check_model_len(config)

def _check_external_model(self, config: Config) -> None:
"""Validate API-specific configuration settings.

- Validates that `base_url_env`, `api_key_env`, and `model_name` are provided
"""
if config.mode != "bench":
raise ValueError("External model is only supported in `bench` mode.")

if config.explorer.rollout_model.engine_type != "external":
config.explorer.rollout_model.engine_type = "external"

external_model = config.model.external_model
rollout_model_config = config.explorer.rollout_model.external_model_config

for key, value in vars(external_model).items():
setattr(rollout_model_config, key, value)

# Validate required env keys.
if not rollout_model_config.base_url_env:
raise ValueError("`model.external_model.base_url_env` is required for external engine.")
if not rollout_model_config.api_key_env:
raise ValueError("`model.external_model.api_key_env` is required for external engine.")
if not rollout_model_config.model_name:
raise ValueError("`model.external_model.model_name` is required for external engine.")

def _check_tinker(self, config: Config) -> None:
"""Validate Tinker-specific configuration settings.

@@ -1135,6 +1167,9 @@ def validate(self, config: Config) -> None:
):
return

if config.model.external_model.enable:
return

if config.trainer.trainer_type == "verl":
if config.trainer.ulysses_sequence_parallel_size < 1:
self.logger.warning(
@@ -1227,7 +1262,7 @@ def validate(self, config: Config) -> None:
if config.ignore_validator_suggestions:
return

if config.model.tinker.enable:
if config.model.tinker.enable or config.model.external_model.enable:
return

if config.mode in {"train", "both"}:
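
The `vars(...)`/`setattr` loop in `_check_external_model` copies every field of `model.external_model` onto `explorer.rollout_model.external_model_config`, keeping the two in sync. A self-contained sketch of that mechanic with stand-in dataclasses:

```python
from dataclasses import dataclass


@dataclass
class _Src:  # stands in for model.external_model
    base_url_env: str = "API_BASE_URL"
    model_name: str = "qwen3-max"


@dataclass
class _Dst:  # stands in for rollout_model.external_model_config
    base_url_env: str = "OPENAI_BASE_URL"
    model_name: str = ""


src, dst = _Src(), _Dst()
for key, value in vars(src).items():  # same pattern as the validator above
    setattr(dst, key, value)
assert dst == _Dst("API_BASE_URL", "qwen3-max")
```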
36 changes: 36 additions & 0 deletions trinity/common/models/__init__.py
@@ -64,6 +64,19 @@ def create_explorer_models(
from trinity.common.models.vllm_model import vLLMRolloutModel

engine_cls = vLLMRolloutModel
elif config.explorer.rollout_model.engine_type == "external":
rollout_engines = create_api_inference_models(
config=config.explorer.rollout_model,
actor_name=f"{config.explorer.name}_rollout_model",
)
auxiliary_engines = []
for i, model_config in enumerate(config.explorer.auxiliary_models):
engines = create_api_inference_models(
config=model_config,
actor_name=f"{config.explorer.name}_auxiliary_model_{model_config.name or i}",
)
auxiliary_engines.append(engines)
return rollout_engines, auxiliary_engines
elif config.explorer.rollout_model.engine_type == "tinker":
from trinity.common.models.tinker_model import TinkerModel

@@ -183,6 +196,29 @@ def create_vllm_inference_models(
return models


def create_api_inference_models(
config: InferenceModelConfig,
actor_name: str,
) -> List:
from trinity.common.models.external_model import ExternalModel

models = []
for i in range(config.engine_num):
models.append(
ray.remote(ExternalModel)
.options(
name=f"{actor_name}_{i}",
num_cpus=0,
num_gpus=0,
namespace=config.ray_namespace,
)
.remote(
config=config,
)
)
return models


async def create_debug_explorer_model(config: Config) -> None:
"""Create explorer inference models for debugging."""
logger = get_logger(__name__)
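
A hedged end-to-end sketch of driving the actors returned by `create_api_inference_models`; the config values are hypothetical, and calling `generate` through the Ray handle assumes it mirrors the async `ExternalModel.generate` exercised in the tests:

```python
import ray

from trinity.common.config import ExternalModelConfig, InferenceModelConfig
from trinity.common.models import create_api_inference_models

ray.init(ignore_reinit_error=True)

rollout_cfg = InferenceModelConfig(  # hypothetical values
    model_path="qwen3-max",
    external_model_config=ExternalModelConfig(
        base_url_env="API_BASE_URL",
        api_key_env="OPENAI_API_KEY",
        model_name="qwen3-max",
    ),
)
engines = create_api_inference_models(config=rollout_cfg, actor_name="demo_rollout_model")
# Async actor method: .remote() returns an ObjectRef that ray.get resolves.
exps = ray.get(engines[0].generate.remote("Say hello in one sentence.", n=1, max_tokens=16))
print(exps[0].response_text)
```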