diff --git a/tests/test_utils.py b/tests/test_utils.py index 60730d685d0..f036a897e1b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -873,7 +873,7 @@ def test_with_scalar(self): class SplitPixelValuesByGridTester(TrlTestCase): def test_split_correctly_0(self): batch = { - "image_split_sizes": [4, 4], + "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 2]]), "pixel_values": torch.arange(8 * 3).reshape(8, 3), # Shape: [8, 3] } result = split_pixel_values_by_grid(batch) @@ -884,7 +884,7 @@ def test_split_correctly_0(self): def test_split_correctly_1(self): batch = { - "image_split_sizes": [4, 8], + "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 4]]), "pixel_values": torch.arange(12 * 3).reshape(12, 3), # Shape: [12, 3] } result = split_pixel_values_by_grid(batch) @@ -900,7 +900,7 @@ def test_missing_keys(self): def test_mismatched_length(self): batch = { - "image_split_sizes": torch.tensor([2, 2]), # Total = 4 + "image_grid_thw": torch.tensor([[1, 1, 2], [1, 2, 1]]), # Total = 4 "pixel_values": torch.randn(3, 5), # Only 3 rows } with self.assertRaises(ValueError): diff --git a/trl/experimental/gfpo/gfpo_trainer.py b/trl/experimental/gfpo/gfpo_trainer.py index b067e7410c7..f2a675fab16 100644 --- a/trl/experimental/gfpo/gfpo_trainer.py +++ b/trl/experimental/gfpo/gfpo_trainer.py @@ -93,7 +93,6 @@ def _generate_and_score_completions(self, inputs): # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] kwargs = {} has_images = "image" in inputs[0] - image_split_sizes = None if has_images: images = [example.get("image") for example in inputs] kwargs = {"images": [[img] for img in images]} @@ -101,11 +100,6 @@ def _generate_and_score_completions(self, inputs): if isinstance(prompt, list): # i.e., when using conversational data prepare_multimodal_messages(prompt, num_images=1) - if hasattr(self.processing_class, "_get_num_multimodal_tokens"): - image_sizes = [(image.height, image.width) for image in images] 
- multimodal_extra_data = self.processing_class._get_num_multimodal_tokens(image_sizes) - image_split_sizes = multimodal_extra_data.num_image_patches - prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] prompt_inputs = self.processing_class( @@ -116,13 +110,9 @@ def _generate_and_score_completions(self, inputs): add_special_tokens=False, **kwargs, ) prompt_inputs = super(_GRPOTrainer, self)._prepare_inputs(prompt_inputs) prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] - if "image_grid_thw" in prompt_inputs and image_split_sizes is None: - # Fallback for VLMs that require image_grid_thw but don't provide _get_num_multimodal_tokens - image_split_sizes = prompt_inputs["image_grid_thw"].prod(dim=1).tolist() - if self.max_prompt_length is not None: # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. # Then we decode those tokens back into text. 
We manually remove leading pad tokens from the decoded text, @@ -407,7 +397,6 @@ def _generate_and_score_completions(self, inputs): image_grid_thw=prompt_inputs.get("image_grid_thw"), pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), - image_split_sizes=image_split_sizes, ) else: old_per_token_logps = None @@ -432,7 +421,6 @@ def _generate_and_score_completions(self, inputs): image_grid_thw=prompt_inputs.get("image_grid_thw"), pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), - image_split_sizes=image_split_sizes, ) else: with self.accelerator.unwrap_model(self.model).disable_adapter(): @@ -446,7 +434,6 @@ def _generate_and_score_completions(self, inputs): image_grid_thw=prompt_inputs.get("image_grid_thw"), pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), - image_split_sizes=image_split_sizes, ) else: ref_per_token_logps = None @@ -652,6 +639,4 @@ def _generate_and_score_completions(self, inputs): output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"] if "image_sizes" in prompt_inputs: output["image_sizes"] = prompt_inputs["image_sizes"] - if image_split_sizes is not None: - output["image_split_sizes"] = image_split_sizes return output diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 0c2ad9a3121..bb902445d09 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -791,7 +791,6 @@ def _get_per_token_logps_and_entropies( image_grid_thw=None, pixel_attention_mask=None, image_sizes=None, - image_split_sizes=None, ) -> dict[str, Optional[torch.Tensor]]: """Compute log-probs and (optionally) entropies for each token.""" batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak @@ -804,15 +803,13 @@ def _get_per_token_logps_and_entropies( # Build model inputs - check if the model 
supports logits_to_keep (some models and VLMs don't) model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} - if image_grid_thw is not None: + if image_grid_thw is not None and pixel_values is not None: model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size] - if pixel_values is not None: - if image_split_sizes is not None: - start_pixel_idx = sum(image_split_sizes[:start]) - end_pixel_idx = sum(image_split_sizes[: start + batch_size]) - model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] - else: - model_inputs["pixel_values"] = pixel_values[start : start + batch_size] + start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item() + end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item() + model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] + elif pixel_values is not None: + model_inputs["pixel_values"] = pixel_values[start : start + batch_size] if pixel_attention_mask is not None: model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size] if image_sizes is not None: @@ -1078,7 +1075,6 @@ def _generate_and_score_completions( # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] kwargs = {} has_images = "image" in inputs[0] - image_split_sizes = None if has_images: images = [example.get("image") for example in inputs] kwargs = {"images": [[img] for img in images]} @@ -1086,11 +1082,6 @@ def _generate_and_score_completions( if isinstance(prompt, list): # i.e., when using conversational data prepare_multimodal_messages(prompt, num_images=1) - if hasattr(self.processing_class, "_get_num_multimodal_tokens"): - image_sizes = [(image.height, image.width) for image in images] - multimodal_extra_data = self.processing_class._get_num_multimodal_tokens(image_sizes) - image_split_sizes = multimodal_extra_data.num_image_patches - prompts_text = [maybe_apply_chat_template(example, 
self.processing_class)["prompt"] for example in inputs] prompt_inputs = self.processing_class( @@ -1104,10 +1095,6 @@ def _generate_and_score_completions( prompt_inputs = super()._prepare_inputs(prompt_inputs) prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] - if "image_grid_thw" in prompt_inputs and image_split_sizes is None: - # Fallback for VLMs that require image_grid_thw but don't provide _get_num_multimodal_tokens - image_split_sizes = prompt_inputs["image_grid_thw"].prod(dim=1).tolist() - if self.max_prompt_length is not None: # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text, @@ -1392,7 +1379,6 @@ def _generate_and_score_completions( image_grid_thw=prompt_inputs.get("image_grid_thw"), pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), - image_split_sizes=image_split_sizes, ) else: old_per_token_logps = None @@ -1417,7 +1403,6 @@ def _generate_and_score_completions( image_grid_thw=prompt_inputs.get("image_grid_thw"), pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), - image_split_sizes=image_split_sizes, ) else: with self.accelerator.unwrap_model(self.model).disable_adapter(): @@ -1431,7 +1416,6 @@ def _generate_and_score_completions( image_grid_thw=prompt_inputs.get("image_grid_thw"), pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), - image_split_sizes=image_split_sizes, ) else: ref_per_token_logps = None @@ -1580,8 +1564,6 @@ def _generate_and_score_completions( output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"] if "image_sizes" in prompt_inputs: output["image_sizes"] = prompt_inputs["image_sizes"] - if image_split_sizes is not None: - output["image_split_sizes"] = 
image_split_sizes return output def compute_liger_loss(self, unwrapped_model, inputs): @@ -1656,7 +1638,6 @@ def _compute_loss(self, model, inputs): image_grid_thw=inputs.get("image_grid_thw"), pixel_attention_mask=inputs.get("pixel_attention_mask"), image_sizes=inputs.get("image_sizes"), - image_split_sizes=inputs.get("image_split_sizes"), ) if self.top_entropy_quantile < 1.0: diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 86aeb8910a0..56ffbfe7fea 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -777,7 +777,6 @@ def _get_per_token_logps_and_entropies( image_grid_thw=None, pixel_attention_mask=None, image_sizes=None, - image_split_sizes=None, ) -> dict[str, Optional[torch.Tensor]]: """Compute log-probs and (optionally) entropies for each token.""" batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak @@ -790,15 +789,13 @@ def _get_per_token_logps_and_entropies( # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} - if image_grid_thw is not None: + if image_grid_thw is not None and pixel_values is not None: model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size] - if pixel_values is not None: - if image_split_sizes is not None: - start_pixel_idx = sum(image_split_sizes[:start]) - end_pixel_idx = sum(image_split_sizes[: start + batch_size]) - model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] - else: - model_inputs["pixel_values"] = pixel_values[start : start + batch_size] + start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item() + end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item() + model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] + elif pixel_values is not None: + model_inputs["pixel_values"] = pixel_values[start : start + batch_size] 
if pixel_attention_mask is not None: model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size] if image_sizes is not None: @@ -1064,7 +1061,6 @@ def _generate_and_score_completions( # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] kwargs = {} has_images = "image" in inputs[0] - image_split_sizes = None if has_images: images = [example.get("image") for example in inputs] kwargs = {"images": [[img] for img in images]} @@ -1072,11 +1068,6 @@ def _generate_and_score_completions( if isinstance(prompt, list): # i.e., when using conversational data prepare_multimodal_messages(prompt, num_images=1) - if hasattr(self.processing_class, "_get_num_multimodal_tokens"): - image_sizes = [(image.height, image.width) for image in images] - multimodal_extra_data = self.processing_class._get_num_multimodal_tokens(image_sizes) - image_split_sizes = multimodal_extra_data.num_image_patches - prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] prompt_inputs = self.processing_class( @@ -1090,10 +1081,6 @@ def _generate_and_score_completions( prompt_inputs = super()._prepare_inputs(prompt_inputs) prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] - if "image_grid_thw" in prompt_inputs and image_split_sizes is None: - # Fallback for VLMs that require image_grid_thw but don't provide _get_num_multimodal_tokens - image_split_sizes = prompt_inputs["image_grid_thw"].prod(dim=1).tolist() - if self.max_prompt_length is not None: # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. # Then we decode those tokens back into text. 
We manually remove leading pad tokens from the decoded text, @@ -1346,7 +1333,6 @@ def _generate_and_score_completions( image_grid_thw=prompt_inputs.get("image_grid_thw"), pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), - image_split_sizes=image_split_sizes, ) old_logps = (old_per_token_logps * completion_mask).sum(1) # mask out padding and tokens after EOS @@ -1363,7 +1349,6 @@ def _generate_and_score_completions( image_grid_thw=prompt_inputs.get("image_grid_thw"), pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), - image_split_sizes=image_split_sizes, ) else: with self.accelerator.unwrap_model(self.model).disable_adapter(): @@ -1377,7 +1362,6 @@ def _generate_and_score_completions( image_grid_thw=prompt_inputs.get("image_grid_thw"), pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), - image_split_sizes=image_split_sizes, ) else: ref_per_token_logps = None @@ -1498,8 +1482,6 @@ def _generate_and_score_completions( output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"] if "image_sizes" in prompt_inputs: output["image_sizes"] = prompt_inputs["image_sizes"] - if image_split_sizes is not None: - output["image_split_sizes"] = image_split_sizes return output @profiling_decorator @@ -1527,7 +1509,6 @@ def _compute_loss(self, model, inputs): image_grid_thw=inputs.get("image_grid_thw"), pixel_attention_mask=inputs.get("pixel_attention_mask"), image_sizes=inputs.get("image_sizes"), - image_split_sizes=inputs.get("image_split_sizes"), ) logps = (per_token_logps * completion_mask).sum(1) # mask out padding and tokens after EOS diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 16ce8321612..37612a423bd 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -1783,10 +1783,10 @@ def split_pixel_values_by_grid(batch: dict[str, torch.Tensor]) -> dict[str, Unio Splits 
`batch["pixel_values"]` into a list of tensors based on the product of each row in `batch["image_grid_thw"]`, while keeping other entries unchanged. """ - if "image_split_sizes" not in batch or "pixel_values" not in batch: + if "image_grid_thw" not in batch or "pixel_values" not in batch: return batch - lengths = batch["image_split_sizes"] # [batch_size] + lengths = batch["image_grid_thw"].prod(-1).tolist() # [batch_size] pixel_values = batch["pixel_values"] # [total, feature_dim] if sum(lengths) != pixel_values.size(0):