Setting trust_remote_code=True for vLLM in GRPOTrainer with vllm_mode=="colocate" #4129

@muupan

Feature request

Currently, there is no way to pass trust_remote_code=True to the vLLM instance in GRPOTrainer when vllm_mode=="colocate", which makes this mode unusable for models that ship custom code.

When vllm_mode=="colocate", GRPOTrainer creates the vLLM instance internally, and the call to vllm.LLM does not forward trust_remote_code:

    self.llm = LLM(
        model=model.name_or_path,
        tensor_parallel_size=args.vllm_tensor_parallel_size,
        gpu_memory_utilization=self.vllm_gpu_memory_utilization,
        max_num_seqs=self.args.per_device_train_batch_size
        * self.vllm_tensor_parallel_size
        * self.args.steps_per_generation,
        max_model_len=max_model_len,
        distributed_executor_backend="external_launcher",
        # Feed identical seed for tp groups to ensure sampling results are the same across workers
        seed=self.accelerator.process_index // self.vllm_tensor_parallel_size,
        # Latest vLLM v1 memory profiler is misled by the high default value (i.e., 32768) - thinking there's not enough memory
        max_num_batched_tokens=4096,
        model_impl=self.args.vllm_model_impl,
        enable_sleep_mode=self.args.vllm_enable_sleep_mode,
    )

Since trl vllm-serve already has a trust_remote_code option, vllm_mode=="server" is not affected by this issue:

    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": "Whether to trust remote code when loading models. Set to True to allow executing code from model "
            "repositories. This is required for some custom models but introduces security risks."
        },
    )

Motivation

It would be nice if the library supported training models with custom code in colocate mode as well.

Your contribution

I can send a PR, but there are several ways to resolve this:

  • Add trust_remote_code (False by default) to GRPOConfig and pass it not only to vllm.LLM but also to the other classes that accept trust_remote_code: AutoConfig and AutoTokenizer.
  • Add vllm_trust_remote_code (False by default) to GRPOConfig and pass it to vllm.LLM only.
  • Add a vLLM-specific but more general argument, such as vllm_init_kwargs, to GRPOConfig.
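
To make the third option concrete, here is a minimal sketch of how user-supplied keyword arguments could be merged over the trainer's defaults before calling vllm.LLM. The helper name build_vllm_kwargs and the vllm_init_kwargs field are hypothetical and do not exist in TRL today; the default values are copied from the snippet above.

```python
from typing import Any, Optional


def build_vllm_kwargs(
    defaults: dict[str, Any],
    vllm_init_kwargs: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Hypothetical helper: user-supplied kwargs override the trainer defaults."""
    merged = dict(defaults)  # copy so the defaults are not mutated
    merged.update(vllm_init_kwargs or {})
    return merged


# Defaults taken from the existing LLM(...) call in colocate mode.
defaults = {
    "distributed_executor_backend": "external_launcher",
    "max_num_batched_tokens": 4096,
}

# Assumed user setting, e.g. GRPOConfig(vllm_init_kwargs={"trust_remote_code": True}).
kwargs = build_vllm_kwargs(defaults, {"trust_remote_code": True})

# The trainer would then call: LLM(model=model.name_or_path, **kwargs)
```

One downside of this approach is that users could also override values the trainer relies on, such as distributed_executor_backend, so a real implementation would probably need to validate or restrict the keys.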

I think the first option is best, because setting trust_remote_code to different values for vLLM and AutoConfig does not make sense. Does this sound reasonable?
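
A rough sketch of the first option, assuming a single trust_remote_code field on the config that is forwarded to every loader. SketchGRPOConfig and loader_kwargs are hypothetical names used only for illustration, not existing TRL code; the calls in the comments are where the flag would be passed.

```python
from dataclasses import dataclass, field


@dataclass
class SketchGRPOConfig:
    """Hypothetical stand-in for GRPOConfig with the proposed field."""

    trust_remote_code: bool = field(
        default=False,
        metadata={"help": "Whether to trust remote code when loading models."},
    )


def loader_kwargs(config: SketchGRPOConfig) -> dict:
    """Build the kwargs shared by every loader so the flag cannot diverge."""
    return {"trust_remote_code": config.trust_remote_code}


args = SketchGRPOConfig(trust_remote_code=True)
shared = loader_kwargs(args)

# Inside the trainer, the same dict would be forwarded to each loader, e.g.:
#   AutoConfig.from_pretrained(model.name_or_path, **shared)
#   AutoTokenizer.from_pretrained(model.name_or_path, **shared)
#   LLM(model=model.name_or_path, ..., **shared)
```

Because all three loaders read the same value, vLLM and the Hugging Face classes cannot end up with conflicting settings.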
