Status: Closed

Commits (31 total; changes shown from 26 commits):
f911c32 - Kept, padding logic (pluesclues, Jun 23, 2025)
2ba7f50 - Made sure prediction step in rl.py allows logging for callbacks in RL… (pluesclues, Jun 23, 2025)
0c1bc4d - Merge branch 'unslothai:main' into main (pluesclues, Jun 23, 2025)
78336ce - Updated llama.py to new online_dpo changes (pluesclues, Jun 23, 2025)
383aa9c - Update rl.py to make logic simpler (pluesclues, Jun 23, 2025)
532af4f - Update rl.py, made sure tokenized_output on eval step was on same device (pluesclues, Jun 24, 2025)
49f77c1 - Update rl.py, corrected tokenized_outputs to inputs (pluesclues, Jun 24, 2025)
7921aa7 - Update rl.py, removed sagemaker stuff (pluesclues, Jun 25, 2025)
54f03ee - Update llama.py, figures out if there is right padding automatically (pluesclues, Jul 2, 2025)
a8d4168 - Update llama.py, changed conditional statement for right padding slig… (pluesclues, Jul 2, 2025)
236b924 - Update llama.py, updated os.environ variable to temp variable (pluesclues, Jul 8, 2025)
76d73c6 - Merge branch 'main' into main (pluesclues, Jul 8, 2025)
fa2e18e - Update rl.py, made it account for right padding in online dpo and rew… (pluesclues, Jul 8, 2025)
80f9cd2 - Update llama.py, automatically figures out if right padding is needed (pluesclues, Jul 8, 2025)
ed1771a - Merge branch 'main' into main (pluesclues, Jul 12, 2025)
49d3844 - Merge branch 'main' into main (pluesclues, Aug 3, 2025)
b0a9c65 - Merge branch 'unslothai:main' into main (pluesclues, Aug 8, 2025)
6edcb0d - Merge branch 'unslothai:main' into main (pluesclues, Aug 11, 2025)
90c581b - Merge branch 'unslothai:main' into main (pluesclues, Aug 22, 2025)
0d2b9dc - Merge branch 'unslothai:main' into fix_grpo_nan (pluesclues, Sep 5, 2025)
5df4532 - Update rl_replacements.py (pluesclues, Sep 5, 2025)
eb65ecf - Update rl.py (pluesclues, Sep 5, 2025)
4751abf - Update rl.py, changed order of util functions for padding (pluesclues, Sep 5, 2025)
d86953b - Update rl_replacements.py, disabled commenting out logits_to_keep (pluesclues, Sep 5, 2025)
190c2c0 - Update llama.py (pluesclues, Sep 5, 2025)
0b9068c - Merge branch 'unslothai:main' into fix_grpo_nan (pluesclues, Sep 8, 2025)
1fba36c - Update unsloth/models/rl.py (pluesclues, Sep 9, 2025)
fa48726 - Merge branch 'unslothai:main' into fix_grpo_nan (pluesclues, Sep 9, 2025)
fad14ca - Update rl_replacements.py (pluesclues, Sep 9, 2025)
6aedc2f - Update rl_replacements.py, added new line (pluesclues, Sep 9, 2025)
4bcd41e - Update rl_replacements.py, updated version (pluesclues, Sep 10, 2025)
7 changes: 1 addition & 6 deletions unsloth/models/llama.py
@@ -840,19 +840,14 @@ def LlamaModel_fast_forward(
inputs_embeds *= attention_mask.unsqueeze(0).transpose(0, 1).transpose(1, 2)
if inputs_requires_grad: inputs_embeds.requires_grad_(True)
pass

# Ignore attention_mask
if attention_mask is None:
padding_mask = None
elif self.training:
attention_mask = None
padding_mask = None
else:
# if 0 in attention_mask:
# padding_mask = attention_mask
# else:
padding_mask = None

attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
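A note on the change above: per the commit log, llama.py now detects right padding automatically. As a minimal, hypothetical sketch of how such a check can work (the helper name and body below are illustrative assumptions, not this PR's code):

import torch

def has_right_padding(attention_mask: torch.Tensor) -> bool:
    # A batch is right padded when at least one row of the attention mask
    # ends in 0, i.e. real tokens are followed by trailing pad positions.
    return bool((attention_mask[:, -1] == 0).any())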
20 changes: 20 additions & 0 deletions unsloth/models/rl.py
@@ -33,7 +33,11 @@
RL_CONFIG_CHANGES,
RL_METRICS_CHANGES,
)

selective_log_softmax = RL_REPLACEMENTS["selective_log_softmax"]
create_completion_attention_mask = RL_REPLACEMENTS["create_completion_attention_mask"]
calculate_pad_tokens_in_prompt = RL_REPLACEMENTS["calculate_pad_tokens_in_prompt"]
left_pack_padding = RL_REPLACEMENTS["left_pack_padding"]

torch_compile_options = {
"epilogue_fusion" : True,
@@ -109,6 +113,12 @@ def generate_with_clone(*args, **kwargs):
from torch.nn import functional as F
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling

{create_completion_attention_mask_code}

{calculate_pad_tokens_in_prompt_code}

{left_pack_padding_code}

torch_compile_options = {{
"epilogue_fusion" : True,
"max_autotune" : False,
@@ -118,6 +128,7 @@ def generate_with_clone(*args, **kwargs):
}}

{selective_log_softmax_code}

{RL_pre}

@dataclass
@@ -695,6 +706,11 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
# Selective log softmax
selective_log_softmax_code = inspect.getsource(selective_log_softmax)

#GRPO masking code
Contributor review comment (suggested change): "#GRPO masking code" -> "# GRPO masking code", adding a space after the "#".

create_completion_attention_mask_code = inspect.getsource(create_completion_attention_mask)
calculate_pad_tokens_in_prompt_code = inspect.getsource(calculate_pad_tokens_in_prompt)
left_pack_padding_code = inspect.getsource(left_pack_padding)

# Get final source code
RLTrainer_source = RLTrainer_replacement.format(
RLTrainer_name = RLTrainer_name,
@@ -720,6 +736,10 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
max_seq_length_post = max_seq_length_post,

selective_log_softmax_code = selective_log_softmax_code,
create_completion_attention_mask_code = create_completion_attention_mask_code,
calculate_pad_tokens_in_prompt_code = calculate_pad_tokens_in_prompt_code,
left_pack_padding_code = left_pack_padding_code,

)

if RLTrainer_name == "SFTTrainer":
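For context, a minimal sketch of what two of the padding helpers pulled from RL_REPLACEMENTS could look like; the real implementations live in rl_replacements.py and are not part of this diff, so treat these bodies as assumptions:

import torch

def calculate_pad_tokens_in_prompt(prompt_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # Count the pad tokens in each row of the prompt batch.
    return (prompt_ids == pad_token_id).sum(dim = 1)

def left_pack_padding(ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # Move every non-pad token to the right edge of its row so the batch
    # becomes left padded, preserving the relative order of real tokens.
    is_token = (ids != pad_token_id).int()
    # A stable ascending sort puts pads (0) before tokens (1) without reordering tokens.
    order = torch.sort(is_token, dim = 1, stable = True).indices
    return ids.gather(1, order)

For example, left_pack_padding(torch.tensor([[5, 6, 0, 0]]), 0) returns tensor([[0, 0, 5, 6]]).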
17 changes: 16 additions & 1 deletion unsloth/models/rl_replacements.py
@@ -245,6 +245,20 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
"prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False",
)

# Left pad prompt before calculating old and ref hidden states
line_to_replace = "batch_size = self.args.per_device_train_batch_size if mode == \"train\" else self.args.per_device_eval_batch_size"

# The new lines you want to insert
replacement_lines = """batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
prompt_completion_ids = left_pack_padding(prompt_completion_ids, self.processing_class.pad_token_id)"""
Contributor review comment: Maybe newline?

Collaborator (author) reply: To more easily resolve merge conflicts when testing with the Fast VLM inference branch, I moved everything in this PR to #3132:

https://github.com/pluesclues/unsloth/blob/fb115fb16cb2592caf99a9414b7d1f95f1f819ca/unsloth/models/rl_replacements.py#L252-L256

function = function.replace(line_to_replace, replacement_lines)

# function = function.replace(
# "logits_to_keep,",
# "#logits_to_keep,",
# )

# Always between max_prompt_length and use_vllm
found = re.findall(
r"\n(([ ]{8,})if self\.max_prompt_length is not None:.*?"\
@@ -282,7 +296,8 @@ def strip_leading_tokens(text):
# Generate completions using either vLLM or regular generation
if self.use_vllm:"""
function = function.replace(replace_part, new_replacement)
pass


return function
pass
RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__generate_and_score_completions)
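The function.replace calls above edit the trainer's source as plain text; the patched source is then materialized (via the RLTrainer_replacement.format call in rl.py) and executed in place of the stock TRL trainer. A self-contained toy version of that string-patching technique, with all names invented for illustration:

import inspect
import textwrap

def score(x):
    y = x + 1
    return y

# Fetch the source, splice a new line in after an anchor line, then re-exec it.
source = textwrap.dedent(inspect.getsource(score))
patched = source.replace(
    "y = x + 1",
    "y = x + 1\n    y = y * 2",
)
namespace = {}
exec(patched, namespace)
score = namespace["score"]
assert score(3) == 8  # (3 + 1) * 2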