diff --git a/docs/source/dataset_formats.md b/docs/source/dataset_formats.md
index da606a7d97e..8a105ff5e34 100644
--- a/docs/source/dataset_formats.md
+++ b/docs/source/dataset_formats.md
@@ -1037,7 +1037,7 @@ Some trainers also support fine-tuning vision-language models (VLMs) using image
A conversational vision dataset differs from a standard conversational dataset in two key ways:
-1. The dataset must contain the key `images` with the image data.
+1. The dataset must contain the key `images` with the image data (a list of PIL Images per example) or the key `image` with a single PIL Image.
2. The `"content"` field in messages must be a list of dictionaries, where each dictionary specifies the type of data: `"image"` or `"text"`.
Example:
diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md
index 172e5f93111..e998ef63f69 100644
--- a/docs/source/grpo_trainer.md
+++ b/docs/source/grpo_trainer.md
@@ -562,7 +562,9 @@ Tested with:
- **SmolVLM2** — e.g., `HuggingFaceTB/SmolVLM2-2.2B-Instruct`
+
Compatibility with all VLMs is not guaranteed. If you believe a model should be supported, feel free to open an issue on GitHub — or better yet, submit a pull request with the required changes.
+
### Quick Start
@@ -605,7 +607,7 @@ VLM training may fail if image tokens are truncated. We highly recommend disabli
Each training sample should include:
- `prompt`: Text formatted via the processor's chat template
-- `image`: A single image (PIL or NumPy array)
+- `image`/`images`: a single PIL Image (`image`) or a list of PIL Images (`images`)
The trainer automatically handles image-to-tensor conversion via the model’s image processor.
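A minimal sketch of such a sample (placeholder values; assumes a conversational prompt with a single image passed under the `image` key):

```python
from PIL import Image

sample = {
    # Conversational prompt; the trainer applies the processor's chat template to it
    "prompt": [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "What color is the sky in this photo?"},
            ],
        }
    ],
    # A single PIL Image; use an `images` list instead for multi-image prompts
    "image": Image.open("sky.png"),  # placeholder path
}
```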
diff --git a/docs/source/rloo_trainer.md b/docs/source/rloo_trainer.md
index 66a8f3e16e4..bce71b1f0bf 100644
--- a/docs/source/rloo_trainer.md
+++ b/docs/source/rloo_trainer.md
@@ -533,7 +533,9 @@ Tested with:
- **SmolVLM2** — e.g., `HuggingFaceTB/SmolVLM2-2.2B-Instruct`
+
Compatibility with all VLMs is not guaranteed. If you believe a model should be supported, feel free to open an issue on GitHub — or better yet, submit a pull request with the required changes.
+
### Quick Start
@@ -576,7 +578,7 @@ VLM training may fail if image tokens are truncated. We highly recommend disabli
Each training sample should include:
- `prompt`: Text formatted via the processor's chat template
-- `image`: A single image (PIL or NumPy array)
+- `image`/`images`: a single PIL Image (`image`) or a list of PIL Images (`images`)
The trainer automatically handles image-to-tensor conversion via the model’s image processor.
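For multi-image prompts the same idea extends naturally; a hedged sketch (placeholder values; one PIL Image per `{"type": "image"}` entry, passed as a list under the `images` key):

```python
from PIL import Image

sample = {
    "prompt": [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "image"},
                {"type": "text", "text": "Do these two photos show the same place?"},
            ],
        }
    ],
    "images": [Image.open("view_a.png"), Image.open("view_b.png")],  # placeholder paths
}
```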
diff --git a/scripts/generate_tiny_models.py b/scripts/generate_tiny_models.py
index b89447a9954..90436c8adc5 100644
--- a/scripts/generate_tiny_models.py
+++ b/scripts/generate_tiny_models.py
@@ -292,6 +292,7 @@ def init_weights_tiny_model(model):
"hidden_size": 16,
"num_attention_heads": 4,
"num_key_value_heads": 2,
+ "embed_dim": 64,
}
kwargs = {}
diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py
index 4e4321febcd..cc484c56d0f 100644
--- a/tests/test_grpo_trainer.py
+++ b/tests/test_grpo_trainer.py
@@ -1258,6 +1258,10 @@ def test_prepare_input_called_with_correct_data(self):
def test_training_vlm(self, model_id):
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
+ def reward_func(completions, **kwargs):
+ """Reward function that rewards longer completions."""
+ return [float(len(completion[0]["content"])) for completion in completions]
+
training_args = GRPOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
@@ -1269,7 +1273,7 @@ def test_training_vlm(self, model_id):
)
trainer = GRPOTrainer(
model=model_id,
- reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+ reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)
@@ -1287,7 +1291,6 @@ def test_training_vlm(self, model_id):
"model.vision_tower.",
"model.multi_modal_projector.",
"model.vision_model.",
- "model.connector.modality_projection.",
"model.visual.",
"model.image_newline",
)
@@ -1301,6 +1304,10 @@ def test_training_vlm(self, model_id):
def test_training_vlm_beta_non_zero(self):
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
+ def reward_func(completions, **kwargs):
+ """Reward function that rewards longer completions."""
+ return [float(len(completion[0]["content"])) for completion in completions]
+
training_args = GRPOConfig(
output_dir=self.tmp_dir,
beta=0.1, # set beta to non-zero value to test the case where the reference model is used
@@ -1312,7 +1319,7 @@ def test_training_vlm_beta_non_zero(self):
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
- reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+ reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)
@@ -1340,7 +1347,11 @@ def test_training_vlm_peft(self):
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration"
)
base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
- dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+ dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
+
+ def reward_func(completions, **kwargs):
+ """Reward function that rewards longer completions."""
+ return [float(len(completion[0]["content"])) for completion in completions]
training_args = GRPOConfig(
output_dir=self.tmp_dir,
@@ -1352,7 +1363,7 @@ def test_training_vlm_peft(self):
)
trainer = GRPOTrainer(
model=model,
- reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+ reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
peft_config=LoraConfig(target_modules=["q_proj", "v_proj"]),
@@ -1376,6 +1387,10 @@ def test_training_vlm_peft(self):
def test_training_vlm_and_importance_sampling(self):
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
+ def reward_func(completions, **kwargs):
+ """Reward function that rewards longer completions."""
+ return [float(len(completion[0]["content"])) for completion in completions]
+
training_args = GRPOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
@@ -1387,7 +1402,7 @@ def test_training_vlm_and_importance_sampling(self):
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
- reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+ reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)
@@ -1413,6 +1428,10 @@ def test_training_vlm_and_importance_sampling(self):
def test_training_vlm_and_liger(self):
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
+ def reward_func(completions, **kwargs):
+ """Reward function that rewards longer completions."""
+ return [float(len(completion[0]["content"])) for completion in completions]
+
training_args = GRPOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
@@ -1425,7 +1444,7 @@ def test_training_vlm_and_liger(self):
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
- reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+ reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)
@@ -1451,6 +1470,10 @@ def test_training_vlm_and_prompt_truncation(self):
# If not handled properly, prompt truncation may truncate image token
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
+ def reward_func(completions, **kwargs):
+ """Reward function that rewards longer completions."""
+ return [float(len(completion[0]["content"])) for completion in completions]
+
training_args = GRPOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
@@ -1462,7 +1485,7 @@ def test_training_vlm_and_prompt_truncation(self):
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
- reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+ reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)
@@ -1495,6 +1518,10 @@ def test_training_vlm_and_prompt_truncation(self):
def test_training_vlm_and_vllm(self, model_id) -> None:
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
+ def reward_func(completions, **kwargs):
+ """Reward function that rewards longer completions."""
+ return [float(len(completion[0]["content"])) for completion in completions]
+
training_args = GRPOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1,
@@ -1508,7 +1535,44 @@ def test_training_vlm_and_vllm(self, model_id) -> None:
)
trainer = GRPOTrainer(
model=model_id,
- reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+ reward_funcs=reward_func,
+ args=training_args,
+ train_dataset=dataset,
+ )
+
+ previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+ trainer.train()
+
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+ for n, param in previous_trainable_params.items():
+ new_param = trainer.model.get_parameter(n)
+ self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+
+ @require_vision
+ def test_training_vlm_multi_image(self):
+ dataset = load_dataset("trl-internal-testing/zen-multi-image", "conversational_prompt_only", split="train")
+
+ # For now, mixing image+text and text-only examples is not supported, so we filter out text-only examples
+ dataset = dataset.filter(lambda x: len(x["images"]) > 0)
+
+ def reward_func(completions, **kwargs):
+ """Reward function that rewards longer completions."""
+ return [float(len(completion[0]["content"])) for completion in completions]
+
+ training_args = GRPOConfig(
+ output_dir=self.tmp_dir,
+ learning_rate=0.1, # increase the learning rate to speed up the test
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
+ num_generations=3, # reduce the number of generations to reduce memory usage
+ max_completion_length=8, # reduce the completion length to reduce memory usage
+            max_prompt_length=None,  # disable prompt truncation, because models usually don't support it
+ report_to="none",
+ )
+ trainer = GRPOTrainer(
+ model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
+ reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)
@@ -1519,6 +1583,9 @@ def test_training_vlm_and_vllm(self, model_id) -> None:
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ # Check that the params have changed
+ # Because of the way the tiny models are initialized, the gradient does not flow properly through the
+ # vision parts of the model, so we skip them. Ideally, we should fix the init of these models.
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py
index 12042bd2b3b..cde52de6047 100644
--- a/tests/test_rloo_trainer.py
+++ b/tests/test_rloo_trainer.py
@@ -1089,6 +1089,10 @@ def test_prepare_input_called_with_correct_data(self):
def test_training_vlm(self, model_id):
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
+ def reward_func(completions, **kwargs):
+ """Reward function that rewards longer completions."""
+ return [float(len(completion[0]["content"])) for completion in completions]
+
training_args = RLOOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
@@ -1100,7 +1104,7 @@ def test_training_vlm(self, model_id):
)
trainer = RLOOTrainer(
model=model_id,
- reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+ reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)
@@ -1117,8 +1121,6 @@ def test_training_vlm(self, model_id):
params_to_skip = (
"model.vision_tower.",
"model.multi_modal_projector.",
- "model.vision_model.",
- "model.connector.modality_projection.",
"model.visual.",
"model.image_newline",
)
@@ -1132,6 +1134,10 @@ def test_training_vlm(self, model_id):
def test_training_vlm_beta_non_zero(self):
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
+ def reward_func(completions, **kwargs):
+ """Reward function that rewards longer completions."""
+ return [float(len(completion[0]["content"])) for completion in completions]
+
training_args = RLOOConfig(
output_dir=self.tmp_dir,
beta=0.1, # set beta to non-zero value to test the case where the reference model is used
@@ -1143,7 +1149,7 @@ def test_training_vlm_beta_non_zero(self):
)
trainer = RLOOTrainer(
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
- reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+ reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)
@@ -1171,7 +1177,11 @@ def test_training_vlm_peft(self):
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration"
)
base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
- dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+ dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
+
+ def reward_func(completions, **kwargs):
+ """Reward function that rewards longer completions."""
+ return [float(len(completion[0]["content"])) for completion in completions]
training_args = RLOOConfig(
output_dir=self.tmp_dir,
@@ -1183,7 +1193,7 @@ def test_training_vlm_peft(self):
)
trainer = RLOOTrainer(
model=model,
- reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+ reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
peft_config=LoraConfig(target_modules=["q_proj", "v_proj"]),
@@ -1208,6 +1218,10 @@ def test_training_vlm_and_prompt_truncation(self):
# If not handled properly, prompt truncation may truncate image token
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
+ def reward_func(completions, **kwargs):
+ """Reward function that rewards longer completions."""
+ return [float(len(completion[0]["content"])) for completion in completions]
+
training_args = RLOOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
@@ -1219,7 +1233,7 @@ def test_training_vlm_and_prompt_truncation(self):
)
trainer = RLOOTrainer(
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
- reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+ reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)
@@ -1252,6 +1266,10 @@ def test_training_vlm_and_prompt_truncation(self):
def test_training_vlm_and_vllm(self, model_id) -> None:
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
+ def reward_func(completions, **kwargs):
+ """Reward function that rewards longer completions."""
+ return [float(len(completion[0]["content"])) for completion in completions]
+
training_args = RLOOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1,
@@ -1265,7 +1283,44 @@ def test_training_vlm_and_vllm(self, model_id) -> None:
)
trainer = RLOOTrainer(
model=model_id,
- reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+ reward_funcs=reward_func,
+ args=training_args,
+ train_dataset=dataset,
+ )
+
+ previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+ trainer.train()
+
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+ for n, param in previous_trainable_params.items():
+ new_param = trainer.model.get_parameter(n)
+ self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+
+ @require_vision
+ def test_training_vlm_multi_image(self):
+ dataset = load_dataset("trl-internal-testing/zen-multi-image", "conversational_prompt_only", split="train")
+
+ # For now, mixing image+text and text-only examples is not supported, so we filter out text-only examples
+ dataset = dataset.filter(lambda x: len(x["images"]) > 0)
+
+ def reward_func(completions, **kwargs):
+ """Reward function that rewards longer completions."""
+ return [float(len(completion[0]["content"])) for completion in completions]
+
+ training_args = RLOOConfig(
+ output_dir=self.tmp_dir,
+ learning_rate=0.1, # increase the learning rate to speed up the test
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
+ num_generations=3, # reduce the number of generations to reduce memory usage
+ max_completion_length=8, # reduce the completion length to reduce memory usage
+            max_prompt_length=None,  # disable prompt truncation, because models usually don't support it
+ report_to="none",
+ )
+ trainer = RLOOTrainer(
+ model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
+ reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index f036a897e1b..6f6ba1579ef 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -874,6 +874,7 @@ class SplitPixelValuesByGridTester(TrlTestCase):
def test_split_correctly_0(self):
batch = {
"image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 2]]),
+ "num_images": [1, 1],
"pixel_values": torch.arange(8 * 3).reshape(8, 3), # Shape: [8, 3]
}
result = split_pixel_values_by_grid(batch)
@@ -881,10 +882,15 @@ def test_split_correctly_0(self):
self.assertEqual(len(result["pixel_values"]), 2)
self.assertTrue(torch.equal(result["pixel_values"][0], batch["pixel_values"][:4]))
self.assertTrue(torch.equal(result["pixel_values"][1], batch["pixel_values"][4:]))
+ self.assertIsInstance(result["image_grid_thw"], list)
+ self.assertEqual(len(result["image_grid_thw"]), 2)
+ self.assertTrue(torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 2, 2]])))
+ self.assertTrue(torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2]])))
def test_split_correctly_1(self):
batch = {
"image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 4]]),
+ "num_images": [1, 1],
"pixel_values": torch.arange(12 * 3).reshape(12, 3), # Shape: [12, 3]
}
result = split_pixel_values_by_grid(batch)
@@ -892,6 +898,10 @@ def test_split_correctly_1(self):
self.assertEqual(len(result["pixel_values"]), 2)
self.assertTrue(torch.equal(result["pixel_values"][0], batch["pixel_values"][:4]))
self.assertTrue(torch.equal(result["pixel_values"][1], batch["pixel_values"][4:12]))
+ self.assertIsInstance(result["image_grid_thw"], list)
+ self.assertEqual(len(result["image_grid_thw"]), 2)
+ self.assertTrue(torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 2, 2]])))
+ self.assertTrue(torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 4]])))
def test_missing_keys(self):
batch = {"pixel_values": torch.tensor([1.0])}
@@ -901,11 +911,28 @@ def test_missing_keys(self):
def test_mismatched_length(self):
batch = {
"image_grid_thw": torch.tensor([[1, 1, 2], [1, 2, 1]]), # Total = 8
+ "num_images": [1, 1],
"pixel_values": torch.randn(3, 5), # Only 3 rows
}
with self.assertRaises(ValueError):
split_pixel_values_by_grid(batch)
+ def test_multi_images(self):
+ batch = {
+ "image_grid_thw": torch.tensor([[1, 1, 2], [1, 2, 2], [1, 2, 1]]), # Total = 8
+ "num_images": [1, 2],
+ "pixel_values": torch.arange(8 * 3).reshape(8, 3), # Shape: [8, 3]
+ }
+ result = split_pixel_values_by_grid(batch)
+ self.assertIsInstance(result["pixel_values"], list)
+ self.assertEqual(len(result["pixel_values"]), 2)
+ self.assertTrue(torch.equal(result["pixel_values"][0], batch["pixel_values"][:2]))
+ self.assertTrue(torch.equal(result["pixel_values"][1], batch["pixel_values"][2:]))
+ self.assertIsInstance(result["image_grid_thw"], list)
+ self.assertEqual(len(result["image_grid_thw"]), 2)
+ self.assertTrue(torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 1, 2]])))
+ self.assertTrue(torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2], [1, 2, 1]])))
+
class TruncateWithProtectedTokensTester(TrlTestCase):
def test_basic_example(self):
@@ -1044,12 +1071,16 @@ def test_empty_protected_tokens_list(self):
class UnsplitPixelValuesByGridTester(TrlTestCase):
def test_unsplit_correctly(self):
- split = [torch.randn(4, 5), torch.randn(2, 5)]
- merged = torch.cat(split, dim=0)
- batch = {"pixel_values": split, "other_key": torch.tensor([1])}
+ pixel_values = [torch.randn(4, 5), torch.randn(2, 5)]
+ pixel_values_merged = torch.cat(pixel_values, dim=0)
+ image_grid_thw = [torch.tensor([[1, 2, 2]]), torch.tensor([[1, 2, 1]])]
+ image_grid_thw_merged = torch.cat(image_grid_thw, dim=0)
+ batch = {"pixel_values": pixel_values, "image_grid_thw": image_grid_thw, "other_key": torch.tensor([1])}
result = unsplit_pixel_values_by_grid(batch)
self.assertIsInstance(result["pixel_values"], torch.Tensor)
- self.assertTrue(torch.allclose(result["pixel_values"], merged))
+ self.assertTrue(torch.allclose(result["pixel_values"], pixel_values_merged))
+ self.assertIsInstance(result["image_grid_thw"], torch.Tensor)
+ self.assertTrue(torch.equal(result["image_grid_thw"], image_grid_thw_merged))
self.assertIn("other_key", result)
def test_no_op_if_not_list(self):
diff --git a/trl/experimental/gfpo/gfpo_trainer.py b/trl/experimental/gfpo/gfpo_trainer.py
index f2a675fab16..6ac4b6acc7f 100644
--- a/trl/experimental/gfpo/gfpo_trainer.py
+++ b/trl/experimental/gfpo/gfpo_trainer.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import copy
import logging
import re
from contextlib import nullcontext
@@ -83,22 +82,22 @@ def _generate_and_score_completions(self, inputs):
prompts = [x["prompt"] for x in inputs]
- # We don't yet support visual reward models/function, so we keep a copy of the original text-only prompts for
- # later use in the reward computation. If images are present, we insert {"type": "image"} as required by the
- # VLM chat template.
- original_prompts = copy.deepcopy(prompts)
+ if "images" in inputs[0]:
+ images = [example.get("images") for example in inputs]
+ elif "image" in inputs[0]:
+ images = [[example.get("image")] if example.get("image") is not None else None for example in inputs]
+ else:
+ images = None
# If the prompts are conversational and the inputs contain images, we need to convert the prompts from
# [{"role": "user", "content": "What color is the sky?"}] to
# [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
kwargs = {}
- has_images = "image" in inputs[0]
- if has_images:
- images = [example.get("image") for example in inputs]
- kwargs = {"images": [[img] for img in images]}
- for prompt in prompts:
+ if images is not None:
+ kwargs = {"images": images}
+ for prompt, image_list in zip(prompts, images):
if isinstance(prompt, list): # i.e., when using conversational data
- prepare_multimodal_messages(prompt, num_images=1)
+ prepare_multimodal_messages(prompt, num_images=len(image_list))
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
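To illustrate the key handling above, a toy sketch (strings stand in for PIL images; not trainer code): both dataset layouts normalize to one image list per sample, from which the per-sample image counts are derived further down.

```python
# Toy illustration of the `images`/`image` normalization (strings stand in for PIL images)
inputs_multi = [{"images": ["img_a", "img_b"]}, {"images": ["img_c"]}]
inputs_single = [{"image": "img_a"}, {"image": "img_b"}]

images_from_multi = [ex.get("images") for ex in inputs_multi]      # [['img_a', 'img_b'], ['img_c']]
images_from_single = [[ex.get("image")] for ex in inputs_single]   # [['img_a'], ['img_b']]

num_images = [len(img_list) for img_list in images_from_multi]     # [2, 1]
```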
@@ -170,7 +169,7 @@ def _generate_and_score_completions(self, inputs):
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
if self.vllm_mode == "server":
all_prompts_text = gather_object(prompts_text)
- if has_images:
+ if images is not None:
all_images = gather_object(images)
if self.accelerator.is_main_process:
@@ -179,7 +178,7 @@ def _generate_and_score_completions(self, inputs):
# prompt individually.
ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
- if has_images:
+ if images is not None:
ordered_set_of_images = all_images[:: self.num_generations]
else:
ordered_set_of_images = None
@@ -244,7 +243,7 @@ def _generate_and_score_completions(self, inputs):
torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group)
all_prompts_text = [p for sublist in gathered_prompts for p in sublist]
- if has_images:
+ if images is not None:
gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)]
torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group)
all_images = [img for sublist in gathered_images for img in sublist]
@@ -252,15 +251,13 @@ def _generate_and_score_completions(self, inputs):
all_images = None
else:
all_prompts_text = prompts_text
- all_images = images if has_images else None
+ all_images = images
- if has_images and all_images:
+ if images is not None and all_images:
vllm_inputs = []
- for prompt, image in zip(all_prompts_text, all_images):
- if image is not None:
- vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}})
- else:
- vllm_inputs.append(prompt)
+ for prompt, image_list in zip(all_prompts_text, all_images):
+ vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}})
+
else:
vllm_inputs = all_prompts_text
@@ -375,6 +372,8 @@ def _generate_and_score_completions(self, inputs):
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
+ num_images = [len(img_list) for img_list in images] if images is not None else None
+
with torch.no_grad():
# If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
# a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
@@ -395,6 +394,7 @@ def _generate_and_score_completions(self, inputs):
batch_size,
pixel_values=prompt_inputs.get("pixel_values"),
image_grid_thw=prompt_inputs.get("image_grid_thw"),
+ num_images=num_images,
pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
image_sizes=prompt_inputs.get("image_sizes"),
)
@@ -419,6 +419,7 @@ def _generate_and_score_completions(self, inputs):
batch_size=batch_size,
pixel_values=prompt_inputs.get("pixel_values"),
image_grid_thw=prompt_inputs.get("image_grid_thw"),
+ num_images=num_images,
pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
image_sizes=prompt_inputs.get("image_sizes"),
)
@@ -432,6 +433,7 @@ def _generate_and_score_completions(self, inputs):
batch_size=batch_size,
pixel_values=prompt_inputs.get("pixel_values"),
image_grid_thw=prompt_inputs.get("image_grid_thw"),
+ num_images=num_images,
pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
image_sizes=prompt_inputs.get("image_sizes"),
)
@@ -451,7 +453,7 @@ def _generate_and_score_completions(self, inputs):
# Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is
# important because rewards will be normalized per group, and completions are distributed. We will later slice
# rewards_per_func to extract each process's subset.
- rewards_per_func = self._calculate_rewards(inputs, original_prompts, completions, completion_ids_list)
+ rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
# Apply weights to each reward function's output and sum
rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
@@ -563,12 +565,12 @@ def _generate_and_score_completions(self, inputs):
# Log prompt and completion texts
all_prompts_text = gather_object(prompts_text)
all_completions_text = gather_object(completions_text)
- all_images = gather_object(images) if has_images else None
+ all_images = gather_object(images) if images is not None else None
if self.num_remains_in_group is not None and mode == "train":
group_global_indices_list = group_global_indices.tolist()
all_prompts_text = [all_prompts_text[i] for i in group_global_indices_list]
all_completions_text = [all_completions_text[i] for i in group_global_indices_list]
- if has_images:
+ if images is not None:
all_images = [all_images[i] for i in group_global_indices_list]
self._logs["prompt"].extend(all_prompts_text)
@@ -577,8 +579,8 @@ def _generate_and_score_completions(self, inputs):
self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
self._logs["advantages"].extend(all_process_advantages.tolist())
- if has_images:
- self._logs["image"].extend(all_images)
+ if images is not None:
+            self._logs["images"].extend(all_images)
if self.use_vllm and self.vllm_importance_sampling_correction:
delta = torch.abs(old_per_token_logps - sampling_per_token_logps)
@@ -639,4 +641,6 @@ def _generate_and_score_completions(self, inputs):
output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"]
if "image_sizes" in prompt_inputs:
output["image_sizes"] = prompt_inputs["image_sizes"]
+ if images is not None:
+ output["num_images"] = num_images
return output
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index bb902445d09..69825102b27 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import copy
import inspect
import os
import re
@@ -464,7 +463,7 @@ def __init__(
self.num_completions_to_print = args.num_completions_to_print
# Keep logs sized to the generation batch to record only outputs from the latest model update.
self._logs = {
- "image": deque(maxlen=args.generation_batch_size),
+ "images": deque(maxlen=args.generation_batch_size),
"prompt": deque(maxlen=args.generation_batch_size),
"completion": deque(maxlen=args.generation_batch_size),
"rewards": defaultdict(lambda: deque(maxlen=args.generation_batch_size)),
@@ -609,7 +608,7 @@ def _set_signature_columns_if_needed(self):
# In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
# Instead, we set them to the columns expected by the `training_step` method, hence the override.
if self._signature_columns is None:
- self._signature_columns = ["prompt", "image"]
+ self._signature_columns = ["prompt", "image", "images"]
# This method overrides `Trainer.get_train_dataloader` to support our custom batching strategy.
# Instead of returning a standard per-step batch (i.e., `per_device_batch_size), our dataloader loads an
@@ -789,6 +788,7 @@ def _get_per_token_logps_and_entropies(
compute_entropy=False,
pixel_values=None,
image_grid_thw=None,
+ num_images=None,
pixel_attention_mask=None,
image_sizes=None,
) -> dict[str, Optional[torch.Tensor]]:
@@ -802,12 +802,16 @@ def _get_per_token_logps_and_entropies(
# Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't)
model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch}
-
if image_grid_thw is not None and pixel_values is not None:
- model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size]
- start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item()
- end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item()
- model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx]
+ rows_per_image = image_grid_thw.prod(dim=-1)
+ rows_per_sample = torch.split(rows_per_image, num_images)
+ rows_per_sample = torch.stack([s.sum() for s in rows_per_sample])
+ cum_rows = torch.cat([torch.tensor([0], device=rows_per_sample.device), rows_per_sample.cumsum(0)])
+ row_start, row_end = cum_rows[start].item(), cum_rows[start + batch_size].item()
+ model_inputs["pixel_values"] = pixel_values[row_start:row_end]
+ cum_imgs = torch.tensor([0] + num_images).cumsum(0)
+ img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size]
+ model_inputs["image_grid_thw"] = image_grid_thw[img_start:img_end]
elif pixel_values is not None:
model_inputs["pixel_values"] = pixel_values[start : start + batch_size]
if pixel_attention_mask is not None:
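A rough worked example of the slicing above (toy shapes, not values from the PR): with three samples where the second sample carries two images, a micro-batch of two starting at index 1 selects exactly the pixel rows and grid entries belonging to samples 1 and 2.

```python
import torch

# Toy setup: 4 images across 3 samples, num_images = [1, 2, 1]
image_grid_thw = torch.tensor([[1, 2, 2], [1, 2, 2], [1, 2, 4], [1, 1, 2]])
num_images = [1, 2, 1]

rows_per_image = image_grid_thw.prod(dim=-1)                                   # [4, 4, 8, 2]
rows_per_sample = torch.stack(
    [s.sum() for s in torch.split(rows_per_image, num_images)]
)                                                                              # [4, 12, 2]
cum_rows = torch.cat([torch.tensor([0]), rows_per_sample.cumsum(0)])           # [0, 4, 16, 18]
cum_imgs = torch.tensor([0] + num_images).cumsum(0)                            # [0, 1, 3, 4]

start, batch_size = 1, 2
row_start, row_end = cum_rows[start].item(), cum_rows[start + batch_size].item()  # 4, 18
img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size]                # 1, 4
# pixel_values[4:18] and image_grid_thw[1:4] then correspond to samples 1 and 2.
```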
@@ -1065,22 +1069,22 @@ def _generate_and_score_completions(
prompts = [x["prompt"] for x in inputs]
- # We don't yet support visual reward models/function, so we keep a copy of the original text-only prompts for
- # later use in the reward computation. If images are present, we insert {"type": "image"} as required by the
- # VLM chat template.
- original_prompts = copy.deepcopy(prompts)
+ if "images" in inputs[0]:
+ images = [example.get("images") for example in inputs]
+ elif "image" in inputs[0]:
+ images = [[example.get("image")] if example.get("image") is not None else None for example in inputs]
+ else:
+ images = None
# If the prompts are conversational and the inputs contain images, we need to convert the prompts from
# [{"role": "user", "content": "What color is the sky?"}] to
# [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
kwargs = {}
- has_images = "image" in inputs[0]
- if has_images:
- images = [example.get("image") for example in inputs]
- kwargs = {"images": [[img] for img in images]}
- for prompt in prompts:
+ if images is not None:
+ kwargs = {"images": images}
+ for prompt, image_list in zip(prompts, images):
if isinstance(prompt, list): # i.e., when using conversational data
- prepare_multimodal_messages(prompt, num_images=1)
+ prepare_multimodal_messages(prompt, num_images=len(image_list))
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
@@ -1152,7 +1156,7 @@ def _generate_and_score_completions(
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
if self.vllm_mode == "server":
all_prompts_text = gather_object(prompts_text)
- if has_images:
+ if images is not None:
all_images = gather_object(images)
if self.accelerator.is_main_process:
@@ -1161,7 +1165,7 @@ def _generate_and_score_completions(
# prompt individually.
ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
- if has_images:
+ if images is not None:
ordered_set_of_images = all_images[:: self.num_generations]
else:
ordered_set_of_images = None
@@ -1226,7 +1230,7 @@ def _generate_and_score_completions(
torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group)
all_prompts_text = [p for sublist in gathered_prompts for p in sublist]
- if has_images:
+ if images is not None:
gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)]
torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group)
all_images = [img for sublist in gathered_images for img in sublist]
@@ -1234,15 +1238,13 @@ def _generate_and_score_completions(
all_images = None
else:
all_prompts_text = prompts_text
- all_images = images if has_images else None
+ all_images = images
- if has_images and all_images:
+ if images is not None and all_images:
vllm_inputs = []
- for prompt, image in zip(all_prompts_text, all_images):
- if image is not None:
- vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}})
- else:
- vllm_inputs.append(prompt)
+ for prompt, image_list in zip(all_prompts_text, all_images):
+ vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}})
+
else:
vllm_inputs = all_prompts_text
@@ -1357,6 +1359,8 @@ def _generate_and_score_completions(
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
+ num_images = [len(img_list) for img_list in images] if images is not None else None
+
with torch.no_grad():
# If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
# a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
@@ -1377,6 +1381,7 @@ def _generate_and_score_completions(
batch_size,
pixel_values=prompt_inputs.get("pixel_values"),
image_grid_thw=prompt_inputs.get("image_grid_thw"),
+ num_images=num_images,
pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
image_sizes=prompt_inputs.get("image_sizes"),
)
@@ -1401,6 +1406,7 @@ def _generate_and_score_completions(
batch_size=batch_size,
pixel_values=prompt_inputs.get("pixel_values"),
image_grid_thw=prompt_inputs.get("image_grid_thw"),
+ num_images=num_images,
pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
image_sizes=prompt_inputs.get("image_sizes"),
)
@@ -1414,6 +1420,7 @@ def _generate_and_score_completions(
batch_size=batch_size,
pixel_values=prompt_inputs.get("pixel_values"),
image_grid_thw=prompt_inputs.get("image_grid_thw"),
+ num_images=num_images,
pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
image_sizes=prompt_inputs.get("image_sizes"),
)
@@ -1433,7 +1440,7 @@ def _generate_and_score_completions(
# Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is
# important because rewards will be normalized per group, and completions are distributed. We will later slice
# rewards_per_func to extract each process's subset.
- rewards_per_func = self._calculate_rewards(inputs, original_prompts, completions, completion_ids_list)
+ rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
# Apply weights to each reward function's output and sum
rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
@@ -1507,8 +1514,8 @@ def _generate_and_score_completions(
self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
self._logs["advantages"].extend(all_process_advantages.tolist())
- if has_images:
- self._logs["image"].extend(gather_object(images))
+ if images is not None:
+ self._logs["images"].extend(gather_object(images))
if self.use_vllm and self.vllm_importance_sampling_correction:
delta = torch.abs(old_per_token_logps - sampling_per_token_logps)
@@ -1564,6 +1571,8 @@ def _generate_and_score_completions(
output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"]
if "image_sizes" in prompt_inputs:
output["image_sizes"] = prompt_inputs["image_sizes"]
+ if images is not None:
+ output["num_images"] = num_images
return output
def compute_liger_loss(self, unwrapped_model, inputs):
@@ -1636,6 +1645,7 @@ def _compute_loss(self, model, inputs):
compute_entropy=True,
pixel_values=inputs.get("pixel_values"),
image_grid_thw=inputs.get("image_grid_thw"),
+ num_images=inputs.get("num_images"),
pixel_attention_mask=inputs.get("pixel_attention_mask"),
image_sizes=inputs.get("image_sizes"),
)
@@ -1790,14 +1800,11 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non
"advantage": self._logs["advantages"],
}
- if self._logs["image"]:
- table["image"] = []
- for img in self._logs["image"]:
- if img is not None:
- # Convert images to wandb Image objects for proper visualization
- table["image"].append(wandb.Image(img))
- else:
- table["image"].append(None)
+ if self._logs["images"]:
+ table["images"] = []
+ for image_list in self._logs["images"]:
+ # Convert images to wandb Image objects for proper visualization
+ table["images"].append([wandb.Image(image) for image in image_list])
df = pd.DataFrame(table)
if self.wandb_log_unique_prompts:
diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py
index 56ffbfe7fea..359cb68e43b 100644
--- a/trl/trainer/rloo_trainer.py
+++ b/trl/trainer/rloo_trainer.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import copy
import inspect
import os
import re
@@ -536,7 +535,7 @@ def decode(example, tokenizer):
self.num_completions_to_print = args.num_completions_to_print
# Keep logs sized to the generation batch to record only outputs from the latest model update.
self._logs = {
- "image": deque(maxlen=args.generation_batch_size),
+ "images": deque(maxlen=args.generation_batch_size),
"prompt": deque(maxlen=args.generation_batch_size),
"completion": deque(maxlen=args.generation_batch_size),
"rewards": defaultdict(lambda: deque(maxlen=args.generation_batch_size)),
@@ -678,7 +677,7 @@ def _set_signature_columns_if_needed(self):
# In RLOOTrainer, we preprocess data, so using the model's signature columns doesn't work.
# Instead, we set them to the columns expected by the `training_step` method, hence the override.
if self._signature_columns is None:
- self._signature_columns = ["prompt", "image"]
+ self._signature_columns = ["prompt", "image", "images"]
# This method overrides `Trainer.get_train_dataloader` to support our custom batching strategy.
# Instead of returning a standard per-step batch (i.e., `per_device_batch_size), our dataloader loads an
@@ -775,6 +774,7 @@ def _get_per_token_logps_and_entropies(
compute_entropy=False,
pixel_values=None,
image_grid_thw=None,
+ num_images=None,
pixel_attention_mask=None,
image_sizes=None,
) -> dict[str, Optional[torch.Tensor]]:
@@ -790,10 +790,15 @@ def _get_per_token_logps_and_entropies(
model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch}
if image_grid_thw is not None and pixel_values is not None:
- model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size]
- start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item()
- end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item()
- model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx]
+ rows_per_image = image_grid_thw.prod(dim=-1)
+ rows_per_sample = torch.split(rows_per_image, num_images)
+ rows_per_sample = torch.stack([s.sum() for s in rows_per_sample])
+ cum_rows = torch.cat([torch.tensor([0], device=rows_per_sample.device), rows_per_sample.cumsum(0)])
+ row_start, row_end = cum_rows[start].item(), cum_rows[start + batch_size].item()
+ model_inputs["pixel_values"] = pixel_values[row_start:row_end]
+ cum_imgs = torch.tensor([0] + num_images).cumsum(0)
+ img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size]
+ model_inputs["image_grid_thw"] = image_grid_thw[img_start:img_end]
elif pixel_values is not None:
model_inputs["pixel_values"] = pixel_values[start : start + batch_size]
if pixel_attention_mask is not None:
@@ -1051,22 +1056,22 @@ def _generate_and_score_completions(
prompts = [x["prompt"] for x in inputs]
- # We don't yet support visual reward models/function, so we keep a copy of the original text-only prompts for
- # later use in the reward computation. If images are present, we insert {"type": "image"} as required by the
- # VLM chat template.
- original_prompts = copy.deepcopy(prompts)
+ if "images" in inputs[0]:
+ images = [example.get("images") for example in inputs]
+ elif "image" in inputs[0]:
+ images = [[example.get("image")] if example.get("image") is not None else None for example in inputs]
+ else:
+ images = None
# If the prompts are conversational and the inputs contain images, we need to convert the prompts from
# [{"role": "user", "content": "What color is the sky?"}] to
# [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
kwargs = {}
- has_images = "image" in inputs[0]
- if has_images:
- images = [example.get("image") for example in inputs]
- kwargs = {"images": [[img] for img in images]}
- for prompt in prompts:
+ if images is not None:
+ kwargs = {"images": images}
+ for prompt, image_list in zip(prompts, images):
if isinstance(prompt, list): # i.e., when using conversational data
- prepare_multimodal_messages(prompt, num_images=1)
+ prepare_multimodal_messages(prompt, num_images=len(image_list))
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
@@ -1133,7 +1138,7 @@ def _generate_and_score_completions(
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
if self.vllm_mode == "server":
all_prompts_text = gather_object(prompts_text)
- if has_images:
+ if images is not None:
all_images = gather_object(images)
if self.accelerator.is_main_process:
@@ -1142,7 +1147,7 @@ def _generate_and_score_completions(
# prompt individually.
ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
- if has_images:
+ if images is not None:
ordered_set_of_images = all_images[:: self.num_generations]
else:
ordered_set_of_images = None
@@ -1205,7 +1210,7 @@ def _generate_and_score_completions(
torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group)
all_prompts_text = [p for sublist in gathered_prompts for p in sublist]
- if has_images:
+ if images is not None:
gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)]
torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group)
all_images = [img for sublist in gathered_images for img in sublist]
@@ -1213,15 +1218,13 @@ def _generate_and_score_completions(
all_images = None
else:
all_prompts_text = prompts_text
- all_images = images if has_images else None
+ all_images = images
- if has_images and all_images:
+ if images is not None and all_images:
vllm_inputs = []
- for prompt, image in zip(all_prompts_text, all_images):
- if image is not None:
- vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}})
- else:
- vllm_inputs.append(prompt)
+ for prompt, image_list in zip(all_prompts_text, all_images):
+ vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}})
+
else:
vllm_inputs = all_prompts_text
@@ -1321,6 +1324,8 @@ def _generate_and_score_completions(
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
+ num_images = [len(img_list) for img_list in images] if images is not None else None
+
with torch.no_grad():
# Compute the per-token log probabilities for the current model
old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
@@ -1331,6 +1336,7 @@ def _generate_and_score_completions(
batch_size,
pixel_values=prompt_inputs.get("pixel_values"),
image_grid_thw=prompt_inputs.get("image_grid_thw"),
+ num_images=num_images,
pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
image_sizes=prompt_inputs.get("image_sizes"),
)
@@ -1347,6 +1353,7 @@ def _generate_and_score_completions(
batch_size=batch_size,
pixel_values=prompt_inputs.get("pixel_values"),
image_grid_thw=prompt_inputs.get("image_grid_thw"),
+ num_images=num_images,
pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
image_sizes=prompt_inputs.get("image_sizes"),
)
@@ -1360,6 +1367,7 @@ def _generate_and_score_completions(
batch_size=batch_size,
pixel_values=prompt_inputs.get("pixel_values"),
image_grid_thw=prompt_inputs.get("image_grid_thw"),
+ num_images=num_images,
pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
image_sizes=prompt_inputs.get("image_sizes"),
)
@@ -1379,7 +1387,7 @@ def _generate_and_score_completions(
# Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is
# important because rewards will be normalized per group, and completions are distributed. We will later slice
# rewards_per_func to extract each process's subset.
- rewards_per_func = self._calculate_rewards(inputs, original_prompts, completions, completion_ids_list)
+ rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
# Apply weights to each reward function's output and sum
rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
@@ -1463,8 +1471,8 @@ def _generate_and_score_completions(
self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
self._logs["advantages"].extend(all_process_advantages.tolist())
- if has_images:
- self._logs["image"].extend(gather_object(images))
+ if images is not None:
+ self._logs["images"].extend(gather_object(images))
output = {
"prompt_ids": prompt_ids,
@@ -1482,6 +1490,8 @@ def _generate_and_score_completions(
output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"]
if "image_sizes" in prompt_inputs:
output["image_sizes"] = prompt_inputs["image_sizes"]
+ if images is not None:
+ output["num_images"] = num_images
return output
@profiling_decorator
@@ -1507,6 +1517,7 @@ def _compute_loss(self, model, inputs):
compute_entropy=True,
pixel_values=inputs.get("pixel_values"),
image_grid_thw=inputs.get("image_grid_thw"),
+ num_images=inputs.get("num_images"),
pixel_attention_mask=inputs.get("pixel_attention_mask"),
image_sizes=inputs.get("image_sizes"),
)
@@ -1588,14 +1599,11 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non
"advantage": self._logs["advantages"],
}
- if self._logs["image"]:
- table["image"] = []
- for img in self._logs["image"]:
- if img is not None:
- # Convert images to wandb Image objects for proper visualization
- table["image"].append(wandb.Image(img))
- else:
- table["image"].append(None)
+ if self._logs["images"]:
+ table["images"] = []
+ for image_list in self._logs["images"]:
+ # Convert images to wandb Image objects for proper visualization
+ table["images"].append([wandb.Image(image) for image in image_list])
df = pd.DataFrame(table)
if self.wandb_log_unique_prompts:
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 37612a423bd..337e9857c1a 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -19,6 +19,7 @@
from collections.abc import Sequence, Sized
from dataclasses import dataclass, field
from importlib.metadata import version
+from itertools import accumulate
from typing import Any, Literal, Optional, Union
import numpy as np
@@ -1780,20 +1781,23 @@ def identity(x):
def split_pixel_values_by_grid(batch: dict[str, torch.Tensor]) -> dict[str, Union[torch.Tensor, list[torch.Tensor]]]:
"""
- Splits `batch["pixel_values"]` into a list of tensors based on the product of each row in
- `batch["image_grid_thw"]`, while keeping other entries unchanged.
+    Splits `batch["pixel_values"]` into a list of per-sample tensors, using the per-image row counts derived from
+    `batch["image_grid_thw"]` and the image counts in `batch["num_images"]`, while keeping other entries unchanged.
"""
- if "image_grid_thw" not in batch or "pixel_values" not in batch:
+ if "image_grid_thw" not in batch or "pixel_values" not in batch or "num_images" not in batch:
return batch
- lengths = batch["image_grid_thw"].prod(-1).tolist() # [batch_size]
+ lengths = batch["image_grid_thw"].prod(-1).tolist() # [num_images]
pixel_values = batch["pixel_values"] # [total, feature_dim]
if sum(lengths) != pixel_values.size(0):
raise ValueError(f"Mismatch: sum(lengths) = {sum(lengths)} != pixel_values.size(0) = {pixel_values.size(0)}")
- split_values = list(torch.split(batch["pixel_values"], lengths, dim=0))
- return {**batch, "pixel_values": split_values}
+ boundaries = [0, *accumulate(batch["num_images"])] # [3, 4, 5] -> [0, 3, 7, 12]
+ sections = [sum(lengths[boundaries[i] : boundaries[i + 1]]) for i in range(len(batch["num_images"]))]
+ split_values = list(torch.split(batch["pixel_values"], sections, dim=0))
+ image_grid_thw = list(torch.split(batch["image_grid_thw"], batch["num_images"], dim=0))
+ return {**batch, "pixel_values": split_values, "image_grid_thw": image_grid_thw}
def unsplit_pixel_values_by_grid(batch: dict[str, Union[torch.Tensor, list[torch.Tensor]]]) -> dict[str, torch.Tensor]:
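A small usage sketch of the per-sample grouping (toy tensors mirroring the multi-image case in `tests/test_utils.py`; illustrative only):

```python
import torch

from trl.trainer.utils import split_pixel_values_by_grid, unsplit_pixel_values_by_grid

batch = {
    "image_grid_thw": torch.tensor([[1, 1, 2], [1, 2, 2], [1, 2, 1]]),  # 3 images, 2 + 4 + 2 = 8 rows
    "num_images": [1, 2],  # sample 0 has 1 image, sample 1 has 2
    "pixel_values": torch.arange(8 * 3).reshape(8, 3),
}

split = split_pixel_values_by_grid(batch)
# split["pixel_values"]   -> [tensor of shape (2, 3), tensor of shape (6, 3)]
# split["image_grid_thw"] -> [tensor([[1, 1, 2]]), tensor([[1, 2, 2], [1, 2, 1]])]

restored = unsplit_pixel_values_by_grid(split)
# pixel_values and image_grid_thw are concatenated back into single tensors
```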
@@ -1802,12 +1806,16 @@ def unsplit_pixel_values_by_grid(batch: dict[str, Union[torch.Tensor, list[torch
tensor along the first dimension.
"""
pixel_values = batch.get("pixel_values")
-
if isinstance(pixel_values, list):
merged = torch.cat(pixel_values, dim=0)
- return {**batch, "pixel_values": merged}
- else:
- return batch
+ batch = {**batch, "pixel_values": merged}
+
+ image_grid_thw = batch.get("image_grid_thw")
+ if isinstance(image_grid_thw, list):
+ merged = torch.cat(image_grid_thw, dim=0)
+ batch = {**batch, "image_grid_thw": merged}
+
+ return batch
def truncate_with_protected_tokens(