[GRPO] Allow the use of the vllm logprobs, rather than recomputing them #3193
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Overall looks good, with a potential typo in the padding of the log probs. It would be quite interesting to know if there is any noticeable difference between using the vLLM logprobs vs the native ones when beta > 0 (not for this PR, but more generally on some simple baseline).
@@ -81,6 +81,8 @@ class GRPOConfig(TrainingArguments):
    use_vllm (`bool`, *optional*, defaults to `False`):
        Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
        training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
    use_vllm_logprobs (`bool`, *optional*, defaults to `False`):
        Whether to use vLLM's logprobs for the `"old_logprobs"` in the GRPO loss. Requires `use_vllm=True`.
nit:
- Whether to use vLLM's logprobs for the `"old_logprobs"` in the GRPO loss. Requires `use_vllm=True`.
+ Whether to use vLLM's logprobs to compute the GRPO loss instead of using the native `forward()` method. This is more compute efficient because vLLM computes the policy's logprobs in parallel to completions. Requires `use_vllm=True`.
Thanks. By the way, I think there may be some confusion: this PR is not related to the ref model at all. Previously, the `old_logprobs` were (re)calculated from the model we are optimizing, in `torch.no_grad` mode, using the tokens generated by the vLLM instance. They are used in GRPO's clipping loss to constrain the model's updates to stay within the clipping range (`epsilon`) of the old policy. Now we use the logprobs from the vLLM instance instead. I believe you are thinking of the KL penalty, which still uses the logprobs from the ref model and the model we are optimizing, not the vLLM logprobs.
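For readers following along, here is a minimal sketch of where those old per-token logprobs enter the clipped part of the loss. The function name, tensor shapes, and the omission of masking/normalization are illustrative assumptions, not TRL's actual implementation:

```python
import torch

def clipped_grpo_term(per_token_logps, old_per_token_logps, advantages, epsilon=0.2):
    # Ratio between the current policy and the "old" policy that generated the samples.
    ratio = torch.exp(per_token_logps - old_per_token_logps)              # (B, T)
    unclipped = ratio * advantages.unsqueeze(-1)                          # (B, T)
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages.unsqueeze(-1)
    # Per-token surrogate loss; masking of padded tokens and averaging are omitted here.
    return -torch.min(unclipped, clipped)
```

This is why the choice of `old_logprobs` (recomputed vs taken from vLLM) only affects the clipping ratio, not the KL term.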
Ah yeah, you're right - I've amended my suggestion to align with that. (I think it's not a good idea to reference internal variables in docstrings, as the user cannot see them unless they drill into the code.)
use_vllm_logprobs: bool = field(
    default=False,
    metadata={
        "help": "Whether to use vLLM's logprobs for the 'old_logprobs' in the GRPO loss. Requires use_vllm=True."
"help": "Whether to use vLLM's logprobs for the 'old_logprobs' in the GRPO loss. Requires use_vllm=True." | |
Whether to use vLLM's logprobs to compute the GRPO loss instead of using the native `forward()` method. This is more compute efficient because vLLM computes the policy's logprobs in parallel to completions. Requires use_vllm=True. |
# Broadcast the completions from the main process to all processes, ensuring each process receives its
# corresponding slice.
completion_ids = broadcast_object_list(completion_ids, from_process=0)
vllm_log_probs = broadcast_object_list(vllm_log_probs, from_process=0)
Do we get a performance hit from broadcasting these arrays when `use_vllm_logprobs=False`, or is it marginal?
Most likely marginal, as the GPUs are all already in sync due to the gather of `completion_ids`.
👌
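For context, a minimal standalone sketch of the broadcast pattern used above, based on `accelerate.utils.broadcast_object_list`; the toy data and the process setup are assumptions for illustration:

```python
from accelerate import Accelerator
from accelerate.utils import broadcast_object_list

accelerator = Accelerator()

# Only the main process queried the vLLM server; other processes hold placeholders
# of the same length (toy data, for illustration).
if accelerator.is_main_process:
    completion_ids = [[101, 102, 103], [201, 202, 203]]
    vllm_log_probs = [[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6]]
else:
    completion_ids = [None] * 2
    vllm_log_probs = [None] * 2

# After broadcasting, every process holds the main process's objects and can
# slice out the portion corresponding to its own prompts.
completion_ids = broadcast_object_list(completion_ids, from_process=0)
vllm_log_probs = broadcast_object_list(vllm_log_probs, from_process=0)
```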
# Pad the completions, and concatenate them with the prompts
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
vllm_log_probs = [torch.tensor(logp, device=device) for logp in vllm_log_probs]
vllm_log_probs = pad(vllm_log_probs, padding_value=self.processing_class.pad_token_id)
- vllm_log_probs = pad(vllm_log_probs, padding_value=self.processing_class.pad_token_id)
+ vllm_log_probs = pad(vllm_log_probs, padding_value=0.0)
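To see why the padding value matters here: `pad_token_id` is a vocabulary index (often a large integer), so padding log-probability tensors with it would inject bogus values, whereas `0.0` keeps the padded positions neutral (they are presumably masked out downstream). A small sketch using `torch.nn.utils.rnn.pad_sequence` in place of TRL's `pad` helper:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Variable-length per-token logprobs for two completions (toy values).
vllm_log_probs = [torch.tensor([-0.1, -0.2, -0.3]), torch.tensor([-0.4, -0.5])]

# Padding with 0.0 rather than a pad_token_id keeps the extra positions neutral.
padded = pad_sequence(vllm_log_probs, batch_first=True, padding_value=0.0)
print(padded)
# tensor([[-0.1000, -0.2000, -0.3000],
#         [-0.4000, -0.5000,  0.0000]])
```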
Thanks for the clarification about how the logprobs are used in the loss. LGTM with the minor tweak to the docstring.
{"completion_ids": [[101, 102, 103], [201, 202, 203]]} | ||
{"log_probs": [[1.1, 1.2, 1.3], [2.1, 2.2, 2.3]]} |
{"completion_ids": [[101, 102, 103], [201, 202, 203]]} | |
{"log_probs": [[1.1, 1.2, 1.3], [2.1, 2.2, 2.3]]} | |
{"completion_ids": [[101, 102, 103], [201, 202, 203]], | |
"log_probs": [[1.1, 1.2, 1.3], [2.1, 2.2, 2.3]]} |
This PR exposes the logprobs of the token IDs that were generated by the vLLM client.
The default will be `False`, as these logprobs are likely different from those produced by the policy.
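A hypothetical end-to-end sketch of how the new flag would be used; the model name, dataset, and reward function are placeholders, and the exact trainer arguments may differ from the final API:

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

def reward_len(completions, **kwargs):
    # Toy reward: prefer completions close to 20 characters.
    return [-abs(20 - len(c)) for c in completions]

dataset = load_dataset("trl-lib/tldr", split="train")

config = GRPOConfig(
    output_dir="grpo-vllm-logprobs",
    use_vllm=True,             # generate completions with vLLM
    use_vllm_logprobs=True,    # reuse vLLM's logprobs instead of recomputing them (this PR)
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=config,
    train_dataset=dataset,
)
trainer.train()
```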