@@ -78,10 +78,10 @@ class FormatConfig:
7878
7979@dataclass
8080class GenerationConfig :
81- temperature : float = 1.0
82- top_p : float = 1.0
83- top_k : int = - 1
84- logprobs : int = 0 # vLLM return `logprobs + 1` elements
81+ temperature : Optional [ float ] = None # 1.0
82+ top_p : Optional [ float ] = None # 1.0
83+ top_k : Optional [ int ] = None # -1
84+ logprobs : Optional [ int ] = None # 0 # vLLM return `logprobs + 1` elements
8585 max_tokens : Optional [int ] = None # if None, use model.max_response_tokens
8686 # repeat each task for `n` times
8787 # ! DO NOT SET in `buffer.explorer_input.taskset.rollout_args`
@@ -412,6 +412,12 @@ class ModelConfig:
412412
413413 custom_chat_template : Optional [str ] = None
414414
415+ # rollout args
416+ temperature : float = 1.0
417+ top_p : float = 1.0
418+ top_k : int = - 1
419+ logprobs : int = 0
420+
415421 # the total number of tokens the model can handle
416422 max_model_len : Optional [int ] = None
417423
@@ -447,6 +453,12 @@ class InferenceModelConfig:
447453 dtype : str = "bfloat16"
448454 seed : int = 42
449455
456+ # rollout args, ! DO NOT SET
457+ temperature : Optional [float ] = None
458+ top_p : Optional [float ] = None
459+ top_k : Optional [int ] = None
460+ logprobs : Optional [int ] = None
461+
450462 # if not set, use `model.max_model_len`
451463 max_model_len : Optional [int ] = None
452464 # if not set, use `model.max_prompt_tokens`
@@ -853,6 +865,10 @@ def _check_explorer_input(self) -> None:
853865 set_if_none (taskset , "default_workflow_type" , explorer_input .default_workflow_type )
854866 set_if_none (taskset , "default_reward_fn_type" , explorer_input .default_reward_fn_type )
855867 set_if_none (taskset , "ray_namespace" , self .ray_namespace )
868+ set_if_none (taskset .rollout_args , "temperature" , self .model .temperature )
869+ set_if_none (taskset .rollout_args , "top_p" , self .model .top_p )
870+ set_if_none (taskset .rollout_args , "top_k" , self .model .top_k )
871+ set_if_none (taskset .rollout_args , "logprobs" , self .model .logprobs )
856872 set_if_none (taskset .rollout_args , "max_tokens" , self .model .max_response_tokens )
857873
858874 for idx , dataset in enumerate (explorer_input .eval_tasksets ):
@@ -868,6 +884,10 @@ def _check_explorer_input(self) -> None:
868884 set_if_none (dataset , "default_workflow_type" , explorer_input .default_workflow_type )
869885 set_if_none (dataset , "default_reward_fn_type" , explorer_input .default_reward_fn_type )
870886 set_if_none (dataset , "ray_namespace" , self .ray_namespace )
887+ set_if_none (dataset .rollout_args , "temperature" , self .model .temperature )
888+ set_if_none (dataset .rollout_args , "top_p" , self .model .top_p )
889+ set_if_none (dataset .rollout_args , "top_k" , self .model .top_k )
890+ set_if_none (dataset .rollout_args , "logprobs" , self .model .logprobs )
871891 set_if_none (dataset .rollout_args , "max_tokens" , self .model .max_response_tokens )
872892
873893 def _check_trainer_input (self ) -> None :
@@ -1161,18 +1181,20 @@ def check_and_update(self) -> Config: # noqa: C901
11611181
11621182 # check explorer
11631183 if self .explorer is not None :
1164- self .explorer .rollout_model .model_path = self .model .model_path
1165- self .explorer .rollout_model .max_model_len = self .model .max_model_len
1166- self .explorer .rollout_model .max_prompt_tokens = self .model .max_prompt_tokens
1167- self .explorer .rollout_model .max_response_tokens = self .model .max_response_tokens
1168- self .explorer .rollout_model .min_response_tokens = self .model .min_response_tokens
1184+ rollout_args = ["temperature" , "top_p" , "top_k" , "logprobs" ]
1185+ length_args = [
1186+ "max_model_len" ,
1187+ "max_prompt_tokens" ,
1188+ "max_response_tokens" ,
1189+ "min_response_tokens" ,
1190+ ]
1191+ for args in ["model_path" ] + rollout_args + length_args :
1192+ setattr (self .explorer .rollout_model , args , getattr (self .model , args ))
11691193 for aux_model in self .explorer .auxiliary_models :
11701194 if not aux_model .model_path :
11711195 raise ValueError ("auxiliary model's model_path is required." )
1172- set_if_none (aux_model , "max_model_len" , self .model .max_model_len )
1173- set_if_none (aux_model , "max_prompt_tokens" , self .model .max_prompt_tokens )
1174- set_if_none (aux_model , "max_response_tokens" , self .model .max_response_tokens )
1175- set_if_none (aux_model , "min_response_tokens" , self .model .min_response_tokens )
1196+ for args in rollout_args + length_args :
1197+ set_if_none (aux_model , args , getattr (self .model , args ))
11761198
11771199 # for lora configs
11781200 if self .model .lora_configs is not None :