32 changes: 32 additions & 0 deletions tests/test_sft_trainer.py
@@ -1333,6 +1333,38 @@ def test_train_vlm_prompt_completion(self):
             new_param = trainer.model.get_parameter(n)
             assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"

+    # Special case for Gemma, as it uses token_type_ids, and we need to ensure they are properly handled in the collator.
+    @require_vision
+    def test_train_vlm_prompt_completion_gemma(self):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_completion", split="train")
+
+        # Initialize the trainer
+        training_args = SFTConfig(
+            output_dir=self.tmp_dir,
+            max_length=None,  # For VLMs, truncating can remove image tokens, leading to errors
+            report_to="none",
+        )
+        trainer = SFTTrainer(
+            model="trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated")
+
     # Gemma 3n uses a timm encoder, making it difficult to create a smaller variant for testing.
     # To ensure coverage, we run tests on the full model but mark them as slow to exclude from default runs.
     @pytest.mark.slow
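Background on the new test: Gemma 3-style processors return a `token_type_ids` tensor alongside `input_ids` (marking which positions hold image tokens), and the model expects it at training time, so the prompt-completion collator has to carry it through. The sketch below is not part of the PR; it only shows where that field comes from. The chat-template content layout, and the assumption that the tiny test checkpoint ships a processor with a chat template, are mine rather than the source's.

```python
# Illustrative only: peek at the token_type_ids field that a Gemma 3 processor emits.
from transformers import AutoProcessor
from PIL import Image

processor = AutoProcessor.from_pretrained("trl-internal-testing/tiny-Gemma3ForConditionalGeneration")

messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What is in the image?"}]}]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
image = Image.new("RGB", (64, 64))  # dummy stand-in for a dataset image

batch = processor(text=[prompt], images=[image], return_tensors="pt")
print(sorted(batch.keys()))     # should include "token_type_ids" next to input_ids / pixel_values
print(batch["token_type_ids"])  # expected: 1 at image-token positions, 0 elsewhere
```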
15 changes: 14 additions & 1 deletion trl/trainer/sft_trainer.py
@@ -418,15 +418,26 @@ def _collate_prompt_completion(self, examples: list[dict[str, Any]]) -> dict[str
         input_ids = torch.cat((prompt_ids, completion_ids), dim=1)
         attention_mask = torch.cat((prompt_mask, completion_mask), dim=1)
         completion_mask = torch.cat((torch.zeros_like(prompt_mask), completion_mask), dim=1)
+        if "token_type_ids" in processed_prompts:  # special case for Gemma
+            prompt_token_type_ids = processed_prompts["token_type_ids"]
+            completion_token_type_ids = processed_completions["token_type_ids"]
+            token_type_ids = torch.cat((prompt_token_type_ids, completion_token_type_ids), dim=1)

         # Flush left to reduce padding
-        attention_mask, input_ids, completion_mask = flush_left(attention_mask, input_ids, completion_mask)
+        if "token_type_ids" in processed_prompts:
+            attention_mask, input_ids, completion_mask, token_type_ids = flush_left(
+                attention_mask, input_ids, completion_mask, token_type_ids
+            )
+        else:
+            attention_mask, input_ids, completion_mask = flush_left(attention_mask, input_ids, completion_mask)

         # Truncate if necessary
         if self.max_length is not None:
             input_ids = input_ids[:, : self.max_length]
             attention_mask = attention_mask[:, : self.max_length]
             completion_mask = completion_mask[:, : self.max_length]
+            if "token_type_ids" in processed_prompts:
+                token_type_ids = token_type_ids[:, : self.max_length]

         # Create labels and mask padding tokens
         labels = input_ids.clone()
@@ -439,6 +450,8 @@ def _collate_prompt_completion(self, examples: list[dict[str, Any]]) -> dict[str
output["input_ids"] = input_ids
output["attention_mask"] = attention_mask
output["labels"] = labels
if "token_type_ids" in processed_prompts:
output["token_type_ids"] = token_type_ids
return output


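What the collator change amounts to: `token_type_ids` is now treated exactly like the other per-token tensors, that is, concatenated across prompt and completion, flushed left together with them, truncated to `max_length`, and finally emitted in the batch. The toy below walks through those steps on made-up tensors; `toy_flush_left` is a simplified stand-in that only mirrors the call pattern of `flush_left` visible in the diff, not TRL's actual implementation.

```python
# Toy illustration (not the PR's code): token_type_ids must go through the same
# concatenate -> flush-left -> truncate pipeline as input_ids and attention_mask.
import torch


def toy_flush_left(mask, *tensors):
    """Simplified stand-in for flush_left: pack the unmasked entries of each row
    to the left, then drop columns that are padding in every row."""
    outs = [torch.zeros_like(mask)] + [torch.zeros_like(t) for t in tensors]
    for i in range(mask.size(0)):
        keep = mask[i].bool()
        n = int(keep.sum())
        outs[0][i, :n] = mask[i, keep]
        for j, t in enumerate(tensors, start=1):
            outs[j][i, :n] = t[i, keep]
    width = int(outs[0].sum(dim=1).max())
    return tuple(o[:, :width] for o in outs)


# Made-up batch of 2: prompts are left-padded, completions right-padded; 9 stands for an image token.
prompt_ids = torch.tensor([[0, 9, 9, 5], [0, 0, 7, 8]])
prompt_mask = torch.tensor([[0, 1, 1, 1], [0, 0, 1, 1]])
prompt_token_type_ids = torch.tensor([[0, 1, 1, 0], [0, 0, 0, 0]])  # 1 marks image tokens
completion_ids = torch.tensor([[6, 2], [4, 2]])
completion_mask = torch.ones_like(completion_ids)
completion_token_type_ids = torch.zeros_like(completion_ids)  # completions are text only

# Same sequence of steps as the collator: concatenate, flush left, truncate.
input_ids = torch.cat((prompt_ids, completion_ids), dim=1)
attention_mask = torch.cat((prompt_mask, completion_mask), dim=1)
token_type_ids = torch.cat((prompt_token_type_ids, completion_token_type_ids), dim=1)

attention_mask, input_ids, token_type_ids = toy_flush_left(attention_mask, input_ids, token_type_ids)

max_length = 5  # would come from SFTConfig.max_length
input_ids = input_ids[:, :max_length]
attention_mask = attention_mask[:, :max_length]
token_type_ids = token_type_ids[:, :max_length]

print(input_ids)       # padding now sits on the right of each row
print(token_type_ids)  # image-token markers stayed aligned with input_ids
```

The invariant the test above exercises is that every per-token tensor takes the same path, so `token_type_ids` stays aligned with `input_ids` all the way into the model's forward pass.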