diff --git a/docs/sphinx_doc/source/tutorial/trinity_installation.md b/docs/sphinx_doc/source/tutorial/trinity_installation.md
index a361d67bd4..30ea15759b 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_installation.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_installation.md
@@ -6,7 +6,7 @@ For installing Trinity-RFT, you have three options: from source (recommended), v
 Before installing, ensure your system meets the following requirements:
 
 - **Python**: Version 3.10 to 3.12 (inclusive)
-- **CUDA**: Version 12.4 to 12.8 (inclusive)
+- **CUDA**: Version >= 12.6
 - **GPUs**: At least 2 GPUs
 
 ---
diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_installation.md b/docs/sphinx_doc/source_zh/tutorial/trinity_installation.md
index 228be1bd3d..5c0e4cdc18 100644
--- a/docs/sphinx_doc/source_zh/tutorial/trinity_installation.md
+++ b/docs/sphinx_doc/source_zh/tutorial/trinity_installation.md
@@ -6,7 +6,7 @@
 在安装前，请确保您的系统满足以下要求：
 
 - **Python**：3.10 至 3.12（包含）
-- **CUDA**：12.4 至 12.8（包含）
+- **CUDA**：大于等于 12.6
 - **GPU**：至少 2 块 GPU
 
 ---
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 456ff136b5..2efec5b24a 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -78,10 +78,10 @@ class FormatConfig:
 
 @dataclass
 class GenerationConfig:
-    temperature: float = 1.0
-    top_p: float = 1.0
-    top_k: int = -1
-    logprobs: int = 0  # vLLM return `logprobs + 1` elements
+    temperature: Optional[float] = None  # 1.0
+    top_p: Optional[float] = None  # 1.0
+    top_k: Optional[int] = None  # -1
+    logprobs: Optional[int] = None  # 0 # vLLM return `logprobs + 1` elements
     max_tokens: Optional[int] = None  # if None, use model.max_response_tokens
     # repeat each task for `n` times
     # ! DO NOT SET in `buffer.explorer_input.taskset.rollout_args`
@@ -412,6 +412,12 @@ class ModelConfig:
 
     custom_chat_template: Optional[str] = None
 
+    # rollout args
+    temperature: float = 1.0
+    top_p: float = 1.0
+    top_k: int = -1
+    logprobs: int = 0
+
     # the total number of tokens the model can handle
     max_model_len: Optional[int] = None
 
@@ -447,6 +453,12 @@ class InferenceModelConfig:
     dtype: str = "bfloat16"
     seed: int = 42
 
+    # rollout args, ! DO NOT SET
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    top_k: Optional[int] = None
+    logprobs: Optional[int] = None
+
     # if not set, use `model.max_model_len`
     max_model_len: Optional[int] = None
     # if not set, use `model.max_prompt_tokens`
@@ -853,6 +865,10 @@ def _check_explorer_input(self) -> None:
         set_if_none(taskset, "default_workflow_type", explorer_input.default_workflow_type)
         set_if_none(taskset, "default_reward_fn_type", explorer_input.default_reward_fn_type)
         set_if_none(taskset, "ray_namespace", self.ray_namespace)
+        set_if_none(taskset.rollout_args, "temperature", self.model.temperature)
+        set_if_none(taskset.rollout_args, "top_p", self.model.top_p)
+        set_if_none(taskset.rollout_args, "top_k", self.model.top_k)
+        set_if_none(taskset.rollout_args, "logprobs", self.model.logprobs)
         set_if_none(taskset.rollout_args, "max_tokens", self.model.max_response_tokens)
 
         for idx, dataset in enumerate(explorer_input.eval_tasksets):
@@ -868,6 +884,10 @@ def _check_explorer_input(self) -> None:
             set_if_none(dataset, "default_workflow_type", explorer_input.default_workflow_type)
             set_if_none(dataset, "default_reward_fn_type", explorer_input.default_reward_fn_type)
             set_if_none(dataset, "ray_namespace", self.ray_namespace)
+            set_if_none(dataset.rollout_args, "temperature", self.model.temperature)
+            set_if_none(dataset.rollout_args, "top_p", self.model.top_p)
+            set_if_none(dataset.rollout_args, "top_k", self.model.top_k)
+            set_if_none(dataset.rollout_args, "logprobs", self.model.logprobs)
             set_if_none(dataset.rollout_args, "max_tokens", self.model.max_response_tokens)
 
     def _check_trainer_input(self) -> None:
@@ -1161,18 +1181,20 @@ def check_and_update(self) -> Config:  # noqa: C901
 
         # check explorer
         if self.explorer is not None:
-            self.explorer.rollout_model.model_path = self.model.model_path
-            self.explorer.rollout_model.max_model_len = self.model.max_model_len
-            self.explorer.rollout_model.max_prompt_tokens = self.model.max_prompt_tokens
-            self.explorer.rollout_model.max_response_tokens = self.model.max_response_tokens
-            self.explorer.rollout_model.min_response_tokens = self.model.min_response_tokens
+            rollout_args = ["temperature", "top_p", "top_k", "logprobs"]
+            length_args = [
+                "max_model_len",
+                "max_prompt_tokens",
+                "max_response_tokens",
+                "min_response_tokens",
+            ]
+            for args in ["model_path"] + rollout_args + length_args:
+                setattr(self.explorer.rollout_model, args, getattr(self.model, args))
             for aux_model in self.explorer.auxiliary_models:
                 if not aux_model.model_path:
                     raise ValueError("auxiliary model's model_path is required.")
-                set_if_none(aux_model, "max_model_len", self.model.max_model_len)
-                set_if_none(aux_model, "max_prompt_tokens", self.model.max_prompt_tokens)
-                set_if_none(aux_model, "max_response_tokens", self.model.max_response_tokens)
-                set_if_none(aux_model, "min_response_tokens", self.model.min_response_tokens)
+                for args in rollout_args + length_args:
+                    set_if_none(aux_model, args, getattr(self.model, args))
 
         # for lora configs
         if self.model.lora_configs is not None:
diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py
index c22f6cbc99..f09756511d 100644
--- a/trinity/common/models/model.py
+++ b/trinity/common/models/model.py
@@ -4,7 +4,7 @@
 import socket
 from abc import ABC, abstractmethod
 from functools import partial
-from typing import Any, List, Optional, Sequence, Tuple, Union
+from typing import List, Optional, Sequence, Tuple, Union
 
 import httpx
 import numpy as np
@@ -83,7 +83,7 @@ class ModelWrapper:
 
     def __init__(
         self,
-        model: Any,
+        model: InferenceModel,
        engine_type: str = "vllm",
         enable_lora: bool = False,
         enable_history: bool = False,
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index 5b22dddb85..982b2e7d0b 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -83,6 +83,12 @@ def __init__(
             gpu_memory_utilization=config.gpu_memory_utilization,
             enable_chunked_prefill=config.enable_chunked_prefill,
             # max_num_batched_tokens=256,  # you can further set this parameter to reduce the vllm peak memory usage
+            override_generation_config={  # TODO: find a way to unittest this
+                "temperature": config.temperature,
+                "top_p": config.top_p,
+                "top_k": config.top_k,
+                "max_new_tokens": config.max_response_tokens,
+            },
             disable_log_stats=True,
             enable_lora=config.enable_lora,
             **config.lora_kwargs,
diff --git a/trinity/common/workflows/agentscope_workflow.py b/trinity/common/workflows/agentscope_workflow.py
index 77a5d5ed7f..03f3689d72 100644
--- a/trinity/common/workflows/agentscope_workflow.py
+++ b/trinity/common/workflows/agentscope_workflow.py
@@ -47,7 +47,9 @@ def __init__(
             model.get_openai_async_client(),
             generate_kwargs={
                 "temperature": self.task.rollout_args.temperature,
+                "top_p": self.task.rollout_args.top_p,
                 "max_tokens": self.task.rollout_args.max_tokens or 4096,
+                "logprobs": True,
                 "top_logprobs": self.task.rollout_args.logprobs,
             },
         )
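
Note on the precedence introduced in `config.py` above: the per-taskset `rollout_args` fields (`temperature`, `top_p`, `top_k`, `logprobs`) now default to `None` and are backfilled from the model-level defaults in `ModelConfig` via `set_if_none`. The snippet below is a minimal, self-contained sketch of that fallback behaviour, not code from the repository; it assumes `set_if_none(obj, name, value)` assigns `value` only when the attribute is currently `None`, which is what the call sites suggest.

```python
# Illustrative sketch only -- simplified stand-ins for the real config classes.
from dataclasses import dataclass
from typing import Optional


@dataclass
class GenerationConfig:
    # per-taskset rollout args; None means "inherit from ModelConfig"
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    logprobs: Optional[int] = None


@dataclass
class ModelConfig:
    # model-wide rollout defaults
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = -1
    logprobs: int = 0


def set_if_none(obj, name, value):
    """Fill obj.<name> from value only when it was left unset (None)."""
    if getattr(obj, name) is None:
        setattr(obj, name, value)


model = ModelConfig(temperature=0.7)
rollout_args = GenerationConfig(top_p=0.9)  # user overrides only top_p

for name in ("temperature", "top_p", "top_k", "logprobs"):
    set_if_none(rollout_args, name, getattr(model, name))

print(rollout_args)
# GenerationConfig(temperature=0.7, top_p=0.9, top_k=-1, logprobs=0)
```

Under this scheme, an explicit value in `buffer.explorer_input.taskset.rollout_args` wins, and any field left unset inherits the model-wide rollout defaults from `model`.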