Conversation

muupan (Contributor) commented on Oct 1, 2025

What does this PR do?

This PR adds trust_remote_code to GRPOConfig and makes GRPOTrainer pass it along when creating the model and its related objects, so that custom models are supported.

Fixes #4129
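
For context, a rough usage sketch of how the new option would be used once merged (the custom model id, dataset, and reward function below are placeholders, not part of this PR):

from trl import GRPOConfig, GRPOTrainer

config = GRPOConfig(
    output_dir="grpo-custom-model",
    trust_remote_code=True,  # new field added by this PR; defaults to False
)

# When `model` is a string id, GRPOTrainer forwards trust_remote_code while
# loading the config, the model, and its related objects.
trainer = GRPOTrainer(
    model="your-org/custom-causal-lm",  # placeholder custom-code checkpoint
    reward_funcs=lambda completions, **kwargs: [0.0] * len(completions),
    args=config,
    train_dataset=my_dataset,  # assumed to be a dataset with a "prompt" column
)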

Before submitting

  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

    trust_remote_code: bool = field(
        default=False,
        metadata={"help": "Whether to trust remote code when loading custom models e.g. from the Hugging Face Hub."},
    )
Member commented:
Suggested change:
- metadata={"help": "Whether to trust remote code when loading custom models e.g. from the Hugging Face Hub."},
+ metadata={"help": "Whether to trust remote code when loading custom models from the Hugging Face Hub."},

muupan (Author) replied:

trust_remote_code matters when loading from local files too. The vLLM docs describe it similarly:

Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer.

https://docs.vllm.ai/en/latest/api/vllm/index.html#vllm.LLM

That said, the transformers docs only mention the Hub. Do you think the "e.g." should be deleted? I don't have a strong opinion, so I'll follow your preference.

Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.

https://huggingface.co/docs/transformers/v4.56.2/en/model_doc/auto#transformers.AutoModel.from_pretrained.trust_remote_code
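
(Illustration of the point above, not part of the PR: trust_remote_code is also required when a checkpoint lives on local disk, as long as its architecture is defined in the repo's own modeling files. The path below is a placeholder.)

from transformers import AutoConfig, AutoModelForCausalLM

# Locally saved checkpoint that ships its own modeling_*.py / configuration_*.py files.
local_path = "/data/checkpoints/my-custom-lm"

# Without trust_remote_code=True, transformers refuses to execute the custom code,
# even though nothing is downloaded from the Hub.
config = AutoConfig.from_pretrained(local_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(local_path, trust_remote_code=True)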

  # Disable caching if gradient checkpointing is enabled (not supported)
- config = AutoConfig.from_pretrained(model_id)
+ config = AutoConfig.from_pretrained(model_id, trust_remote_code=self.args.trust_remote_code)
  architecture = getattr(transformers, config.architectures[0])
Member commented:

When using remote code, the idea is that the model is not included in transformers, right? So maybe you need something like this instead:

if hasattr(transformers, config.architectures[0]):
    architecture = getattr(transformers, config.architectures[0])
    model = architecture.from_pretrained(model_id, **model_init_kwargs)
else:
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

muupan (Author) replied:

Thanks, you are right. I hadn't tested the AutoConfig code path, since I was passing an already loaded model to GRPOTrainer. I will try rewriting it along the lines of your code.

muupan (Author) replied:

I fixed it. Now it works when the model passed to GRPOTrainer is a str.
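
(For illustration only, not the merged code: one way the two pieces above might fit together when model is a str, with the fallback also forwarding the configured flag and the other init kwargs. Names such as model_id and model_init_kwargs follow the earlier snippets.)

config = AutoConfig.from_pretrained(model_id, trust_remote_code=self.args.trust_remote_code)
if hasattr(transformers, config.architectures[0]):
    # Architecture ships with transformers: instantiate the concrete class.
    architecture = getattr(transformers, config.architectures[0])
    model = architecture.from_pretrained(model_id, **model_init_kwargs)
else:
    # Custom architecture: fall back to AutoModelForCausalLM with trust_remote_code.
    model = AutoModelForCausalLM.from_pretrained(
        model_id, trust_remote_code=self.args.trust_remote_code, **model_init_kwargs
    )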

muupan force-pushed the feature/grpo-config-trust-remote-code branch from 8376f61 to 90c91b0 on October 2, 2025, 02:09
muupan requested a review from qgallouedec on October 2, 2025, 02:13
Successfully merging this pull request may close these issues.

Setting trust_remote_code=True for vLLM in GRPOTrainer with vllm_mode=="colocate"