
[GRPO] Allow the use of the vllm logprobs, rather than recomputing them #3193


Closed
edbeeching wants to merge 6 commits

Conversation

edbeeching
Collaborator

This PR exposes the logprobs of the token IDs generated by the vLLM client.
The default will be `False`, as these logprobs are likely to differ from those produced by the policy.
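
For illustration, a minimal sketch of how the sampled token IDs and their logprobs can be read off a vLLM generation call (the model name is a placeholder and the exact `SamplingParams`/`Logprob` fields may vary across vLLM versions; this is not the PR's actual client code):

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model
params = SamplingParams(max_tokens=32, logprobs=0)  # 0 -> return only the sampled token's logprob

completion_ids, vllm_log_probs = [], []
for output in llm.generate(["The capital of France is"], params):
    completion = output.outputs[0]
    completion_ids.append(list(completion.token_ids))
    # completion.logprobs is a list of {token_id: Logprob} dicts, one entry per generated token
    vllm_log_probs.append(
        [step[tok].logprob for tok, step in zip(completion.token_ids, completion.logprobs)]
    )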

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@lewtun lewtun left a comment


Overall looks good with a potential typo in the padding of the log probs. It would be quite interesting to know if there is any noticeable difference between using the vLLM logprobs vs native ones when beta > 0 (not for this PR, but more generally on some simple baseline)

@@ -81,6 +81,8 @@ class GRPOConfig(TrainingArguments):
use_vllm (`bool`, *optional*, defaults to `False`):
Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
use_vllm_logprobs (`bool`, *optional*, defaults to `False`):
Whether to use vLLM's logprobs for the `"old_logprobs"` in the GRPO loss. Requires `use_vllm=True`.
Member

@lewtun lewtun Mar 31, 2025


nit:

Suggested change
Whether to use vLLM's logprobs for the `"old_logprobs"` in the GRPO loss. Requires `use_vllm=True`.
Whether to use vLLM's logprobs to compute the GRPO loss instead of using the native `forward()` method. This is more compute efficient because vLLM computes the policy's logprobs in parallel to completions. Requires use_vllm=True.

Collaborator Author


Thanks. By the way, I think there may be some confusion: this PR is not related to the ref model at all. Previously the old_logprobs were (re)calculated from the model we are optimizing, in torch.no_grad mode, using the tokens generated by the vLLM instance. They are used as part of GRPO's clipping loss to constrain the model's updates to stay within beta of the old model.

Now we use the logprobs from the vLLM instance. I believe you are thinking of the KL penalty, which still uses the log_probs from the ref model and the model we are optimizing, not the vLLM logprobs.
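
For context, a rough sketch of where these old logprobs enter the clipped surrogate (variable names and the epsilon clip range are illustrative, not the trainer's exact code):

import torch

# Illustrative shapes: (batch, completion_len); values are placeholders.
per_token_logps = torch.randn(2, 5)      # logprobs under the model being optimized
old_per_token_logps = torch.randn(2, 5)  # logprobs under the policy that generated the samples
advantages = torch.randn(2, 1)           # group-relative advantages, broadcast over tokens
epsilon = 0.2                            # clip range

# PPO-style clipped surrogate; old_per_token_logps is what this PR proposes to
# take directly from vLLM instead of recomputing with a no-grad forward pass.
ratio = torch.exp(per_token_logps - old_per_token_logps)
clipped_ratio = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon)
per_token_loss = -torch.min(ratio * advantages, clipped_ratio * advantages)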

Member


Ah yeah you're right - I've amended my suggestion to align with that. (I think it's not a good idea to reference internal variables in docstrings as the user cannot see them unless they drill into the code)

use_vllm_logprobs: bool = field(
    default=False,
    metadata={
        "help": "Whether to use vLLM's logprobs for the 'old_logprobs' in the GRPO loss. Requires use_vllm=True."
    },
)
Member

@lewtun lewtun Mar 31, 2025


Suggested change
"help": "Whether to use vLLM's logprobs for the 'old_logprobs' in the GRPO loss. Requires use_vllm=True."
"help": "Whether to use vLLM's logprobs to compute the GRPO loss instead of using the native `forward()` method. This is more compute efficient because vLLM computes the policy's logprobs in parallel to completions. Requires use_vllm=True."

# Broadcast the completions from the main process to all processes, ensuring each process receives its
# corresponding slice.
completion_ids = broadcast_object_list(completion_ids, from_process=0)
vllm_log_probs = broadcast_object_list(vllm_log_probs, from_process=0)
Member


Do we get a performance hit from broadcasting these arrays when use_vllm_logprobs=False or is it marginal?

Collaborator Author


Most likely marginal, as the GPUs are all already in sync due to the gather of completion_ids.
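
For what it's worth, a standalone sketch of the broadcast pattern being discussed, using accelerate's broadcast_object_list (illustrative values, not the trainer's code):

from accelerate import Accelerator
from accelerate.utils import broadcast_object_list

accelerator = Accelerator()

# Only the main process talks to the vLLM instance; the others hold placeholders
# of the same length.
if accelerator.is_main_process:
    completion_ids = [[101, 102, 103], [201, 202, 203]]
else:
    completion_ids = [None] * 2

# After the broadcast every process holds the full list; the trainer then keeps
# only the slice corresponding to its own prompts.
completion_ids = broadcast_object_list(completion_ids, from_process=0)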

Member

@qgallouedec qgallouedec left a comment


👌


# Pad the completions, and concatenate them with the prompts
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
vllm_log_probs = [torch.tensor(logp, device=device) for logp in vllm_log_probs]
vllm_log_probs = pad(vllm_log_probs, padding_value=self.processing_class.pad_token_id)
Member


Suggested change
vllm_log_probs = pad(vllm_log_probs, padding_value=self.processing_class.pad_token_id)
vllm_log_probs = pad(vllm_log_probs, padding_value=0.0)
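
For context, a minimal sketch of a right-padding helper behaving like the `pad` used above (an assumption about its behavior, not TRL's exact implementation). Padding the log probs with 0.0 rather than the tokenizer's pad token id is the point of the suggestion: the values are floats, and padded positions are masked out of the loss anyway.

import torch

def pad(tensors, padding_value=0):
    # Right-pad a list of 1-D tensors to the length of the longest one.
    max_len = max(t.size(0) for t in tensors)
    out = torch.full(
        (len(tensors), max_len), padding_value, dtype=tensors[0].dtype, device=tensors[0].device
    )
    for i, t in enumerate(tensors):
        out[i, : t.size(0)] = t
    return out

# pad([torch.tensor([1.1, 1.2, 1.3]), torch.tensor([2.1])], padding_value=0.0)
# -> tensor([[1.1000, 1.2000, 1.3000],
#            [2.1000, 0.0000, 0.0000]])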

Member

@lewtun lewtun left a comment


Thanks for the clarification about how the logprobs are used in the loss. LGTM with the minor tweak to the docstring

Comment on lines 325 to +326
{"completion_ids": [[101, 102, 103], [201, 202, 203]]}
{"log_probs": [[1.1, 1.2, 1.3], [2.1, 2.2, 2.3]]}
Member


Suggested change
{"completion_ids": [[101, 102, 103], [201, 202, 203]]}
{"log_probs": [[1.1, 1.2, 1.3], [2.1, 2.2, 2.3]]}
{"completion_ids": [[101, 102, 103], [201, 202, 203]],
"log_probs": [[1.1, 1.2, 1.3], [2.1, 2.2, 2.3]]}

@edbeeching
Collaborator Author

I benchmarked this, and the performance is comparable but the speed improvement is marginal, so it's not worth adding the complexity to the codebase.
[benchmark plot]

@edbeeching edbeeching closed this Apr 2, 2025