diff --git a/docs/source/dataset_formats.md b/docs/source/dataset_formats.md index da606a7d97e..8a105ff5e34 100644 --- a/docs/source/dataset_formats.md +++ b/docs/source/dataset_formats.md @@ -1037,7 +1037,7 @@ Some trainers also support fine-tuning vision-language models (VLMs) using image A conversational vision dataset differs from a standard conversational dataset in two key ways: -1. The dataset must contain the key `images` with the image data. +1. The dataset must contain the key `images` with the image data (as lists of PIL images) or `image` with a single PIL image. 2. The `"content"` field in messages must be a list of dictionaries, where each dictionary specifies the type of data: `"image"` or `"text"`. Example: diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index 172e5f93111..e998ef63f69 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -562,7 +562,9 @@ Tested with: - **SmolVLM2** — e.g., `HuggingFaceTB/SmolVLM2-2.2B-Instruct` + Compatibility with all VLMs is not guaranteed. If you believe a model should be supported, feel free to open an issue on GitHub — or better yet, submit a pull request with the required changes. + ### Quick Start @@ -605,7 +607,7 @@ VLM training may fail if image tokens are truncated. We highly recommend disabli Each training sample should include: - `prompt`: Text formatted via the processor's chat template -- `image`: A single image (PIL or NumPy array) +- `image`/`images`: PIL Image or list of PIL Images The trainer automatically handles image-to-tensor conversion via the model’s image processor. diff --git a/docs/source/rloo_trainer.md b/docs/source/rloo_trainer.md index 66a8f3e16e4..bce71b1f0bf 100644 --- a/docs/source/rloo_trainer.md +++ b/docs/source/rloo_trainer.md @@ -533,7 +533,9 @@ Tested with: - **SmolVLM2** — e.g., `HuggingFaceTB/SmolVLM2-2.2B-Instruct` + Compatibility with all VLMs is not guaranteed. If you believe a model should be supported, feel free to open an issue on GitHub — or better yet, submit a pull request with the required changes. + ### Quick Start @@ -576,7 +578,7 @@ VLM training may fail if image tokens are truncated. We highly recommend disabli Each training sample should include: - `prompt`: Text formatted via the processor's chat template -- `image`: A single image (PIL or NumPy array) +- `image`/`images`: PIL Image or list of PIL Images The trainer automatically handles image-to-tensor conversion via the model’s image processor. 
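For reference, below is a minimal sketch of what a conversational multi-image example could look like under the format described in the documentation changes above: an `images` column holding a list of PIL images, and each message's `"content"` given as a list of typed entries. The image sizes, colors, and the question text are placeholders invented purely for illustration, not values taken from any TRL dataset.

```python
# Hypothetical example row illustrating the documented conversational vision format.
# The PIL images and the question text are placeholders created for illustration only.
from PIL import Image

example = {
    "images": [Image.new("RGB", (64, 64), "red"), Image.new("RGB", (64, 64), "blue")],
    "prompt": [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "image"},
                {"type": "text", "text": "Which of the two images is brighter?"},
            ],
        }
    ],
}
```

With the single-image `image` key, the same row would instead carry one PIL image and a single `{"type": "image"}` entry in the content list.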
diff --git a/scripts/generate_tiny_models.py b/scripts/generate_tiny_models.py index b89447a9954..90436c8adc5 100644 --- a/scripts/generate_tiny_models.py +++ b/scripts/generate_tiny_models.py @@ -292,6 +292,7 @@ def init_weights_tiny_model(model): "hidden_size": 16, "num_attention_heads": 4, "num_key_value_heads": 2, + "embed_dim": 64, } kwargs = {} diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 4e4321febcd..cc484c56d0f 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -1258,6 +1258,10 @@ def test_prepare_input_called_with_correct_data(self): def test_training_vlm(self, model_id): dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = GRPOConfig( output_dir=self.tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test @@ -1269,7 +1273,7 @@ def test_training_vlm(self, model_id): ) trainer = GRPOTrainer( model=model_id, - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) @@ -1287,7 +1291,6 @@ def test_training_vlm(self, model_id): "model.vision_tower.", "model.multi_modal_projector.", "model.vision_model.", - "model.connector.modality_projection.", "model.visual.", "model.image_newline", ) @@ -1301,6 +1304,10 @@ def test_training_vlm(self, model_id): def test_training_vlm_beta_non_zero(self): dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = GRPOConfig( output_dir=self.tmp_dir, beta=0.1, # set beta to non-zero value to test the case where the reference model is used @@ -1312,7 +1319,7 @@ def test_training_vlm_beta_non_zero(self): ) trainer = GRPOTrainer( model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) @@ -1340,7 +1347,11 @@ def test_training_vlm_peft(self): "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration" ) base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] - dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] training_args = GRPOConfig( output_dir=self.tmp_dir, @@ -1352,7 +1363,7 @@ def test_training_vlm_peft(self): ) trainer = GRPOTrainer( model=model, - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, peft_config=LoraConfig(target_modules=["q_proj", "v_proj"]), @@ -1376,6 +1387,10 @@ def test_training_vlm_peft(self): def test_training_vlm_and_importance_sampling(self): dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): 
+ """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = GRPOConfig( output_dir=self.tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test @@ -1387,7 +1402,7 @@ def test_training_vlm_and_importance_sampling(self): ) trainer = GRPOTrainer( model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) @@ -1413,6 +1428,10 @@ def test_training_vlm_and_importance_sampling(self): def test_training_vlm_and_liger(self): dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = GRPOConfig( output_dir=self.tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test @@ -1425,7 +1444,7 @@ def test_training_vlm_and_liger(self): ) trainer = GRPOTrainer( model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) @@ -1451,6 +1470,10 @@ def test_training_vlm_and_prompt_truncation(self): # If not handled properly, prompt truncation may truncate image token dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = GRPOConfig( output_dir=self.tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test @@ -1462,7 +1485,7 @@ def test_training_vlm_and_prompt_truncation(self): ) trainer = GRPOTrainer( model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) @@ -1495,6 +1518,10 @@ def test_training_vlm_and_prompt_truncation(self): def test_training_vlm_and_vllm(self, model_id) -> None: dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = GRPOConfig( output_dir=self.tmp_dir, learning_rate=0.1, @@ -1508,7 +1535,44 @@ def test_training_vlm_and_vllm(self, model_id) -> None: ) trainer = GRPOTrainer( model=model_id, - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + @require_vision + def test_training_vlm_multi_image(self): + dataset = load_dataset("trl-internal-testing/zen-multi-image", 
"conversational_prompt_only", split="train") + + # For now, mixing image+text and text-only examples is not supported, so we filter out text-only examples + dataset = dataset.filter(lambda x: len(x["images"]) > 0) + + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + max_prompt_length=None, # disable prompt truncation, because usually, models don't support it + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) @@ -1519,6 +1583,9 @@ def test_training_vlm_and_vllm(self, model_id) -> None: self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + # Check that the params have changed + # Because of the way the tiny models are initialized, the gradient does not flow properly through the + # vision parts of the model, so we skip them. Ideally, we should fix the init of these models. for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index 12042bd2b3b..cde52de6047 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -1089,6 +1089,10 @@ def test_prepare_input_called_with_correct_data(self): def test_training_vlm(self, model_id): dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = RLOOConfig( output_dir=self.tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test @@ -1100,7 +1104,7 @@ def test_training_vlm(self, model_id): ) trainer = RLOOTrainer( model=model_id, - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) @@ -1117,8 +1121,6 @@ def test_training_vlm(self, model_id): params_to_skip = ( "model.vision_tower.", "model.multi_modal_projector.", - "model.vision_model.", - "model.connector.modality_projection.", "model.visual.", "model.image_newline", ) @@ -1132,6 +1134,10 @@ def test_training_vlm(self, model_id): def test_training_vlm_beta_non_zero(self): dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = RLOOConfig( output_dir=self.tmp_dir, beta=0.1, # set beta to non-zero value to test the case where the reference model is used @@ -1143,7 +1149,7 @@ def test_training_vlm_beta_non_zero(self): ) trainer = RLOOTrainer( model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", - 
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) @@ -1171,7 +1177,11 @@ def test_training_vlm_peft(self): "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration" ) base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] - dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] training_args = RLOOConfig( output_dir=self.tmp_dir, @@ -1183,7 +1193,7 @@ def test_training_vlm_peft(self): ) trainer = RLOOTrainer( model=model, - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, peft_config=LoraConfig(target_modules=["q_proj", "v_proj"]), @@ -1208,6 +1218,10 @@ def test_training_vlm_and_prompt_truncation(self): # If not handled properly, prompt truncation may truncate image token dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = RLOOConfig( output_dir=self.tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test @@ -1219,7 +1233,7 @@ def test_training_vlm_and_prompt_truncation(self): ) trainer = RLOOTrainer( model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) @@ -1252,6 +1266,10 @@ def test_training_vlm_and_prompt_truncation(self): def test_training_vlm_and_vllm(self, model_id) -> None: dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = RLOOConfig( output_dir=self.tmp_dir, learning_rate=0.1, @@ -1265,7 +1283,44 @@ def test_training_vlm_and_vllm(self, model_id) -> None: ) trainer = RLOOTrainer( model=model_id, - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + @require_vision + def test_training_vlm_multi_image(self): + dataset = load_dataset("trl-internal-testing/zen-multi-image", "conversational_prompt_only", split="train") + + # For now, mixing image+text and text-only examples is not supported, so we filter out text-only examples + dataset = dataset.filter(lambda x: len(x["images"]) > 0) + + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return 
[float(len(completion[0]["content"])) for completion in completions] + + training_args = RLOOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + max_prompt_length=None, # disable prompt truncation, because usually, models don't support it + report_to="none", + ) + trainer = RLOOTrainer( + model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) diff --git a/tests/test_utils.py b/tests/test_utils.py index f036a897e1b..6f6ba1579ef 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -874,6 +874,7 @@ class SplitPixelValuesByGridTester(TrlTestCase): def test_split_correctly_0(self): batch = { "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 2]]), + "num_images": [1, 1], "pixel_values": torch.arange(8 * 3).reshape(8, 3), # Shape: [8, 3] } result = split_pixel_values_by_grid(batch) @@ -881,10 +882,15 @@ def test_split_correctly_0(self): self.assertEqual(len(result["pixel_values"]), 2) self.assertTrue(torch.equal(result["pixel_values"][0], batch["pixel_values"][:4])) self.assertTrue(torch.equal(result["pixel_values"][1], batch["pixel_values"][4:])) + self.assertIsInstance(result["image_grid_thw"], list) + self.assertEqual(len(result["image_grid_thw"]), 2) + self.assertTrue(torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 2, 2]]))) + self.assertTrue(torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2]]))) def test_split_correctly_1(self): batch = { "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 4]]), + "num_images": [1, 1], "pixel_values": torch.arange(12 * 3).reshape(12, 3), # Shape: [12, 3] } result = split_pixel_values_by_grid(batch) @@ -892,6 +898,10 @@ def test_split_correctly_1(self): self.assertEqual(len(result["pixel_values"]), 2) self.assertTrue(torch.equal(result["pixel_values"][0], batch["pixel_values"][:4])) self.assertTrue(torch.equal(result["pixel_values"][1], batch["pixel_values"][4:12])) + self.assertIsInstance(result["image_grid_thw"], list) + self.assertEqual(len(result["image_grid_thw"]), 2) + self.assertTrue(torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 2, 2]]))) + self.assertTrue(torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 4]]))) def test_missing_keys(self): batch = {"pixel_values": torch.tensor([1.0])} @@ -901,11 +911,28 @@ def test_missing_keys(self): def test_mismatched_length(self): batch = { "image_grid_thw": torch.tensor([[1, 1, 2], [1, 2, 1]]), # Total = 8 + "num_images": [1, 1], "pixel_values": torch.randn(3, 5), # Only 3 rows } with self.assertRaises(ValueError): split_pixel_values_by_grid(batch) + def test_multi_images(self): + batch = { + "image_grid_thw": torch.tensor([[1, 1, 2], [1, 2, 2], [1, 2, 1]]), # Total = 8 + "num_images": [1, 2], + "pixel_values": torch.arange(8 * 3).reshape(8, 3), # Shape: [8, 3] + } + result = split_pixel_values_by_grid(batch) + self.assertIsInstance(result["pixel_values"], list) + self.assertEqual(len(result["pixel_values"]), 2) + self.assertTrue(torch.equal(result["pixel_values"][0], batch["pixel_values"][:2])) + self.assertTrue(torch.equal(result["pixel_values"][1], batch["pixel_values"][2:])) + self.assertIsInstance(result["image_grid_thw"], list) + 
self.assertEqual(len(result["image_grid_thw"]), 2) + self.assertTrue(torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 1, 2]]))) + self.assertTrue(torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2], [1, 2, 1]]))) + class TruncateWithProtectedTokensTester(TrlTestCase): def test_basic_example(self): @@ -1044,12 +1071,16 @@ def test_empty_protected_tokens_list(self): class UnsplitPixelValuesByGridTester(TrlTestCase): def test_unsplit_correctly(self): - split = [torch.randn(4, 5), torch.randn(2, 5)] - merged = torch.cat(split, dim=0) - batch = {"pixel_values": split, "other_key": torch.tensor([1])} + pixel_values = [torch.randn(4, 5), torch.randn(2, 5)] + pixel_values_merged = torch.cat(pixel_values, dim=0) + image_grid_thw = [torch.tensor([[1, 2, 2]]), torch.tensor([[1, 2, 1]])] + image_grid_thw_merged = torch.cat(image_grid_thw, dim=0) + batch = {"pixel_values": pixel_values, "image_grid_thw": image_grid_thw, "other_key": torch.tensor([1])} result = unsplit_pixel_values_by_grid(batch) self.assertIsInstance(result["pixel_values"], torch.Tensor) - self.assertTrue(torch.allclose(result["pixel_values"], merged)) + self.assertTrue(torch.allclose(result["pixel_values"], pixel_values_merged)) + self.assertIsInstance(result["image_grid_thw"], torch.Tensor) + self.assertTrue(torch.equal(result["image_grid_thw"], image_grid_thw_merged)) self.assertIn("other_key", result) def test_no_op_if_not_list(self): diff --git a/trl/experimental/gfpo/gfpo_trainer.py b/trl/experimental/gfpo/gfpo_trainer.py index f2a675fab16..6ac4b6acc7f 100644 --- a/trl/experimental/gfpo/gfpo_trainer.py +++ b/trl/experimental/gfpo/gfpo_trainer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import logging import re from contextlib import nullcontext @@ -83,22 +82,22 @@ def _generate_and_score_completions(self, inputs): prompts = [x["prompt"] for x in inputs] - # We don't yet support visual reward models/function, so we keep a copy of the original text-only prompts for - # later use in the reward computation. If images are present, we insert {"type": "image"} as required by the - # VLM chat template. 
- original_prompts = copy.deepcopy(prompts) + if "images" in inputs[0]: + images = [example.get("images") for example in inputs] + elif "image" in inputs[0]: + images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] + else: + images = None # If the prompts are conversational and the inputs contain images, we need to convert the prompts from # [{"role": "user", "content": "What color is the sky?"}] to # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] kwargs = {} - has_images = "image" in inputs[0] - if has_images: - images = [example.get("image") for example in inputs] - kwargs = {"images": [[img] for img in images]} - for prompt in prompts: + if images is not None: + kwargs = {"images": images} + for prompt, image_list in zip(prompts, images): if isinstance(prompt, list): # i.e., when using conversational data - prepare_multimodal_messages(prompt, num_images=1) + prepare_multimodal_messages(prompt, num_images=len(image_list)) prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] @@ -170,7 +169,7 @@ def _generate_and_score_completions(self, inputs): # Generate completions using vLLM: gather all prompts and use them in a single call in the main process if self.vllm_mode == "server": all_prompts_text = gather_object(prompts_text) - if has_images: + if images is not None: all_images = gather_object(images) if self.accelerator.is_main_process: @@ -179,7 +178,7 @@ def _generate_and_score_completions(self, inputs): # prompt individually. ordered_set_of_prompts = all_prompts_text[:: self.num_generations] - if has_images: + if images is not None: ordered_set_of_images = all_images[:: self.num_generations] else: ordered_set_of_images = None @@ -244,7 +243,7 @@ def _generate_and_score_completions(self, inputs): torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group) all_prompts_text = [p for sublist in gathered_prompts for p in sublist] - if has_images: + if images is not None: gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)] torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group) all_images = [img for sublist in gathered_images for img in sublist] @@ -252,15 +251,13 @@ def _generate_and_score_completions(self, inputs): all_images = None else: all_prompts_text = prompts_text - all_images = images if has_images else None + all_images = images - if has_images and all_images: + if images is not None and all_images: vllm_inputs = [] - for prompt, image in zip(all_prompts_text, all_images): - if image is not None: - vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}}) - else: - vllm_inputs.append(prompt) + for prompt, image_list in zip(all_prompts_text, all_images): + vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}}) + else: vllm_inputs = all_prompts_text @@ -375,6 +372,8 @@ def _generate_and_score_completions(self, inputs): logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size + num_images = [len(img_list) for img_list in images] if images is not None else None + with torch.no_grad(): # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of # a full optimizer step (when gradient_accumulation_steps is not a 
multiple of generate_every)—then the @@ -395,6 +394,7 @@ def _generate_and_score_completions(self, inputs): batch_size, pixel_values=prompt_inputs.get("pixel_values"), image_grid_thw=prompt_inputs.get("image_grid_thw"), + num_images=num_images, pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), ) @@ -419,6 +419,7 @@ def _generate_and_score_completions(self, inputs): batch_size=batch_size, pixel_values=prompt_inputs.get("pixel_values"), image_grid_thw=prompt_inputs.get("image_grid_thw"), + num_images=num_images, pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), ) @@ -432,6 +433,7 @@ def _generate_and_score_completions(self, inputs): batch_size=batch_size, pixel_values=prompt_inputs.get("pixel_values"), image_grid_thw=prompt_inputs.get("image_grid_thw"), + num_images=num_images, pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), ) @@ -451,7 +453,7 @@ def _generate_and_score_completions(self, inputs): # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is # important because rewards will be normalized per group, and completions are distributed. We will later slice # rewards_per_func to extract each process's subset. - rewards_per_func = self._calculate_rewards(inputs, original_prompts, completions, completion_ids_list) + rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) # Apply weights to each reward function's output and sum rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) @@ -563,12 +565,12 @@ def _generate_and_score_completions(self, inputs): # Log prompt and completion texts all_prompts_text = gather_object(prompts_text) all_completions_text = gather_object(completions_text) - all_images = gather_object(images) if has_images else None + all_images = gather_object(images) if images is not None else None if self.num_remains_in_group is not None and mode == "train": group_global_indices_list = group_global_indices.tolist() all_prompts_text = [all_prompts_text[i] for i in group_global_indices_list] all_completions_text = [all_completions_text[i] for i in group_global_indices_list] - if has_images: + if images is not None: all_images = [all_images[i] for i in group_global_indices_list] self._logs["prompt"].extend(all_prompts_text) @@ -577,8 +579,8 @@ def _generate_and_score_completions(self, inputs): self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) self._logs["advantages"].extend(all_process_advantages.tolist()) - if has_images: - self._logs["image"].extend(all_images) + if images is not None: + self._logs["images"].extend(gather_object(images)) if self.use_vllm and self.vllm_importance_sampling_correction: delta = torch.abs(old_per_token_logps - sampling_per_token_logps) @@ -639,4 +641,6 @@ def _generate_and_score_completions(self, inputs): output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"] if "image_sizes" in prompt_inputs: output["image_sizes"] = prompt_inputs["image_sizes"] + if images is not None: + output["num_images"] = num_images return output diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index bb902445d09..69825102b27 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the 
License. -import copy import inspect import os import re @@ -464,7 +463,7 @@ def __init__( self.num_completions_to_print = args.num_completions_to_print # Keep logs sized to the generation batch to record only outputs from the latest model update. self._logs = { - "image": deque(maxlen=args.generation_batch_size), + "images": deque(maxlen=args.generation_batch_size), "prompt": deque(maxlen=args.generation_batch_size), "completion": deque(maxlen=args.generation_batch_size), "rewards": defaultdict(lambda: deque(maxlen=args.generation_batch_size)), @@ -609,7 +608,7 @@ def _set_signature_columns_if_needed(self): # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work. # Instead, we set them to the columns expected by the `training_step` method, hence the override. if self._signature_columns is None: - self._signature_columns = ["prompt", "image"] + self._signature_columns = ["prompt", "image", "images"] # This method overrides `Trainer.get_train_dataloader` to support our custom batching strategy. # Instead of returning a standard per-step batch (i.e., `per_device_batch_size), our dataloader loads an @@ -789,6 +788,7 @@ def _get_per_token_logps_and_entropies( compute_entropy=False, pixel_values=None, image_grid_thw=None, + num_images=None, pixel_attention_mask=None, image_sizes=None, ) -> dict[str, Optional[torch.Tensor]]: @@ -802,12 +802,16 @@ def _get_per_token_logps_and_entropies( # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} - if image_grid_thw is not None and pixel_values is not None: - model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size] - start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item() - end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item() - model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] + rows_per_image = image_grid_thw.prod(dim=-1) + rows_per_sample = torch.split(rows_per_image, num_images) + rows_per_sample = torch.stack([s.sum() for s in rows_per_sample]) + cum_rows = torch.cat([torch.tensor([0], device=rows_per_sample.device), rows_per_sample.cumsum(0)]) + row_start, row_end = cum_rows[start].item(), cum_rows[start + batch_size].item() + model_inputs["pixel_values"] = pixel_values[row_start:row_end] + cum_imgs = torch.tensor([0] + num_images).cumsum(0) + img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size] + model_inputs["image_grid_thw"] = image_grid_thw[img_start:img_end] elif pixel_values is not None: model_inputs["pixel_values"] = pixel_values[start : start + batch_size] if pixel_attention_mask is not None: @@ -1065,22 +1069,22 @@ def _generate_and_score_completions( prompts = [x["prompt"] for x in inputs] - # We don't yet support visual reward models/function, so we keep a copy of the original text-only prompts for - # later use in the reward computation. If images are present, we insert {"type": "image"} as required by the - # VLM chat template. 
- original_prompts = copy.deepcopy(prompts) + if "images" in inputs[0]: + images = [example.get("images") for example in inputs] + elif "image" in inputs[0]: + images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] + else: + images = None # If the prompts are conversational and the inputs contain images, we need to convert the prompts from # [{"role": "user", "content": "What color is the sky?"}] to # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] kwargs = {} - has_images = "image" in inputs[0] - if has_images: - images = [example.get("image") for example in inputs] - kwargs = {"images": [[img] for img in images]} - for prompt in prompts: + if images is not None: + kwargs = {"images": images} + for prompt, image_list in zip(prompts, images): if isinstance(prompt, list): # i.e., when using conversational data - prepare_multimodal_messages(prompt, num_images=1) + prepare_multimodal_messages(prompt, num_images=len(image_list)) prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] @@ -1152,7 +1156,7 @@ def _generate_and_score_completions( # Generate completions using vLLM: gather all prompts and use them in a single call in the main process if self.vllm_mode == "server": all_prompts_text = gather_object(prompts_text) - if has_images: + if images is not None: all_images = gather_object(images) if self.accelerator.is_main_process: @@ -1161,7 +1165,7 @@ def _generate_and_score_completions( # prompt individually. ordered_set_of_prompts = all_prompts_text[:: self.num_generations] - if has_images: + if images is not None: ordered_set_of_images = all_images[:: self.num_generations] else: ordered_set_of_images = None @@ -1226,7 +1230,7 @@ def _generate_and_score_completions( torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group) all_prompts_text = [p for sublist in gathered_prompts for p in sublist] - if has_images: + if images is not None: gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)] torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group) all_images = [img for sublist in gathered_images for img in sublist] @@ -1234,15 +1238,13 @@ def _generate_and_score_completions( all_images = None else: all_prompts_text = prompts_text - all_images = images if has_images else None + all_images = images - if has_images and all_images: + if images is not None and all_images: vllm_inputs = [] - for prompt, image in zip(all_prompts_text, all_images): - if image is not None: - vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}}) - else: - vllm_inputs.append(prompt) + for prompt, image_list in zip(all_prompts_text, all_images): + vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}}) + else: vllm_inputs = all_prompts_text @@ -1357,6 +1359,8 @@ def _generate_and_score_completions( logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size + num_images = [len(img_list) for img_list in images] if images is not None else None + with torch.no_grad(): # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the @@ -1377,6 +1381,7 @@ def 
_generate_and_score_completions( batch_size, pixel_values=prompt_inputs.get("pixel_values"), image_grid_thw=prompt_inputs.get("image_grid_thw"), + num_images=num_images, pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), ) @@ -1401,6 +1406,7 @@ def _generate_and_score_completions( batch_size=batch_size, pixel_values=prompt_inputs.get("pixel_values"), image_grid_thw=prompt_inputs.get("image_grid_thw"), + num_images=num_images, pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), ) @@ -1414,6 +1420,7 @@ def _generate_and_score_completions( batch_size=batch_size, pixel_values=prompt_inputs.get("pixel_values"), image_grid_thw=prompt_inputs.get("image_grid_thw"), + num_images=num_images, pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), ) @@ -1433,7 +1440,7 @@ def _generate_and_score_completions( # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is # important because rewards will be normalized per group, and completions are distributed. We will later slice # rewards_per_func to extract each process's subset. - rewards_per_func = self._calculate_rewards(inputs, original_prompts, completions, completion_ids_list) + rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) # Apply weights to each reward function's output and sum rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) @@ -1507,8 +1514,8 @@ def _generate_and_score_completions( self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) self._logs["advantages"].extend(all_process_advantages.tolist()) - if has_images: - self._logs["image"].extend(gather_object(images)) + if images is not None: + self._logs["images"].extend(gather_object(images)) if self.use_vllm and self.vllm_importance_sampling_correction: delta = torch.abs(old_per_token_logps - sampling_per_token_logps) @@ -1564,6 +1571,8 @@ def _generate_and_score_completions( output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"] if "image_sizes" in prompt_inputs: output["image_sizes"] = prompt_inputs["image_sizes"] + if images is not None: + output["num_images"] = num_images return output def compute_liger_loss(self, unwrapped_model, inputs): @@ -1636,6 +1645,7 @@ def _compute_loss(self, model, inputs): compute_entropy=True, pixel_values=inputs.get("pixel_values"), image_grid_thw=inputs.get("image_grid_thw"), + num_images=inputs.get("num_images"), pixel_attention_mask=inputs.get("pixel_attention_mask"), image_sizes=inputs.get("image_sizes"), ) @@ -1790,14 +1800,11 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non "advantage": self._logs["advantages"], } - if self._logs["image"]: - table["image"] = [] - for img in self._logs["image"]: - if img is not None: - # Convert images to wandb Image objects for proper visualization - table["image"].append(wandb.Image(img)) - else: - table["image"].append(None) + if self._logs["images"]: + table["images"] = [] + for image_list in self._logs["images"]: + # Convert images to wandb Image objects for proper visualization + table["images"].append([wandb.Image(image) for image in image_list]) df = pd.DataFrame(table) if self.wandb_log_unique_prompts: diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 56ffbfe7fea..359cb68e43b 100644 --- a/trl/trainer/rloo_trainer.py 
+++ b/trl/trainer/rloo_trainer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import inspect import os import re @@ -536,7 +535,7 @@ def decode(example, tokenizer): self.num_completions_to_print = args.num_completions_to_print # Keep logs sized to the generation batch to record only outputs from the latest model update. self._logs = { - "image": deque(maxlen=args.generation_batch_size), + "images": deque(maxlen=args.generation_batch_size), "prompt": deque(maxlen=args.generation_batch_size), "completion": deque(maxlen=args.generation_batch_size), "rewards": defaultdict(lambda: deque(maxlen=args.generation_batch_size)), @@ -678,7 +677,7 @@ def _set_signature_columns_if_needed(self): # In RLOOTrainer, we preprocess data, so using the model's signature columns doesn't work. # Instead, we set them to the columns expected by the `training_step` method, hence the override. if self._signature_columns is None: - self._signature_columns = ["prompt", "image"] + self._signature_columns = ["prompt", "image", "images"] # This method overrides `Trainer.get_train_dataloader` to support our custom batching strategy. # Instead of returning a standard per-step batch (i.e., `per_device_batch_size), our dataloader loads an @@ -775,6 +774,7 @@ def _get_per_token_logps_and_entropies( compute_entropy=False, pixel_values=None, image_grid_thw=None, + num_images=None, pixel_attention_mask=None, image_sizes=None, ) -> dict[str, Optional[torch.Tensor]]: @@ -790,10 +790,15 @@ def _get_per_token_logps_and_entropies( model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} if image_grid_thw is not None and pixel_values is not None: - model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size] - start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item() - end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item() - model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] + rows_per_image = image_grid_thw.prod(dim=-1) + rows_per_sample = torch.split(rows_per_image, num_images) + rows_per_sample = torch.stack([s.sum() for s in rows_per_sample]) + cum_rows = torch.cat([torch.tensor([0], device=rows_per_sample.device), rows_per_sample.cumsum(0)]) + row_start, row_end = cum_rows[start].item(), cum_rows[start + batch_size].item() + model_inputs["pixel_values"] = pixel_values[row_start:row_end] + cum_imgs = torch.tensor([0] + num_images).cumsum(0) + img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size] + model_inputs["image_grid_thw"] = image_grid_thw[img_start:img_end] elif pixel_values is not None: model_inputs["pixel_values"] = pixel_values[start : start + batch_size] if pixel_attention_mask is not None: @@ -1051,22 +1056,22 @@ def _generate_and_score_completions( prompts = [x["prompt"] for x in inputs] - # We don't yet support visual reward models/function, so we keep a copy of the original text-only prompts for - # later use in the reward computation. If images are present, we insert {"type": "image"} as required by the - # VLM chat template. 
- original_prompts = copy.deepcopy(prompts) + if "images" in inputs[0]: + images = [example.get("images") for example in inputs] + elif "image" in inputs[0]: + images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] + else: + images = None # If the prompts are conversational and the inputs contain images, we need to convert the prompts from # [{"role": "user", "content": "What color is the sky?"}] to # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] kwargs = {} - has_images = "image" in inputs[0] - if has_images: - images = [example.get("image") for example in inputs] - kwargs = {"images": [[img] for img in images]} - for prompt in prompts: + if images is not None: + kwargs = {"images": images} + for prompt, image_list in zip(prompts, images): if isinstance(prompt, list): # i.e., when using conversational data - prepare_multimodal_messages(prompt, num_images=1) + prepare_multimodal_messages(prompt, num_images=len(image_list)) prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] @@ -1133,7 +1138,7 @@ def _generate_and_score_completions( # Generate completions using vLLM: gather all prompts and use them in a single call in the main process if self.vllm_mode == "server": all_prompts_text = gather_object(prompts_text) - if has_images: + if images is not None: all_images = gather_object(images) if self.accelerator.is_main_process: @@ -1142,7 +1147,7 @@ def _generate_and_score_completions( # prompt individually. ordered_set_of_prompts = all_prompts_text[:: self.num_generations] - if has_images: + if images is not None: ordered_set_of_images = all_images[:: self.num_generations] else: ordered_set_of_images = None @@ -1205,7 +1210,7 @@ def _generate_and_score_completions( torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group) all_prompts_text = [p for sublist in gathered_prompts for p in sublist] - if has_images: + if images is not None: gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)] torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group) all_images = [img for sublist in gathered_images for img in sublist] @@ -1213,15 +1218,13 @@ def _generate_and_score_completions( all_images = None else: all_prompts_text = prompts_text - all_images = images if has_images else None + all_images = images - if has_images and all_images: + if images is not None and all_images: vllm_inputs = [] - for prompt, image in zip(all_prompts_text, all_images): - if image is not None: - vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}}) - else: - vllm_inputs.append(prompt) + for prompt, image_list in zip(all_prompts_text, all_images): + vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}}) + else: vllm_inputs = all_prompts_text @@ -1321,6 +1324,8 @@ def _generate_and_score_completions( logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size + num_images = [len(img_list) for img_list in images] if images is not None else None + with torch.no_grad(): # Compute the per-token log probabilities for the current model old_per_token_logps, _ = self._get_per_token_logps_and_entropies( @@ -1331,6 +1336,7 @@ def _generate_and_score_completions( batch_size, 
pixel_values=prompt_inputs.get("pixel_values"), image_grid_thw=prompt_inputs.get("image_grid_thw"), + num_images=num_images, pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), ) @@ -1347,6 +1353,7 @@ def _generate_and_score_completions( batch_size=batch_size, pixel_values=prompt_inputs.get("pixel_values"), image_grid_thw=prompt_inputs.get("image_grid_thw"), + num_images=num_images, pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), ) @@ -1360,6 +1367,7 @@ def _generate_and_score_completions( batch_size=batch_size, pixel_values=prompt_inputs.get("pixel_values"), image_grid_thw=prompt_inputs.get("image_grid_thw"), + num_images=num_images, pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), ) @@ -1379,7 +1387,7 @@ def _generate_and_score_completions( # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is # important because rewards will be normalized per group, and completions are distributed. We will later slice # rewards_per_func to extract each process's subset. - rewards_per_func = self._calculate_rewards(inputs, original_prompts, completions, completion_ids_list) + rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) # Apply weights to each reward function's output and sum rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) @@ -1463,8 +1471,8 @@ def _generate_and_score_completions( self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) self._logs["advantages"].extend(all_process_advantages.tolist()) - if has_images: - self._logs["image"].extend(gather_object(images)) + if images is not None: + self._logs["images"].extend(gather_object(images)) output = { "prompt_ids": prompt_ids, @@ -1482,6 +1490,8 @@ def _generate_and_score_completions( output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"] if "image_sizes" in prompt_inputs: output["image_sizes"] = prompt_inputs["image_sizes"] + if images is not None: + output["num_images"] = num_images return output @profiling_decorator @@ -1507,6 +1517,7 @@ def _compute_loss(self, model, inputs): compute_entropy=True, pixel_values=inputs.get("pixel_values"), image_grid_thw=inputs.get("image_grid_thw"), + num_images=inputs.get("num_images"), pixel_attention_mask=inputs.get("pixel_attention_mask"), image_sizes=inputs.get("image_sizes"), ) @@ -1588,14 +1599,11 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non "advantage": self._logs["advantages"], } - if self._logs["image"]: - table["image"] = [] - for img in self._logs["image"]: - if img is not None: - # Convert images to wandb Image objects for proper visualization - table["image"].append(wandb.Image(img)) - else: - table["image"].append(None) + if self._logs["images"]: + table["images"] = [] + for image_list in self._logs["images"]: + # Convert images to wandb Image objects for proper visualization + table["images"].append([wandb.Image(image) for image in image_list]) df = pd.DataFrame(table) if self.wandb_log_unique_prompts: diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 37612a423bd..337e9857c1a 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -19,6 +19,7 @@ from collections.abc import Sequence, Sized from dataclasses import dataclass, field from importlib.metadata import version +from itertools import 
accumulate from typing import Any, Literal, Optional, Union import numpy as np @@ -1780,20 +1781,23 @@ def identity(x): def split_pixel_values_by_grid(batch: dict[str, torch.Tensor]) -> dict[str, Union[torch.Tensor, list[torch.Tensor]]]: """ - Splits `batch["pixel_values"]` into a list of tensors based on the product of each row in - `batch["image_grid_thw"]`, while keeping other entries unchanged. + Splits `batch["pixel_values"]` into a list of tensors based on the product of each row in `batch["image_grid_thw"]` + and batch["num_images"] while keeping other entries unchanged. """ - if "image_grid_thw" not in batch or "pixel_values" not in batch: + if "image_grid_thw" not in batch or "pixel_values" not in batch or "num_images" not in batch: return batch - lengths = batch["image_grid_thw"].prod(-1).tolist() # [batch_size] + lengths = batch["image_grid_thw"].prod(-1).tolist() # [num_images] pixel_values = batch["pixel_values"] # [total, feature_dim] if sum(lengths) != pixel_values.size(0): raise ValueError(f"Mismatch: sum(lengths) = {sum(lengths)} != pixel_values.size(0) = {pixel_values.size(0)}") - split_values = list(torch.split(batch["pixel_values"], lengths, dim=0)) - return {**batch, "pixel_values": split_values} + boundaries = [0, *accumulate(batch["num_images"])] # [3, 4, 5] -> [0, 3, 7, 12] + sections = [sum(lengths[boundaries[i] : boundaries[i + 1]]) for i in range(len(batch["num_images"]))] + split_values = list(torch.split(batch["pixel_values"], sections, dim=0)) + image_grid_thw = list(torch.split(batch["image_grid_thw"], batch["num_images"], dim=0)) + return {**batch, "pixel_values": split_values, "image_grid_thw": image_grid_thw} def unsplit_pixel_values_by_grid(batch: dict[str, Union[torch.Tensor, list[torch.Tensor]]]) -> dict[str, torch.Tensor]: @@ -1802,12 +1806,16 @@ def unsplit_pixel_values_by_grid(batch: dict[str, Union[torch.Tensor, list[torch tensor along the first dimension. """ pixel_values = batch.get("pixel_values") - if isinstance(pixel_values, list): merged = torch.cat(pixel_values, dim=0) - return {**batch, "pixel_values": merged} - else: - return batch + batch = {**batch, "pixel_values": merged} + + image_grid_thw = batch.get("image_grid_thw") + if isinstance(image_grid_thw, list): + merged = torch.cat(image_grid_thw, dim=0) + batch = {**batch, "image_grid_thw": merged} + + return batch def truncate_with_protected_tokens(