Merged · Changes from 2 commits
6 changes: 3 additions & 3 deletions tests/test_utils.py
@@ -873,7 +873,7 @@ def test_with_scalar(self):
 class SplitPixelValuesByGridTester(TrlTestCase):
     def test_split_correctly_0(self):
         batch = {
-            "image_split_sizes": [4, 4],
+            "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 2]]),
             "pixel_values": torch.arange(8 * 3).reshape(8, 3),  # Shape: [8, 3]
         }
         result = split_pixel_values_by_grid(batch)
@@ -884,7 +884,7 @@ def test_split_correctly_0(self):

     def test_split_correctly_1(self):
         batch = {
-            "image_split_sizes": [4, 8],
+            "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 4]]),
             "pixel_values": torch.arange(12 * 3).reshape(12, 3),  # Shape: [12, 3]
         }
         result = split_pixel_values_by_grid(batch)
@@ -900,7 +900,7 @@ def test_missing_keys(self):

     def test_mismatched_length(self):
         batch = {
-            "image_split_sizes": torch.tensor([2, 2]),  # Total = 4
+            "image_grid_thw": torch.tensor([[1, 1, 2], [1, 2, 1]]),  # Total = 4
             "pixel_values": torch.randn(3, 5),  # Only 3 rows
         }
         with self.assertRaises(ValueError):
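For intuition (an editorial aside, not part of the diff): each row of image_grid_thw is a (t, h, w) patch grid whose product is the number of pixel_values rows that image occupies, which is exactly what the old image_split_sizes listed. A quick standalone check, assuming only torch:

import torch

grid = torch.tensor([[1, 2, 2], [1, 2, 4]])  # grids from test_split_correctly_1
lengths = grid.prod(-1).tolist()             # per-image row counts
print(lengths)       # [4, 8], the values the old image_split_sizes held
print(sum(lengths))  # 12, matches pixel_values.shape[0] in the test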
31 changes: 6 additions & 25 deletions trl/trainer/grpo_trainer.py
@@ -791,7 +791,6 @@ def _get_per_token_logps_and_entropies(
         image_grid_thw=None,
         pixel_attention_mask=None,
         image_sizes=None,
-        image_split_sizes=None,
     ) -> dict[str, Optional[torch.Tensor]]:
         """Compute log-probs and (optionally) entropies for each token."""
         batch_size = batch_size or input_ids.size(0)  # Chunk inputs into smaller batches to reduce memory peak
@@ -804,15 +803,13 @@ def _get_per_token_logps_and_entropies(
             # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't)
             model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch}

-            if image_grid_thw is not None:
+            if image_grid_thw is not None and pixel_values is not None:
                 model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size]
-            if pixel_values is not None:
-                if image_split_sizes is not None:
-                    start_pixel_idx = sum(image_split_sizes[:start])
-                    end_pixel_idx = sum(image_split_sizes[: start + batch_size])
-                    model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx]
-                else:
-                    model_inputs["pixel_values"] = pixel_values[start : start + batch_size]
+                start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item()
+                end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item()
+                model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx]
+            elif pixel_values is not None:
+                model_inputs["pixel_values"] = pixel_values[start : start + batch_size]
             if pixel_attention_mask is not None:
                 model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size]
             if image_sizes is not None:
@@ -1078,19 +1075,13 @@ def _generate_and_score_completions(
         # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
         kwargs = {}
         has_images = "image" in inputs[0]
-        image_split_sizes = None
         if has_images:
             images = [example.get("image") for example in inputs]
             kwargs = {"images": [[img] for img in images]}
             for prompt in prompts:
                 if isinstance(prompt, list):  # i.e., when using conversational data
                     prepare_multimodal_messages(prompt, num_images=1)

-            if hasattr(self.processing_class, "_get_num_multimodal_tokens"):
-                image_sizes = [(image.height, image.width) for image in images]
-                multimodal_extra_data = self.processing_class._get_num_multimodal_tokens(image_sizes)
-                image_split_sizes = multimodal_extra_data.num_image_patches

Member Author commented: it also avoids calling a private method

         prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]

         prompt_inputs = self.processing_class(
@@ -1104,10 +1095,6 @@ def _generate_and_score_completions(
         prompt_inputs = super()._prepare_inputs(prompt_inputs)
         prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

-        if "image_grid_thw" in prompt_inputs and image_split_sizes is None:
-            # Fallback for VLMs that require image_grid_thw but don't provide _get_num_multimodal_tokens
-            image_split_sizes = prompt_inputs["image_grid_thw"].prod(dim=1).tolist()
-
         if self.max_prompt_length is not None:
             # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
             # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
@@ -1392,7 +1379,6 @@ def _generate_and_score_completions(
                     image_grid_thw=prompt_inputs.get("image_grid_thw"),
                     pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                     image_sizes=prompt_inputs.get("image_sizes"),
-                    image_split_sizes=image_split_sizes,
                 )
         else:
             old_per_token_logps = None
@@ -1417,7 +1403,6 @@ def _generate_and_score_completions(
                     image_grid_thw=prompt_inputs.get("image_grid_thw"),
                     pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                     image_sizes=prompt_inputs.get("image_sizes"),
-                    image_split_sizes=image_split_sizes,
                 )
             else:
                 with self.accelerator.unwrap_model(self.model).disable_adapter():
@@ -1431,7 +1416,6 @@ def _generate_and_score_completions(
                         image_grid_thw=prompt_inputs.get("image_grid_thw"),
                         pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                         image_sizes=prompt_inputs.get("image_sizes"),
-                        image_split_sizes=image_split_sizes,
                     )
         else:
             ref_per_token_logps = None
@@ -1580,8 +1564,6 @@ def _generate_and_score_completions(
             output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"]
         if "image_sizes" in prompt_inputs:
             output["image_sizes"] = prompt_inputs["image_sizes"]
-        if image_split_sizes is not None:
-            output["image_split_sizes"] = image_split_sizes
         return output

     def compute_liger_loss(self, unwrapped_model, inputs):
@@ -1656,7 +1638,6 @@ def _compute_loss(self, model, inputs):
             image_grid_thw=inputs.get("image_grid_thw"),
             pixel_attention_mask=inputs.get("pixel_attention_mask"),
             image_sizes=inputs.get("image_sizes"),
-            image_split_sizes=inputs.get("image_split_sizes"),
         )

         if self.top_entropy_quantile < 1.0:
31 changes: 6 additions & 25 deletions trl/trainer/rloo_trainer.py
@@ -777,7 +777,6 @@ def _get_per_token_logps_and_entropies(
         image_grid_thw=None,
         pixel_attention_mask=None,
         image_sizes=None,
-        image_split_sizes=None,
     ) -> dict[str, Optional[torch.Tensor]]:
         """Compute log-probs and (optionally) entropies for each token."""
         batch_size = batch_size or input_ids.size(0)  # Chunk inputs into smaller batches to reduce memory peak
@@ -790,15 +789,13 @@ def _get_per_token_logps_and_entropies(
             # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't)
             model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch}

-            if image_grid_thw is not None:
+            if image_grid_thw is not None and pixel_values is not None:
                 model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size]
-            if pixel_values is not None:
-                if image_split_sizes is not None:
-                    start_pixel_idx = sum(image_split_sizes[:start])
-                    end_pixel_idx = sum(image_split_sizes[: start + batch_size])
-                    model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx]
-                else:
-                    model_inputs["pixel_values"] = pixel_values[start : start + batch_size]
+                start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item()
+                end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item()
+                model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx]
+            elif pixel_values is not None:
+                model_inputs["pixel_values"] = pixel_values[start : start + batch_size]
             if pixel_attention_mask is not None:
                 model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size]
             if image_sizes is not None:
@@ -1064,19 +1061,13 @@ def _generate_and_score_completions(
         # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
         kwargs = {}
         has_images = "image" in inputs[0]
-        image_split_sizes = None
         if has_images:
             images = [example.get("image") for example in inputs]
             kwargs = {"images": [[img] for img in images]}
             for prompt in prompts:
                 if isinstance(prompt, list):  # i.e., when using conversational data
                     prepare_multimodal_messages(prompt, num_images=1)

-            if hasattr(self.processing_class, "_get_num_multimodal_tokens"):
-                image_sizes = [(image.height, image.width) for image in images]
-                multimodal_extra_data = self.processing_class._get_num_multimodal_tokens(image_sizes)
-                image_split_sizes = multimodal_extra_data.num_image_patches

         prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]

         prompt_inputs = self.processing_class(
@@ -1090,10 +1081,6 @@ def _generate_and_score_completions(
         prompt_inputs = super()._prepare_inputs(prompt_inputs)
         prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

-        if "image_grid_thw" in prompt_inputs and image_split_sizes is None:
-            # Fallback for VLMs that require image_grid_thw but don't provide _get_num_multimodal_tokens
-            image_split_sizes = prompt_inputs["image_grid_thw"].prod(dim=1).tolist()
-
         if self.max_prompt_length is not None:
             # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
             # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
@@ -1346,7 +1333,6 @@ def _generate_and_score_completions(
                 image_grid_thw=prompt_inputs.get("image_grid_thw"),
                 pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                 image_sizes=prompt_inputs.get("image_sizes"),
-                image_split_sizes=image_split_sizes,
             )
             old_logps = (old_per_token_logps * completion_mask).sum(1)  # mask out padding and tokens after EOS

@@ -1363,7 +1349,6 @@ def _generate_and_score_completions(
                     image_grid_thw=prompt_inputs.get("image_grid_thw"),
                     pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                     image_sizes=prompt_inputs.get("image_sizes"),
-                    image_split_sizes=image_split_sizes,
                 )
             else:
                 with self.accelerator.unwrap_model(self.model).disable_adapter():
@@ -1377,7 +1362,6 @@ def _generate_and_score_completions(
                         image_grid_thw=prompt_inputs.get("image_grid_thw"),
                         pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                         image_sizes=prompt_inputs.get("image_sizes"),
-                        image_split_sizes=image_split_sizes,
                     )
         else:
             ref_per_token_logps = None
@@ -1498,8 +1482,6 @@ def _generate_and_score_completions(
             output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"]
         if "image_sizes" in prompt_inputs:
             output["image_sizes"] = prompt_inputs["image_sizes"]
-        if image_split_sizes is not None:
-            output["image_split_sizes"] = image_split_sizes
         return output

     @profiling_decorator
@@ -1527,7 +1509,6 @@ def _compute_loss(self, model, inputs):
             image_grid_thw=inputs.get("image_grid_thw"),
             pixel_attention_mask=inputs.get("pixel_attention_mask"),
             image_sizes=inputs.get("image_sizes"),
-            image_split_sizes=inputs.get("image_split_sizes"),
         )

         logps = (per_token_logps * completion_mask).sum(1)  # mask out padding and tokens after EOS
4 changes: 2 additions & 2 deletions trl/trainer/utils.py
@@ -1783,10 +1783,10 @@ def split_pixel_values_by_grid(batch: dict[str, torch.Tensor]) -> dict[str, Unio
     Splits `batch["pixel_values"]` into a list of tensors based on the product of each row in
     `batch["image_grid_thw"]`, while keeping other entries unchanged.
     """
-    if "image_split_sizes" not in batch or "pixel_values" not in batch:
+    if "image_grid_thw" not in batch or "pixel_values" not in batch:
         return batch

-    lengths = batch["image_split_sizes"]  # [batch_size]
+    lengths = batch["image_grid_thw"].prod(-1).tolist()  # [batch_size]

Member commented: Do I understand correctly, that image_split_sizes was only really used in this helper method to extract the image lengths? If so, then I agree it's redundant with image_grid_thw

Member Author replied: yes, precisely

     pixel_values = batch["pixel_values"]  # [total, feature_dim]

     if sum(lengths) != pixel_values.size(0):
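Given the diff and the tests above, the updated helper can be sketched roughly as follows. This is a reconstruction for illustration (using torch.split for the actual splitting), not the verbatim TRL implementation:

import torch

def split_pixel_values_by_grid(batch):
    # Leave the batch untouched unless both keys are present.
    if "image_grid_thw" not in batch or "pixel_values" not in batch:
        return batch
    # Each image's (t, h, w) product is its number of pixel_values rows.
    lengths = batch["image_grid_thw"].prod(-1).tolist()  # [batch_size]
    pixel_values = batch["pixel_values"]  # [total, feature_dim]
    if sum(lengths) != pixel_values.size(0):
        raise ValueError(f"Mismatch: {sum(lengths)} != {pixel_values.size(0)}")
    # Split the flat tensor into one tensor per image.
    return {**batch, "pixel_values": list(torch.split(pixel_values, lengths, dim=0))}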