4 changes: 4 additions & 0 deletions trl/trainer/grpo_config.py
@@ -286,6 +286,10 @@ class GRPOConfig(TrainingArguments):
"it prevents the model from generating different logprobs for the same input."
},
)
trust_remote_code: bool = field(
default=False,
metadata={"help": "Whether to trust remote code when loading custom models e.g. from the Hugging Face Hub."},
Member:
Suggested change:
-        metadata={"help": "Whether to trust remote code when loading custom models e.g. from the Hugging Face Hub."},
+        metadata={"help": "Whether to trust remote code when loading custom models from the Hugging Face Hub."},

Contributor Author:
trust_remote_code matters when loading from local files too. The vLLM docs describe it similarly:

    Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer.

https://docs.vllm.ai/en/latest/api/vllm/index.html#vllm.LLM

That said, the transformers docs only mention the Hub. Do you think the "e.g." should be deleted? I don't have a strong opinion, so I'll follow your preference.

    Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.

https://huggingface.co/docs/transformers/v4.56.2/en/model_doc/auto#transformers.AutoModel.from_pretrained.trust_remote_code

+    )
 
     # Parameters that control the data preprocessing
     # The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on
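A minimal usage sketch of the new option, assuming this PR as written: `my-org/custom-arch` is a hypothetical Hub repo whose architecture is defined in its own modeling files, and the length-based reward is just a toy reward in the style of the TRL docs. The flag is forwarded by the trainer changes below to `AutoConfig`/`AutoModelForCausalLM`/`AutoProcessor` and, when enabled, to vLLM.

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

def reward_len(completions, **kwargs):
    # Toy reward: prefer completions close to 20 characters.
    return [-abs(20 - len(completion)) for completion in completions]

# trust_remote_code is the field added by this PR.
training_args = GRPOConfig(output_dir="custom-arch-grpo", trust_remote_code=True)
trainer = GRPOTrainer(
    model="my-org/custom-arch",  # hypothetical custom-code model
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```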
26 changes: 19 additions & 7 deletions trl/trainer/grpo_trainer.py
@@ -34,6 +34,7 @@
 from torch.utils.data import DataLoader, Sampler
 from transformers import (
     AutoConfig,
+    AutoModelForCausalLM,
     AutoModelForSequenceClassification,
     AutoProcessor,
     AutoTokenizer,
@@ -251,9 +252,13 @@ def __init__(
f"a `torch.dtype` (e.g., 'float32'), but got {dtype}."
)
# Disable caching if gradient checkpointing is enabled (not supported)
config = AutoConfig.from_pretrained(model_id)
architecture = getattr(transformers, config.architectures[0])
model = architecture.from_pretrained(model_id, **model_init_kwargs)
config = AutoConfig.from_pretrained(model_id, trust_remote_code=args.trust_remote_code)
if architecture := getattr(transformers, config.architectures[0], None):
model = architecture.from_pretrained(model_id, **model_init_kwargs)
else:
model = AutoModelForCausalLM.from_pretrained(
model_id, trust_remote_code=args.trust_remote_code, **model_init_kwargs
)
else:
model_id = model.config._name_or_path
if args.model_init_kwargs is not None:
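Restating the fallback outside the diff, as a rough sketch: architectures that ship with transformers are resolved by name, while custom architectures (where `getattr` returns `None`) fall back to `AutoModelForCausalLM` with `trust_remote_code`. The repo name here is hypothetical.

```python
import transformers
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "my-org/custom-arch"  # hypothetical repo with its own modeling code
trust_remote_code = True

config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
if architecture := getattr(transformers, config.architectures[0], None):
    # The architecture class ships with transformers (e.g. "LlamaForCausalLM").
    model = architecture.from_pretrained(model_id)
else:
    # Custom architecture: only loadable by executing the repo's own modeling files.
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=trust_remote_code)
```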
@@ -275,7 +280,9 @@

         # Processing class
         if processing_class is None:
-            processing_class = AutoProcessor.from_pretrained(model.config._name_or_path)
+            processing_class = AutoProcessor.from_pretrained(
+                model.config._name_or_path, trust_remote_code=args.trust_remote_code
+            )
 
         # Handle pad token for processors or tokenizers
         if isinstance(processing_class, ProcessorMixin):
@@ -439,9 +446,13 @@ def __init__(
             self.ref_model = None
         else:
             # For deepspeed, fsdp or non-distributed models, create a reference model from scratch
-            config = AutoConfig.from_pretrained(model_id)
-            architecture = getattr(transformers, config.architectures[0])
-            self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs)
+            config = AutoConfig.from_pretrained(model_id, trust_remote_code=args.trust_remote_code)
+            if architecture := getattr(transformers, config.architectures[0], None):
+                self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs)
+            else:
+                self.ref_model = AutoModelForCausalLM.from_pretrained(
+                    model_id, trust_remote_code=args.trust_remote_code, **model_init_kwargs
+                )

         # Disable dropout in the models
         if args.disable_dropout:
@@ -549,6 +560,7 @@
                     max_num_batched_tokens=4096,
                     model_impl=self.args.vllm_model_impl,
                     enable_sleep_mode=self.args.vllm_enable_sleep_mode,
+                    trust_remote_code=self.args.trust_remote_code,
                     # Important so temperature scaling/logit tweaking affects the TIS log probs
                     logprobs_mode="processed_logprobs",
                 )
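On the vLLM side, the trainer's in-process engine now receives the same flag; roughly equivalent to constructing the engine as below (model name hypothetical):

```python
from vllm import LLM

llm = LLM(model="my-org/custom-arch", trust_remote_code=True)
```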