Commit 90c91b0

Add trust_remote_code to GRPOConfig
1 parent 5c52f46 commit 90c91b0

File tree

2 files changed: +23 -7 lines changed

trl/trainer/grpo_config.py

Lines changed: 4 additions & 0 deletions

@@ -286,6 +286,10 @@ class GRPOConfig(TrainingArguments):
             "it prevents the model from generating different logprobs for the same input."
         },
     )
+    trust_remote_code: bool = field(
+        default=False,
+        metadata={"help": "Whether to trust remote code when loading custom models e.g. from the Hugging Face Hub."},
+    )

     # Parameters that control the data preprocessing
     # The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on
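
With the new field in place, enabling it is a one-liner on the config. A minimal usage sketch (the output directory is a placeholder for illustration, not part of this commit):

from trl import GRPOConfig

# trust_remote_code defaults to False, so custom modeling code from the Hub
# runs only when the user opts in explicitly.
training_args = GRPOConfig(
    output_dir="grpo-output",  # hypothetical path, for illustration only
    trust_remote_code=True,
)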

trl/trainer/grpo_trainer.py

Lines changed: 19 additions & 7 deletions

@@ -34,6 +34,7 @@
 from torch.utils.data import DataLoader, Sampler
 from transformers import (
     AutoConfig,
+    AutoModelForCausalLM,
     AutoModelForSequenceClassification,
     AutoProcessor,
     AutoTokenizer,

@@ -239,9 +240,13 @@ def __init__(
                 f"a `torch.dtype` (e.g., 'float32'), but got {dtype}."
             )
             # Disable caching if gradient checkpointing is enabled (not supported)
-            config = AutoConfig.from_pretrained(model_id)
-            architecture = getattr(transformers, config.architectures[0])
-            model = architecture.from_pretrained(model_id, **model_init_kwargs)
+            config = AutoConfig.from_pretrained(model_id, trust_remote_code=args.trust_remote_code)
+            if architecture := getattr(transformers, config.architectures[0], None):
+                model = architecture.from_pretrained(model_id, **model_init_kwargs)
+            else:
+                model = AutoModelForCausalLM.from_pretrained(
+                    model_id, trust_remote_code=args.trust_remote_code, **model_init_kwargs
+                )
         else:
             model_id = model.config._name_or_path
             if args.model_init_kwargs is not None:
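
The pattern above resolves the class name declared in the checkpoint's config and only falls back to AutoModelForCausalLM when transformers ships no class of that name, i.e. when the architecture is defined by the repo's own modeling code. A self-contained sketch of the same fallback, assuming a hypothetical Hub repo id:

import transformers
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "org/custom-model"  # hypothetical repo that ships custom modeling code
trust_remote_code = True       # mirrors args.trust_remote_code in the trainer

config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
# config.architectures[0] is the class name the checkpoint declares,
# e.g. "LlamaForCausalLM"; getattr(..., None) yields None when transformers
# does not ship a class with that name.
if architecture := getattr(transformers, config.architectures[0], None):
    model = architecture.from_pretrained(model_id)
else:
    # Only this branch executes the repo's remote code, and only by opt-in.
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=trust_remote_code)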
@@ -263,7 +268,9 @@ def __init__(

         # Processing class
         if processing_class is None:
-            processing_class = AutoProcessor.from_pretrained(model.config._name_or_path)
+            processing_class = AutoProcessor.from_pretrained(
+                model.config._name_or_path, trust_remote_code=args.trust_remote_code
+            )

         # Handle pad token for processors or tokenizers
         if isinstance(processing_class, ProcessorMixin):
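
The flag is threaded through to the processing class as well, since repos with custom architectures often ship a custom tokenizer or processor. The standalone equivalent, with the same hypothetical repo id:

from transformers import AutoProcessor

# AutoProcessor accepts trust_remote_code just like the model loaders do.
processing_class = AutoProcessor.from_pretrained("org/custom-model", trust_remote_code=True)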
@@ -427,9 +434,13 @@ def __init__(
             self.ref_model = None
         else:
             # For deepspeed, fsdp or non-distributed models, create a reference model from scratch
-            config = AutoConfig.from_pretrained(model_id)
-            architecture = getattr(transformers, config.architectures[0])
-            self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs)
+            config = AutoConfig.from_pretrained(model_id, trust_remote_code=args.trust_remote_code)
+            if architecture := getattr(transformers, config.architectures[0], None):
+                self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs)
+            else:
+                self.ref_model = AutoModelForCausalLM.from_pretrained(
+                    model_id, trust_remote_code=args.trust_remote_code, **model_init_kwargs
+                )

         # Disable dropout in the models
         if args.disable_dropout:

@@ -537,6 +548,7 @@ def __init__(
                 max_num_batched_tokens=4096,
                 model_impl=self.args.vllm_model_impl,
                 enable_sleep_mode=self.args.vllm_enable_sleep_mode,
+                trust_remote_code=self.args.trust_remote_code,
             )
             if self.args.vllm_enable_sleep_mode:
                 self.llm.sleep(level=1)
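
Finally, the flag is forwarded to the colocated vLLM engine under the keyword argument of the same name, so generation can load the identical custom architecture. A minimal standalone sketch, again with a hypothetical repo id:

from vllm import LLM

# vLLM forwards trust_remote_code to the Hugging Face config/tokenizer loaders,
# matching what the trainer now passes through.
llm = LLM(model="org/custom-model", trust_remote_code=True)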
