tests/test_sft_trainer.py: 8 additions & 2 deletions
@@ -1342,8 +1342,14 @@ def test_train_vlm_multi_image(self):
             new_param = trainer.model.get_parameter(n)
             assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
 
+    @parameterized.expand(
+        [
+            ("trl-internal-testing/tiny-Gemma3ForConditionalGeneration",),
+            ("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",),
+        ]
+    )
     @require_vision
-    def test_train_vlm_prompt_completion(self):
+    def test_train_vlm_prompt_completion(self, model_id):
         # Get the dataset
         dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_completion", split="train")
 
@@ -1354,7 +1360,7 @@ def test_train_vlm_prompt_completion(self):
             report_to="none",
         )
         trainer = SFTTrainer(
-            model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
+            model=model_id,
             args=training_args,
             train_dataset=dataset,
         )
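The test change swaps the hard-coded Qwen2.5-VL checkpoint for a parameterized matrix, so the same prompt-completion test now also runs against a tiny Gemma3 checkpoint, whose processor emits the extra per-token tensors the collator change below has to handle. parameterized.expand unpacks each tuple into the test's extra positional arguments and generates one test case per entry, named with an index and a slug of the parameters. A minimal sketch of the mechanism (the Demo class and its test are hypothetical):

    import unittest

    from parameterized import parameterized

    class Demo(unittest.TestCase):
        # Each tuple becomes one generated test case, its items passed
        # as extra positional arguments after self.
        @parameterized.expand(
            [
                ("trl-internal-testing/tiny-Gemma3ForConditionalGeneration",),
                ("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",),
            ]
        )
        def test_model(self, model_id):
            self.assertIn("tiny", model_id)

    if __name__ == "__main__":
        unittest.main()  # runs two test cases, one per checkpoint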
trl/trainer/sft_trainer.py: 34 additions & 17 deletions
@@ -418,34 +418,51 @@ def _collate_prompt_completion(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
             add_special_tokens=False,  # to avoid adding the BOS twice, see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens
         )
 
+        to_concat_keys = {"input_ids", "attention_mask"}
+        for candidate_key in (processed_prompts.keys() & processed_completions.keys()) - to_concat_keys:
+            if (
+                processed_prompts[candidate_key].shape == processed_prompts["attention_mask"].shape
+                and processed_completions[candidate_key].shape == processed_completions["attention_mask"].shape
+            ):
+                to_concat_keys.add(candidate_key)
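Why the new to_concat_keys loop: a VLM processor can return per-token tensors beyond input_ids and attention_mask (Gemma3's processor, for instance, returns token_type_ids marking which positions are image tokens). Any key present in both the prompt and completion batches whose shape matches the attention mask is treated as per-token and scheduled for concatenation; outputs with other shapes, such as pixel_values, are left untouched. A minimal sketch with made-up tensors, where token_type_ids stands in for such an extra key:

    import torch

    # Toy processor outputs; values are invented for illustration.
    processed_prompts = {
        "input_ids": torch.tensor([[1, 2, 3]]),
        "attention_mask": torch.tensor([[1, 1, 1]]),
        "token_type_ids": torch.tensor([[0, 0, 0]]),
        "pixel_values": torch.rand(1, 3, 8, 8),  # image tensor, not per-token
    }
    processed_completions = {
        "input_ids": torch.tensor([[4, 5]]),
        "attention_mask": torch.tensor([[1, 1]]),
        "token_type_ids": torch.tensor([[0, 1]]),
    }

    to_concat_keys = {"input_ids", "attention_mask"}
    for candidate_key in (processed_prompts.keys() & processed_completions.keys()) - to_concat_keys:
        if (
            processed_prompts[candidate_key].shape == processed_prompts["attention_mask"].shape
            and processed_completions[candidate_key].shape == processed_completions["attention_mask"].shape
        ):
            to_concat_keys.add(candidate_key)

    print(sorted(to_concat_keys))  # ['attention_mask', 'input_ids', 'token_type_ids']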

         # Concatenate prompts and completions
-        prompt_ids, completion_ids = processed_prompts["input_ids"], processed_completions["input_ids"]
-        prompt_mask, completion_mask = processed_prompts["attention_mask"], processed_completions["attention_mask"]
-        input_ids = torch.cat((prompt_ids, completion_ids), dim=1)
-        attention_mask = torch.cat((prompt_mask, completion_mask), dim=1)
-        completion_mask = torch.cat((torch.zeros_like(prompt_mask), completion_mask), dim=1)
+        prompt_completion_concats = {
+            key: torch.cat((processed_prompts[key], processed_completions[key]), dim=1) for key in to_concat_keys
+        }
+        prompt_completion_concats["completion_mask"] = torch.cat(
+            (torch.zeros_like(processed_prompts["attention_mask"]), processed_completions["attention_mask"]), dim=1
+        )
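Continuing the toy sketch above, the dict comprehension concatenates every selected key along the sequence axis, and completion_mask is built as zeros over the prompt span followed by the completion's attention mask, so the loss can later be restricted to completion tokens:

    prompt_completion_concats = {
        key: torch.cat((processed_prompts[key], processed_completions[key]), dim=1) for key in to_concat_keys
    }
    prompt_completion_concats["completion_mask"] = torch.cat(
        (torch.zeros_like(processed_prompts["attention_mask"]), processed_completions["attention_mask"]), dim=1
    )
    print(prompt_completion_concats["input_ids"])        # tensor([[1, 2, 3, 4, 5]])
    print(prompt_completion_concats["completion_mask"])  # tensor([[0, 0, 0, 1, 1]])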

         # Flush left to reduce padding
-        attention_mask, input_ids, completion_mask = flush_left(attention_mask, input_ids, completion_mask)
+        non_attention_mask_keys = prompt_completion_concats.keys() - {"attention_mask"}
+        prompt_completion_concats = dict(
+            zip(
+                ["attention_mask", *non_attention_mask_keys],
+                flush_left(
+                    prompt_completion_concats["attention_mask"],
+                    *(prompt_completion_concats[key] for key in non_attention_mask_keys),
+                ),
+            )
+        )
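Prompts are tokenized with left padding and completions with right padding, so a concatenated row can carry padding at both edges. flush_left takes the attention mask first (it defines where the padding is), shifts every sequence to the start, and drops columns that are padding in all rows; the surrounding zip then keys the returned tensors back by name. Note that non_attention_mask_keys is a set: the code relies on the fact that iterating the same unmodified set object twice, once for the names and once for the tensors, yields the same order in CPython. A rough sketch of the flush-left idea, assuming contiguous padding; this is an illustration, not TRL's actual flush_left implementation:

    import torch

    def flush_left_sketch(mask, *tensors):
        # Shift each row left so padding (mask == 0) ends up on the right,
        # then drop trailing columns that are padding in every row.
        flushed_mask = torch.zeros_like(mask)
        flushed = [torch.zeros_like(t) for t in tensors]
        for i in range(mask.shape[0]):
            keep = mask[i].bool()
            n = int(keep.sum())
            flushed_mask[i, :n] = mask[i, keep]
            for src, dst in zip(tensors, flushed):
                dst[i, :n] = src[i, keep]
        seq_len = int(flushed_mask.sum(dim=1).max())
        return flushed_mask[:, :seq_len], *(t[:, :seq_len] for t in flushed)

    mask = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
    ids = torch.tensor([[9, 9, 7, 8], [9, 5, 6, 7]])  # 9s sit under padding
    new_mask, new_ids = flush_left_sketch(mask, ids)
    print(new_ids)  # flushed left: [[7, 8, 0], [5, 6, 7]]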

         # Truncate if necessary
         if self.max_length is not None:
-            input_ids = input_ids[:, : self.max_length]
-            attention_mask = attention_mask[:, : self.max_length]
-            completion_mask = completion_mask[:, : self.max_length]
+            prompt_completion_concats = {
+                key: val[:, : self.max_length] for key, val in prompt_completion_concats.items()
+            }

         # Create labels and mask padding tokens
-        labels = input_ids.clone()
-        labels[attention_mask == 0] = -100
+        labels = prompt_completion_concats["input_ids"].clone()
+        labels[prompt_completion_concats["attention_mask"] == 0] = -100
         if self.completion_only_loss:
-            labels[completion_mask == 0] = -100
+            labels[prompt_completion_concats["completion_mask"] == 0] = -100
+        prompt_completion_concats["labels"] = labels
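The value -100 matches the default ignore_index of PyTorch's cross-entropy loss, so masked positions contribute nothing to the gradient: padding is always ignored, and with completion_only_loss the prompt tokens are ignored too. Also new here is that labels now travels inside prompt_completion_concats instead of a standalone variable. A self-contained toy check (values made up, completion_only_loss assumed enabled):

    import torch

    input_ids = torch.tensor([[1, 2, 3, 4, 5]])
    attention_mask = torch.tensor([[1, 1, 1, 1, 1]])
    completion_mask = torch.tensor([[0, 0, 0, 1, 1]])

    labels = input_ids.clone()
    labels[attention_mask == 0] = -100   # no padding in this toy batch
    labels[completion_mask == 0] = -100  # completion-only loss path
    print(labels)  # only the completion tokens 4 and 5 keep real labels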

-        # Build the output dictionary
-        output = processed_prompts  # we take processed_prompts because it contains the images
-        output["input_ids"] = input_ids
-        output["attention_mask"] = attention_mask
-        output["labels"] = labels
-        return output
+        return (
+            processed_prompts | prompt_completion_concats
+        )  # we take processed_prompts because it contains the images
 
 
 def dft_loss(outputs, labels, num_items_in_batch=None):
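The merged return value uses the dict union operator (Python 3.9+): on duplicate keys the right operand wins, so the concatenated, flushed, and truncated per-token tensors plus the new labels override the prompt-only versions, while keys that exist only in processed_prompts, notably image tensors such as pixel_values, survive the merge. A two-line illustration with stand-in values:

    prompt_side = {"pixel_values": "kept", "input_ids": "prompt-only"}
    concat_side = {"input_ids": "concatenated", "labels": "new"}
    print(prompt_side | concat_side)  # {'pixel_values': 'kept', 'input_ids': 'concatenated', 'labels': 'new'}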