
Commit bb910e4

Add generation args to ModelConfig (#357)
1 parent 3343729 commit bb910e4

6 files changed: +47 additions, -17 deletions

docs/sphinx_doc/source/tutorial/trinity_installation.md

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ For installing Trinity-RFT, you have three options: from source (recommended), v
 Before installing, ensure your system meets the following requirements:
 
 - **Python**: Version 3.10 to 3.12 (inclusive)
-- **CUDA**: Version 12.4 to 12.8 (inclusive)
+- **CUDA**: Version >= 12.6
 - **GPUs**: At least 2 GPUs
 
 ---
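For readers upgrading an existing environment, a quick way to confirm that the CUDA build bundled with the installed PyTorch meets the new >= 12.6 floor is sketched below. This check is illustrative and not part of the commit.

```python
# Hypothetical sanity check (not part of this commit): verify the CUDA
# version the installed PyTorch build was compiled against is >= 12.6.
import torch

cuda = torch.version.cuda  # e.g. "12.6"; None for CPU-only builds
assert cuda is not None, "CPU-only PyTorch build; Trinity-RFT needs CUDA >= 12.6"
major, minor = (int(x) for x in cuda.split(".")[:2])
assert (major, minor) >= (12, 6), f"CUDA {cuda} found; >= 12.6 required"
print(f"CUDA {cuda} OK")
```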

docs/sphinx_doc/source_zh/tutorial/trinity_installation.md

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 在安装前,请确保您的系统满足以下要求:
 
 - **Python**:3.10 至 3.12(包含)
-- **CUDA**:12.4 至 12.8(包含)
+- **CUDA**:大于等于 12.6
 - **GPU**:至少 2 块 GPU
 
 ---

trinity/common/config.py

Lines changed: 35 additions & 13 deletions
@@ -78,10 +78,10 @@ class FormatConfig:
 
 @dataclass
 class GenerationConfig:
-    temperature: float = 1.0
-    top_p: float = 1.0
-    top_k: int = -1
-    logprobs: int = 0  # vLLM return `logprobs + 1` elements
+    temperature: Optional[float] = None  # 1.0
+    top_p: Optional[float] = None  # 1.0
+    top_k: Optional[int] = None  # -1
+    logprobs: Optional[int] = None  # 0 # vLLM return `logprobs + 1` elements
     max_tokens: Optional[int] = None  # if None, use model.max_response_tokens
     # repeat each task for `n` times
     # ! DO NOT SET in `buffer.explorer_input.taskset.rollout_args`

@@ -412,6 +412,12 @@ class ModelConfig:
 
     custom_chat_template: Optional[str] = None
 
+    # rollout args
+    temperature: float = 1.0
+    top_p: float = 1.0
+    top_k: int = -1
+    logprobs: int = 0
+
     # the total number of tokens the model can handle
     max_model_len: Optional[int] = None
 

@@ -447,6 +453,12 @@ class InferenceModelConfig:
     dtype: str = "bfloat16"
     seed: int = 42
 
+    # rollout args, ! DO NOT SET
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    top_k: Optional[int] = None
+    logprobs: Optional[int] = None
+
     # if not set, use `model.max_model_len`
     max_model_len: Optional[int] = None
     # if not set, use `model.max_prompt_tokens`

@@ -853,6 +865,10 @@ def _check_explorer_input(self) -> None:
         set_if_none(taskset, "default_workflow_type", explorer_input.default_workflow_type)
         set_if_none(taskset, "default_reward_fn_type", explorer_input.default_reward_fn_type)
         set_if_none(taskset, "ray_namespace", self.ray_namespace)
+        set_if_none(taskset.rollout_args, "temperature", self.model.temperature)
+        set_if_none(taskset.rollout_args, "top_p", self.model.top_p)
+        set_if_none(taskset.rollout_args, "top_k", self.model.top_k)
+        set_if_none(taskset.rollout_args, "logprobs", self.model.logprobs)
         set_if_none(taskset.rollout_args, "max_tokens", self.model.max_response_tokens)
 
         for idx, dataset in enumerate(explorer_input.eval_tasksets):

@@ -868,6 +884,10 @@ def _check_explorer_input(self) -> None:
             set_if_none(dataset, "default_workflow_type", explorer_input.default_workflow_type)
             set_if_none(dataset, "default_reward_fn_type", explorer_input.default_reward_fn_type)
             set_if_none(dataset, "ray_namespace", self.ray_namespace)
+            set_if_none(dataset.rollout_args, "temperature", self.model.temperature)
+            set_if_none(dataset.rollout_args, "top_p", self.model.top_p)
+            set_if_none(dataset.rollout_args, "top_k", self.model.top_k)
+            set_if_none(dataset.rollout_args, "logprobs", self.model.logprobs)
             set_if_none(dataset.rollout_args, "max_tokens", self.model.max_response_tokens)
 
     def _check_trainer_input(self) -> None:

@@ -1161,18 +1181,20 @@ def check_and_update(self) -> Config:  # noqa: C901
 
         # check explorer
         if self.explorer is not None:
-            self.explorer.rollout_model.model_path = self.model.model_path
-            self.explorer.rollout_model.max_model_len = self.model.max_model_len
-            self.explorer.rollout_model.max_prompt_tokens = self.model.max_prompt_tokens
-            self.explorer.rollout_model.max_response_tokens = self.model.max_response_tokens
-            self.explorer.rollout_model.min_response_tokens = self.model.min_response_tokens
+            rollout_args = ["temperature", "top_p", "top_k", "logprobs"]
+            length_args = [
+                "max_model_len",
+                "max_prompt_tokens",
+                "max_response_tokens",
+                "min_response_tokens",
+            ]
+            for args in ["model_path"] + rollout_args + length_args:
+                setattr(self.explorer.rollout_model, args, getattr(self.model, args))
             for aux_model in self.explorer.auxiliary_models:
                 if not aux_model.model_path:
                     raise ValueError("auxiliary model's model_path is required.")
-                set_if_none(aux_model, "max_model_len", self.model.max_model_len)
-                set_if_none(aux_model, "max_prompt_tokens", self.model.max_prompt_tokens)
-                set_if_none(aux_model, "max_response_tokens", self.model.max_response_tokens)
-                set_if_none(aux_model, "min_response_tokens", self.model.min_response_tokens)
+                for args in rollout_args + length_args:
+                    set_if_none(aux_model, args, getattr(self.model, args))
 
         # for lora configs
         if self.model.lora_configs is not None:
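Taken together, these changes make the per-task `GenerationConfig` fields optional and let `ModelConfig` carry the defaults, which `check_and_update` then fills into each taskset's `rollout_args` and copies onto the rollout model. The snippet below is a minimal, self-contained sketch of that fill-in behaviour with a simplified `set_if_none`; the names mirror the diff, but it is illustrative rather than the repository code.

```python
# Illustrative sketch of the fallback behaviour added in this commit:
# rollout args left unset on a taskset inherit their values from ModelConfig.
from dataclasses import dataclass
from typing import Optional


def set_if_none(obj, attr, value):
    # Simplified stand-in for the helper used in trinity/common/config.py.
    if getattr(obj, attr) is None:
        setattr(obj, attr, value)


@dataclass
class GenerationConfig:
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    logprobs: Optional[int] = None


@dataclass
class ModelConfig:
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = -1
    logprobs: int = 0


model = ModelConfig(temperature=0.7, top_k=50)
rollout_args = GenerationConfig(top_p=0.9)  # the user only pinned top_p

for arg in ("temperature", "top_p", "top_k", "logprobs"):
    set_if_none(rollout_args, arg, getattr(model, arg))

print(rollout_args)
# GenerationConfig(temperature=0.7, top_p=0.9, top_k=50, logprobs=0)
```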

trinity/common/models/model.py

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@
 import socket
 from abc import ABC, abstractmethod
 from functools import partial
-from typing import Any, List, Optional, Sequence, Tuple, Union
+from typing import List, Optional, Sequence, Tuple, Union
 
 import httpx
 import numpy as np

@@ -83,7 +83,7 @@ class ModelWrapper:
 
     def __init__(
         self,
-        model: Any,
+        model: InferenceModel,
         engine_type: str = "vllm",
         enable_lora: bool = False,
         enable_history: bool = False,
trinity/common/models/vllm_model.py

Lines changed: 6 additions & 0 deletions
@@ -83,6 +83,12 @@ def __init__(
             gpu_memory_utilization=config.gpu_memory_utilization,
             enable_chunked_prefill=config.enable_chunked_prefill,
             # max_num_batched_tokens=256, # you can further set this parameter to reduce the vllm peak memory usage
+            override_generation_config={  # TODO: find a way to unittest this
+                "temperature": config.temperature,
+                "top_p": config.top_p,
+                "top_k": config.top_k,
+                "max_new_tokens": config.max_response_tokens,
+            },
             disable_log_stats=True,
             enable_lora=config.enable_lora,
             **config.lora_kwargs,
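The new model-level fields reach vLLM through `override_generation_config`, which replaces the generation defaults the engine would otherwise pick up from the model's bundled generation config. A hedged sketch of passing the same dictionary when constructing a vLLM engine directly is shown below; the model path is a placeholder and the availability of `override_generation_config` depends on the installed vLLM version.

```python
# Sketch only: forwarding generation defaults to vLLM the same way the commit
# does. `override_generation_config` is supported in recent vLLM releases;
# check your installed version before relying on it.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model path
    dtype="bfloat16",
    override_generation_config={
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": -1,
        "max_new_tokens": 2048,
    },
)
```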

trinity/common/workflows/agentscope_workflow.py

Lines changed: 2 additions & 0 deletions
@@ -47,7 +47,9 @@ def __init__(
             model.get_openai_async_client(),
             generate_kwargs={
                 "temperature": self.task.rollout_args.temperature,
+                "top_p": self.task.rollout_args.top_p,
                 "max_tokens": self.task.rollout_args.max_tokens or 4096,
+                "logprobs": True,
                 "top_logprobs": self.task.rollout_args.logprobs,
             },
         )
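These `generate_kwargs` follow the OpenAI chat-completions convention, where `logprobs` is a boolean toggle and `top_logprobs` sets how many alternatives are returned per token. The sketch below shows how such kwargs map onto an OpenAI-compatible request; the `AsyncOpenAI` client and endpoint are stand-ins for the client returned by `model.get_openai_async_client()`, not code from this commit.

```python
# Hedged sketch: the generate_kwargs above expressed as an OpenAI-compatible
# chat.completions request against a local, vLLM-style endpoint (placeholder).
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    response = await client.chat.completions.create(
        model="placeholder-model",
        messages=[{"role": "user", "content": "Hello"}],
        temperature=1.0,
        top_p=1.0,
        max_tokens=4096,
        logprobs=True,   # boolean toggle in the chat completions API
        top_logprobs=5,  # number of alternatives returned per token
    )
    print(response.choices[0].message.content)


asyncio.run(main())
```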
