tests/test_dpo_trainer.py: 63 additions & 0 deletions
@@ -1152,6 +1152,69 @@ def test_dpo_trainer_use_logits_to_keep(self):

         trainer.train()
 
+    # Special case for Gemma, as it uses token_type_ids, and we need to ensure they are properly handled by the collator.
+    def test_dpo_trainer_token_type_ids(self):
+        model_id = "trl-internal-testing/tiny-Gemma3ForConditionalGeneration"
+        # fmt: off
+        dataset_dict = {
"prompt": [
[{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe the image in great detail."}]}],
[{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Is this bus in the USA?"}]}],
[{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Give a thorough description of the image."}]}],
[{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Who are the people in the image?"}]}],
[{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What is written?"}]}],
],
"chosen": [
[{"role": "assistant", "content": [{"type": "text", "text": "The image features a modern, multi-colored train."}]}],
[{"role": "assistant", "content": [{"type": "text", "text": "Yes, it can be assumed that this bus is in the USA."}]}],
[{"role": "assistant", "content": [{"type": "text", "text": "The image features a forest path."}]}],
[{"role": "assistant", "content": [{"type": "text", "text": "There are two individuals, possibly girls or women."}]}],
[{"role": "assistant", "content": [{"type": "text", "text": '"ccpb".'}]}],
],
"rejected": [
[{"role": "assistant", "content": [{"type": "text", "text": "The image features a modern, colorful train."}]}],
[{"role": "assistant", "content": [{"type": "text", "text": "No, it's not in the USA."}]}],
[{"role": "assistant", "content": [{"type": "text", "text": "The image features a forest path surrounded by trees."}]}],
[{"role": "assistant", "content": [{"type": "text", "text": "In the image, there are two individuals."}]}],
[{"role": "assistant", "content": [{"type": "text", "text": '"ccpb".'}]}],
],
"images": [
[Image.fromarray(np.random.randint(0, 255, (92, 33, 3), dtype=np.uint8))],
[Image.fromarray(np.random.randint(0, 255, (64, 48, 3), dtype=np.uint8))],
[Image.fromarray(np.random.randint(0, 255, (80, 152, 3), dtype=np.uint8))],
[Image.fromarray(np.random.randint(0, 255, (57, 24, 3), dtype=np.uint8))],
[Image.fromarray(np.random.randint(0, 255, (102, 48, 3), dtype=np.uint8))],
],
}
# fmt: on
dataset = Dataset.from_dict(dataset_dict)
dataset = dataset.cast_column("images", features.Sequence(features.Image()))

# Instantiate the model and processor
model = AutoModelForImageTextToText.from_pretrained(model_id)
ref_model = AutoModelForImageTextToText.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

training_args = DPOConfig(
output_dir=self.tmp_dir,
per_device_train_batch_size=2,
remove_unused_columns=False,
learning_rate=0.01, # increase learning rate to speed up test
max_prompt_length=None, # don't truncate to avoid issues with patch tokens
max_length=None,
report_to="none",
)
trainer = DPOTrainer(
model=model,
ref_model=ref_model,
args=training_args,
processing_class=processor,
train_dataset=dataset,
eval_dataset=dataset,
)

trainer.train()

def test_dpo_trainer_with_tools(self):
model_id = "trl-internal-testing/tiny-LlamaForCausalLM-3.2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
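For background on the test above: Gemma 3's processor emits a token_type_ids tensor that marks which positions hold image tokens versus text tokens, and the model expects it alongside input_ids. A minimal sketch of the idea; the token id and the masking rule here are illustrative assumptions, not the real Gemma vocabulary:

import torch

# Hypothetical processor output: 901 stands in for an image-patch token id.
input_ids = torch.tensor([[5, 7, 901, 901, 901, 12, 42]])
token_type_ids = (input_ids == 901).long()  # 1 where the position holds an image token
print(token_type_ids)  # tensor([[0, 0, 1, 1, 1, 0, 0]])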
trl/trainer/dpo_trainer.py: 39 additions & 4 deletions
@@ -177,6 +177,9 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
if "ref_chosen_logps" in examples[0] and "ref_rejected_logps" in examples[0]:
output["ref_chosen_logps"] = ref_chosen_logps
output["ref_rejected_logps"] = ref_rejected_logps
if "token_type_ids" in examples[0]:
token_type_ids = [torch.tensor(example["token_type_ids"]) for example in examples]
output["token_type_ids"] = pad(token_type_ids, padding_value=0, padding_side="left")

return output

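The new collator branch left-pads token_type_ids exactly like the prompt ids, so the image/text flags stay aligned with their tokens after batching. A rough sketch of that padding behavior; it re-implements the effect of TRL's pad(..., padding_side="left") utility rather than calling it:

import torch

def pad_left(tensors, padding_value=0):
    # Stand-in for pad(tensors, padding_value=..., padding_side="left"):
    # prepend padding_value until every tensor reaches the batch max length.
    max_len = max(t.shape[0] for t in tensors)
    return torch.stack(
        [torch.cat((torch.full((max_len - t.shape[0],), padding_value, dtype=t.dtype), t)) for t in tensors]
    )

token_type_ids = [torch.tensor([0, 1, 1, 0]), torch.tensor([0, 1])]
print(pad_left(token_type_ids))
# tensor([[0, 1, 1, 0],
#         [0, 0, 0, 1]])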
@@ -790,6 +793,8 @@ def process_row(
output["pixel_attention_mask"] = processed_features["pixel_attention_mask"][0]
if "image_sizes" in processed_features:
output["image_sizes"] = processed_features["image_sizes"][0]
if "token_type_ids" in processed_features:
output["token_type_ids"] = processed_features["token_type_ids"][0]

return output

@@ -806,6 +811,7 @@ def _set_signature_columns_if_needed(self):
"image_sizes",
"ref_chosen_logps",
"ref_rejected_logps",
"token_type_ids",
]

def get_train_dataloader(self) -> DataLoader:
@@ -991,6 +997,8 @@ def concatenated_inputs(
             )
         if "image_sizes" in batch:
             output["image_sizes"] = torch.cat([batch["image_sizes"], batch["image_sizes"]], dim=0)
+        if "token_type_ids" in batch:
+            output["token_type_ids"] = torch.cat((batch["token_type_ids"], batch["token_type_ids"]))
 
         # Concatenate the chosen and rejected completions
         max_completion_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
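As elsewhere in concatenated_inputs, the prompt-side token_type_ids are simply duplicated so the first half of the batch pairs with the chosen completions and the second half with the rejected ones, allowing a single forward pass. A toy illustration of the shape arithmetic:

import torch

token_type_ids = torch.tensor([[0, 1, 1], [0, 0, 1]])  # (batch, prompt_len)
doubled = torch.cat((token_type_ids, token_type_ids))  # (2 * batch, prompt_len)
print(doubled.shape)  # torch.Size([4, 3])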
@@ -1516,6 +1524,9 @@ def concatenated_forward(
         # Concatenate the prompt and completion inputs
         input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
         attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1)
+        if "token_type_ids" in concatenated_batch:
+            prompt_token_type_ids = concatenated_batch["token_type_ids"]
+            token_type_ids = pad_to_length(prompt_token_type_ids, input_ids.shape[1], 0)
         # Mask the prompt but not the completion for the loss
         loss_mask = torch.cat(
             (torch.zeros_like(prompt_attention_mask), completion_attention_mask),
@@ -1528,19 +1539,35 @@
             # Flush left to reduce the memory usage
             # [[0, 0, x, x, x, x], -> [[x, x, x, x],
             #  [0, x, x, x, 0, 0]]     [x, x, x, 0]]
-            attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
+            if "token_type_ids" in concatenated_batch:
+                attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
+                    attention_mask, input_ids, loss_mask, token_type_ids
+                )
+            else:
+                attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
             attention_mask = attention_mask[:, : self.max_length]
             input_ids = input_ids[:, : self.max_length]
             loss_mask = loss_mask[:, : self.max_length]
         elif self.truncation_mode == "keep_end":
             # Flush right before truncating left, then flush left
             # [[0, 0, x, x, x, x], -> [[0, 0, x, x],
             #  [0, x, x, x, 0, 0]]     [0, x, x, x]]
-            attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
if "token_type_ids" in concatenated_batch:
attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
attention_mask, input_ids, loss_mask, token_type_ids
)
token_type_ids = token_type_ids[:, -self.max_length :]
else:
attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
             input_ids = input_ids[:, -self.max_length :]
             attention_mask = attention_mask[:, -self.max_length :]
             loss_mask = loss_mask[:, -self.max_length :]
-            attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
+            if "token_type_ids" in concatenated_batch:
+                attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
+                    attention_mask, input_ids, loss_mask, token_type_ids
+                )
+            else:
+                attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
         else:
             raise ValueError(
                 f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', "
@@ -1550,7 +1577,15 @@
             # Flush left to reduce the memory usage
             # [[0, 0, x, x, x, x], -> [[x, x, x, x],
             #  [0, x, x, x, 0, 0]]     [x, x, x, 0]]
-            attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
+            if "token_type_ids" in concatenated_batch:
+                attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
+                    attention_mask, input_ids, loss_mask, token_type_ids
+                )
+            else:
+                attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
+
+        if "token_type_ids" in concatenated_batch:
+            model_kwargs["token_type_ids"] = token_type_ids
 
         if self.use_logits_to_keep:
             # Compute logits_to_keep based on loss_mask pattern:
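Once placed in model_kwargs, the flags ride along into the forward call together with the other multimodal extras. A hedged sketch of that final hop; TinyModel is a hypothetical stand-in for the wrapped policy model, not TRL's code:

import torch

class TinyModel(torch.nn.Module):
    # Hypothetical model: accepts token_type_ids the way Gemma-style models do.
    def __init__(self, vocab=128, dim=16):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab, dim)
        self.type_embed = torch.nn.Embedding(2, dim)  # 0 = text, 1 = image
        self.head = torch.nn.Linear(dim, vocab)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, **kwargs):
        h = self.embed(input_ids)
        if token_type_ids is not None:  # fold the image/text flag into the hidden states
            h = h + self.type_embed(token_type_ids)
        return self.head(h)

model_kwargs = {"token_type_ids": torch.tensor([[0, 1, 1, 0]])}
logits = TinyModel()(torch.tensor([[3, 9, 9, 4]]), attention_mask=torch.ones(1, 4, dtype=torch.long), **model_kwargs)
print(logits.shape)  # torch.Size([1, 4, 128])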