Skip to content

Commit 8376f61

Browse files
committed
Add trust_remote_code to GRPOConfig
1 parent 5c52f46 commit 8376f61

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

trl/trainer/grpo_config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,10 @@ class GRPOConfig(TrainingArguments):
286286
"it prevents the model from generating different logprobs for the same input."
287287
},
288288
)
289+
trust_remote_code: bool = field(
290+
default=False,
291+
metadata={"help": "Whether to trust remote code when loading custom models e.g. from the Hugging Face Hub."},
292+
)
289293

290294
# Parameters that control the data preprocessing
291295
# The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on

trl/trainer/grpo_trainer.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def __init__(
239239
f"a `torch.dtype` (e.g., 'float32'), but got {dtype}."
240240
)
241241
# Disable caching if gradient checkpointing is enabled (not supported)
242-
config = AutoConfig.from_pretrained(model_id)
242+
config = AutoConfig.from_pretrained(model_id, trust_remote_code=self.args.trust_remote_code)
243243
architecture = getattr(transformers, config.architectures[0])
244244
model = architecture.from_pretrained(model_id, **model_init_kwargs)
245245
else:
@@ -263,7 +263,9 @@ def __init__(
263263

264264
# Processing class
265265
if processing_class is None:
266-
processing_class = AutoProcessor.from_pretrained(model.config._name_or_path)
266+
processing_class = AutoProcessor.from_pretrained(
267+
model.config._name_or_path, trust_remote_code=self.args.trust_remote_code
268+
)
267269

268270
# Handle pad token for processors or tokenizers
269271
if isinstance(processing_class, ProcessorMixin):
@@ -427,7 +429,7 @@ def __init__(
427429
self.ref_model = None
428430
else:
429431
# For deepspeed, fsdp or non-distributed models, create a reference model from scratch
430-
config = AutoConfig.from_pretrained(model_id)
432+
config = AutoConfig.from_pretrained(model_id, trust_remote_code=self.args.trust_remote_code)
431433
architecture = getattr(transformers, config.architectures[0])
432434
self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs)
433435

@@ -537,6 +539,7 @@ def __init__(
537539
max_num_batched_tokens=4096,
538540
model_impl=self.args.vllm_model_impl,
539541
enable_sleep_mode=self.args.vllm_enable_sleep_mode,
542+
trust_remote_code=self.args.trust_remote_code,
540543
)
541544
if self.args.vllm_enable_sleep_mode:
542545
self.llm.sleep(level=1)

0 commit comments

Comments
 (0)