2 changes: 1 addition & 1 deletion docs/sphinx_doc/source/tutorial/trinity_installation.md
@@ -6,7 +6,7 @@ For installing Trinity-RFT, you have three options: from source (recommended), v
Before installing, ensure your system meets the following requirements:

- **Python**: Version 3.10 to 3.12 (inclusive)
- **CUDA**: Version 12.4 to 12.8 (inclusive)
- **CUDA**: Version >= 12.6
- **GPUs**: At least 2 GPUs

---
2 changes: 1 addition & 1 deletion docs/sphinx_doc/source_zh/tutorial/trinity_installation.md
@@ -6,7 +6,7 @@
Before installing, ensure your system meets the following requirements:

- **Python**: 3.10 to 3.12 (inclusive)
- **CUDA**: 12.4 to 12.8 (inclusive)
- **CUDA**: >= 12.6
- **GPU**: At least 2 GPUs

---
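A quick way to verify the updated requirements (CUDA >= 12.6, at least 2 GPUs) on a machine before installing; this is a convenience sketch, not part of the PR, and it assumes a CUDA-enabled PyTorch build is already present:

```python
import torch

# torch.version.cuda reports the CUDA version this PyTorch build was compiled
# against; it is None on CPU-only builds.
cuda = torch.version.cuda
assert cuda is not None and tuple(map(int, cuda.split("."))) >= (12, 6), \
    f"CUDA >= 12.6 required, found {cuda}"
assert torch.cuda.device_count() >= 2, "at least 2 GPUs required"
print(f"CUDA {cuda}, {torch.cuda.device_count()} GPU(s) visible")
```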
48 changes: 35 additions & 13 deletions trinity/common/config.py
@@ -78,10 +78,10 @@ class FormatConfig:

@dataclass
class GenerationConfig:
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1
logprobs: int = 0 # vLLM return `logprobs + 1` elements
temperature: Optional[float] = None # 1.0
top_p: Optional[float] = None # 1.0
top_k: Optional[int] = None # -1
logprobs: Optional[int] = None # 0 # vLLM return `logprobs + 1` elements
max_tokens: Optional[int] = None # if None, use model.max_response_tokens
# repeat each task for `n` times
# ! DO NOT SET in `buffer.explorer_input.taskset.rollout_args`
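Switching these fields to `Optional` makes `None` mean "inherit the model-level value"; the old defaults are kept only as trailing comments and now live on `ModelConfig`. A minimal sketch of the intended fallback, using simplified stand-in classes and a hypothetical `resolve` helper rather than the project's actual validation code:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class RolloutArgs:  # simplified stand-in for GenerationConfig's sampling fields
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    logprobs: Optional[int] = None

def resolve(task_args: RolloutArgs, model_defaults: dict) -> RolloutArgs:
    # Fill every field that is still None from the model-level defaults;
    # values set explicitly on the task are left untouched.
    for name, default in model_defaults.items():
        if getattr(task_args, name) is None:
            setattr(task_args, name, default)
    return task_args

args = resolve(RolloutArgs(temperature=0.7),
               {"temperature": 1.0, "top_p": 1.0, "top_k": -1, "logprobs": 0})
print(args)  # RolloutArgs(temperature=0.7, top_p=1.0, top_k=-1, logprobs=0)
```

In this PR, the actual filling happens in `_check_explorer_input` and `check_and_update`, shown further down this diff.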
@@ -412,6 +412,12 @@ class ModelConfig:

custom_chat_template: Optional[str] = None

# rollout args
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1
logprobs: int = 0

# the total number of tokens the model can handle
max_model_len: Optional[int] = None

@@ -447,6 +453,12 @@ class InferenceModelConfig:
dtype: str = "bfloat16"
seed: int = 42

# rollout args, ! DO NOT SET
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
logprobs: Optional[int] = None

# if not set, use `model.max_model_len`
max_model_len: Optional[int] = None
# if not set, use `model.max_prompt_tokens`
@@ -853,6 +865,10 @@ def _check_explorer_input(self) -> None:
set_if_none(taskset, "default_workflow_type", explorer_input.default_workflow_type)
set_if_none(taskset, "default_reward_fn_type", explorer_input.default_reward_fn_type)
set_if_none(taskset, "ray_namespace", self.ray_namespace)
set_if_none(taskset.rollout_args, "temperature", self.model.temperature)
set_if_none(taskset.rollout_args, "top_p", self.model.top_p)
set_if_none(taskset.rollout_args, "top_k", self.model.top_k)
set_if_none(taskset.rollout_args, "logprobs", self.model.logprobs)
set_if_none(taskset.rollout_args, "max_tokens", self.model.max_response_tokens)

for idx, dataset in enumerate(explorer_input.eval_tasksets):
@@ -868,6 +884,10 @@ def _check_explorer_input(self) -> None:
set_if_none(dataset, "default_workflow_type", explorer_input.default_workflow_type)
set_if_none(dataset, "default_reward_fn_type", explorer_input.default_reward_fn_type)
set_if_none(dataset, "ray_namespace", self.ray_namespace)
set_if_none(dataset.rollout_args, "temperature", self.model.temperature)
set_if_none(dataset.rollout_args, "top_p", self.model.top_p)
set_if_none(dataset.rollout_args, "top_k", self.model.top_k)
set_if_none(dataset.rollout_args, "logprobs", self.model.logprobs)
set_if_none(dataset.rollout_args, "max_tokens", self.model.max_response_tokens)

def _check_trainer_input(self) -> None:
@@ -1161,18 +1181,20 @@ def check_and_update(self) -> Config: # noqa: C901

# check explorer
if self.explorer is not None:
self.explorer.rollout_model.model_path = self.model.model_path
self.explorer.rollout_model.max_model_len = self.model.max_model_len
self.explorer.rollout_model.max_prompt_tokens = self.model.max_prompt_tokens
self.explorer.rollout_model.max_response_tokens = self.model.max_response_tokens
self.explorer.rollout_model.min_response_tokens = self.model.min_response_tokens
rollout_args = ["temperature", "top_p", "top_k", "logprobs"]
length_args = [
"max_model_len",
"max_prompt_tokens",
"max_response_tokens",
"min_response_tokens",
]
for args in ["model_path"] + rollout_args + length_args:
setattr(self.explorer.rollout_model, args, getattr(self.model, args))
for aux_model in self.explorer.auxiliary_models:
if not aux_model.model_path:
raise ValueError("auxiliary model's model_path is required.")
set_if_none(aux_model, "max_model_len", self.model.max_model_len)
set_if_none(aux_model, "max_prompt_tokens", self.model.max_prompt_tokens)
set_if_none(aux_model, "max_response_tokens", self.model.max_response_tokens)
set_if_none(aux_model, "min_response_tokens", self.model.min_response_tokens)
for args in rollout_args + length_args:
set_if_none(aux_model, args, getattr(self.model, args))

# for lora configs
if self.model.lora_configs is not None:
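The refactor above also highlights a design difference: the rollout model always mirrors `ModelConfig` (unconditional `setattr`), which is why the new rollout fields on `InferenceModelConfig` are marked "DO NOT SET", while auxiliary models only inherit values they left unset. A toy illustration of that contrast with simplified stand-in classes, not the real config types:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelCfg:                    # stands in for ModelConfig
    temperature: float = 1.0
    top_p: float = 1.0

@dataclass
class EngineCfg:                   # stands in for InferenceModelConfig
    temperature: Optional[float] = None
    top_p: Optional[float] = None

src = ModelCfg(temperature=0.7)
rollout, aux = EngineCfg(top_p=0.9), EngineCfg(top_p=0.9)

for name in ("temperature", "top_p"):
    # Rollout model: unconditional copy -- any user-set value is overwritten.
    setattr(rollout, name, getattr(src, name))
    # Auxiliary model: fill only the gaps, so its explicit top_p survives.
    if getattr(aux, name) is None:
        setattr(aux, name, getattr(src, name))

print(rollout)  # EngineCfg(temperature=0.7, top_p=1.0)
print(aux)      # EngineCfg(temperature=0.7, top_p=0.9)
```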
4 changes: 2 additions & 2 deletions trinity/common/models/model.py
@@ -4,7 +4,7 @@
import socket
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, List, Optional, Sequence, Tuple, Union
from typing import List, Optional, Sequence, Tuple, Union

import httpx
import numpy as np
@@ -83,7 +83,7 @@ class ModelWrapper:

def __init__(
self,
model: Any,
model: InferenceModel,
engine_type: str = "vllm",
enable_lora: bool = False,
enable_history: bool = False,
6 changes: 6 additions & 0 deletions trinity/common/models/vllm_model.py
@@ -83,6 +83,12 @@ def __init__(
gpu_memory_utilization=config.gpu_memory_utilization,
enable_chunked_prefill=config.enable_chunked_prefill,
# max_num_batched_tokens=256, # you can further set this parameter to reduce the vllm peak memory usage
override_generation_config={ # TODO: find a way to unittest this
"temperature": config.temperature,
"top_p": config.top_p,
"top_k": config.top_k,
"max_new_tokens": config.max_response_tokens,
},
disable_log_stats=True,
enable_lora=config.enable_lora,
**config.lora_kwargs,
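The new `override_generation_config` block passes the `ModelConfig` values to vLLM as engine-level generation defaults. Once the per-task rollout args have been resolved by the config changes above, turning them into vLLM sampling parameters is direct; a sketch with a hypothetical helper, not the project's actual conversion code:

```python
from vllm import SamplingParams

def to_sampling_params(rollout_args) -> SamplingParams:
    # Hypothetical helper: assumes config validation has already filled every
    # field (see check_and_update above), so no None values reach vLLM.
    return SamplingParams(
        temperature=rollout_args.temperature,
        top_p=rollout_args.top_p,
        top_k=rollout_args.top_k,
        logprobs=rollout_args.logprobs,
        max_tokens=rollout_args.max_tokens,
    )
```

Per the comment on `GenerationConfig`, `logprobs=0` still yields the sampled token's logprob, since vLLM returns `logprobs + 1` entries per token.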
2 changes: 2 additions & 0 deletions trinity/common/workflows/agentscope_workflow.py
@@ -47,7 +47,9 @@ def __init__(
model.get_openai_async_client(),
generate_kwargs={
"temperature": self.task.rollout_args.temperature,
"top_p": self.task.rollout_args.top_p,
"max_tokens": self.task.rollout_args.max_tokens or 4096,
"logprobs": True,
"top_logprobs": self.task.rollout_args.logprobs,
},
)
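In the OpenAI-compatible chat-completions API, `logprobs` is a boolean switch and `top_logprobs` is an integer count of ranked alternatives per token, which is why the integer `rollout_args.logprobs` is passed as `top_logprobs` here. A generic usage sketch of reading those values back, not code from this repo; the model name, prompt, and helper are placeholders:

```python
from openai import AsyncOpenAI

async def sample_with_logprobs(client: AsyncOpenAI, model_name: str, prompt: str, k: int):
    # logprobs=True switches token logprobs on; top_logprobs=k asks for k
    # ranked alternatives per generated token.
    resp = await client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
        logprobs=True,
        top_logprobs=k,
        max_tokens=64,
    )
    first = resp.choices[0].logprobs.content[0]  # first generated token
    return first.token, first.logprob, [t.token for t in first.top_logprobs]
```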