diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py
index f6f3c6e346..d6cea987f1 100644
--- a/trl/trainer/grpo_config.py
+++ b/trl/trainer/grpo_config.py
@@ -286,6 +286,10 @@ class GRPOConfig(TrainingArguments):
             "it prevents the model from generating different logprobs for the same input."
         },
     )
+    trust_remote_code: bool = field(
+        default=False,
+        metadata={"help": "Whether to trust remote code when loading custom models, e.g. from the Hugging Face Hub."},
+    )
 
     # Parameters that control the data preprocessing
     # The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 352a0144ef..dfee0fd545 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -33,6 +33,7 @@
 from torch.utils.data import DataLoader, Sampler
 from transformers import (
     AutoConfig,
+    AutoModelForCausalLM,
     AutoModelForSequenceClassification,
     AutoProcessor,
     AutoTokenizer,
@@ -249,9 +250,13 @@ def __init__(
                     f"a `torch.dtype` (e.g., 'float32'), but got {dtype}."
                 )
             # Disable caching if gradient checkpointing is enabled (not supported)
-            config = AutoConfig.from_pretrained(model_id)
-            architecture = getattr(transformers, config.architectures[0])
-            model = architecture.from_pretrained(model_id, **model_init_kwargs)
+            config = AutoConfig.from_pretrained(model_id, trust_remote_code=args.trust_remote_code)
+            if architecture := getattr(transformers, config.architectures[0], None):
+                model = architecture.from_pretrained(model_id, **model_init_kwargs)
+            else:
+                model = AutoModelForCausalLM.from_pretrained(
+                    model_id, trust_remote_code=args.trust_remote_code, **model_init_kwargs
+                )
         else:
             model_id = model.config._name_or_path
             if args.model_init_kwargs is not None:
@@ -273,7 +278,9 @@ def __init__(
 
         # Processing class
         if processing_class is None:
-            processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side="left")
+            processing_class = AutoProcessor.from_pretrained(
+                model.config._name_or_path, trust_remote_code=args.trust_remote_code, truncation_side="left"
+            )
 
         # Handle pad token for processors or tokenizers
         if isinstance(processing_class, ProcessorMixin):
@@ -433,9 +440,13 @@ def __init__(
             self.ref_model = None
         else:
             # For deepspeed, fsdp or non-distributed models, create a reference model from scratch
-            config = AutoConfig.from_pretrained(model_id)
-            architecture = getattr(transformers, config.architectures[0])
-            self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs)
+            config = AutoConfig.from_pretrained(model_id, trust_remote_code=args.trust_remote_code)
+            if architecture := getattr(transformers, config.architectures[0], None):
+                self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs)
+            else:
+                self.ref_model = AutoModelForCausalLM.from_pretrained(
+                    model_id, trust_remote_code=args.trust_remote_code, **model_init_kwargs
+                )
 
         # Disable dropout in the models
         if args.disable_dropout:
@@ -543,6 +554,7 @@ def __init__(
                 max_num_batched_tokens=4096,
                 model_impl=self.args.vllm_model_impl,
                 enable_sleep_mode=self.args.vllm_enable_sleep_mode,
+                trust_remote_code=self.args.trust_remote_code,
                 # Important so temperature scaling/logit tweaking affects the TIS log probs
                 logprobs_mode="processed_logprobs",
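
A minimal usage sketch of the new flag once this patch is applied. The model id below is hypothetical (any Hub repo that ships its own modeling code and whose architecture class is not in `transformers`); the dataset and reward function follow the TRL GRPO quickstart example.

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

# Dataset and toy reward function taken from the TRL GRPO quickstart.
dataset = load_dataset("trl-lib/tldr", split="train")

def reward_len(completions, **kwargs):
    # Toy reward: prefer completions close to 20 characters.
    return [-abs(20 - len(c)) for c in completions]

training_args = GRPOConfig(
    output_dir="grpo-custom-arch",
    trust_remote_code=True,  # new flag added by this patch; forwarded to config, model, processor, and vLLM loading
)
trainer = GRPOTrainer(
    model="my-org/my-custom-architecture",  # hypothetical custom-code checkpoint
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```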