58 changes: 45 additions & 13 deletions trinity/common/config.py
@@ -78,10 +78,10 @@ class FormatConfig:

@dataclass
class GenerationConfig:
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1
logprobs: int = 0 # vLLM return `logprobs + 1` elements
temperature: Optional[float] = None # 1.0
top_p: Optional[float] = None # 1.0
top_k: Optional[int] = None # -1
logprobs: Optional[int] = None # 0 # vLLM returns `logprobs + 1` elements
max_tokens: Optional[int] = None # if None, use model.max_response_tokens
# repeat each task for `n` times
# ! DO NOT SET in `buffer.explorer_input.taskset.rollout_args`
@@ -412,6 +412,12 @@ class ModelConfig:

custom_chat_template: Optional[str] = None

# rollout args
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1
logprobs: int = 0

# the total number of tokens the model can handle
max_model_len: Optional[int] = None

Expand Down Expand Up @@ -447,6 +453,12 @@ class InferenceModelConfig:
dtype: str = "bfloat16"
seed: int = 42

# rollout args, ! DO NOT SET
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
logprobs: Optional[int] = None

# if not set, use `model.max_model_len`
max_model_len: Optional[int] = None
# if not set, use `model.max_prompt_tokens`
Expand Down Expand Up @@ -853,6 +865,10 @@ def _check_explorer_input(self) -> None:
set_if_none(taskset, "default_workflow_type", explorer_input.default_workflow_type)
set_if_none(taskset, "default_reward_fn_type", explorer_input.default_reward_fn_type)
set_if_none(taskset, "ray_namespace", self.ray_namespace)
set_if_none(taskset.rollout_args, "temperature", self.model.temperature)
set_if_none(taskset.rollout_args, "top_p", self.model.top_p)
set_if_none(taskset.rollout_args, "top_k", self.model.top_k)
set_if_none(taskset.rollout_args, "logprobs", self.model.logprobs)
set_if_none(taskset.rollout_args, "max_tokens", self.model.max_response_tokens)

for idx, dataset in enumerate(explorer_input.eval_tasksets):
@@ -868,6 +884,10 @@ def _check_explorer_input(self) -> None:
set_if_none(dataset, "default_workflow_type", explorer_input.default_workflow_type)
set_if_none(dataset, "default_reward_fn_type", explorer_input.default_reward_fn_type)
set_if_none(dataset, "ray_namespace", self.ray_namespace)
set_if_none(dataset.rollout_args, "temperature", self.model.temperature)
set_if_none(dataset.rollout_args, "top_p", self.model.top_p)
set_if_none(dataset.rollout_args, "top_k", self.model.top_k)
set_if_none(dataset.rollout_args, "logprobs", self.model.logprobs)
set_if_none(dataset.rollout_args, "max_tokens", self.model.max_response_tokens)

def _check_trainer_input(self) -> None:
@@ -1161,18 +1181,30 @@ def check_and_update(self) -> Config: # noqa: C901

# check explorer
if self.explorer is not None:
self.explorer.rollout_model.model_path = self.model.model_path
self.explorer.rollout_model.max_model_len = self.model.max_model_len
self.explorer.rollout_model.max_prompt_tokens = self.model.max_prompt_tokens
self.explorer.rollout_model.max_response_tokens = self.model.max_response_tokens
self.explorer.rollout_model.min_response_tokens = self.model.min_response_tokens
rollout_args = ["temperature", "top_p", "top_k", "logprobs"]
length_args = [
"max_model_len",
"max_prompt_tokens",
"max_response_tokens",
"min_response_tokens",
]
for args in ["model_path"] + rollout_args + length_args:
setattr(self.explorer.rollout_model, args, getattr(self.model, args))
# self.explorer.rollout_model.model_path = self.model.model_path
# self.explorer.rollout_model.temperature = self.model.temperature
# self.explorer.rollout_model.max_model_len = self.model.max_model_len
# self.explorer.rollout_model.max_prompt_tokens = self.model.max_prompt_tokens
# self.explorer.rollout_model.max_response_tokens = self.model.max_response_tokens
# self.explorer.rollout_model.min_response_tokens = self.model.min_response_tokens
for aux_model in self.explorer.auxiliary_models:
if not aux_model.model_path:
raise ValueError("auxiliary model's model_path is required.")
set_if_none(aux_model, "max_model_len", self.model.max_model_len)
set_if_none(aux_model, "max_prompt_tokens", self.model.max_prompt_tokens)
set_if_none(aux_model, "max_response_tokens", self.model.max_response_tokens)
set_if_none(aux_model, "min_response_tokens", self.model.min_response_tokens)
for args in rollout_args + length_args:
set_if_none(aux_model, args, getattr(self.model, args))
# set_if_none(aux_model, "max_model_len", self.model.max_model_len)
# set_if_none(aux_model, "max_prompt_tokens", self.model.max_prompt_tokens)
# set_if_none(aux_model, "max_response_tokens", self.model.max_response_tokens)
# set_if_none(aux_model, "min_response_tokens", self.model.min_response_tokens)

# for lora configs
if self.model.lora_configs is not None:
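The config.py hunks above move the default sampling settings (temperature, top_p, top_k, logprobs) onto ModelConfig and propagate them into every taskset's and eval dataset's rollout_args, into explorer.rollout_model, and into the auxiliary models. The propagation goes through set_if_none, whose definition is not part of this diff; a minimal sketch of the assumed behavior:

```python
# Minimal sketch (assumption) of the set_if_none helper used in the hunks above;
# its real definition is not shown in this diff. It only fills a field that is
# still None, so values the user set explicitly are never overwritten.
def set_if_none(obj, attr: str, value) -> None:
    """Assign `value` to `obj.attr` only if the current value is None."""
    if getattr(obj, attr) is None:
        setattr(obj, attr, value)
```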
21 changes: 16 additions & 5 deletions trinity/common/models/model.py
@@ -59,6 +59,10 @@ def get_model_path(self) -> Optional[str]:
"""Get the model path"""
return None

def get_default_rollout_args(self) -> dict:
"""Get the default rollout arguments."""
raise NotImplementedError


def _history_recorder(func):
"""Decorator to record history of the model calls."""
@@ -83,7 +87,7 @@ class ModelWrapper:

def __init__(
self,
model: Any,
model: InferenceModel,
engine_type: str = "vllm",
enable_lora: bool = False,
enable_history: bool = False,
@@ -99,6 +103,8 @@ def __init__(
self.history = []
self.status = RunningStatus.RUNNING
self.request_count = 0
self.default_rollout_args = model.get_default_rollout_args()
self.default_rollout_args["logprobs"] = True

async def prepare(self) -> None:
"""Prepare the model wrapper."""
@@ -271,10 +277,12 @@ def get_openai_client(self) -> openai.OpenAI:
)
if self.enable_history:
# add a decorator to the openai client to record history
ori_create = partial(self.openai_client.chat.completions.create, logprobs=True)
ori_create = self.openai_client.chat.completions.create

def record_chat_completions(*args, **kwargs):
response = ori_create(*args, **kwargs)
default_kwargs = self.default_rollout_args.copy()
default_kwargs.update(kwargs)
response = ori_create(*args, **default_kwargs)
self.history.extend(convert_api_output_to_experience(response))
return response

@@ -301,10 +309,13 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI:
)
if self.enable_history:
# add a decorator to the openai client to record history
ori_create = partial(self.openai_async_client.chat.completions.create, logprobs=True)
ori_create = self.openai_async_client.chat.completions.create

async def record_chat_completions(*args, **kwargs):
response = await ori_create(*args, **kwargs)
default_kwargs = self.default_rollout_args.copy()
default_kwargs.update(kwargs)
response = await ori_create(*args, **default_kwargs)
self.history.extend(convert_api_output_to_experience(response))
return response

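In model.py, the wrapped OpenAI clients no longer hard-code logprobs=True through functools.partial. Instead, ModelWrapper caches model.get_default_rollout_args() at construction, forces logprobs to True, and merges those defaults into each chat-completion call so that caller-supplied kwargs take precedence. A small, self-contained illustration of the assumed merge order (the concrete values are invented for the example):

```python
# Illustrates the kwargs merge used in record_chat_completions; the values
# below are made up for the example, not read from a real config.
default_rollout_args = {"temperature": 1.0, "top_p": 1.0, "top_k": -1, "logprobs": True}

def merge(call_kwargs: dict) -> dict:
    merged = default_rollout_args.copy()
    merged.update(call_kwargs)  # per-call kwargs override the model-level defaults
    return merged

assert merge({"temperature": 0.2}) == {
    "temperature": 0.2, "top_p": 1.0, "top_k": -1, "logprobs": True
}
```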
9 changes: 9 additions & 0 deletions trinity/common/models/vllm_model.py
@@ -503,6 +503,15 @@ def get_model_version(self) -> int:
def get_model_path(self) -> str:
return self.config.model_path

def get_default_rollout_args(self) -> dict:
return {
"temperature": self.config.temperature,
"top_p": self.config.top_p,
"top_k": self.config.top_k,
"max_tokens": self.config.max_response_tokens,
# "n": self.config.repeat_times,
}

def get_lora_request(self, lora_path: Optional[str] = None) -> LoRARequest:
assert self.config.lora_modules is not None
lora_request = LoRARequest(**self.config.lora_modules[0])
3 changes: 3 additions & 0 deletions trinity/common/workflows/agentscope_workflow.py
@@ -47,7 +47,10 @@ def __init__(
model.get_openai_async_client(),
generate_kwargs={
"temperature": self.task.rollout_args.temperature,
"top_p": self.task.rollout_args.top_p,
"top_k": self.task.rollout_args.top_k,
"max_tokens": self.task.rollout_args.max_tokens or 4096,
"logprobs": True,
"top_logprobs": self.task.rollout_args.logprobs,
},
)
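In agentscope_workflow.py, the generate_kwargs now forward the task's full sampling settings instead of only temperature and max_tokens. Note the OpenAI-style pairing: logprobs is a boolean switch, while top_logprobs carries the per-token count taken from rollout_args.logprobs. Assuming a task that keeps the defaults shown earlier in this diff (temperature 1.0, top_p 1.0, top_k -1, logprobs 0, max_tokens unset), the resulting kwargs would look roughly like this:

```python
# Illustrative generate_kwargs for assumed default rollout_args; the values are
# examples only, not read from an actual task.
generate_kwargs = {
    "temperature": 1.0,
    "top_p": 1.0,
    "top_k": -1,
    "max_tokens": 4096,   # fallback used when rollout_args.max_tokens is None
    "logprobs": True,     # boolean switch enabling logprob output
    "top_logprobs": 0,    # per-token count, taken from rollout_args.logprobs
}
```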