
Commit b5ca379

🟩 Drop image_split_sizes in favour of image_grid_thw (#4111)
1 parent a68b4af commit b5ca379

File tree (5 files changed, +18 −71 lines)

tests/test_utils.py
trl/experimental/gfpo/gfpo_trainer.py
trl/trainer/grpo_trainer.py
trl/trainer/rloo_trainer.py
trl/trainer/utils.py

tests/test_utils.py

Lines changed: 3 additions & 3 deletions
@@ -873,7 +873,7 @@ def test_with_scalar(self):
 class SplitPixelValuesByGridTester(TrlTestCase):
     def test_split_correctly_0(self):
         batch = {
-            "image_split_sizes": [4, 4],
+            "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 2]]),
             "pixel_values": torch.arange(8 * 3).reshape(8, 3),  # Shape: [8, 3]
         }
         result = split_pixel_values_by_grid(batch)
@@ -884,7 +884,7 @@ def test_split_correctly_0(self):

     def test_split_correctly_1(self):
         batch = {
-            "image_split_sizes": [4, 8],
+            "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 4]]),
             "pixel_values": torch.arange(12 * 3).reshape(12, 3),  # Shape: [12, 3]
         }
         result = split_pixel_values_by_grid(batch)
@@ -900,7 +900,7 @@ def test_missing_keys(self):

     def test_mismatched_length(self):
         batch = {
-            "image_split_sizes": torch.tensor([2, 2]),  # Total = 4
+            "image_grid_thw": torch.tensor([[1, 1, 2], [1, 2, 1]]),  # Total = 4
            "pixel_values": torch.randn(3, 5),  # Only 3 rows
         }
         with self.assertRaises(ValueError):
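The new test values follow directly from how split sizes are now derived: each image contributes t·h·w rows of `pixel_values`, i.e. the product of its `image_grid_thw` row. A minimal sketch with the values from `test_split_correctly_1` (the `torch.split` call here only illustrates the expected splitting, it is not necessarily how `split_pixel_values_by_grid` is implemented):

```python
import torch

# Each row of image_grid_thw is (t, h, w); its product is the number of
# pixel_values rows (patches) that belong to that image.
image_grid_thw = torch.tensor([[1, 2, 2], [1, 2, 4]])
split_sizes = image_grid_thw.prod(-1).tolist()       # -> [4, 8]

pixel_values = torch.arange(12 * 3).reshape(12, 3)   # 4 + 8 = 12 rows in total
per_image = list(torch.split(pixel_values, split_sizes, dim=0))
print([t.shape for t in per_image])                  # [torch.Size([4, 3]), torch.Size([8, 3])]
```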

trl/experimental/gfpo/gfpo_trainer.py

Lines changed: 1 addition & 16 deletions
@@ -93,19 +93,13 @@ def _generate_and_score_completions(self, inputs):
         # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
         kwargs = {}
         has_images = "image" in inputs[0]
-        image_split_sizes = None
         if has_images:
             images = [example.get("image") for example in inputs]
             kwargs = {"images": [[img] for img in images]}
             for prompt in prompts:
                 if isinstance(prompt, list):  # i.e., when using conversational data
                     prepare_multimodal_messages(prompt, num_images=1)

-            if hasattr(self.processing_class, "_get_num_multimodal_tokens"):
-                image_sizes = [(image.height, image.width) for image in images]
-                multimodal_extra_data = self.processing_class._get_num_multimodal_tokens(image_sizes)
-                image_split_sizes = multimodal_extra_data.num_image_patches
-
         prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]

         prompt_inputs = self.processing_class(
@@ -116,13 +110,9 @@ def _generate_and_score_completions(self, inputs):
             add_special_tokens=False,
             **kwargs,
         )
-        prompt_inputs = super(_GRPOTrainer, self)._prepare_inputs(prompt_inputs)
+        prompt_inputs = super()._prepare_inputs(prompt_inputs)
         prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

-        if "image_grid_thw" in prompt_inputs and image_split_sizes is None:
-            # Fallback for VLMs that require image_grid_thw but don't provide _get_num_multimodal_tokens
-            image_split_sizes = prompt_inputs["image_grid_thw"].prod(dim=1).tolist()
-
         if self.max_prompt_length is not None:
             # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
             # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
@@ -407,7 +397,6 @@ def _generate_and_score_completions(self, inputs):
                     image_grid_thw=prompt_inputs.get("image_grid_thw"),
                     pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                     image_sizes=prompt_inputs.get("image_sizes"),
-                    image_split_sizes=image_split_sizes,
                 )
             else:
                 old_per_token_logps = None
@@ -432,7 +421,6 @@ def _generate_and_score_completions(self, inputs):
                     image_grid_thw=prompt_inputs.get("image_grid_thw"),
                     pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                     image_sizes=prompt_inputs.get("image_sizes"),
-                    image_split_sizes=image_split_sizes,
                 )
             else:
                 with self.accelerator.unwrap_model(self.model).disable_adapter():
@@ -446,7 +434,6 @@ def _generate_and_score_completions(self, inputs):
                         image_grid_thw=prompt_inputs.get("image_grid_thw"),
                         pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                         image_sizes=prompt_inputs.get("image_sizes"),
-                        image_split_sizes=image_split_sizes,
                     )
             else:
                 ref_per_token_logps = None
@@ -652,6 +639,4 @@ def _generate_and_score_completions(self, inputs):
             output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"]
         if "image_sizes" in prompt_inputs:
             output["image_sizes"] = prompt_inputs["image_sizes"]
-        if image_split_sizes is not None:
-            output["image_split_sizes"] = image_split_sizes
         return output

trl/trainer/grpo_trainer.py

Lines changed: 6 additions & 25 deletions
@@ -791,7 +791,6 @@ def _get_per_token_logps_and_entropies(
         image_grid_thw=None,
         pixel_attention_mask=None,
         image_sizes=None,
-        image_split_sizes=None,
     ) -> dict[str, Optional[torch.Tensor]]:
         """Compute log-probs and (optionally) entropies for each token."""
         batch_size = batch_size or input_ids.size(0)  # Chunk inputs into smaller batches to reduce memory peak
@@ -804,15 +803,13 @@ def _get_per_token_logps_and_entropies(
             # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't)
             model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch}

-            if image_grid_thw is not None:
+            if image_grid_thw is not None and pixel_values is not None:
                 model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size]
-            if pixel_values is not None:
-                if image_split_sizes is not None:
-                    start_pixel_idx = sum(image_split_sizes[:start])
-                    end_pixel_idx = sum(image_split_sizes[: start + batch_size])
-                    model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx]
-                else:
-                    model_inputs["pixel_values"] = pixel_values[start : start + batch_size]
+                start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item()
+                end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item()
+                model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx]
+            elif pixel_values is not None:
+                model_inputs["pixel_values"] = pixel_values[start : start + batch_size]
             if pixel_attention_mask is not None:
                 model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size]
             if image_sizes is not None:
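With the replacement indexing, no separate `image_split_sizes` list is needed: for a chunk of images `[start, start + batch_size)`, the matching `pixel_values` rows are located by summing the per-image products of `image_grid_thw`. A small standalone sketch of that arithmetic (the helper name and example tensors are made up for illustration, they are not part of the commit):

```python
import torch

def slice_pixel_values_for_chunk(pixel_values, image_grid_thw, start, batch_size):
    """Illustrative helper: select the pixel_values rows belonging to images
    start .. start + batch_size - 1 using cumulative grid products."""
    start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item()
    end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item()
    return pixel_values[start_pixel_idx:end_pixel_idx]

# Example: three images with 4, 8, and 4 patches respectively.
grid = torch.tensor([[1, 2, 2], [1, 2, 4], [1, 2, 2]])
pixels = torch.randn(16, 3)  # 4 + 8 + 4 rows
chunk = slice_pixel_values_for_chunk(pixels, grid, start=1, batch_size=2)
print(chunk.shape)  # torch.Size([12, 3]) -> rows 4..15, i.e. images 1 and 2
```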
@@ -1078,19 +1075,13 @@ def _generate_and_score_completions(
         # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
         kwargs = {}
         has_images = "image" in inputs[0]
-        image_split_sizes = None
         if has_images:
             images = [example.get("image") for example in inputs]
             kwargs = {"images": [[img] for img in images]}
             for prompt in prompts:
                 if isinstance(prompt, list):  # i.e., when using conversational data
                     prepare_multimodal_messages(prompt, num_images=1)

-            if hasattr(self.processing_class, "_get_num_multimodal_tokens"):
-                image_sizes = [(image.height, image.width) for image in images]
-                multimodal_extra_data = self.processing_class._get_num_multimodal_tokens(image_sizes)
-                image_split_sizes = multimodal_extra_data.num_image_patches
-
         prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]

         prompt_inputs = self.processing_class(
@@ -1104,10 +1095,6 @@ def _generate_and_score_completions(
         prompt_inputs = super()._prepare_inputs(prompt_inputs)
         prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

-        if "image_grid_thw" in prompt_inputs and image_split_sizes is None:
-            # Fallback for VLMs that require image_grid_thw but don't provide _get_num_multimodal_tokens
-            image_split_sizes = prompt_inputs["image_grid_thw"].prod(dim=1).tolist()
-
         if self.max_prompt_length is not None:
             # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
             # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
@@ -1392,7 +1379,6 @@ def _generate_and_score_completions(
                     image_grid_thw=prompt_inputs.get("image_grid_thw"),
                     pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                     image_sizes=prompt_inputs.get("image_sizes"),
-                    image_split_sizes=image_split_sizes,
                 )
             else:
                 old_per_token_logps = None
@@ -1417,7 +1403,6 @@ def _generate_and_score_completions(
                     image_grid_thw=prompt_inputs.get("image_grid_thw"),
                     pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                     image_sizes=prompt_inputs.get("image_sizes"),
-                    image_split_sizes=image_split_sizes,
                 )
             else:
                 with self.accelerator.unwrap_model(self.model).disable_adapter():
@@ -1431,7 +1416,6 @@ def _generate_and_score_completions(
                         image_grid_thw=prompt_inputs.get("image_grid_thw"),
                         pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                         image_sizes=prompt_inputs.get("image_sizes"),
-                        image_split_sizes=image_split_sizes,
                     )
             else:
                 ref_per_token_logps = None
@@ -1580,8 +1564,6 @@ def _generate_and_score_completions(
             output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"]
         if "image_sizes" in prompt_inputs:
             output["image_sizes"] = prompt_inputs["image_sizes"]
-        if image_split_sizes is not None:
-            output["image_split_sizes"] = image_split_sizes
         return output

     def compute_liger_loss(self, unwrapped_model, inputs):
@@ -1656,7 +1638,6 @@ def _compute_loss(self, model, inputs):
            image_grid_thw=inputs.get("image_grid_thw"),
            pixel_attention_mask=inputs.get("pixel_attention_mask"),
            image_sizes=inputs.get("image_sizes"),
-            image_split_sizes=inputs.get("image_split_sizes"),
        )

        if self.top_entropy_quantile < 1.0:

trl/trainer/rloo_trainer.py

Lines changed: 6 additions & 25 deletions
@@ -777,7 +777,6 @@ def _get_per_token_logps_and_entropies(
         image_grid_thw=None,
         pixel_attention_mask=None,
         image_sizes=None,
-        image_split_sizes=None,
     ) -> dict[str, Optional[torch.Tensor]]:
         """Compute log-probs and (optionally) entropies for each token."""
         batch_size = batch_size or input_ids.size(0)  # Chunk inputs into smaller batches to reduce memory peak
@@ -790,15 +789,13 @@ def _get_per_token_logps_and_entropies(
             # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't)
             model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch}

-            if image_grid_thw is not None:
+            if image_grid_thw is not None and pixel_values is not None:
                 model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size]
-            if pixel_values is not None:
-                if image_split_sizes is not None:
-                    start_pixel_idx = sum(image_split_sizes[:start])
-                    end_pixel_idx = sum(image_split_sizes[: start + batch_size])
-                    model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx]
-                else:
-                    model_inputs["pixel_values"] = pixel_values[start : start + batch_size]
+                start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item()
+                end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item()
+                model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx]
+            elif pixel_values is not None:
+                model_inputs["pixel_values"] = pixel_values[start : start + batch_size]
             if pixel_attention_mask is not None:
                 model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size]
             if image_sizes is not None:
@@ -1064,19 +1061,13 @@ def _generate_and_score_completions(
         # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
         kwargs = {}
         has_images = "image" in inputs[0]
-        image_split_sizes = None
         if has_images:
             images = [example.get("image") for example in inputs]
             kwargs = {"images": [[img] for img in images]}
             for prompt in prompts:
                 if isinstance(prompt, list):  # i.e., when using conversational data
                     prepare_multimodal_messages(prompt, num_images=1)

-            if hasattr(self.processing_class, "_get_num_multimodal_tokens"):
-                image_sizes = [(image.height, image.width) for image in images]
-                multimodal_extra_data = self.processing_class._get_num_multimodal_tokens(image_sizes)
-                image_split_sizes = multimodal_extra_data.num_image_patches
-
         prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]

         prompt_inputs = self.processing_class(
@@ -1090,10 +1081,6 @@ def _generate_and_score_completions(
         prompt_inputs = super()._prepare_inputs(prompt_inputs)
         prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

-        if "image_grid_thw" in prompt_inputs and image_split_sizes is None:
-            # Fallback for VLMs that require image_grid_thw but don't provide _get_num_multimodal_tokens
-            image_split_sizes = prompt_inputs["image_grid_thw"].prod(dim=1).tolist()
-
         if self.max_prompt_length is not None:
             # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
             # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
@@ -1346,7 +1333,6 @@ def _generate_and_score_completions(
                 image_grid_thw=prompt_inputs.get("image_grid_thw"),
                 pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                 image_sizes=prompt_inputs.get("image_sizes"),
-                image_split_sizes=image_split_sizes,
             )
             old_logps = (old_per_token_logps * completion_mask).sum(1)  # mask out padding and tokens after EOS

@@ -1363,7 +1349,6 @@ def _generate_and_score_completions(
                     image_grid_thw=prompt_inputs.get("image_grid_thw"),
                     pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                     image_sizes=prompt_inputs.get("image_sizes"),
-                    image_split_sizes=image_split_sizes,
                 )
             else:
                 with self.accelerator.unwrap_model(self.model).disable_adapter():
@@ -1377,7 +1362,6 @@ def _generate_and_score_completions(
                         image_grid_thw=prompt_inputs.get("image_grid_thw"),
                         pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                         image_sizes=prompt_inputs.get("image_sizes"),
-                        image_split_sizes=image_split_sizes,
                     )
             else:
                 ref_per_token_logps = None
@@ -1498,8 +1482,6 @@ def _generate_and_score_completions(
             output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"]
         if "image_sizes" in prompt_inputs:
             output["image_sizes"] = prompt_inputs["image_sizes"]
-        if image_split_sizes is not None:
-            output["image_split_sizes"] = image_split_sizes
         return output

     @profiling_decorator
@@ -1527,7 +1509,6 @@ def _compute_loss(self, model, inputs):
            image_grid_thw=inputs.get("image_grid_thw"),
            pixel_attention_mask=inputs.get("pixel_attention_mask"),
            image_sizes=inputs.get("image_sizes"),
-            image_split_sizes=inputs.get("image_split_sizes"),
        )

        logps = (per_token_logps * completion_mask).sum(1)  # mask out padding and tokens after EOS

trl/trainer/utils.py

Lines changed: 2 additions & 2 deletions
@@ -1783,10 +1783,10 @@ def split_pixel_values_by_grid(batch: dict[str, torch.Tensor]) -> dict[str, Unio
     Splits `batch["pixel_values"]` into a list of tensors based on the product of each row in
     `batch["image_grid_thw"]`, while keeping other entries unchanged.
     """
-    if "image_split_sizes" not in batch or "pixel_values" not in batch:
+    if "image_grid_thw" not in batch or "pixel_values" not in batch:
         return batch

-    lengths = batch["image_split_sizes"]  # [batch_size]
+    lengths = batch["image_grid_thw"].prod(-1).tolist()  # [batch_size]
     pixel_values = batch["pixel_values"]  # [total, feature_dim]

     if sum(lengths) != pixel_values.size(0):
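Only the first lines of `split_pixel_values_by_grid` appear in this hunk. Pieced together from the docstring, the visible lines, and the tests above, the helper's contract after the change looks roughly like the sketch below; this is an approximation for illustration, not the actual TRL implementation:

```python
from typing import Union

import torch


def split_pixel_values_by_grid_sketch(
    batch: dict[str, torch.Tensor],
) -> dict[str, Union[torch.Tensor, list[torch.Tensor]]]:
    """Approximate the helper's behaviour after this commit (illustrative only)."""
    if "image_grid_thw" not in batch or "pixel_values" not in batch:
        return batch

    lengths = batch["image_grid_thw"].prod(-1).tolist()  # rows per image
    pixel_values = batch["pixel_values"]

    if sum(lengths) != pixel_values.size(0):
        raise ValueError(
            f"Mismatch: sum of grid products is {sum(lengths)} but pixel_values has {pixel_values.size(0)} rows"
        )

    # Keep other entries unchanged, replace pixel_values with a per-image list.
    return {**batch, "pixel_values": list(torch.split(pixel_values, lengths, dim=0))}
```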
