From 90c91b0fdea068496e07c698ac8aab6eaac5244a Mon Sep 17 00:00:00 2001
From: muupan
Date: Wed, 1 Oct 2025 16:39:47 +0900
Subject: [PATCH] Add trust_remote_code to GRPOConfig

---
 trl/trainer/grpo_config.py  |  4 ++++
 trl/trainer/grpo_trainer.py | 26 +++++++++++++++++++-------
 2 files changed, 23 insertions(+), 7 deletions(-)

diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py
index 9c4f4b3a686..378fcb40dd9 100644
--- a/trl/trainer/grpo_config.py
+++ b/trl/trainer/grpo_config.py
@@ -286,6 +286,10 @@ class GRPOConfig(TrainingArguments):
             "it prevents the model from generating different logprobs for the same input."
         },
     )
+    trust_remote_code: bool = field(
+        default=False,
+        metadata={"help": "Whether to trust remote code when loading custom models e.g. from the Hugging Face Hub."},
+    )
 
     # Parameters that control the data preprocessing
     # The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 51db9a1f505..a10e1e21b30 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -34,6 +34,7 @@
 from torch.utils.data import DataLoader, Sampler
 from transformers import (
     AutoConfig,
+    AutoModelForCausalLM,
     AutoModelForSequenceClassification,
     AutoProcessor,
     AutoTokenizer,
@@ -239,9 +240,13 @@ def __init__(
                     f"a `torch.dtype` (e.g., 'float32'), but got {dtype}."
                 )
             # Disable caching if gradient checkpointing is enabled (not supported)
-            config = AutoConfig.from_pretrained(model_id)
-            architecture = getattr(transformers, config.architectures[0])
-            model = architecture.from_pretrained(model_id, **model_init_kwargs)
+            config = AutoConfig.from_pretrained(model_id, trust_remote_code=args.trust_remote_code)
+            if architecture := getattr(transformers, config.architectures[0], None):
+                model = architecture.from_pretrained(model_id, **model_init_kwargs)
+            else:
+                model = AutoModelForCausalLM.from_pretrained(
+                    model_id, trust_remote_code=args.trust_remote_code, **model_init_kwargs
+                )
         else:
             model_id = model.config._name_or_path
             if args.model_init_kwargs is not None:
@@ -263,7 +268,9 @@

         # Processing class
         if processing_class is None:
-            processing_class = AutoProcessor.from_pretrained(model.config._name_or_path)
+            processing_class = AutoProcessor.from_pretrained(
+                model.config._name_or_path, trust_remote_code=args.trust_remote_code
+            )

         # Handle pad token for processors or tokenizers
         if isinstance(processing_class, ProcessorMixin):
@@ -427,9 +434,13 @@
             self.ref_model = None
         else:
             # For deepspeed, fsdp or non-distributed models, create a reference model from scratch
-            config = AutoConfig.from_pretrained(model_id)
-            architecture = getattr(transformers, config.architectures[0])
-            self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs)
+            config = AutoConfig.from_pretrained(model_id, trust_remote_code=args.trust_remote_code)
+            if architecture := getattr(transformers, config.architectures[0], None):
+                self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs)
+            else:
+                self.ref_model = AutoModelForCausalLM.from_pretrained(
+                    model_id, trust_remote_code=args.trust_remote_code, **model_init_kwargs
+                )

         # Disable dropout in the models
         if args.disable_dropout:
@@ -537,6 +548,7 @@
                     max_num_batched_tokens=4096,
                     model_impl=self.args.vllm_model_impl,
                     enable_sleep_mode=self.args.vllm_enable_sleep_mode,
+                    trust_remote_code=self.args.trust_remote_code,
                 )
                 if self.args.vllm_enable_sleep_mode:
                     self.llm.sleep(level=1)
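
Illustrative usage sketch (not part of the patch). It assumes a Hub repository whose modeling code
requires trust_remote_code=True; the model id, dataset, and reward function below are placeholders:

    from datasets import Dataset
    from trl import GRPOConfig, GRPOTrainer

    # Placeholder reward: prefers shorter completions; real runs plug in task-specific rewards.
    def reward_len(completions, **kwargs):
        return [-float(len(c)) for c in completions]

    dataset = Dataset.from_dict({"prompt": ["Write a haiku about autumn.", "Summarize GRPO in one sentence."]})

    training_args = GRPOConfig(
        output_dir="grpo-custom-arch",
        trust_remote_code=True,  # allow custom config/model/tokenizer code from the Hub to be executed
    )

    trainer = GRPOTrainer(
        model="org/model-with-custom-code",  # hypothetical repo that ships its own modeling code
        reward_funcs=reward_len,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()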