Fix grpo nan #3278
Closed
Commits (31 total, all by pluesclues; the diff below reflects changes from 26 of them)

- f911c32  Kept, padding logic
- 2ba7f50  Made sure prediction step in rl.py allows logging for callbacks in RL…
- 0c1bc4d  Merge branch 'unslothai:main' into main
- 78336ce  updated llama.py to new online_dpo changes
- 383aa9c  Update rl.py to make logic simpiler
- 532af4f  Update rl.py, made sure tokenized_output on eval step was on same device
- 49f77c1  Update rl.py, corrected tokenized_outputs to inputs
- 7921aa7  Update rl.py, removed sagemaker stuff
- 54f03ee  Update llama.py, figures out if there is right padding automatically
- a8d4168  Update llama.py, changed conditional statement for right padding slig…
- 236b924  Update llama.py, updated OS.environ variable to temp variable
- 76d73c6  Merge branch 'main' into main
- fa2e18e  Update rl.py, made it account for right padding in online dpo and rew…
- 80f9cd2  Update llama.py, automatically figures out if right padding is needed
- ed1771a  Merge branch 'main' into main
- 49d3844  Merge branch 'main' into main
- b0a9c65  Merge branch 'unslothai:main' into main
- 6edcb0d  Merge branch 'unslothai:main' into main
- 90c581b  Merge branch 'unslothai:main' into main
- 0d2b9dc  Merge branch 'unslothai:main' into fix_grpo_nan
- 5df4532  Update rl_replacements.py
- eb65ecf  Update rl.py
- 4751abf  Update rl.py, chagned order of util functions for padding
- d86953b  Update rl_replacements.py, disabled commenting out logits_to_keep
- 190c2c0  Update llama.py
- 0b9068c  Merge branch 'unslothai:main' into fix_grpo_nan
- 1fba36c  Update unsloth/models/rl.py
- fa48726  Merge branch 'unslothai:main' into fix_grpo_nan
- fad14ca  Update rl_replacements.py
- 6aedc2f  Update rl_replacements.py, added new line
- 4bcd41e  Update rl_replacements.py, updated version
Diff in rl_replacements.py:

```diff
@@ -245,6 +245,20 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
         "prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False",
     )
 
+    # Left pad prompt before calculation old and ref hidden states
+    line_to_replace = "batch_size = self.args.per_device_train_batch_size if mode == \"train\" else self.args.per_device_eval_batch_size"
+
+    # The new lines you want to insert
+    replacement_lines = """batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
+prompt_completion_ids = left_pack_padding(prompt_completion_ids, self.processing_class.pad_token_id)"""
+
+    function = function.replace(line_to_replace, replacement_lines)
+
+    # function = function.replace(
+    #     "logits_to_keep,",
+    #     "#logits_to_keep,",
+    # )
+
     # Always between max_prompt_length and use_vllm
     found = re.findall(
         r"\n(([ ]{8,})if self\.max_prompt_length is not None:.*?"\
```
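The helper this hunk wires in, `left_pack_padding`, is not defined anywhere in this diff. As a rough illustration only, here is a minimal sketch of a left-packing helper with the same call signature, assuming it simply moves each row's pad tokens to the left so the real tokens sit right-aligned before the old and reference hidden states are computed; the actual implementation in the PR may differ.

```python
import torch

def left_pack_padding_sketch(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    """Move all pad tokens to the left of each row, keeping real tokens in order."""
    mask = input_ids != pad_token_id               # True where real (non-pad) tokens sit
    counts = mask.sum(dim=1)                       # number of real tokens per row
    seq_len = input_ids.size(1)
    out = torch.full_like(input_ids, pad_token_id) # start from a fully padded batch
    for row in range(input_ids.size(0)):
        n = int(counts[row].item())
        out[row, seq_len - n:] = input_ids[row, mask[row]]  # pack real tokens to the right
    return out

# A right-padded batch becomes left-padded:
ids = torch.tensor([[5, 6, 7, 0, 0],
                    [8, 9, 0, 0, 0]])
print(left_pack_padding_sketch(ids, pad_token_id=0))
# tensor([[0, 0, 5, 6, 7],
#         [0, 0, 0, 8, 9]])
```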
```diff
@@ -282,7 +296,8 @@ def strip_leading_tokens(text):
     # Generate completions using either vLLM or regular generation
     if self.use_vllm:"""
     function = function.replace(replace_part, new_replacement)
     pass
 
     return function
 pass
 RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__generate_and_score_completions)
```
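Stepping back, the whole patch works on source text: `function` holds what is presumably the source of the GRPO trainer's generate-and-score step as a string, the `str.replace` call in hunk 1 splices the left-packing call in right after the `batch_size` assignment, and the patcher itself is registered via `RL_FUNCTIONS` at the bottom of the file. A toy demonstration of that mechanism, using a stand-in source string rather than the real trainer method (only the `batch_size` line is taken verbatim from the diff):

```python
# Toy stand-in for the trainer source string that `function` would hold at
# patch time; everything except the batch_size line is a placeholder.
trainer_src = (
    'batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size\n'
    "# ... rest of the generate-and-score method ...\n"
)

line_to_replace = 'batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size'

# The same line plus the new left-packing call, mirroring what hunk 1 splices in.
replacement_lines = (
    line_to_replace
    + "\nprompt_completion_ids = left_pack_padding(prompt_completion_ids, self.processing_class.pad_token_id)"
)

patched_src = trainer_src.replace(line_to_replace, replacement_lines)
print(patched_src)  # the batch_size line is now followed by the left_pack_padding call
```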