Merged
2 changes: 1 addition & 1 deletion benchmark/config/guru_math-template.yaml
@@ -52,7 +52,7 @@ explorer:
eval_interval: 10
runner_per_model: 8
rollout_model:
engine_type: vllm_async
engine_type: vllm
engine_num: 3
tensor_parallel_size: 1
enable_prefix_caching: false
15 changes: 14 additions & 1 deletion docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -179,6 +179,11 @@ model:
train_mlp: true
train_attn: true
train_unembed: true
external_model:
enable: false
base_url_env: OPENAI_BASE_URL
api_key_env: OPENAI_API_KEY
model_name: null
```

- `model_path`: Path to the model being trained. If `tinker` is enabled, this is the path to the local tokenizer.
@@ -208,6 +213,11 @@ model:
- `train_mlp`: Whether to train the MLP layer. Default is `true`.
- `train_attn`: Whether to train the attention layer. Default is `true`.
- `train_unembed`: Whether to train the unembedding layer. Default is `true`.
- `external_model`: Optional external API model configuration.
  - `enable`: Whether to enable the external API model. Default is `false`.
  - `model_name`: The name of the external API model. Defaults to `null` if not specified.
  - `base_url_env`: Name of the environment variable that holds the base URL of the external API model. Defaults to `OPENAI_BASE_URL`.
  - `api_key_env`: Name of the environment variable that holds the API key of the external API model. Defaults to `OPENAI_API_KEY`.
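Putting these fields together, a minimal configuration that routes rollouts to an external endpoint might look like the fragment below (the environment-variable names `MY_BASE_URL` and `MY_API_KEY` are illustrative, not defaults):

```yaml
model:
  external_model:
    enable: true
    model_name: qwen3-max      # served model name at the endpoint
    base_url_env: MY_BASE_URL  # env var holding the endpoint base URL
    api_key_env: MY_API_KEY    # env var holding the API key
```

Export both environment variables before launching so the explorer can resolve the endpoint at startup.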

```{tip}
If you are using the OpenAI API provided by Explorer, only `max_model_len` will take effect, and the values of `max_response_tokens`, `max_prompt_tokens`, and `min_response_tokens` will be ignored. When `max_tokens` is not independently specified, each API call will generate up to `max_model_len - prompt_length` tokens. Therefore, please ensure that the prompt length is less than `max_model_len` when using the API.
@@ -427,7 +437,10 @@ explorer:
- `max_timeout`: Maximum time (in seconds) for a workflow to complete.
- `max_retry_times`: Maximum number of retries for a workflow.
- `env_vars`: Environment variables to be set for every workflow runner.
- `rollout_model.engine_type`: Type of inference engine. For now, only `vllm_async` and `vllm` are supported; they have the same meaning and both use the asynchronous engine. In subsequent versions, only `vllm` may be retained for simplicity.
- `rollout_model.engine_type`: Type of inference engine. Supported options:
  - `vllm`: Use the vLLM asynchronous engine.
  - `tinker`: Use the Tinker engine.
  - `external`: Use an external API-based model engine.
- `rollout_model.engine_num`: Number of inference engines.
- `rollout_model.tensor_parallel_size`: Degree of tensor parallelism.
- `rollout_model.enable_history`: Whether to enable model call history recording. If set to `True`, the model wrapper automatically records the experiences returned by model calls. Please periodically extract the history via `extract_experience_from_history` to avoid out-of-memory issues. Default is `False`.
15 changes: 14 additions & 1 deletion docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
@@ -179,6 +179,11 @@ model:
train_mlp: true
train_attn: true
train_unembed: true
external_model:
enable: false
base_url_env: OPENAI_BASE_URL
api_key_env: OPENAI_API_KEY
model_name: null
```

- `model_path`: Path to the model being trained. If `tinker` is enabled, this is the path to the local tokenizer.
@@ -208,6 +213,11 @@ model:
- `train_mlp`: Whether to train the MLP layer. Default is `true`.
- `train_attn`: Whether to train the attention layer. Default is `true`.
- `train_unembed`: Whether to train the unembedding layer. Default is `true`.
- `external_model`: Optional external API model configuration.
  - `enable`: Whether to enable the external model. Default is `false`.
  - `model_name`: The name of the external model. Defaults to `null` if not specified.
  - `base_url_env`: Name of the environment variable that holds the external model's base URL. Defaults to `OPENAI_BASE_URL`.
  - `api_key_env`: Name of the environment variable that holds the external model's API key. Defaults to `OPENAI_API_KEY`.

```{tip}
If you are using the OpenAI API provided by Explorer, only `max_model_len` takes effect; the values of `max_response_tokens`, `max_prompt_tokens`, and `min_response_tokens` are ignored. When `max_tokens` is not independently specified, each API call generates up to `max_model_len - prompt_length` tokens, so please ensure that the prompt length is less than `max_model_len` when using the API.
@@ -424,7 +434,10 @@ explorer:
- `max_timeout`: Maximum time (in seconds) to wait for a workflow to complete.
- `max_retry_times`: Maximum number of retries when a workflow fails or times out.
- `env_vars`: Environment variables to be set for every WorkflowRunner.
- `rollout_model.engine_type`: Type of inference engine. `vllm_async` and `vllm` are supported; they have the same meaning and both use the asynchronous engine. Subsequent versions will keep only `vllm`.
- `rollout_model.engine_type`: Type of inference engine. Supported options:
  - `vllm`: Use the vLLM asynchronous engine.
  - `tinker`: Use the Tinker engine.
  - `external`: Use an external API engine.
- `rollout_model.engine_num`: Number of inference engine instances.
- `rollout_model.tensor_parallel_size`: Degree of tensor parallelism per instance.
- `rollout_model.enable_history`: Whether to enable model call history recording. If set to `True`, the model automatically records the experiences returned by model calls. Please periodically extract the history via `extract_experience_from_history` to avoid out-of-memory issues. Default is `False`.
34 changes: 34 additions & 0 deletions examples/bench/external_model.yaml
@@ -0,0 +1,34 @@
mode: bench
project: Trinity-RFT
name: external-model-qwen3-max
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
model:
model_path: qwen3-max
max_model_len: 4096
max_prompt_tokens: 2048
max_response_tokens: 1024
external_model:
enable: true
model_name: qwen3-max
base_url_env: API_BASE_URL
api_key_env: OPENAI_API_KEY
buffer:
batch_size: 16
total_epochs: 1
explorer_input:
eval_tasksets:
- name: gsm8k_eval
storage_type: file
path: ${oc.env:TRINITY_TASKSET_PATH,openai/gsm8k}
subset_name: main
split: test
repeat_times: 1
format:
prompt_key: question
response_key: answer
rollout_args:
temperature: 1.0
logprobs: 0
default_eval_workflow_type: math_workflow
monitor:
monitor_type: tensorboard
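The `base_url_env` / `api_key_env` fields in the config above are indirections: they name environment variables, and the actual endpoint URL and API key are read from those variables at runtime. A minimal, self-contained sketch of that resolution logic follows; the dataclass mirrors the fields added to `trinity/common/config.py` in this PR, while `resolve_endpoint` is illustrative and not the actual Trinity-RFT code:

```python
import os
from dataclasses import dataclass
from typing import Optional


@dataclass
class ExternalModelConfig:
    # Mirrors the fields introduced in trinity/common/config.py.
    enable: bool = False
    base_url_env: str = "OPENAI_BASE_URL"
    api_key_env: str = "OPENAI_API_KEY"
    model_name: Optional[str] = None


def resolve_endpoint(cfg: ExternalModelConfig) -> tuple:
    """Read the endpoint URL and API key from the env vars the config names."""
    base_url = os.environ.get(cfg.base_url_env)
    api_key = os.environ.get(cfg.api_key_env)
    if not base_url or not api_key:
        raise RuntimeError(
            f"Set {cfg.base_url_env} and {cfg.api_key_env} before enabling the external model."
        )
    return base_url, api_key


# Simulate the environment the example config expects.
os.environ["API_BASE_URL"] = "https://example.com/v1"
os.environ["OPENAI_API_KEY"] = "sk-demo"

cfg = ExternalModelConfig(enable=True, model_name="qwen3-max", base_url_env="API_BASE_URL")
print(resolve_endpoint(cfg))  # prints ('https://example.com/v1', 'sk-demo')
```

Keeping only variable *names* in the YAML means credentials never land in checked-in config files.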
2 changes: 1 addition & 1 deletion examples/entropy/clipb.yaml
@@ -83,7 +83,7 @@ explorer:
eval_on_startup: true
runner_per_model: 8
rollout_model:
engine_type: vllm_async
engine_type: vllm
engine_num: 4
tensor_parallel_size: 1
seed: 42
2 changes: 1 addition & 1 deletion examples/entropy/clipv.yaml
@@ -83,7 +83,7 @@ explorer:
eval_on_startup: true
runner_per_model: 8
rollout_model:
engine_type: vllm_async
engine_type: vllm
engine_num: 4
tensor_parallel_size: 1
seed: 42
2 changes: 1 addition & 1 deletion examples/learn_to_ask/train.yaml
@@ -58,7 +58,7 @@ explorer:
max_timeout: 900
max_retry_times: 2
rollout_model:
engine_type: vllm_async
engine_type: vllm
engine_num: 4
tensor_parallel_size: 1
use_v1: true
1 change: 1 addition & 0 deletions tests/common/config_test.py
@@ -48,6 +48,7 @@ def test_all_examples_are_valid(self):
or filename.startswith("verl_")
or filename.startswith("dj_")
or filename.startswith("tinker")
or filename.startswith("external")
):
print(f"Checking config: {filename}")
config_path = os.path.join(example_dir, example_name, filename)
143 changes: 143 additions & 0 deletions tests/common/external_model_test.py
@@ -0,0 +1,143 @@
import asyncio
import gc
import os
import unittest

import ray

from tests.tools import get_model_path, get_template_config
from trinity.common.config import ExternalModelConfig, InferenceModelConfig
from trinity.common.models import create_explorer_models
from trinity.common.models.external_model import ExternalModel
from trinity.common.models.model import ModelWrapper


async def prepare_engines(engines, auxiliary_engines):
prepare_refs = []
for engine in engines:
prepare_refs.append(engine.prepare.remote())
for models in auxiliary_engines:
for engine in models:
prepare_refs.append(engine.prepare.remote())
await asyncio.gather(*prepare_refs)


class TestExternalModel(unittest.IsolatedAsyncioTestCase):
@classmethod
def setUpClass(cls):
ray.init(ignore_reinit_error=True, namespace="trinity_unittest")
gc.collect()

@classmethod
def tearDownClass(cls):
ray.shutdown(_exiting_interpreter=True)

async def asyncSetUp(self):
model_path = get_model_path()
# Part 1: bootstrap a local OpenAI-compatible endpoint via vLLM.
config = get_template_config()
config.mode = "explore"
config.model.model_path = model_path
config.explorer.rollout_model.engine_type = "vllm"
config.explorer.rollout_model.engine_num = 1
config.explorer.rollout_model.tensor_parallel_size = 1
config.explorer.rollout_model.enable_openai_api = True
config.check_and_update()

self.engines, self.auxiliary_engines = create_explorer_models(config)
self.vllm_wrapper = ModelWrapper(self.engines[0], enable_history=False)
await prepare_engines(self.engines, self.auxiliary_engines)
await self.vllm_wrapper.prepare()

openai_client = self.vllm_wrapper.get_openai_client()
self.model_name = openai_client.models.list().data[0].id

self.base_url_env = "TRINITY_OPENAI_BASE_URL_TEST"
self.api_key_env = "TRINITY_OPENAI_API_KEY_TEST"
os.environ[self.base_url_env] = f"{self.vllm_wrapper.api_address}/v1"
os.environ[self.api_key_env] = "EMPTY"
self.model_path = model_path
print(
f"Model is prepared at {self.vllm_wrapper.api_address}/v1, model_name: {self.model_name}"
)

async def test_external_model(self):
# Part 2: verify ExternalModel can call the endpoint correctly.
model = ExternalModel(
InferenceModelConfig(
model_path=self.model_path,
external_model_config=ExternalModelConfig(
base_url_env=self.base_url_env,
api_key_env=self.api_key_env,
model_name=self.model_name,
),
)
)

generate_exps = await model.generate("Say hello in one sentence.", n=1, max_tokens=16)
self.assertEqual(len(generate_exps), 1)
self.assertTrue(len(generate_exps[0].response_text) > 0)
self.assertEqual(generate_exps[0].reward, 0.0)
self.assertIn("usage/prompt_tokens", generate_exps[0].metrics)
self.assertIn("usage/completion_tokens", generate_exps[0].metrics)
self.assertIn("usage/total_tokens", generate_exps[0].metrics)
self.assertGreater(generate_exps[0].metrics["usage/total_tokens"], 0.0)

messages = [
{"role": "system", "content": "You are an assistant. Answer the question briefly."},
{"role": "user", "content": [{"type": "text", "text": "What is 1+1?"}]},
]
chat_exps = await model.chat(messages, n=4, max_tokens=32)
self.assertEqual(len(chat_exps), 4)
self.assertTrue(len(chat_exps[0].response_text) > 0)
self.assertEqual(chat_exps[0].reward, 0.0)
self.assertIn("usage/prompt_tokens", chat_exps[0].metrics)
self.assertIn("usage/completion_tokens", chat_exps[0].metrics)
self.assertIn("usage/total_tokens", chat_exps[0].metrics)
self.assertGreater(chat_exps[0].metrics["usage/total_tokens"], 0.0)


@ray.remote
class MockExternalInferenceModelActor:
def __init__(self, config, api_server_url, api_key):
self._config = config
self._api_server_url = api_server_url
self._api_key = api_key

def get_model_config(self):
return self._config

def get_api_key(self):
return self._api_key

def get_api_server_url(self):
return self._api_server_url

def get_model_path(self):
return self._config.model_path


class TestExternalModelLoad(unittest.IsolatedAsyncioTestCase):
@classmethod
def setUpClass(cls):
ray.init(ignore_reinit_error=True, namespace="trinity_unittest")

@classmethod
def tearDownClass(cls):
ray.shutdown(_exiting_interpreter=True)

async def test_external_model_load(self):
mock_base_url = "https://mock.external.endpoint/v1/"
mock_api_key = "dummy-api-key"
config = InferenceModelConfig(
model_path="mock-model-name",
engine_type="external",
external_model_config=ExternalModelConfig(enable=True),
)
model_actor = MockExternalInferenceModelActor.remote(config, mock_base_url, mock_api_key)
wrapper = ModelWrapper(model_actor, enable_history=False)
await wrapper.prepare()

client = wrapper.get_openai_client()
self.assertEqual(client.api_base_url, f"{mock_base_url}/v1")
self.assertEqual(client.api_key, mock_api_key)
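The `TestExternalModelLoad` case above exercises a handshake between the wrapper and the model actor: at `prepare()` time the wrapper asks the actor for its endpoint URL and API key, then builds an OpenAI-compatible client from them. Stripped of Ray, that pattern is roughly as follows; all class and method names here are a simplified sketch, not the actual `ModelWrapper` implementation:

```python
class FakeExternalModelActor:
    """Stands in for the remote inference-model actor in the test above."""

    def __init__(self, api_server_url: str, api_key: str):
        self._url = api_server_url
        self._key = api_key

    def get_api_server_url(self) -> str:
        return self._url

    def get_api_key(self) -> str:
        return self._key


class MiniWrapper:
    """Fetches and caches endpoint details from the actor at prepare() time."""

    def __init__(self, actor):
        self._actor = actor
        self.api_base_url = None
        self.api_key = None

    def prepare(self):
        # In the real code these are remote (ray) calls awaited by the wrapper.
        self.api_base_url = self._actor.get_api_server_url()
        self.api_key = self._actor.get_api_key()


wrapper = MiniWrapper(FakeExternalModelActor("https://mock.external.endpoint/v1", "dummy-key"))
wrapper.prepare()
print(wrapper.api_base_url, wrapper.api_key)
# prints https://mock.external.endpoint/v1 dummy-key
```

Because the endpoint details live on the actor, any process holding the actor handle can construct a client without needing the env vars set locally.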
15 changes: 15 additions & 0 deletions trinity/common/config.py
@@ -442,6 +442,15 @@ class TinkerConfig:
base_url: Optional[str] = None


@dataclass
class ExternalModelConfig:
enable: bool = False
base_url_env: str = "OPENAI_BASE_URL"
api_key_env: str = "OPENAI_API_KEY"
model_name: Optional[str] = None
# support_prompt_logprobs: bool = False # TODO


@dataclass
class ModelConfig:
# source model path
@@ -491,6 +500,9 @@ class ModelConfig:
# tinker config
tinker: TinkerConfig = field(default_factory=TinkerConfig)

# external API-based engine
external_model: ExternalModelConfig = field(default_factory=ExternalModelConfig)


@dataclass
class InferenceModelConfig:
@@ -551,6 +563,9 @@ class InferenceModelConfig:

reasoning_parser: Optional[str] = None

# For external API-based engine
external_model_config: ExternalModelConfig = field(default_factory=ExternalModelConfig)

# ! DO NOT SET
bundle_indices: str = ""
ray_namespace: Optional[str] = None