diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 9703fad09a..cde6ee76a5 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -1342,8 +1342,14 @@ def test_train_vlm_multi_image(self):
             new_param = trainer.model.get_parameter(n)
             assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
 
+    @parameterized.expand(
+        [
+            ("trl-internal-testing/tiny-Gemma3ForConditionalGeneration",),
+            ("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",),
+        ]
+    )
     @require_vision
-    def test_train_vlm_prompt_completion(self):
+    def test_train_vlm_prompt_completion(self, model_id):
         # Get the dataset
         dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_completion", split="train")
 
@@ -1354,7 +1360,7 @@ def test_train_vlm_prompt_completion(self):
             report_to="none",
         )
         trainer = SFTTrainer(
-            model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
+            model=model_id,
             args=training_args,
             train_dataset=dataset,
         )
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index 3f0509a9ae..032d12fb66 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -418,34 +418,51 @@ def _collate_prompt_completion(self, examples: list[dict[str, Any]]) -> dict[str
             add_special_tokens=False,  # to avoid adding the BOS, twice see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens
         )
 
+        to_concat_keys = {"input_ids", "attention_mask"}
+        for candidate_key in (processed_prompts.keys() & processed_completions.keys()) - to_concat_keys:
+            if (
+                processed_prompts[candidate_key].shape == processed_prompts["attention_mask"].shape
+                and processed_completions[candidate_key].shape == processed_completions["attention_mask"].shape
+            ):
+                to_concat_keys.add(candidate_key)
+
         # Concatenate prompts and completions
-        prompt_ids, completion_ids = processed_prompts["input_ids"], processed_completions["input_ids"]
-        prompt_mask, completion_mask = processed_prompts["attention_mask"], processed_completions["attention_mask"]
-        input_ids = torch.cat((prompt_ids, completion_ids), dim=1)
-        attention_mask = torch.cat((prompt_mask, completion_mask), dim=1)
-        completion_mask = torch.cat((torch.zeros_like(prompt_mask), completion_mask), dim=1)
+        prompt_completion_concats = {
+            key: torch.cat((processed_prompts[key], processed_completions[key]), dim=1) for key in to_concat_keys
+        }
+        prompt_completion_concats["completion_mask"] = torch.cat(
+            (torch.zeros_like(processed_prompts["attention_mask"]), processed_completions["attention_mask"]), dim=1
+        )
 
         # Flush left to reduce padding
-        attention_mask, input_ids, completion_mask = flush_left(attention_mask, input_ids, completion_mask)
+        non_attention_mask_keys = prompt_completion_concats.keys() - {"attention_mask"}
+        prompt_completion_concats = dict(
+            zip(
+                ["attention_mask", *non_attention_mask_keys],
+                flush_left(
+                    prompt_completion_concats["attention_mask"],
+                    *(prompt_completion_concats[key] for key in non_attention_mask_keys),
+                ),
+            )
+        )
 
         # Truncate if necessary
         if self.max_length is not None:
-            input_ids = input_ids[:, : self.max_length]
-            attention_mask = attention_mask[:, : self.max_length]
-            completion_mask = completion_mask[:, : self.max_length]
+            prompt_completion_concats = {
+                key: val[:, : self.max_length] for key, val in prompt_completion_concats.items()
+            }
 
         # Create labels and mask padding tokens
-        labels = input_ids.clone()
-        labels[attention_mask == 0] = -100
+        labels = prompt_completion_concats["input_ids"].clone()
+        labels[prompt_completion_concats["attention_mask"] == 0] = -100
         if self.completion_only_loss:
-            labels[completion_mask == 0] = -100
+            labels[prompt_completion_concats["completion_mask"] == 0] = -100
+        prompt_completion_concats["labels"] = labels
 
         # Build the output dictionary
-        output = processed_prompts  # we take processed_prompts because it contains the images
-        output["input_ids"] = input_ids
-        output["attention_mask"] = attention_mask
-        output["labels"] = labels
-        return output
+        return (
+            processed_prompts | prompt_completion_concats
+        )  # we take processed_prompts because it contains the images
 
 
 def dft_loss(outputs, labels, num_items_in_batch=None):
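
For reviewers, a minimal standalone sketch of the dynamic key handling the collator now performs. This is not part of the patch: the toy tensor values and the extra token_type_ids key are illustrative assumptions (Gemma3-style processors return token_type_ids alongside input_ids, which is what motivates the change); flush_left is the same helper the trainer imports from trl.trainer.utils.

# Sketch only, not part of the patch. Toy values are made up for illustration.
import torch
from trl.trainer.utils import flush_left  # same helper used by the trainer

processed_prompts = {
    "input_ids": torch.tensor([[0, 0, 1, 2]]),       # left-padded prompt
    "attention_mask": torch.tensor([[0, 0, 1, 1]]),
    "token_type_ids": torch.tensor([[0, 0, 0, 1]]),  # same shape as attention_mask
    "pixel_values": torch.randn(1, 3, 4, 4),         # image features, different shape
}
processed_completions = {
    "input_ids": torch.tensor([[3, 4]]),
    "attention_mask": torch.tensor([[1, 1]]),
    "token_type_ids": torch.tensor([[0, 0]]),
}

# Any key present in both batches whose shape matches attention_mask gets
# concatenated along the sequence dimension; "token_type_ids" qualifies here,
# while "pixel_values" is prompt-only, so it is never even a candidate and is
# carried over unchanged via the final `processed_prompts | ...` merge.
to_concat_keys = {"input_ids", "attention_mask"}
for key in (processed_prompts.keys() & processed_completions.keys()) - to_concat_keys:
    if (
        processed_prompts[key].shape == processed_prompts["attention_mask"].shape
        and processed_completions[key].shape == processed_completions["attention_mask"].shape
    ):
        to_concat_keys.add(key)

concats = {key: torch.cat((processed_prompts[key], processed_completions[key]), dim=1) for key in to_concat_keys}

# flush_left takes the mask first and shifts all tensors left in lockstep,
# dropping columns that are padding in every row.
other_keys = list(concats.keys() - {"attention_mask"})
flushed = flush_left(concats["attention_mask"], *(concats[key] for key in other_keys))
concats = dict(zip(["attention_mask", *other_keys], flushed))

print(concats["input_ids"])  # tensor([[1, 2, 3, 4]]), left padding removed

The design point the sketch demonstrates: the old collator hard-coded input_ids/attention_mask and silently dropped any other per-token key, whereas the shape check lets model-specific keys such as Gemma3's token_type_ids ride through concatenation, flush_left, and truncation together with the rest of the batch.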