From 552e899015c18c1a10a3b2ffe80eaa964d44afbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 19 Sep 2025 20:57:51 +0000 Subject: [PATCH 01/29] Refactor image handling: replace `image_split_sizes` with `image_grid_thw` in GRPO and RLOO trainers; update `split_pixel_values_by_grid` to use `image_grid_thw` --- tests/test_utils.py | 6 +++--- trl/trainer/grpo_trainer.py | 20 ++------------------ trl/trainer/rloo_trainer.py | 20 ++------------------ trl/trainer/utils.py | 4 ++-- 4 files changed, 9 insertions(+), 41 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 60730d685d0..f036a897e1b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -873,7 +873,7 @@ def test_with_scalar(self): class SplitPixelValuesByGridTester(TrlTestCase): def test_split_correctly_0(self): batch = { - "image_split_sizes": [4, 4], + "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 2]]), "pixel_values": torch.arange(8 * 3).reshape(8, 3), # Shape: [8, 3] } result = split_pixel_values_by_grid(batch) @@ -884,7 +884,7 @@ def test_split_correctly_0(self): def test_split_correctly_1(self): batch = { - "image_split_sizes": [4, 8], + "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 4]]), "pixel_values": torch.arange(12 * 3).reshape(12, 3), # Shape: [12, 3] } result = split_pixel_values_by_grid(batch) @@ -900,7 +900,7 @@ def test_missing_keys(self): def test_mismatched_length(self): batch = { - "image_split_sizes": torch.tensor([2, 2]), # Total = 4 + "image_grid_thw": torch.tensor([[1, 1, 2], [1, 2, 1]]), # Total = 8 "pixel_values": torch.randn(3, 5), # Only 3 rows } with self.assertRaises(ValueError): diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 0c2ad9a3121..9f618eefe8f 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -791,7 +791,6 @@ def _get_per_token_logps_and_entropies( image_grid_thw=None, pixel_attention_mask=None, image_sizes=None, - image_split_sizes=None, ) -> dict[str, Optional[torch.Tensor]]: """Compute log-probs and (optionally) entropies for each token.""" batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak @@ -807,7 +806,8 @@ def _get_per_token_logps_and_entropies( if image_grid_thw is not None: model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size] if pixel_values is not None: - if image_split_sizes is not None: + if image_grid_thw is not None: + image_split_sizes = image_grid_thw.prod(dim=1).tolist() start_pixel_idx = sum(image_split_sizes[:start]) end_pixel_idx = sum(image_split_sizes[: start + batch_size]) model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] @@ -1078,7 +1078,6 @@ def _generate_and_score_completions( # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] kwargs = {} has_images = "image" in inputs[0] - image_split_sizes = None if has_images: images = [example.get("image") for example in inputs] kwargs = {"images": [[img] for img in images]} @@ -1086,11 +1085,6 @@ def _generate_and_score_completions( if isinstance(prompt, list): # i.e., when using conversational data prepare_multimodal_messages(prompt, num_images=1) - if hasattr(self.processing_class, "_get_num_multimodal_tokens"): - image_sizes = [(image.height, image.width) for image in images] - multimodal_extra_data = self.processing_class._get_num_multimodal_tokens(image_sizes) - image_split_sizes = multimodal_extra_data.num_image_patches - prompts_text = 
[maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] prompt_inputs = self.processing_class( @@ -1104,10 +1098,6 @@ def _generate_and_score_completions( prompt_inputs = super()._prepare_inputs(prompt_inputs) prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] - if "image_grid_thw" in prompt_inputs and image_split_sizes is None: - # Fallback for VLMs that require image_grid_thw but don't provide _get_num_multimodal_tokens - image_split_sizes = prompt_inputs["image_grid_thw"].prod(dim=1).tolist() - if self.max_prompt_length is not None: # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text, @@ -1392,7 +1382,6 @@ def _generate_and_score_completions( image_grid_thw=prompt_inputs.get("image_grid_thw"), pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), - image_split_sizes=image_split_sizes, ) else: old_per_token_logps = None @@ -1417,7 +1406,6 @@ def _generate_and_score_completions( image_grid_thw=prompt_inputs.get("image_grid_thw"), pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), - image_split_sizes=image_split_sizes, ) else: with self.accelerator.unwrap_model(self.model).disable_adapter(): @@ -1431,7 +1419,6 @@ def _generate_and_score_completions( image_grid_thw=prompt_inputs.get("image_grid_thw"), pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), - image_split_sizes=image_split_sizes, ) else: ref_per_token_logps = None @@ -1580,8 +1567,6 @@ def _generate_and_score_completions( output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"] if "image_sizes" in prompt_inputs: output["image_sizes"] = prompt_inputs["image_sizes"] - if image_split_sizes is not None: - output["image_split_sizes"] = image_split_sizes return output def compute_liger_loss(self, unwrapped_model, inputs): @@ -1656,7 +1641,6 @@ def _compute_loss(self, model, inputs): image_grid_thw=inputs.get("image_grid_thw"), pixel_attention_mask=inputs.get("pixel_attention_mask"), image_sizes=inputs.get("image_sizes"), - image_split_sizes=inputs.get("image_split_sizes"), ) if self.top_entropy_quantile < 1.0: diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 86aeb8910a0..b70c3b4db4f 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -777,7 +777,6 @@ def _get_per_token_logps_and_entropies( image_grid_thw=None, pixel_attention_mask=None, image_sizes=None, - image_split_sizes=None, ) -> dict[str, Optional[torch.Tensor]]: """Compute log-probs and (optionally) entropies for each token.""" batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak @@ -793,7 +792,8 @@ def _get_per_token_logps_and_entropies( if image_grid_thw is not None: model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size] if pixel_values is not None: - if image_split_sizes is not None: + if image_grid_thw is not None: + image_split_sizes = image_grid_thw.prod(dim=1).tolist() start_pixel_idx = sum(image_split_sizes[:start]) end_pixel_idx = sum(image_split_sizes[: start + batch_size]) model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] @@ -1064,7 +1064,6 @@ def _generate_and_score_completions( # [{"role": "user", "content": 
[{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] kwargs = {} has_images = "image" in inputs[0] - image_split_sizes = None if has_images: images = [example.get("image") for example in inputs] kwargs = {"images": [[img] for img in images]} @@ -1072,11 +1071,6 @@ def _generate_and_score_completions( if isinstance(prompt, list): # i.e., when using conversational data prepare_multimodal_messages(prompt, num_images=1) - if hasattr(self.processing_class, "_get_num_multimodal_tokens"): - image_sizes = [(image.height, image.width) for image in images] - multimodal_extra_data = self.processing_class._get_num_multimodal_tokens(image_sizes) - image_split_sizes = multimodal_extra_data.num_image_patches - prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] prompt_inputs = self.processing_class( @@ -1090,10 +1084,6 @@ def _generate_and_score_completions( prompt_inputs = super()._prepare_inputs(prompt_inputs) prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] - if "image_grid_thw" in prompt_inputs and image_split_sizes is None: - # Fallback for VLMs that require image_grid_thw but don't provide _get_num_multimodal_tokens - image_split_sizes = prompt_inputs["image_grid_thw"].prod(dim=1).tolist() - if self.max_prompt_length is not None: # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text, @@ -1346,7 +1336,6 @@ def _generate_and_score_completions( image_grid_thw=prompt_inputs.get("image_grid_thw"), pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), - image_split_sizes=image_split_sizes, ) old_logps = (old_per_token_logps * completion_mask).sum(1) # mask out padding and tokens after EOS @@ -1363,7 +1352,6 @@ def _generate_and_score_completions( image_grid_thw=prompt_inputs.get("image_grid_thw"), pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), - image_split_sizes=image_split_sizes, ) else: with self.accelerator.unwrap_model(self.model).disable_adapter(): @@ -1377,7 +1365,6 @@ def _generate_and_score_completions( image_grid_thw=prompt_inputs.get("image_grid_thw"), pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), - image_split_sizes=image_split_sizes, ) else: ref_per_token_logps = None @@ -1498,8 +1485,6 @@ def _generate_and_score_completions( output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"] if "image_sizes" in prompt_inputs: output["image_sizes"] = prompt_inputs["image_sizes"] - if image_split_sizes is not None: - output["image_split_sizes"] = image_split_sizes return output @profiling_decorator @@ -1527,7 +1512,6 @@ def _compute_loss(self, model, inputs): image_grid_thw=inputs.get("image_grid_thw"), pixel_attention_mask=inputs.get("pixel_attention_mask"), image_sizes=inputs.get("image_sizes"), - image_split_sizes=inputs.get("image_split_sizes"), ) logps = (per_token_logps * completion_mask).sum(1) # mask out padding and tokens after EOS diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 16ce8321612..37612a423bd 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -1783,10 +1783,10 @@ def split_pixel_values_by_grid(batch: dict[str, torch.Tensor]) -> dict[str, Unio Splits `batch["pixel_values"]` into a list of tensors based 
on the product of each row in `batch["image_grid_thw"]`, while keeping other entries unchanged. """ - if "image_split_sizes" not in batch or "pixel_values" not in batch: + if "image_grid_thw" not in batch or "pixel_values" not in batch: return batch - lengths = batch["image_split_sizes"] # [batch_size] + lengths = batch["image_grid_thw"].prod(-1).tolist() # [batch_size] pixel_values = batch["pixel_values"] # [total, feature_dim] if sum(lengths) != pixel_values.size(0): From 449ef079191ed50fae281c07b9d6775efcb345f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 19 Sep 2025 21:05:47 +0000 Subject: [PATCH 02/29] simpler --- trl/trainer/grpo_trainer.py | 15 ++++++--------- trl/trainer/rloo_trainer.py | 15 ++++++--------- 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 9f618eefe8f..bb902445d09 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -803,16 +803,13 @@ def _get_per_token_logps_and_entropies( # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} - if image_grid_thw is not None: + if image_grid_thw is not None and pixel_values is not None: model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size] - if pixel_values is not None: - if image_grid_thw is not None: - image_split_sizes = image_grid_thw.prod(dim=1).tolist() - start_pixel_idx = sum(image_split_sizes[:start]) - end_pixel_idx = sum(image_split_sizes[: start + batch_size]) - model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] - else: - model_inputs["pixel_values"] = pixel_values[start : start + batch_size] + start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item() + end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item() + model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] + elif pixel_values is not None: + model_inputs["pixel_values"] = pixel_values[start : start + batch_size] if pixel_attention_mask is not None: model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size] if image_sizes is not None: diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index b70c3b4db4f..56ffbfe7fea 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -789,16 +789,13 @@ def _get_per_token_logps_and_entropies( # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} - if image_grid_thw is not None: + if image_grid_thw is not None and pixel_values is not None: model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size] - if pixel_values is not None: - if image_grid_thw is not None: - image_split_sizes = image_grid_thw.prod(dim=1).tolist() - start_pixel_idx = sum(image_split_sizes[:start]) - end_pixel_idx = sum(image_split_sizes[: start + batch_size]) - model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] - else: - model_inputs["pixel_values"] = pixel_values[start : start + batch_size] + start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item() + end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item() + model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] + elif pixel_values is not None: + model_inputs["pixel_values"] = 
pixel_values[start : start + batch_size] if pixel_attention_mask is not None: model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size] if image_sizes is not None: From c8933aa856b2b71d10470456356c48bae4aefa17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 19 Sep 2025 21:10:06 +0000 Subject: [PATCH 03/29] gfpo --- trl/experimental/gfpo/gfpo_trainer.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/trl/experimental/gfpo/gfpo_trainer.py b/trl/experimental/gfpo/gfpo_trainer.py index b067e7410c7..f2a675fab16 100644 --- a/trl/experimental/gfpo/gfpo_trainer.py +++ b/trl/experimental/gfpo/gfpo_trainer.py @@ -93,7 +93,6 @@ def _generate_and_score_completions(self, inputs): # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] kwargs = {} has_images = "image" in inputs[0] - image_split_sizes = None if has_images: images = [example.get("image") for example in inputs] kwargs = {"images": [[img] for img in images]} @@ -101,11 +100,6 @@ def _generate_and_score_completions(self, inputs): if isinstance(prompt, list): # i.e., when using conversational data prepare_multimodal_messages(prompt, num_images=1) - if hasattr(self.processing_class, "_get_num_multimodal_tokens"): - image_sizes = [(image.height, image.width) for image in images] - multimodal_extra_data = self.processing_class._get_num_multimodal_tokens(image_sizes) - image_split_sizes = multimodal_extra_data.num_image_patches - prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] prompt_inputs = self.processing_class( @@ -116,13 +110,9 @@ def _generate_and_score_completions(self, inputs): add_special_tokens=False, **kwargs, ) - prompt_inputs = super(_GRPOTrainer, self)._prepare_inputs(prompt_inputs) + prompt_inputs = super()._prepare_inputs(prompt_inputs) prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] - if "image_grid_thw" in prompt_inputs and image_split_sizes is None: - # Fallback for VLMs that require image_grid_thw but don't provide _get_num_multimodal_tokens - image_split_sizes = prompt_inputs["image_grid_thw"].prod(dim=1).tolist() - if self.max_prompt_length is not None: # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. # Then we decode those tokens back into text. 
We manually remove leading pad tokens from the decoded text, @@ -407,7 +397,6 @@ def _generate_and_score_completions(self, inputs): image_grid_thw=prompt_inputs.get("image_grid_thw"), pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), - image_split_sizes=image_split_sizes, ) else: old_per_token_logps = None @@ -432,7 +421,6 @@ def _generate_and_score_completions(self, inputs): image_grid_thw=prompt_inputs.get("image_grid_thw"), pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), - image_split_sizes=image_split_sizes, ) else: with self.accelerator.unwrap_model(self.model).disable_adapter(): @@ -446,7 +434,6 @@ def _generate_and_score_completions(self, inputs): image_grid_thw=prompt_inputs.get("image_grid_thw"), pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), - image_split_sizes=image_split_sizes, ) else: ref_per_token_logps = None @@ -652,6 +639,4 @@ def _generate_and_score_completions(self, inputs): output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"] if "image_sizes" in prompt_inputs: output["image_sizes"] = prompt_inputs["image_sizes"] - if image_split_sizes is not None: - output["image_split_sizes"] = image_split_sizes return output From 229c5549291b65c59537717893b7b09ad1cec0e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 19 Sep 2025 22:45:57 +0000 Subject: [PATCH 04/29] multi-image grpo --- tests/test_utils.py | 27 +++++++++++++++++ trl/trainer/grpo_trainer.py | 60 +++++++++++++++++++------------------ trl/trainer/utils.py | 16 ++++++---- 3 files changed, 68 insertions(+), 35 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index f036a897e1b..0fc16682336 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -874,6 +874,7 @@ class SplitPixelValuesByGridTester(TrlTestCase): def test_split_correctly_0(self): batch = { "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 2]]), + "num_images": [1, 1], "pixel_values": torch.arange(8 * 3).reshape(8, 3), # Shape: [8, 3] } result = split_pixel_values_by_grid(batch) @@ -881,10 +882,15 @@ def test_split_correctly_0(self): self.assertEqual(len(result["pixel_values"]), 2) self.assertTrue(torch.equal(result["pixel_values"][0], batch["pixel_values"][:4])) self.assertTrue(torch.equal(result["pixel_values"][1], batch["pixel_values"][4:])) + self.assertIsInstance(result["image_grid_thw"], list) + self.assertEqual(len(result["image_grid_thw"]), 2) + self.assertTrue(torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 2, 2]]))) + self.assertTrue(torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2]]))) def test_split_correctly_1(self): batch = { "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 4]]), + "num_images": [1, 1], "pixel_values": torch.arange(12 * 3).reshape(12, 3), # Shape: [12, 3] } result = split_pixel_values_by_grid(batch) @@ -892,6 +898,10 @@ def test_split_correctly_1(self): self.assertEqual(len(result["pixel_values"]), 2) self.assertTrue(torch.equal(result["pixel_values"][0], batch["pixel_values"][:4])) self.assertTrue(torch.equal(result["pixel_values"][1], batch["pixel_values"][4:12])) + self.assertIsInstance(result["image_grid_thw"], list) + self.assertEqual(len(result["image_grid_thw"]), 2) + self.assertTrue(torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 2, 2]]))) + self.assertTrue(torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 4]]))) def 
test_missing_keys(self): batch = {"pixel_values": torch.tensor([1.0])} @@ -901,11 +911,28 @@ def test_missing_keys(self): def test_mismatched_length(self): batch = { "image_grid_thw": torch.tensor([[1, 1, 2], [1, 2, 1]]), # Total = 8 + "num_images": [1, 1], "pixel_values": torch.randn(3, 5), # Only 3 rows } with self.assertRaises(ValueError): split_pixel_values_by_grid(batch) + def test_multi_images(self): + batch = { + "image_grid_thw": torch.tensor([[1, 1, 2], [1, 2, 2], [1, 2, 1]]), # Total = 8 + "num_images": [1, 2], + "pixel_values": torch.arange(8 * 3).reshape(8, 3), # Shape: [8, 3] + } + result = split_pixel_values_by_grid(batch) + self.assertIsInstance(result["pixel_values"], list) + self.assertEqual(len(result["pixel_values"]), 2) + self.assertTrue(torch.equal(result["pixel_values"][0], batch["pixel_values"][:2])) + self.assertTrue(torch.equal(result["pixel_values"][1], batch["pixel_values"][2:])) + self.assertIsInstance(result["image_grid_thw"], list) + self.assertEqual(len(result["image_grid_thw"]), 2) + self.assertTrue(torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 1, 2]]))) + self.assertTrue(torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2], [1, 2, 1]]))) + class TruncateWithProtectedTokensTester(TrlTestCase): def test_basic_example(self): diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index bb902445d09..f98d895fb18 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -464,7 +464,7 @@ def __init__( self.num_completions_to_print = args.num_completions_to_print # Keep logs sized to the generation batch to record only outputs from the latest model update. self._logs = { - "image": deque(maxlen=args.generation_batch_size), + "images": deque(maxlen=args.generation_batch_size), "prompt": deque(maxlen=args.generation_batch_size), "completion": deque(maxlen=args.generation_batch_size), "rewards": defaultdict(lambda: deque(maxlen=args.generation_batch_size)), @@ -609,7 +609,7 @@ def _set_signature_columns_if_needed(self): # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work. # Instead, we set them to the columns expected by the `training_step` method, hence the override. if self._signature_columns is None: - self._signature_columns = ["prompt", "image"] + self._signature_columns = ["prompt", "image", "images"] # This method overrides `Trainer.get_train_dataloader` to support our custom batching strategy. # Instead of returning a standard per-step batch (i.e., `per_device_batch_size), our dataloader loads an @@ -804,9 +804,9 @@ def _get_per_token_logps_and_entropies( model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} if image_grid_thw is not None and pixel_values is not None: - model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size] - start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item() - end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item() + model_inputs["image_grid_thw"] = torch.cat(image_grid_thw[start : start + batch_size]) + start_pixel_idx = 0 if start == 0 else torch.cat(image_grid_thw[:start]).prod(-1).sum().item() + end_pixel_idx = torch.cat(image_grid_thw[: start + batch_size]).prod(-1).sum().item() model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] elif pixel_values is not None: model_inputs["pixel_values"] = pixel_values[start : start + batch_size] @@ -1070,14 +1070,19 @@ def _generate_and_score_completions( # VLM chat template. 
original_prompts = copy.deepcopy(prompts) + if "images" in inputs[0]: + images = [example.get("images") for example in inputs] + elif "image" in inputs[0]: + images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] + else: + images = None + # If the prompts are conversational and the inputs contain images, we need to convert the prompts from # [{"role": "user", "content": "What color is the sky?"}] to # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] kwargs = {} - has_images = "image" in inputs[0] - if has_images: - images = [example.get("image") for example in inputs] - kwargs = {"images": [[img] for img in images]} + if images is not None: + kwargs = {"images": images} for prompt in prompts: if isinstance(prompt, list): # i.e., when using conversational data prepare_multimodal_messages(prompt, num_images=1) @@ -1152,7 +1157,7 @@ def _generate_and_score_completions( # Generate completions using vLLM: gather all prompts and use them in a single call in the main process if self.vllm_mode == "server": all_prompts_text = gather_object(prompts_text) - if has_images: + if images is not None: all_images = gather_object(images) if self.accelerator.is_main_process: @@ -1161,7 +1166,7 @@ def _generate_and_score_completions( # prompt individually. ordered_set_of_prompts = all_prompts_text[:: self.num_generations] - if has_images: + if images is not None: ordered_set_of_images = all_images[:: self.num_generations] else: ordered_set_of_images = None @@ -1226,7 +1231,7 @@ def _generate_and_score_completions( torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group) all_prompts_text = [p for sublist in gathered_prompts for p in sublist] - if has_images: + if images is not None: gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)] torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group) all_images = [img for sublist in gathered_images for img in sublist] @@ -1234,15 +1239,13 @@ def _generate_and_score_completions( all_images = None else: all_prompts_text = prompts_text - all_images = images if has_images else None + all_images = images - if has_images and all_images: + if images is not None and all_images: vllm_inputs = [] - for prompt, image in zip(all_prompts_text, all_images): - if image is not None: - vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}}) - else: - vllm_inputs.append(prompt) + for prompt, image_list in zip(all_prompts_text, all_images): + vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}}) + else: vllm_inputs = all_prompts_text @@ -1507,8 +1510,8 @@ def _generate_and_score_completions( self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) self._logs["advantages"].extend(all_process_advantages.tolist()) - if has_images: - self._logs["image"].extend(gather_object(images)) + if images is not None: + self._logs["images"].extend(gather_object(images)) if self.use_vllm and self.vllm_importance_sampling_correction: delta = torch.abs(old_per_token_logps - sampling_per_token_logps) @@ -1564,6 +1567,8 @@ def _generate_and_score_completions( output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"] if "image_sizes" in prompt_inputs: output["image_sizes"] = prompt_inputs["image_sizes"] + if images is not None: + output["num_images"] = [len(img_list) if img_list is not None else 0 for img_list in images] return output def compute_liger_loss(self, 
unwrapped_model, inputs): @@ -1790,14 +1795,11 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non "advantage": self._logs["advantages"], } - if self._logs["image"]: - table["image"] = [] - for img in self._logs["image"]: - if img is not None: - # Convert images to wandb Image objects for proper visualization - table["image"].append(wandb.Image(img)) - else: - table["image"].append(None) + if self._logs["images"]: + table["images"] = [] + for img in self._logs["images"]: + # Convert images to wandb Image objects for proper visualization + table["images"].append(wandb.Image(img)) df = pd.DataFrame(table) if self.wandb_log_unique_prompts: diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 37612a423bd..7cd16472c16 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -19,6 +19,7 @@ from collections.abc import Sequence, Sized from dataclasses import dataclass, field from importlib.metadata import version +from itertools import accumulate from typing import Any, Literal, Optional, Union import numpy as np @@ -1780,20 +1781,23 @@ def identity(x): def split_pixel_values_by_grid(batch: dict[str, torch.Tensor]) -> dict[str, Union[torch.Tensor, list[torch.Tensor]]]: """ - Splits `batch["pixel_values"]` into a list of tensors based on the product of each row in - `batch["image_grid_thw"]`, while keeping other entries unchanged. + Splits `batch["pixel_values"]` into a list of tensors based on the product of each row in `batch["image_grid_thw"]` + and batch["num_images"] while keeping other entries unchanged. """ - if "image_grid_thw" not in batch or "pixel_values" not in batch: + if "image_grid_thw" not in batch or "pixel_values" not in batch or "num_images" not in batch: return batch - lengths = batch["image_grid_thw"].prod(-1).tolist() # [batch_size] + lengths = batch["image_grid_thw"].prod(-1).tolist() # [num_images] pixel_values = batch["pixel_values"] # [total, feature_dim] if sum(lengths) != pixel_values.size(0): raise ValueError(f"Mismatch: sum(lengths) = {sum(lengths)} != pixel_values.size(0) = {pixel_values.size(0)}") - split_values = list(torch.split(batch["pixel_values"], lengths, dim=0)) - return {**batch, "pixel_values": split_values} + boundaries = [0, *accumulate(batch["num_images"])] # [3, 4, 5] -> [0, 3, 7, 12] + sections = [sum(lengths[boundaries[i] : boundaries[i + 1]]) for i in range(len(batch["num_images"]))] + split_values = list(torch.split(batch["pixel_values"], sections, dim=0)) + image_grid_thw = list(torch.split(batch["image_grid_thw"], batch["num_images"], dim=0)) + return {**batch, "pixel_values": split_values, "image_grid_thw": image_grid_thw} def unsplit_pixel_values_by_grid(batch: dict[str, Union[torch.Tensor, list[torch.Tensor]]]) -> dict[str, torch.Tensor]: From 3ca6ad50036aba363f8a87e7c227efceab7b4496 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 19 Sep 2025 23:31:06 +0000 Subject: [PATCH 05/29] log with wandb --- trl/trainer/grpo_trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index f98d895fb18..eb9e9bfcd9f 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1083,9 +1083,9 @@ def _generate_and_score_completions( kwargs = {} if images is not None: kwargs = {"images": images} - for prompt in prompts: + for prompt, image_list in zip(prompts, images): if isinstance(prompt, list): # i.e., when using conversational data - prepare_multimodal_messages(prompt, 
num_images=1) + prepare_multimodal_messages(prompt, num_images=len(image_list)) prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] @@ -1797,9 +1797,9 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non if self._logs["images"]: table["images"] = [] - for img in self._logs["images"]: + for image_list in self._logs["images"]: # Convert images to wandb Image objects for proper visualization - table["images"].append(wandb.Image(img)) + table["images"].append([wandb.Image(image) for image in image_list]) df = pd.DataFrame(table) if self.wandb_log_unique_prompts: From dcf4b92da0085d2d94f3a8ca5bb9e1b18e20b86a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 00:18:18 +0000 Subject: [PATCH 06/29] no vlm reward models --- tests/test_grpo_trainer.py | 92 ++++++++++++++++++++++++++++++++++--- trl/trainer/grpo_trainer.py | 16 ++++--- 2 files changed, 94 insertions(+), 14 deletions(-) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 4e4321febcd..ced4de9d73a 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -1258,6 +1258,10 @@ def test_prepare_input_called_with_correct_data(self): def test_training_vlm(self, model_id): dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = GRPOConfig( output_dir=self.tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test @@ -1269,7 +1273,7 @@ def test_training_vlm(self, model_id): ) trainer = GRPOTrainer( model=model_id, - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) @@ -1301,6 +1305,10 @@ def test_training_vlm(self, model_id): def test_training_vlm_beta_non_zero(self): dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = GRPOConfig( output_dir=self.tmp_dir, beta=0.1, # set beta to non-zero value to test the case where the reference model is used @@ -1312,7 +1320,7 @@ def test_training_vlm_beta_non_zero(self): ) trainer = GRPOTrainer( model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) @@ -1342,6 +1350,10 @@ def test_training_vlm_peft(self): base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = GRPOConfig( output_dir=self.tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test @@ -1352,7 +1364,7 @@ def test_training_vlm_peft(self): ) trainer = GRPOTrainer( model=model, - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, 
train_dataset=dataset, peft_config=LoraConfig(target_modules=["q_proj", "v_proj"]), @@ -1376,6 +1388,10 @@ def test_training_vlm_peft(self): def test_training_vlm_and_importance_sampling(self): dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = GRPOConfig( output_dir=self.tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test @@ -1387,7 +1403,7 @@ def test_training_vlm_and_importance_sampling(self): ) trainer = GRPOTrainer( model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) @@ -1413,6 +1429,10 @@ def test_training_vlm_and_importance_sampling(self): def test_training_vlm_and_liger(self): dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = GRPOConfig( output_dir=self.tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test @@ -1425,7 +1445,7 @@ def test_training_vlm_and_liger(self): ) trainer = GRPOTrainer( model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) @@ -1451,6 +1471,10 @@ def test_training_vlm_and_prompt_truncation(self): # If not handled properly, prompt truncation may truncate image token dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = GRPOConfig( output_dir=self.tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test @@ -1462,7 +1486,7 @@ def test_training_vlm_and_prompt_truncation(self): ) trainer = GRPOTrainer( model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) @@ -1495,6 +1519,10 @@ def test_training_vlm_and_prompt_truncation(self): def test_training_vlm_and_vllm(self, model_id) -> None: dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = GRPOConfig( output_dir=self.tmp_dir, learning_rate=0.1, @@ -1508,7 +1536,44 @@ def test_training_vlm_and_vllm(self, model_id) -> None: ) trainer = GRPOTrainer( model=model_id, - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + 
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + @require_vision + def test_training_vlm_multi_image(self): + dataset = load_dataset("trl-internal-testing/zen-multi-image", "conversational_prompt_only", split="train") + + # For now, mixing image+text and text-only examples is not supported, so we filter out text-only examples + dataset = dataset.filter(lambda x: len(x["images"]) > 0) + + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + max_prompt_length=None, # disable prompt truncation, because usually, models don't support it + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) @@ -1519,7 +1584,20 @@ def test_training_vlm_and_vllm(self, model_id) -> None: self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + # Check that the params have changed + # Because of the way the tiny models are initialized, the gradient does not flow properly through the + # vision parts of the model, so we skip them. Ideally, we should fix the init of these models. + params_to_skip = ( + # "model.vision_tower.", + # "model.multi_modal_projector.", + # "model.vision_model.", + # "model.connector.modality_projection.", + # "model.visual.", + # "model.image_newline", + ) for n, param in previous_trainable_params.items(): + if n.startswith(params_to_skip): + continue new_param = trainer.model.get_parameter(n) self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index eb9e9bfcd9f..59135c232e6 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy import inspect import os import re @@ -1020,6 +1019,14 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): with profiling_context(self, reward_func_name): if isinstance(reward_func, nn.Module): # Module (no PretrainedModel) for compat with compiled models if is_conversational(inputs[0]): + # VLM reward models aren't supported yet, so we drop the image and raise a warning if needed + for prompt in prompts: + for turn in prompt: + if isinstance(turn["content"], list): + logger.warning_once("Visual reward models aren't supported yet; dropping image.") + turn["content"] = " ".join( + e["text"] for e in turn["content"] if e["type"] == "text" + ) messages = [{"messages": p + c} for p, c in zip(prompts, completions)] texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages] else: @@ -1065,11 +1072,6 @@ def _generate_and_score_completions( prompts = [x["prompt"] for x in inputs] - # We don't yet support visual reward models/function, so we keep a copy of the original text-only prompts for - # later use in the reward computation. If images are present, we insert {"type": "image"} as required by the - # VLM chat template. - original_prompts = copy.deepcopy(prompts) - if "images" in inputs[0]: images = [example.get("images") for example in inputs] elif "image" in inputs[0]: @@ -1436,7 +1438,7 @@ def _generate_and_score_completions( # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is # important because rewards will be normalized per group, and completions are distributed. We will later slice # rewards_per_func to extract each process's subset. - rewards_per_func = self._calculate_rewards(inputs, original_prompts, completions, completion_ids_list) + rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) # Apply weights to each reward function's output and sum rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) From 30ad7ca371286e2d55998d285ae66a3f123eee83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 00:37:54 +0000 Subject: [PATCH 07/29] rloo --- trl/trainer/rloo_trainer.py | 78 +++++++++++++++++++------------------ 1 file changed, 41 insertions(+), 37 deletions(-) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 56ffbfe7fea..48801496b01 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import inspect import os import re @@ -536,7 +535,7 @@ def decode(example, tokenizer): self.num_completions_to_print = args.num_completions_to_print # Keep logs sized to the generation batch to record only outputs from the latest model update. self._logs = { - "image": deque(maxlen=args.generation_batch_size), + "images": deque(maxlen=args.generation_batch_size), "prompt": deque(maxlen=args.generation_batch_size), "completion": deque(maxlen=args.generation_batch_size), "rewards": defaultdict(lambda: deque(maxlen=args.generation_batch_size)), @@ -678,7 +677,7 @@ def _set_signature_columns_if_needed(self): # In RLOOTrainer, we preprocess data, so using the model's signature columns doesn't work. # Instead, we set them to the columns expected by the `training_step` method, hence the override. 
if self._signature_columns is None: - self._signature_columns = ["prompt", "image"] + self._signature_columns = ["prompt", "image", "images"] # This method overrides `Trainer.get_train_dataloader` to support our custom batching strategy. # Instead of returning a standard per-step batch (i.e., `per_device_batch_size), our dataloader loads an @@ -790,9 +789,9 @@ def _get_per_token_logps_and_entropies( model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} if image_grid_thw is not None and pixel_values is not None: - model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size] - start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item() - end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item() + model_inputs["image_grid_thw"] = torch.cat(image_grid_thw[start : start + batch_size]) + start_pixel_idx = 0 if start == 0 else torch.cat(image_grid_thw[:start]).prod(-1).sum().item() + end_pixel_idx = torch.cat(image_grid_thw[: start + batch_size]).prod(-1).sum().item() model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] elif pixel_values is not None: model_inputs["pixel_values"] = pixel_values[start : start + batch_size] @@ -1006,6 +1005,14 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): with profiling_context(self, reward_func_name): if isinstance(reward_func, nn.Module): # Module (no PretrainedModel) for compat with compiled models if is_conversational(inputs[0]): + # VLM reward models aren't supported yet, so we drop the image and raise a warning if needed + for prompt in prompts: + for turn in prompt: + if isinstance(turn["content"], list): + logger.warning_once("Visual reward models aren't supported yet; dropping image.") + turn["content"] = " ".join( + e["text"] for e in turn["content"] if e["type"] == "text" + ) messages = [{"messages": p + c} for p, c in zip(prompts, completions)] texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages] else: @@ -1051,22 +1058,22 @@ def _generate_and_score_completions( prompts = [x["prompt"] for x in inputs] - # We don't yet support visual reward models/function, so we keep a copy of the original text-only prompts for - # later use in the reward computation. If images are present, we insert {"type": "image"} as required by the - # VLM chat template. 
- original_prompts = copy.deepcopy(prompts) + if "images" in inputs[0]: + images = [example.get("images") for example in inputs] + elif "image" in inputs[0]: + images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] + else: + images = None # If the prompts are conversational and the inputs contain images, we need to convert the prompts from # [{"role": "user", "content": "What color is the sky?"}] to # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] kwargs = {} - has_images = "image" in inputs[0] - if has_images: - images = [example.get("image") for example in inputs] - kwargs = {"images": [[img] for img in images]} - for prompt in prompts: + if images is not None: + kwargs = {"images": images} + for prompt, image_list in zip(prompts, images): if isinstance(prompt, list): # i.e., when using conversational data - prepare_multimodal_messages(prompt, num_images=1) + prepare_multimodal_messages(prompt, num_images=len(image_list)) prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] @@ -1133,7 +1140,7 @@ def _generate_and_score_completions( # Generate completions using vLLM: gather all prompts and use them in a single call in the main process if self.vllm_mode == "server": all_prompts_text = gather_object(prompts_text) - if has_images: + if images is not None: all_images = gather_object(images) if self.accelerator.is_main_process: @@ -1142,7 +1149,7 @@ def _generate_and_score_completions( # prompt individually. ordered_set_of_prompts = all_prompts_text[:: self.num_generations] - if has_images: + if images is not None: ordered_set_of_images = all_images[:: self.num_generations] else: ordered_set_of_images = None @@ -1205,7 +1212,7 @@ def _generate_and_score_completions( torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group) all_prompts_text = [p for sublist in gathered_prompts for p in sublist] - if has_images: + if images is not None: gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)] torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group) all_images = [img for sublist in gathered_images for img in sublist] @@ -1213,15 +1220,13 @@ def _generate_and_score_completions( all_images = None else: all_prompts_text = prompts_text - all_images = images if has_images else None + all_images = images - if has_images and all_images: + if images is not None and all_images: vllm_inputs = [] - for prompt, image in zip(all_prompts_text, all_images): - if image is not None: - vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}}) - else: - vllm_inputs.append(prompt) + for prompt, image_list in zip(all_prompts_text, all_images): + vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}}) + else: vllm_inputs = all_prompts_text @@ -1379,7 +1384,7 @@ def _generate_and_score_completions( # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is # important because rewards will be normalized per group, and completions are distributed. We will later slice # rewards_per_func to extract each process's subset. 
- rewards_per_func = self._calculate_rewards(inputs, original_prompts, completions, completion_ids_list) + rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) # Apply weights to each reward function's output and sum rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) @@ -1463,8 +1468,8 @@ def _generate_and_score_completions( self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) self._logs["advantages"].extend(all_process_advantages.tolist()) - if has_images: - self._logs["image"].extend(gather_object(images)) + if images is not None: + self._logs["images"].extend(gather_object(images)) output = { "prompt_ids": prompt_ids, @@ -1482,6 +1487,8 @@ def _generate_and_score_completions( output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"] if "image_sizes" in prompt_inputs: output["image_sizes"] = prompt_inputs["image_sizes"] + if images is not None: + output["num_images"] = [len(img_list) if img_list is not None else 0 for img_list in images] return output @profiling_decorator @@ -1588,14 +1595,11 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non "advantage": self._logs["advantages"], } - if self._logs["image"]: - table["image"] = [] - for img in self._logs["image"]: - if img is not None: - # Convert images to wandb Image objects for proper visualization - table["image"].append(wandb.Image(img)) - else: - table["image"].append(None) + if self._logs["images"]: + table["images"] = [] + for image_list in self._logs["images"]: + # Convert images to wandb Image objects for proper visualization + table["images"].append([wandb.Image(image) for image in image_list]) df = pd.DataFrame(table) if self.wandb_log_unique_prompts: From 86cc30bf3c307eebcd9223ec7db8bcc8784573b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 00:43:43 +0000 Subject: [PATCH 08/29] gfpo --- trl/experimental/gfpo/gfpo_trainer.py | 50 +++++++++++++-------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/trl/experimental/gfpo/gfpo_trainer.py b/trl/experimental/gfpo/gfpo_trainer.py index f2a675fab16..af83e076dd9 100644 --- a/trl/experimental/gfpo/gfpo_trainer.py +++ b/trl/experimental/gfpo/gfpo_trainer.py @@ -83,22 +83,22 @@ def _generate_and_score_completions(self, inputs): prompts = [x["prompt"] for x in inputs] - # We don't yet support visual reward models/function, so we keep a copy of the original text-only prompts for - # later use in the reward computation. If images are present, we insert {"type": "image"} as required by the - # VLM chat template. 
- original_prompts = copy.deepcopy(prompts) + if "images" in inputs[0]: + images = [example.get("images") for example in inputs] + elif "image" in inputs[0]: + images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] + else: + images = None # If the prompts are conversational and the inputs contain images, we need to convert the prompts from # [{"role": "user", "content": "What color is the sky?"}] to # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] kwargs = {} - has_images = "image" in inputs[0] - if has_images: - images = [example.get("image") for example in inputs] - kwargs = {"images": [[img] for img in images]} - for prompt in prompts: + if images is not None: + kwargs = {"images": images} + for prompt, image_list in zip(prompts, images): if isinstance(prompt, list): # i.e., when using conversational data - prepare_multimodal_messages(prompt, num_images=1) + prepare_multimodal_messages(prompt, num_images=len(image_list)) prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] @@ -170,7 +170,7 @@ def _generate_and_score_completions(self, inputs): # Generate completions using vLLM: gather all prompts and use them in a single call in the main process if self.vllm_mode == "server": all_prompts_text = gather_object(prompts_text) - if has_images: + if images is not None: all_images = gather_object(images) if self.accelerator.is_main_process: @@ -179,7 +179,7 @@ def _generate_and_score_completions(self, inputs): # prompt individually. ordered_set_of_prompts = all_prompts_text[:: self.num_generations] - if has_images: + if images is not None: ordered_set_of_images = all_images[:: self.num_generations] else: ordered_set_of_images = None @@ -244,7 +244,7 @@ def _generate_and_score_completions(self, inputs): torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group) all_prompts_text = [p for sublist in gathered_prompts for p in sublist] - if has_images: + if images is not None: gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)] torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group) all_images = [img for sublist in gathered_images for img in sublist] @@ -252,15 +252,13 @@ def _generate_and_score_completions(self, inputs): all_images = None else: all_prompts_text = prompts_text - all_images = images if has_images else None + all_images = images - if has_images and all_images: + if images is not None and all_images: vllm_inputs = [] - for prompt, image in zip(all_prompts_text, all_images): - if image is not None: - vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}}) - else: - vllm_inputs.append(prompt) + for prompt, image_list in zip(all_prompts_text, all_images): + vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}}) + else: vllm_inputs = all_prompts_text @@ -451,7 +449,7 @@ def _generate_and_score_completions(self, inputs): # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is # important because rewards will be normalized per group, and completions are distributed. We will later slice # rewards_per_func to extract each process's subset. 
- rewards_per_func = self._calculate_rewards(inputs, original_prompts, completions, completion_ids_list) + rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) # Apply weights to each reward function's output and sum rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) @@ -563,12 +561,12 @@ def _generate_and_score_completions(self, inputs): # Log prompt and completion texts all_prompts_text = gather_object(prompts_text) all_completions_text = gather_object(completions_text) - all_images = gather_object(images) if has_images else None + all_images = gather_object(images) if images is not None else None if self.num_remains_in_group is not None and mode == "train": group_global_indices_list = group_global_indices.tolist() all_prompts_text = [all_prompts_text[i] for i in group_global_indices_list] all_completions_text = [all_completions_text[i] for i in group_global_indices_list] - if has_images: + if images is not None: all_images = [all_images[i] for i in group_global_indices_list] self._logs["prompt"].extend(all_prompts_text) @@ -577,8 +575,8 @@ def _generate_and_score_completions(self, inputs): self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) self._logs["advantages"].extend(all_process_advantages.tolist()) - if has_images: - self._logs["image"].extend(all_images) + if images is not None: + self._logs["images"].extend(gather_object(images)) if self.use_vllm and self.vllm_importance_sampling_correction: delta = torch.abs(old_per_token_logps - sampling_per_token_logps) @@ -639,4 +637,6 @@ def _generate_and_score_completions(self, inputs): output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"] if "image_sizes" in prompt_inputs: output["image_sizes"] = prompt_inputs["image_sizes"] + if images is not None: + output["num_images"] = [len(img_list) if img_list is not None else 0 for img_list in images] return output From 088897b9cd37925268fb8fcb48b56aa4bc2b65d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 02:25:10 +0000 Subject: [PATCH 09/29] fix --- trl/trainer/grpo_trainer.py | 23 +++++++++++++++++------ trl/trainer/rloo_trainer.py | 21 ++++++++++++++++----- trl/trainer/utils.py | 12 ++++++++---- 3 files changed, 41 insertions(+), 15 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 59135c232e6..87a08096a13 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -788,6 +788,7 @@ def _get_per_token_logps_and_entropies( compute_entropy=False, pixel_values=None, image_grid_thw=None, + num_images=None, pixel_attention_mask=None, image_sizes=None, ) -> dict[str, Optional[torch.Tensor]]: @@ -801,12 +802,16 @@ def _get_per_token_logps_and_entropies( # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} - if image_grid_thw is not None and pixel_values is not None: - model_inputs["image_grid_thw"] = torch.cat(image_grid_thw[start : start + batch_size]) - start_pixel_idx = 0 if start == 0 else torch.cat(image_grid_thw[:start]).prod(-1).sum().item() - end_pixel_idx = torch.cat(image_grid_thw[: start + batch_size]).prod(-1).sum().item() - model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] + rows_per_image = image_grid_thw.prod(dim=-1) + rows_per_sample = torch.split(rows_per_image, num_images) + rows_per_sample = torch.stack([s.sum() for s in 
rows_per_sample]) + cum_rows = torch.cat([torch.tensor([0], device=rows_per_sample.device), rows_per_sample.cumsum(0)]) + row_start, row_end = cum_rows[start].item(), cum_rows[start + batch_size].item() + model_inputs["pixel_values"] = pixel_values[row_start:row_end] + cum_imgs = torch.tensor([0] + num_images).cumsum(0) + img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size] + model_inputs["image_grid_thw"] = image_grid_thw[img_start:img_end] elif pixel_values is not None: model_inputs["pixel_values"] = pixel_values[start : start + batch_size] if pixel_attention_mask is not None: @@ -1362,6 +1367,8 @@ def _generate_and_score_completions( logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size + num_images = [len(img_list) for img_list in images] if images is not None else None + with torch.no_grad(): # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the @@ -1382,6 +1389,7 @@ def _generate_and_score_completions( batch_size, pixel_values=prompt_inputs.get("pixel_values"), image_grid_thw=prompt_inputs.get("image_grid_thw"), + num_images=num_images, pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), ) @@ -1406,6 +1414,7 @@ def _generate_and_score_completions( batch_size=batch_size, pixel_values=prompt_inputs.get("pixel_values"), image_grid_thw=prompt_inputs.get("image_grid_thw"), + num_images=num_images, pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), ) @@ -1419,6 +1428,7 @@ def _generate_and_score_completions( batch_size=batch_size, pixel_values=prompt_inputs.get("pixel_values"), image_grid_thw=prompt_inputs.get("image_grid_thw"), + num_images=num_images, pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), ) @@ -1570,7 +1580,7 @@ def _generate_and_score_completions( if "image_sizes" in prompt_inputs: output["image_sizes"] = prompt_inputs["image_sizes"] if images is not None: - output["num_images"] = [len(img_list) if img_list is not None else 0 for img_list in images] + output["num_images"] = num_images return output def compute_liger_loss(self, unwrapped_model, inputs): @@ -1643,6 +1653,7 @@ def _compute_loss(self, model, inputs): compute_entropy=True, pixel_values=inputs.get("pixel_values"), image_grid_thw=inputs.get("image_grid_thw"), + num_images=inputs.get("num_images"), pixel_attention_mask=inputs.get("pixel_attention_mask"), image_sizes=inputs.get("image_sizes"), ) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 48801496b01..4e50bbc4501 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -774,6 +774,7 @@ def _get_per_token_logps_and_entropies( compute_entropy=False, pixel_values=None, image_grid_thw=None, + num_images=None, pixel_attention_mask=None, image_sizes=None, ) -> dict[str, Optional[torch.Tensor]]: @@ -789,10 +790,15 @@ def _get_per_token_logps_and_entropies( model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} if image_grid_thw is not None and pixel_values is not None: - model_inputs["image_grid_thw"] = torch.cat(image_grid_thw[start : start + batch_size]) - start_pixel_idx = 0 if start == 0 
else torch.cat(image_grid_thw[:start]).prod(-1).sum().item() - end_pixel_idx = torch.cat(image_grid_thw[: start + batch_size]).prod(-1).sum().item() - model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] + rows_per_image = image_grid_thw.prod(dim=-1) + rows_per_sample = torch.split(rows_per_image, num_images) + rows_per_sample = torch.stack([s.sum() for s in rows_per_sample]) + cum_rows = torch.cat([torch.tensor([0], device=rows_per_sample.device), rows_per_sample.cumsum(0)]) + row_start, row_end = cum_rows[start].item(), cum_rows[start + batch_size].item() + model_inputs["pixel_values"] = pixel_values[row_start:row_end] + cum_imgs = torch.tensor([0] + num_images).cumsum(0) + img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size] + model_inputs["image_grid_thw"] = image_grid_thw[img_start:img_end] elif pixel_values is not None: model_inputs["pixel_values"] = pixel_values[start : start + batch_size] if pixel_attention_mask is not None: @@ -1326,6 +1332,8 @@ def _generate_and_score_completions( logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size + num_images = [len(img_list) for img_list in images] if images is not None else None + with torch.no_grad(): # Compute the per-token log probabilities for the current model old_per_token_logps, _ = self._get_per_token_logps_and_entropies( @@ -1336,6 +1344,7 @@ def _generate_and_score_completions( batch_size, pixel_values=prompt_inputs.get("pixel_values"), image_grid_thw=prompt_inputs.get("image_grid_thw"), + num_images=num_images, pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), ) @@ -1352,6 +1361,7 @@ def _generate_and_score_completions( batch_size=batch_size, pixel_values=prompt_inputs.get("pixel_values"), image_grid_thw=prompt_inputs.get("image_grid_thw"), + num_images=num_images, pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), ) @@ -1365,6 +1375,7 @@ def _generate_and_score_completions( batch_size=batch_size, pixel_values=prompt_inputs.get("pixel_values"), image_grid_thw=prompt_inputs.get("image_grid_thw"), + num_images=num_images, pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), ) @@ -1488,7 +1499,7 @@ def _generate_and_score_completions( if "image_sizes" in prompt_inputs: output["image_sizes"] = prompt_inputs["image_sizes"] if images is not None: - output["num_images"] = [len(img_list) if img_list is not None else 0 for img_list in images] + output["num_images"] = num_images return output @profiling_decorator diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 7cd16472c16..337e9857c1a 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -1806,12 +1806,16 @@ def unsplit_pixel_values_by_grid(batch: dict[str, Union[torch.Tensor, list[torch tensor along the first dimension. 
""" pixel_values = batch.get("pixel_values") - if isinstance(pixel_values, list): merged = torch.cat(pixel_values, dim=0) - return {**batch, "pixel_values": merged} - else: - return batch + batch = {**batch, "pixel_values": merged} + + image_grid_thw = batch.get("image_grid_thw") + if isinstance(image_grid_thw, list): + merged = torch.cat(image_grid_thw, dim=0) + batch = {**batch, "image_grid_thw": merged} + + return batch def truncate_with_protected_tokens( From d2adc63eb66c70592813b718ad5791e5fdead371 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 02:52:33 +0000 Subject: [PATCH 10/29] test peft --- tests/test_grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index ced4de9d73a..5577e1dd25d 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -1348,7 +1348,7 @@ def test_training_vlm_peft(self): "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration" ) base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] - dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") def reward_func(completions, **kwargs): """Reward function that rewards longer completions.""" From f4c82bfc0470c5a2fb590880e7a639eb77b93f2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 02:55:59 +0000 Subject: [PATCH 11/29] fix gfpo --- trl/experimental/gfpo/gfpo_trainer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/trl/experimental/gfpo/gfpo_trainer.py b/trl/experimental/gfpo/gfpo_trainer.py index af83e076dd9..6ac4b6acc7f 100644 --- a/trl/experimental/gfpo/gfpo_trainer.py +++ b/trl/experimental/gfpo/gfpo_trainer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy import logging import re from contextlib import nullcontext @@ -373,6 +372,8 @@ def _generate_and_score_completions(self, inputs): logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size + num_images = [len(img_list) for img_list in images] if images is not None else None + with torch.no_grad(): # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the @@ -393,6 +394,7 @@ def _generate_and_score_completions(self, inputs): batch_size, pixel_values=prompt_inputs.get("pixel_values"), image_grid_thw=prompt_inputs.get("image_grid_thw"), + num_images=num_images, pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), ) @@ -417,6 +419,7 @@ def _generate_and_score_completions(self, inputs): batch_size=batch_size, pixel_values=prompt_inputs.get("pixel_values"), image_grid_thw=prompt_inputs.get("image_grid_thw"), + num_images=num_images, pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), ) @@ -430,6 +433,7 @@ def _generate_and_score_completions(self, inputs): batch_size=batch_size, pixel_values=prompt_inputs.get("pixel_values"), image_grid_thw=prompt_inputs.get("image_grid_thw"), + num_images=num_images, pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), ) @@ -638,5 +642,5 @@ def _generate_and_score_completions(self, inputs): if "image_sizes" in prompt_inputs: output["image_sizes"] = prompt_inputs["image_sizes"] if images is not None: - output["num_images"] = [len(img_list) if img_list is not None else 0 for img_list in images] + output["num_images"] = num_images return output From 1257796ba85a6566ff4a3749a36b4b90360fd5ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 03:01:47 +0000 Subject: [PATCH 12/29] rloo test --- tests/test_rloo_trainer.py | 67 +++++++++++++++++++++++++++++++++++--- 1 file changed, 62 insertions(+), 5 deletions(-) diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index 12042bd2b3b..2bab06f218c 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -1089,6 +1089,10 @@ def test_prepare_input_called_with_correct_data(self): def test_training_vlm(self, model_id): dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = RLOOConfig( output_dir=self.tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test @@ -1100,7 +1104,7 @@ def test_training_vlm(self, model_id): ) trainer = RLOOTrainer( model=model_id, - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) @@ -1132,6 +1136,10 @@ def test_training_vlm(self, model_id): def test_training_vlm_beta_non_zero(self): dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + 
return [float(len(completion[0]["content"])) for completion in completions] + training_args = RLOOConfig( output_dir=self.tmp_dir, beta=0.1, # set beta to non-zero value to test the case where the reference model is used @@ -1143,7 +1151,7 @@ def test_training_vlm_beta_non_zero(self): ) trainer = RLOOTrainer( model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) @@ -1173,6 +1181,10 @@ def test_training_vlm_peft(self): base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = RLOOConfig( output_dir=self.tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test @@ -1183,7 +1195,7 @@ def test_training_vlm_peft(self): ) trainer = RLOOTrainer( model=model, - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, peft_config=LoraConfig(target_modules=["q_proj", "v_proj"]), @@ -1208,6 +1220,10 @@ def test_training_vlm_and_prompt_truncation(self): # If not handled properly, prompt truncation may truncate image token dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = RLOOConfig( output_dir=self.tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test @@ -1219,7 +1235,7 @@ def test_training_vlm_and_prompt_truncation(self): ) trainer = RLOOTrainer( model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) @@ -1252,6 +1268,10 @@ def test_training_vlm_and_prompt_truncation(self): def test_training_vlm_and_vllm(self, model_id) -> None: dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = RLOOConfig( output_dir=self.tmp_dir, learning_rate=0.1, @@ -1265,7 +1285,44 @@ def test_training_vlm_and_vllm(self, model_id) -> None: ) trainer = RLOOTrainer( model=model_id, - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + @require_vision + def test_training_vlm_multi_image(self): + dataset = load_dataset("trl-internal-testing/zen-multi-image", "conversational_prompt_only", split="train") + + # For 
now, mixing image+text and text-only examples is not supported, so we filter out text-only examples + dataset = dataset.filter(lambda x: len(x["images"]) > 0) + + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + + training_args = RLOOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + max_prompt_length=None, # disable prompt truncation, because usually, models don't support it + report_to="none", + ) + trainer = RLOOTrainer( + model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) From 099a39bd6a9f90b082f3554facc699ac5463ee86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 03:04:07 +0000 Subject: [PATCH 13/29] peft rloo --- tests/test_rloo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index 2bab06f218c..399419ec3c1 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -1179,7 +1179,7 @@ def test_training_vlm_peft(self): "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration" ) base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] - dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") def reward_func(completions, **kwargs): """Reward function that rewards longer completions.""" From 529add673c30175a32be9454973e9435e28b2251 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 03:55:03 +0000 Subject: [PATCH 14/29] oops --- trl/trainer/rloo_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 4e50bbc4501..3671af229e2 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -1525,6 +1525,7 @@ def _compute_loss(self, model, inputs): compute_entropy=True, pixel_values=inputs.get("pixel_values"), image_grid_thw=inputs.get("image_grid_thw"), + num_images=inputs.get("num_images"), pixel_attention_mask=inputs.get("pixel_attention_mask"), image_sizes=inputs.get("image_sizes"), ) From fc6b11fcaeb182edbe0c3e5c336bddcaeea7bf3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 04:22:54 +0000 Subject: [PATCH 15/29] update test --- tests/test_utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 0fc16682336..6f6ba1579ef 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1071,12 +1071,16 @@ def test_empty_protected_tokens_list(self): class UnsplitPixelValuesByGridTester(TrlTestCase): def test_unsplit_correctly(self): - split = [torch.randn(4, 5), torch.randn(2, 5)] - merged = torch.cat(split, dim=0) - batch = {"pixel_values": split, "other_key": torch.tensor([1])} + pixel_values = [torch.randn(4, 5), torch.randn(2, 5)] + pixel_values_merged = torch.cat(pixel_values, dim=0) + image_grid_thw = [torch.tensor([[1, 2, 2]]), 
torch.tensor([[1, 2, 1]])] + image_grid_thw_merged = torch.cat(image_grid_thw, dim=0) + batch = {"pixel_values": pixel_values, "image_grid_thw": image_grid_thw, "other_key": torch.tensor([1])} result = unsplit_pixel_values_by_grid(batch) self.assertIsInstance(result["pixel_values"], torch.Tensor) - self.assertTrue(torch.allclose(result["pixel_values"], merged)) + self.assertTrue(torch.allclose(result["pixel_values"], pixel_values_merged)) + self.assertIsInstance(result["image_grid_thw"], torch.Tensor) + self.assertTrue(torch.equal(result["image_grid_thw"], image_grid_thw_merged)) self.assertIn("other_key", result) def test_no_op_if_not_list(self): From ae1f497959032ae6ba0120cbb13b8f406b9e1799 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 05:08:48 +0000 Subject: [PATCH 16/29] generate method --- trl/trainer/grpo_trainer.py | 155 +++++++++++++++++++++--------------- 1 file changed, 92 insertions(+), 63 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 87a08096a13..bfc24a651c3 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1069,21 +1069,10 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): rewards_per_func = gather(rewards_per_func) return rewards_per_func - def _generate_and_score_completions( - self, inputs: list[dict[str, Union[torch.Tensor, Any]]] - ) -> dict[str, Union[torch.Tensor, Any]]: + def _generate(self, prompts: list[str], images: Optional[list]): device = self.accelerator.device mode = "train" if self.model.training else "eval" - prompts = [x["prompt"] for x in inputs] - - if "images" in inputs[0]: - images = [example.get("images") for example in inputs] - elif "image" in inputs[0]: - images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] - else: - images = None - # If the prompts are conversational and the inputs contain images, we need to convert the prompts from # [{"role": "user", "content": "What color is the sky?"}] to # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] @@ -1094,7 +1083,9 @@ def _generate_and_score_completions( if isinstance(prompt, list): # i.e., when using conversational data prepare_multimodal_messages(prompt, num_images=len(image_list)) - prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] + prompts_text = [ + maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts + ] prompt_inputs = self.processing_class( text=prompts_text, @@ -1106,6 +1097,7 @@ def _generate_and_score_completions( ) prompt_inputs = super()._prepare_inputs(prompt_inputs) prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] + forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} if self.max_prompt_length is not None: # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. 
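An illustrative aside on the image plumbing these hunks rely on: everything in `prompt_inputs` other than `input_ids`/`attention_mask` (typically `pixel_values`, `image_grid_thw`, `pixel_attention_mask`, `image_sizes`) is carried along as `forward_kwargs`, and `_get_per_token_logps_and_entropies` uses `image_grid_thw` together with `num_images` to slice the flattened `pixel_values` per sub-batch. A minimal sketch of that slicing with made-up shapes (the tensors below are illustrative only, not taken from the patch):

    import torch

    # Illustrative only: two prompts, one image each, with patch grids of 2x2 and 4x4.
    image_grid_thw = torch.tensor([[1, 2, 2], [1, 4, 4]])
    num_images = [1, 1]

    rows_per_image = image_grid_thw.prod(dim=-1)  # 4 and 16 rows of pixel_values per image
    rows_per_sample = torch.stack([s.sum() for s in torch.split(rows_per_image, num_images)])
    cum_rows = torch.cat([torch.tensor([0]), rows_per_sample.cumsum(0)])  # 0, 4, 20
    # pixel_values[cum_rows[i] : cum_rows[i + 1]] holds the rows belonging to sample i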
@@ -1279,8 +1271,9 @@ def _generate_and_score_completions( # Pad the completions, and concatenate them with the prompts completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] + completion_mask = [torch.ones(len(ids), device=device, dtype=torch.long) for ids in completion_ids] completion_ids = pad(completion_ids, padding_value=self.pad_token_id) - prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) + completion_mask = pad(completion_mask, padding_value=0) sampling_per_token_logps = [ torch.tensor(logprobs, device=device, dtype=torch.float32) for logprobs in all_logprobs ] @@ -1318,9 +1311,9 @@ def _generate_and_score_completions( completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") prompt_ids = [torch.tensor(ids, device=device) for ids in paged_prompt_inputs.input_ids] prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") - prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # Restore the original attention implementation, training mode self.model_wrapped.config._attn_implementation = previous_attn + else: # Regular generation path with ( @@ -1331,14 +1324,18 @@ def _generate_and_score_completions( torch.no_grad(), FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), ): - prompt_inputs["input_ids"], prompt_inputs["attention_mask"] = prompt_ids, prompt_mask prompt_completion_ids = unwrapped_model.generate( - **prompt_inputs, generation_config=self.generation_config, disable_compile=True + input_ids=prompt_ids, + attention_mask=prompt_mask, + **forward_kwargs, + generation_config=self.generation_config, + disable_compile=True, ) # Compute prompt length and extract completion ids prompt_length = prompt_ids.size(1) prompt_ids = prompt_completion_ids[:, :prompt_length] completion_ids = prompt_completion_ids[:, prompt_length:] + sampling_per_token_logps = None # not used in this case # Mask everything after the first EOS token is_eos = completion_ids == self.eos_token_id @@ -1347,10 +1344,6 @@ def _generate_and_score_completions( sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() - # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need - # to re-tokenize completions if the reward is computed from tokens. 
- completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())] - # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging completion_lengths = completion_mask.sum(1) agg_completion_lengths = self.accelerator.gather(completion_lengths) @@ -1361,7 +1354,72 @@ def _generate_and_score_completions( truncated_completions = ~is_eos.any(dim=1) completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int() + # Log the metrics + if mode == "train": + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item() + self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] + + # Log completion lengths, mean, min, max + self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) + + # Identify sequences that terminated with EOS and log their lengths + agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1)) + term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos] + clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths) + self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio) + if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found + term_completion_lengths = torch.zeros(1, device=device) + self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) + + if images is not None: + self._logs["image"].extend(gather_object(images)) + + return ( + prompt_ids, + completion_ids, + prompt_mask, + completion_mask, + num_items_in_batch, + sampling_per_token_logps, + forward_kwargs, + ) + + def _generate_and_score_completions( + self, inputs: list[dict[str, Union[torch.Tensor, Any]]] + ) -> dict[str, Union[torch.Tensor, Any]]: + device = self.accelerator.device + mode = "train" if self.model.training else "eval" + + prompts = [x["prompt"] for x in inputs] + + if "images" in inputs[0]: + images = [example.get("images") for example in inputs] + elif "image" in inputs[0]: + images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] + else: + images = None + + ( + prompt_ids, + completion_ids, + prompt_mask, + completion_mask, + num_items_in_batch, + sampling_per_token_logps, + forward_kwargs, + ) = self._generate(prompts, images) + + # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need + # to re-tokenize completions if the reward is computed from tokens. 
+ completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())] + # Concatenate prompt_mask with completion_mask for logit computation + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens @@ -1387,11 +1445,8 @@ def _generate_and_score_completions( attention_mask, logits_to_keep, batch_size, - pixel_values=prompt_inputs.get("pixel_values"), - image_grid_thw=prompt_inputs.get("image_grid_thw"), num_images=num_images, - pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), - image_sizes=prompt_inputs.get("image_sizes"), + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes ) else: old_per_token_logps = None @@ -1412,11 +1467,8 @@ def _generate_and_score_completions( attention_mask, logits_to_keep, batch_size=batch_size, - pixel_values=prompt_inputs.get("pixel_values"), - image_grid_thw=prompt_inputs.get("image_grid_thw"), num_images=num_images, - pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), - image_sizes=prompt_inputs.get("image_sizes"), + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes ) else: with self.accelerator.unwrap_model(self.model).disable_adapter(): @@ -1426,16 +1478,14 @@ def _generate_and_score_completions( attention_mask, logits_to_keep, batch_size=batch_size, - pixel_values=prompt_inputs.get("pixel_values"), - image_grid_thw=prompt_inputs.get("image_grid_thw"), num_images=num_images, - pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), - image_sizes=prompt_inputs.get("image_sizes"), + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes ) else: ref_per_token_logps = None - # Decode the generated completions + # Decode + prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True) completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) if is_conversational(inputs[0]): completions = [] @@ -1484,27 +1534,6 @@ def _generate_and_score_completions( all_process_advantages = advantages.clone() # keep the aggregated advantages for logging advantages = advantages[process_slice] - # Log the metrics - if mode == "train": - self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item() - self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] - - # Log completion lengths, mean, min, max - self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) - - # Identify sequences that terminated with EOS and log their lengths - agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1)) - term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos] - clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths) - self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio) - if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found - term_completion_lengths = torch.zeros(1, device=device) - 
self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) for i, reward_func_name in enumerate(self.reward_func_names): mean_rewards = torch.nanmean(rewards_per_func[:, i]).item() @@ -1571,14 +1600,14 @@ def _generate_and_score_completions( output["importance_sampling_ratio"] = importance_sampling_ratio if ref_per_token_logps is not None: output["ref_per_token_logps"] = ref_per_token_logps - if "pixel_values" in prompt_inputs: - output["pixel_values"] = prompt_inputs["pixel_values"] - if "image_grid_thw" in prompt_inputs: - output["image_grid_thw"] = prompt_inputs["image_grid_thw"] - if "pixel_attention_mask" in prompt_inputs: - output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"] - if "image_sizes" in prompt_inputs: - output["image_sizes"] = prompt_inputs["image_sizes"] + if "pixel_values" in forward_kwargs: + output["pixel_values"] = forward_kwargs["pixel_values"] + if "image_grid_thw" in forward_kwargs: + output["image_grid_thw"] = forward_kwargs["image_grid_thw"] + if "pixel_attention_mask" in forward_kwargs: + output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"] + if "image_sizes" in forward_kwargs: + output["image_sizes"] = forward_kwargs["image_sizes"] if images is not None: output["num_images"] = num_images return output From f99843262210380e08a43874e778b3270381bffd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 05:18:40 +0000 Subject: [PATCH 17/29] debug --- .github/workflows/tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4231ef227ec..48ee6cc9295 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -77,6 +77,7 @@ jobs: - name: Test with pytest run: | source .venv/bin/activate + export CUDA_LAUNCH_BLOCKING=1 make test - name: Post to Slack From fa738768c685196407db94a652d1e52cf88bcc22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 15:21:36 +0000 Subject: [PATCH 18/29] skip failing test --- tests/test_online_dpo_trainer.py | 58 ++++++++++++++++---------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py index 47fbd1f5a1f..30b39ce3464 100644 --- a/tests/test_online_dpo_trainer.py +++ b/tests/test_online_dpo_trainer.py @@ -419,35 +419,35 @@ def test_generation_config_setup(self): self.assertEqual(trainer.generation_config.max_new_tokens, 64) self.assertFalse(trainer.generation_config.do_sample) # From generation_kwargs - @require_torch_accelerator - @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) - def test_training_with_transformers_paged(self, config_name): - if Version(transformers.__version__) < Version("4.56.2"): - pytest.xfail("Upstream bug in transformers (GH#40692). 
Fix merged; awaiting release >= 4.56.2") - training_args = OnlineDPOConfig( - output_dir=self.tmp_dir, - per_device_train_batch_size=2, - max_steps=3, - learning_rate=5.0e-7, - eval_strategy="steps", - report_to="none", - use_transformers_paged=True, - ) - dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) - - trainer = OnlineDPOTrainer( - model=self.model, - reward_funcs=self.reward_model, - args=training_args, - train_dataset=dummy_dataset["train"], - eval_dataset=dummy_dataset["test"], - processing_class=self.tokenizer, - reward_processing_classes=self.reward_tokenizer, - ) - trainer.train() - - # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + # @require_torch_accelerator + # @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) + # def test_training_with_transformers_paged(self, config_name): + # if Version(transformers.__version__) < Version("4.56.2"): + # pytest.xfail("Upstream bug in transformers (GH#40692). Fix merged; awaiting release >= 4.56.2") + # training_args = OnlineDPOConfig( + # output_dir=self.tmp_dir, + # per_device_train_batch_size=2, + # max_steps=3, + # learning_rate=5.0e-7, + # eval_strategy="steps", + # report_to="none", + # use_transformers_paged=True, + # ) + # dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + + # trainer = OnlineDPOTrainer( + # model=self.model, + # reward_funcs=self.reward_model, + # args=training_args, + # train_dataset=dummy_dataset["train"], + # eval_dataset=dummy_dataset["test"], + # processing_class=self.tokenizer, + # reward_processing_classes=self.reward_tokenizer, + # ) + # trainer.train() + + # # Check if training loss is available + # self.assertIn("train_loss", trainer.state.log_history[-1]) @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) def test_training_with_reward_funcs(self, config_name): From fc52e6832d5b0e9f3403a01fd9571f4dd537ee5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 16:26:34 +0000 Subject: [PATCH 19/29] test fixed! 
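Reverts the CUDA_LAUNCH_BLOCKING debug flag in CI, restores the previously skipped paged online-DPO test, and adds `embed_dim` to the tiny test-model config; presumably that missing field is what made paged generation fail. A sketch of the assumed config dict after this patch (the dict name is an assumption, the field values are as shown in the diff below):

    vision_config = {
        "hidden_size": 16,
        "num_attention_heads": 4,
        "num_key_value_heads": 2,
        "embed_dim": 64,  # the newly added field; presumably needed by the tiny Qwen2.5-VL test model
    }
    # The script then builds the tiny model from a patched config, roughly:
    # config = AutoConfig.from_pretrained(model_id, text_config=text_config, vision_config=vision_config)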
--- .github/workflows/tests.yml | 1 - scripts/generate_tiny_models.py | 1 + tests/test_online_dpo_trainer.py | 58 ++++++++++++++++---------------- 3 files changed, 30 insertions(+), 30 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 48ee6cc9295..4231ef227ec 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -77,7 +77,6 @@ jobs: - name: Test with pytest run: | source .venv/bin/activate - export CUDA_LAUNCH_BLOCKING=1 make test - name: Post to Slack diff --git a/scripts/generate_tiny_models.py b/scripts/generate_tiny_models.py index f8e779896f6..0000f788d09 100644 --- a/scripts/generate_tiny_models.py +++ b/scripts/generate_tiny_models.py @@ -292,6 +292,7 @@ def init_weights_tiny_model(model): "hidden_size": 16, "num_attention_heads": 4, "num_key_value_heads": 2, + "embed_dim": 64, } config = AutoConfig.from_pretrained(model_id, text_config=text_config, vision_config=vision_config) diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py index 30b39ce3464..47fbd1f5a1f 100644 --- a/tests/test_online_dpo_trainer.py +++ b/tests/test_online_dpo_trainer.py @@ -419,35 +419,35 @@ def test_generation_config_setup(self): self.assertEqual(trainer.generation_config.max_new_tokens, 64) self.assertFalse(trainer.generation_config.do_sample) # From generation_kwargs - # @require_torch_accelerator - # @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) - # def test_training_with_transformers_paged(self, config_name): - # if Version(transformers.__version__) < Version("4.56.2"): - # pytest.xfail("Upstream bug in transformers (GH#40692). Fix merged; awaiting release >= 4.56.2") - # training_args = OnlineDPOConfig( - # output_dir=self.tmp_dir, - # per_device_train_batch_size=2, - # max_steps=3, - # learning_rate=5.0e-7, - # eval_strategy="steps", - # report_to="none", - # use_transformers_paged=True, - # ) - # dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) - - # trainer = OnlineDPOTrainer( - # model=self.model, - # reward_funcs=self.reward_model, - # args=training_args, - # train_dataset=dummy_dataset["train"], - # eval_dataset=dummy_dataset["test"], - # processing_class=self.tokenizer, - # reward_processing_classes=self.reward_tokenizer, - # ) - # trainer.train() - - # # Check if training loss is available - # self.assertIn("train_loss", trainer.state.log_history[-1]) + @require_torch_accelerator + @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) + def test_training_with_transformers_paged(self, config_name): + if Version(transformers.__version__) < Version("4.56.2"): + pytest.xfail("Upstream bug in transformers (GH#40692). 
Fix merged; awaiting release >= 4.56.2") + training_args = OnlineDPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + use_transformers_paged=True, + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + + trainer = OnlineDPOTrainer( + model=self.model, + reward_funcs=self.reward_model, + args=training_args, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + processing_class=self.tokenizer, + reward_processing_classes=self.reward_tokenizer, + ) + trainer.train() + + # Check if training loss is available + self.assertIn("train_loss", trainer.state.log_history[-1]) @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) def test_training_with_reward_funcs(self, config_name): From 4fc2b5b71d7b11c4e5488c8ad90b8999061798ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 17:13:23 +0000 Subject: [PATCH 20/29] gfpo --- trl/experimental/gfpo/gfpo_trainer.py | 354 ++------------------------ 1 file changed, 27 insertions(+), 327 deletions(-) diff --git a/trl/experimental/gfpo/gfpo_trainer.py b/trl/experimental/gfpo/gfpo_trainer.py index 6ac4b6acc7f..d3b59a72c81 100644 --- a/trl/experimental/gfpo/gfpo_trainer.py +++ b/trl/experimental/gfpo/gfpo_trainer.py @@ -13,22 +13,15 @@ # limitations under the License. import logging -import re -from contextlib import nullcontext from typing import Any, Callable import torch -import torch.utils.data -from accelerate.utils import broadcast_object_list, gather_object -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from transformers.utils import is_flash_attn_2_available +from accelerate.utils import gather_object -from ...data_utils import is_conversational, maybe_apply_chat_template, prepare_multimodal_messages -from ...extras.profiling import profiling_context +from ...data_utils import is_conversational from ...import_utils import is_vllm_available -from ...models import unwrap_model_for_generation from ...trainer.grpo_trainer import GRPOTrainer as _GRPOTrainer -from ...trainer.utils import nanmax, nanmin, nanstd, pad, truncate_with_protected_tokens +from ...trainer.utils import nanmax, nanmin, nanstd logger = logging.getLogger(__name__) @@ -36,8 +29,7 @@ GroupFilterFunc = Callable[[list[list[Any]], list[list[Any]]], list[list[float]]] if is_vllm_available(): - from vllm import SamplingParams - from vllm.sampling_params import GuidedDecodingParams + pass class GFPOTrainer(_GRPOTrainer): @@ -89,284 +81,22 @@ def _generate_and_score_completions(self, inputs): else: images = None - # If the prompts are conversational and the inputs contain images, we need to convert the prompts from - # [{"role": "user", "content": "What color is the sky?"}] to - # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] - kwargs = {} - if images is not None: - kwargs = {"images": images} - for prompt, image_list in zip(prompts, images): - if isinstance(prompt, list): # i.e., when using conversational data - prepare_multimodal_messages(prompt, num_images=len(image_list)) - - prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] - - prompt_inputs = self.processing_class( - text=prompts_text, - return_tensors="pt", - padding=True, - padding_side="left", - add_special_tokens=False, - **kwargs, - ) - prompt_inputs = 
super()._prepare_inputs(prompt_inputs) - prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] - - if self.max_prompt_length is not None: - # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. - # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text, - # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation). - protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id] - protected = [token for token in protected if token is not None] - prompt_ids, prompt_mask = truncate_with_protected_tokens( - prompt_ids, prompt_mask, self.max_prompt_length, protected - ) - - prompts_text = self.processing_class.batch_decode( - prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False - ) - prompts_text = [re.sub(rf"^({re.escape(self.pad_token)})+", "", text) for text in prompts_text] - - # The chat template sometimes inserts a single image token into the prompt text. However, when this text is - # later tokenized, the single image token string is expanded into multiple image token IDs, depending on the - # image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We - # collapse them back into a single token string to match the original chat template in case it originally - # applies it. Otherwise, it assumes that the chat template uses only vision_start_token_id to indicate images - # (e.g. Gemma 3) and removes all image_token instances and vision_end_token_id as well, leaving only - # the vision_start_token_id (e.g. ). - if self.image_token is not None: - escaped_img_token = re.escape(self.image_token) - # Search for the image token in the chat template - if re.search(escaped_img_token, self.processing_class.chat_template): - prompts_text = [ - re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text - ] - else: - # If the chat template doesn't use the image token, we remove all instances of it + vision_end_token_id - if self.vision_end_token_id is not None: - escaped_eoi_token = re.escape( - self.processing_class.tokenizer.decode([self.vision_end_token_id]) - ) - prompts_text = [ - re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text - ] - else: - # If vision_end_token_id is None, just remove the image tokens - prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text] - - # Generate completions using either vLLM or regular generation - if self.use_vllm: - if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode: - # wake up colocated vLLM instances if needed - torch.cuda.empty_cache() # required to avoid OOM in some cases - self.llm.wake_up() - - # First, update the vLLM weights if needed - if self.state.global_step != self._last_loaded_step: - self._move_model_to_vllm() - self._last_loaded_step = self.state.global_step - - # Generate completions using vLLM: gather all prompts and use them in a single call in the main process - if self.vllm_mode == "server": - all_prompts_text = gather_object(prompts_text) - if images is not None: - all_images = gather_object(images) - - if self.accelerator.is_main_process: - # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate - # num_generations outputs for each one. This is faster than generating outputs for each duplicate - # prompt individually. 
- ordered_set_of_prompts = all_prompts_text[:: self.num_generations] - - if images is not None: - ordered_set_of_images = all_images[:: self.num_generations] - else: - ordered_set_of_images = None - - with profiling_context(self, "vLLM.generate"): - output = self.vllm_client.generate( - prompts=ordered_set_of_prompts, - images=ordered_set_of_images, - n=self.num_generations, - repetition_penalty=self.repetition_penalty, - temperature=self.temperature, - top_p=self.top_p, - top_k=-1 if self.top_k is None else self.top_k, - min_p=0.0 if self.min_p is None else self.min_p, - max_tokens=self.max_completion_length, - guided_decoding_regex=self.guided_decoding_regex, - generation_kwargs=self.args.generation_kwargs, - ) - payload = (output["completion_ids"], output["logprobs"]) - else: - payload = None - - # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice. - obj_list = [payload] - broadcast_object_list(obj_list, from_process=0) - completion_ids, all_logprobs = obj_list[0] - - process_slice = slice( - self.accelerator.process_index * len(prompts), - (self.accelerator.process_index + 1) * len(prompts), - ) - completion_ids = completion_ids[process_slice] - all_logprobs = all_logprobs[process_slice] - - # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts - elif self.vllm_mode == "colocate": - if self.guided_decoding_regex: - guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex) - else: - guided_decoding = None - - generation_kwargs = { - "n": 1, # vLLM on each GPU generates only 1 in colocate mode - "repetition_penalty": self.repetition_penalty, - "temperature": self.temperature, - "top_p": self.top_p, - "top_k": -1 if self.top_k is None else self.top_k, - "min_p": 0.0 if self.min_p is None else self.min_p, - "max_tokens": self.max_completion_length, - "guided_decoding": guided_decoding, - "logprobs": 0, # only return the logprob of the generated token - } - if self.args.generation_kwargs is not None: - generation_kwargs.update(self.args.generation_kwargs) - sampling_params = SamplingParams(**generation_kwargs) - - if self.vllm_tensor_parallel_size > 1: - # Gather prompts from all ranks in the TP group and flatten. - # Each rank starts with its own prompts; after gathering, all ranks see the full group set. 
- orig_size = len(prompts_text) - gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group) - all_prompts_text = [p for sublist in gathered_prompts for p in sublist] - - if images is not None: - gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group) - all_images = [img for sublist in gathered_images for img in sublist] - else: - all_images = None - else: - all_prompts_text = prompts_text - all_images = images - - if images is not None and all_images: - vllm_inputs = [] - for prompt, image_list in zip(all_prompts_text, all_images): - vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}}) - - else: - vllm_inputs = all_prompts_text - - with profiling_context(self, "vLLM.generate"): - all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False) - - completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] - all_logprobs = [ - [next(iter(lp.values())).logprob for lp in output.logprobs] - for outputs in all_outputs - for output in outputs.outputs - ] - - if self.vllm_tensor_parallel_size > 1: - # Slice completions for this rank within its TP group. - # Each rank generates all outputs — we keep only our share. - local_rank_in_group = torch.distributed.get_rank(group=self.tp_group) - tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) - completion_ids = completion_ids[tp_slice] - all_logprobs = all_logprobs[tp_slice] - - if self.args.vllm_enable_sleep_mode: - self.llm.sleep(level=1) - - # Pad the completions, and concatenate them with the prompts - completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] - completion_ids = pad(completion_ids, padding_value=self.pad_token_id) - prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) - sampling_per_token_logps = [ - torch.tensor(logprobs, device=device, dtype=torch.float32) for logprobs in all_logprobs - ] - sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0) - - elif self.use_transformers_paged: - # Re-process inputs for paged generation if needed - # Note: images are already validated and preprocessed above - paged_prompt_inputs = self.processing_class(text=prompts_text, **kwargs) - previous_attn = self.model_wrapped.config._attn_implementation - - if is_flash_attn_2_available(): - self.model_wrapped.config._attn_implementation = "paged_attention" - else: - self.model_wrapped.config._attn_implementation = "sdpa_paged" - with ( - profiling_context(self, "transformers.generate_batch"), - unwrap_model_for_generation( - self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation - ) as unwrapped_model, - torch.no_grad(), - FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), - ): - # Cast to the appropriate dtype based on training configuration - if self.args.bf16: - unwrapped_model.to(torch.bfloat16) - elif self.args.fp16: - unwrapped_model.to(torch.float16) - with torch.inference_mode(): - all_outputs = unwrapped_model.generate_batch( - paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False - ) - completion_ids = [output.generated_tokens for output in all_outputs.values()] - completion_ids = [torch.tensor(ids, device=device) for ids 
in completion_ids] - completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") - prompt_ids = [torch.tensor(ids, device=device) for ids in paged_prompt_inputs.input_ids] - prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") - prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) - # Restore the original attention implementation, training mode - self.model_wrapped.config._attn_implementation = previous_attn - else: - # Regular generation path - with ( - profiling_context(self, "transformers.generate"), - unwrap_model_for_generation( - self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation - ) as unwrapped_model, - torch.no_grad(), - FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), - ): - prompt_inputs["input_ids"], prompt_inputs["attention_mask"] = prompt_ids, prompt_mask - prompt_completion_ids = unwrapped_model.generate( - **prompt_inputs, generation_config=self.generation_config, disable_compile=True - ) - # Compute prompt length and extract completion ids - prompt_length = prompt_ids.size(1) - prompt_ids = prompt_completion_ids[:, :prompt_length] - completion_ids = prompt_completion_ids[:, prompt_length:] - - # Mask everything after the first EOS token - is_eos = completion_ids == self.eos_token_id - eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) - eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] - sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) - completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() + ( + prompt_ids, + completion_ids, + prompt_mask, + completion_mask, + num_items_in_batch, + sampling_per_token_logps, + forward_kwargs, + ) = self._generate(prompts, images) # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need # to re-tokenize completions if the reward is computed from tokens. 
completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())] - # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging - completion_lengths = completion_mask.sum(1) - agg_completion_lengths = self.accelerator.gather(completion_lengths) - num_items_in_batch = agg_completion_lengths.sum() # this is required for the DAPO loss - - # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask - if self.mask_truncated_completions: - truncated_completions = ~is_eos.any(dim=1) - completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int() - # Concatenate prompt_mask with completion_mask for logit computation + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens @@ -392,11 +122,8 @@ def _generate_and_score_completions(self, inputs): attention_mask, logits_to_keep, batch_size, - pixel_values=prompt_inputs.get("pixel_values"), - image_grid_thw=prompt_inputs.get("image_grid_thw"), num_images=num_images, - pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), - image_sizes=prompt_inputs.get("image_sizes"), + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes ) else: old_per_token_logps = None @@ -417,11 +144,8 @@ def _generate_and_score_completions(self, inputs): attention_mask, logits_to_keep, batch_size=batch_size, - pixel_values=prompt_inputs.get("pixel_values"), - image_grid_thw=prompt_inputs.get("image_grid_thw"), num_images=num_images, - pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), - image_sizes=prompt_inputs.get("image_sizes"), + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes ) else: with self.accelerator.unwrap_model(self.model).disable_adapter(): @@ -431,16 +155,14 @@ def _generate_and_score_completions(self, inputs): attention_mask, logits_to_keep, batch_size=batch_size, - pixel_values=prompt_inputs.get("pixel_values"), - image_grid_thw=prompt_inputs.get("image_grid_thw"), num_images=num_images, - pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), - image_sizes=prompt_inputs.get("image_sizes"), + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes ) else: ref_per_token_logps = None - # Decode the generated completions + # Decode + prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True) completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) if is_conversational(inputs[0]): completions = [] @@ -529,28 +251,6 @@ def _generate_and_score_completions(self, inputs): completion_lengths = completion_mask.sum(1) agg_completion_lengths = self.accelerator.gather(completion_lengths) num_items_in_batch = agg_completion_lengths.sum() - is_eos = completion_ids == self.eos_token_id - - # Log the metrics - if mode == "train": - self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item() - self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] - - # Log completion lengths, mean, min, max - self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) - 
self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) - - # Identify sequences that terminated with EOS and log their lengths - agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1)) - term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos] - clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths) - self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio) - if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found - term_completion_lengths = torch.zeros(1, device=device) - self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) for i, reward_func_name in enumerate(self.reward_func_names): @@ -633,14 +333,14 @@ def _generate_and_score_completions(self, inputs): output["importance_sampling_ratio"] = importance_sampling_ratio if ref_per_token_logps is not None: output["ref_per_token_logps"] = ref_per_token_logps - if "pixel_values" in prompt_inputs: - output["pixel_values"] = prompt_inputs["pixel_values"] - if "image_grid_thw" in prompt_inputs: - output["image_grid_thw"] = prompt_inputs["image_grid_thw"] - if "pixel_attention_mask" in prompt_inputs: - output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"] - if "image_sizes" in prompt_inputs: - output["image_sizes"] = prompt_inputs["image_sizes"] + if "pixel_values" in forward_kwargs: + output["pixel_values"] = forward_kwargs["pixel_values"] + if "image_grid_thw" in forward_kwargs: + output["image_grid_thw"] = forward_kwargs["image_grid_thw"] + if "pixel_attention_mask" in forward_kwargs: + output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"] + if "image_sizes" in forward_kwargs: + output["image_sizes"] = forward_kwargs["image_sizes"] if images is not None: output["num_images"] = num_images return output From b628744752d54c3f2028b4a2d420bd88fc06a188 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 17:15:02 +0000 Subject: [PATCH 21/29] rm vllm --- trl/experimental/gfpo/gfpo_trainer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/trl/experimental/gfpo/gfpo_trainer.py b/trl/experimental/gfpo/gfpo_trainer.py index d3b59a72c81..5e228c1e883 100644 --- a/trl/experimental/gfpo/gfpo_trainer.py +++ b/trl/experimental/gfpo/gfpo_trainer.py @@ -19,7 +19,6 @@ from accelerate.utils import gather_object from ...data_utils import is_conversational -from ...import_utils import is_vllm_available from ...trainer.grpo_trainer import GRPOTrainer as _GRPOTrainer from ...trainer.utils import nanmax, nanmin, nanstd @@ -28,9 +27,6 @@ GroupFilterFunc = Callable[[list[list[Any]], list[list[Any]]], list[list[float]]] -if is_vllm_available(): - pass - class GFPOTrainer(_GRPOTrainer): def __init__( From d3a769fe8fb5a8a4b9e21b6d96c8b1d8f84c7960 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 17:15:13 +0000 Subject: [PATCH 22/29] fix doc --- docs/source/experimental.md | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/docs/source/experimental.md b/docs/source/experimental.md index 65471f4e421..1413e56a2df 100644 --- a/docs/source/experimental.md +++ b/docs/source/experimental.md @@ -66,7 +66,7 @@ class GroupFilter: return group_scores training_args = GFPOConfig( - output_dir="Qwen3-0.6B-GFPO" + output_dir="Qwen3-0.6B-GFPO", per_device_train_batch_size=4, num_remains_in_group=2, bf16=True, From c9693b255655fffc5eb8690d872e810a0353df44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 21 Sep 2025 00:43:21 +0000 Subject: [PATCH 23/29] a bit messy! --- trl/data_utils.py | 2 +- trl/trainer/grpo_trainer.py | 70 ++++++++++++++++++++++++++++++++++++- 2 files changed, 70 insertions(+), 2 deletions(-) diff --git a/trl/data_utils.py b/trl/data_utils.py index 50931fb9a15..3093e0efa99 100644 --- a/trl/data_utils.py +++ b/trl/data_utils.py @@ -148,7 +148,7 @@ def apply_chat_template( # Apply the chat template to the prompt, adding the generation prompt if "prompt" in example: last_role = example["prompt"][-1]["role"] - if last_role == "user": + if last_role in ["user", "tool"]: add_generation_prompt = True continue_final_message = False elif last_role == "assistant": diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index bfc24a651c3..d99c27119e1 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect +import json import os import re import textwrap @@ -60,6 +61,8 @@ RepeatSampler, disable_dropout_in_model, entropy_from_logits, + flush_left, + flush_right, generate_model_card, get_comet_experiment_url, identity, @@ -98,6 +101,20 @@ RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]] +def extract_tool_calls(text: str) -> dict[str, Any]: + """ + Given a list of strings, extract all JSON blocks and return them as a list of dictionaries. + """ + pattern = re.compile(r"\s*(\{.*?\})\s*", re.DOTALL) + + for match in pattern.findall(text): + try: + return json.loads(match) + except json.JSONDecodeError: + pass + return None + + class GRPOTrainer(Trainer): """ Trainer for the Group Relative Policy Optimization (GRPO) method. 
This algorithm was initially proposed in the @@ -214,7 +231,10 @@ def __init__( callbacks: Optional[list[TrainerCallback]] = None, optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), peft_config: Optional["PeftConfig"] = None, + tools=None, ): + self.tools = tools or [] + self._tool_dict = {name: tool for name, tool in zip([tool.__name__ for tool in self.tools], self.tools)} # Args if args is None: model_name = model if isinstance(model, str) else model.config._name_or_path @@ -1084,7 +1104,8 @@ def _generate(self, prompts: list[str], images: Optional[list]): prepare_multimodal_messages(prompt, num_images=len(image_list)) prompts_text = [ - maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts + maybe_apply_chat_template({"prompt": prompt}, self.processing_class, tools=self.tools)["prompt"] + for prompt in prompts ] prompt_inputs = self.processing_class( @@ -1413,6 +1434,53 @@ def _generate_and_score_completions( sampling_per_token_logps, forward_kwargs, ) = self._generate(prompts, images) + completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + tool_calls = [extract_tool_calls(completion) for completion in completions] + tool_results = [self._tool_dict[tc["name"]](**tc["arguments"]) if tc else None for tc in tool_calls] + tool_messages = [ + [{"role": "tool", "name": tc["name"], "content": str(tr)}] if tc else None + for tc, tr in zip(tool_calls, tool_results) + ] + new_prompts = [ + p + [{"role": "user", "content": c}] + t for p, c, t in zip(prompts, completions, tool_messages) if t + ] + needs_tool = torch.tensor([tc is not None for tc in tool_calls], device=device) + if new_prompts: + ( + new_prompt_ids, + new_completion_ids, + new_prompt_mask, + new_completion_mask, + new_num_items_in_batch, + new_sampling_per_token_logps, + new_forward_kwargs, + ) = self._generate(new_prompts, images) + num_tool_ids = new_prompt_mask.sum(-1) - torch.cat( + [prompt_mask[needs_tool], completion_mask[needs_tool]], dim=1 + ).sum(-1) + tool_ids = [ids[-num:] for ids, num in zip(new_prompt_ids, num_tool_ids)] + tool_mask = [torch.ones_like(ids) for ids in tool_ids] + r_completion_mask, r_completion_ids = flush_right(completion_mask[needs_tool], completion_ids[needs_tool]) + ci = [torch.cat(x) for x in zip(r_completion_ids, tool_ids, new_completion_ids)] + cm = [torch.cat(x) for x in zip(r_completion_mask, tool_mask, new_completion_mask)] + + new_ci = [] + new_cm = [] + true_idx = 0 + for i, m in enumerate(needs_tool): + if m: + # take the next tensor from list_true + new_ci.append(ci[true_idx]) + new_cm.append(cm[true_idx]) + true_idx += 1 + else: + new_ci.append(completion_ids[i]) + new_cm.append(completion_mask[i]) + + completion_ids = pad(new_ci, self.pad_token_id) + completion_mask = pad(new_cm, 0) + completion_mask, completion_ids = flush_left(completion_mask, completion_ids) + num_items_in_batch += new_num_items_in_batch # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need # to re-tokenize completions if the reward is computed from tokens. 
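The tool loop added in this patch boils down to three steps: parse a JSON tool call out of the completion, dispatch it through `self._tool_dict`, and feed the result back as a `"role": "tool"` message before generating again. A minimal standalone sketch of the dispatch step follows; the `get_temperature` tool and the parsed call are made up for illustration, while the name→callable mapping and the tool-message shape mirror the patch.

```python
def get_temperature(city: str) -> float:
    # Hypothetical tool, used only for illustration.
    return 21.5


tools = [get_temperature]
# Same name -> callable mapping the trainer keeps in self._tool_dict.
tool_dict = {tool.__name__: tool for tool in tools}

# Suppose extract_tool_calls() parsed this out of a model completion.
tool_call = {"name": "get_temperature", "arguments": {"city": "Paris"}}

if tool_call is not None:
    result = tool_dict[tool_call["name"]](**tool_call["arguments"])
    tool_message = [{"role": "tool", "name": tool_call["name"], "content": str(result)}]
    # The trainer appends the completion and this tool message to the prompt,
    # then calls self._generate() again on the extended conversation.
    print(tool_message)
```

On the token side, the patch uses flush_right/flush_left to re-align the padded completion, tool, and follow-up segments before padding them back into a single batch, so samples without a tool call keep their original ids untouched.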
From 05270f820f69bad6b3edc1edbeee2d53189a9143 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 22 Sep 2025 23:51:57 +0000 Subject: [PATCH 24/29] update layers to ignore --- tests/test_grpo_trainer.py | 11 ----------- tests/test_rloo_trainer.py | 2 -- 2 files changed, 13 deletions(-) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 5577e1dd25d..cc484c56d0f 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -1291,7 +1291,6 @@ def reward_func(completions, **kwargs): "model.vision_tower.", "model.multi_modal_projector.", "model.vision_model.", - "model.connector.modality_projection.", "model.visual.", "model.image_newline", ) @@ -1587,17 +1586,7 @@ def reward_func(completions, **kwargs): # Check that the params have changed # Because of the way the tiny models are initialized, the gradient does not flow properly through the # vision parts of the model, so we skip them. Ideally, we should fix the init of these models. - params_to_skip = ( - # "model.vision_tower.", - # "model.multi_modal_projector.", - # "model.vision_model.", - # "model.connector.modality_projection.", - # "model.visual.", - # "model.image_newline", - ) for n, param in previous_trainable_params.items(): - if n.startswith(params_to_skip): - continue new_param = trainer.model.get_parameter(n) self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index 399419ec3c1..cde52de6047 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -1121,8 +1121,6 @@ def reward_func(completions, **kwargs): params_to_skip = ( "model.vision_tower.", "model.multi_modal_projector.", - "model.vision_model.", - "model.connector.modality_projection.", "model.visual.", "model.image_newline", ) From 1c530948681255ce59bc267939d79a80ec1c5d93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 22 Sep 2025 23:57:13 +0000 Subject: [PATCH 25/29] clarify image column desc --- docs/source/dataset_formats.md | 2 +- docs/source/grpo_trainer.md | 4 +++- docs/source/rloo_trainer.md | 4 +++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/docs/source/dataset_formats.md b/docs/source/dataset_formats.md index da606a7d97e..8a105ff5e34 100644 --- a/docs/source/dataset_formats.md +++ b/docs/source/dataset_formats.md @@ -1037,7 +1037,7 @@ Some trainers also support fine-tuning vision-language models (VLMs) using image A conversational vision dataset differs from a standard conversational dataset in two key ways: -1. The dataset must contain the key `images` with the image data. +1. The dataset must contain the key `images` with the image data (as lists of PIL images) or `image` with a single PIL image. 2. The `"content"` field in messages must be a list of dictionaries, where each dictionary specifies the type of data: `"image"` or `"text"`. Example: diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index 172e5f93111..e998ef63f69 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -562,7 +562,9 @@ Tested with: - **SmolVLM2** — e.g., `HuggingFaceTB/SmolVLM2-2.2B-Instruct` + Compatibility with all VLMs is not guaranteed. If you believe a model should be supported, feel free to open an issue on GitHub — or better yet, submit a pull request with the required changes. + ### Quick Start @@ -605,7 +607,7 @@ VLM training may fail if image tokens are truncated. 
We highly recommend disabli Each training sample should include: - `prompt`: Text formatted via the processor's chat template -- `image`: A single image (PIL or NumPy array) +- `image`/`images`: PIL Image or list of PIL Images The trainer automatically handles image-to-tensor conversion via the model’s image processor. diff --git a/docs/source/rloo_trainer.md b/docs/source/rloo_trainer.md index 66a8f3e16e4..bce71b1f0bf 100644 --- a/docs/source/rloo_trainer.md +++ b/docs/source/rloo_trainer.md @@ -533,7 +533,9 @@ Tested with: - **SmolVLM2** — e.g., `HuggingFaceTB/SmolVLM2-2.2B-Instruct` + Compatibility with all VLMs is not guaranteed. If you believe a model should be supported, feel free to open an issue on GitHub — or better yet, submit a pull request with the required changes. + ### Quick Start @@ -576,7 +578,7 @@ VLM training may fail if image tokens are truncated. We highly recommend disabli Each training sample should include: - `prompt`: Text formatted via the processor's chat template -- `image`: A single image (PIL or NumPy array) +- `image`/`images`: PIL Image or list of PIL Images The trainer automatically handles image-to-tensor conversion via the model’s image processor. From 9b6652eed4fdc6c7e73c40d81917a3e5c9ad024c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 23 Sep 2025 00:05:23 +0000 Subject: [PATCH 26/29] rm VLM x RM warning --- trl/trainer/grpo_trainer.py | 8 -------- trl/trainer/rloo_trainer.py | 8 -------- 2 files changed, 16 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 87a08096a13..69825102b27 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1024,14 +1024,6 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): with profiling_context(self, reward_func_name): if isinstance(reward_func, nn.Module): # Module (no PretrainedModel) for compat with compiled models if is_conversational(inputs[0]): - # VLM reward models aren't supported yet, so we drop the image and raise a warning if needed - for prompt in prompts: - for turn in prompt: - if isinstance(turn["content"], list): - logger.warning_once("Visual reward models aren't supported yet; dropping image.") - turn["content"] = " ".join( - e["text"] for e in turn["content"] if e["type"] == "text" - ) messages = [{"messages": p + c} for p, c in zip(prompts, completions)] texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages] else: diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 3671af229e2..359cb68e43b 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -1011,14 +1011,6 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): with profiling_context(self, reward_func_name): if isinstance(reward_func, nn.Module): # Module (no PretrainedModel) for compat with compiled models if is_conversational(inputs[0]): - # VLM reward models aren't supported yet, so we drop the image and raise a warning if needed - for prompt in prompts: - for turn in prompt: - if isinstance(turn["content"], list): - logger.warning_once("Visual reward models aren't supported yet; dropping image.") - turn["content"] = " ".join( - e["text"] for e in turn["content"] if e["type"] == "text" - ) messages = [{"messages": p + c} for p, c in zip(prompts, completions)] texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages] else: From c83e7108319d19ffe55866a8f7401f9741b93df1 Mon Sep 17 00:00:00 2001 
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 24 Sep 2025 17:17:14 +0000 Subject: [PATCH 27/29] same for rloo --- trl/trainer/grpo_trainer.py | 7 +- trl/trainer/rloo_trainer.py | 150 ++++++++++++++++++++---------------- 2 files changed, 87 insertions(+), 70 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 78c3ccff638..a947d52cec3 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -480,7 +480,7 @@ def __init__( if not is_vllm_available(): raise ImportError( "vLLM is not available and `use_vllm` is set to True. Please install vLLM with " - "`pip install [vllm]` to use it." + "`pip install trl[vllm]` to use it." ) if self.vllm_mode == "server": @@ -533,7 +533,7 @@ def __init__( distributed_executor_backend="external_launcher", # Feed identical seed for tp groups to ensure sampling results are the same across workers seed=self.accelerator.process_index // self.vllm_tensor_parallel_size, - # Latest vLLM v1 memory profiler is misled by the high default value (i.e., 32768) - thinking there's not enough memory + # Latest vLLM v1 memory profiler is misled by the high default value (i.e., 32768) - thinking there's not enough memory max_num_batched_tokens=4096, model_impl=self.args.vllm_model_impl, enable_sleep_mode=self.args.vllm_enable_sleep_mode, @@ -1366,9 +1366,6 @@ def _generate(self, prompts: list[str], images: Optional[list]): self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - if images is not None: - self._logs["image"].extend(gather_object(images)) - return ( prompt_ids, completion_ids, diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index e87ecf95b37..5ff29112e9c 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -81,7 +81,6 @@ if is_peft_available(): from peft import PeftConfig, PeftModel - if is_vllm_available(): from vllm import LLM, SamplingParams from vllm.sampling_params import GuidedDecodingParams @@ -788,7 +787,6 @@ def _get_per_token_logps_and_entropies( # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} - if image_grid_thw is not None and pixel_values is not None: rows_per_image = image_grid_thw.prod(dim=-1) rows_per_sample = torch.split(rows_per_image, num_images) @@ -1048,21 +1046,10 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): rewards_per_func = gather(rewards_per_func) return rewards_per_func - def _generate_and_score_completions( - self, inputs: list[dict[str, Union[torch.Tensor, Any]]] - ) -> dict[str, Union[torch.Tensor, Any]]: + def _generate(self, prompts: list[str], images: Optional[list]): device = self.accelerator.device mode = "train" if self.model.training else "eval" - prompts = [x["prompt"] for x in inputs] - - if "images" in inputs[0]: - images = [example.get("images") for example in inputs] - elif "image" in inputs[0]: - images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] - else: - images = None - # If the prompts are conversational and the inputs contain images, we need to convert the prompts from # [{"role": "user", "content": "What color is the sky?"}] to # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the 
sky?"}]}] @@ -1073,7 +1060,9 @@ def _generate_and_score_completions( if isinstance(prompt, list): # i.e., when using conversational data prepare_multimodal_messages(prompt, num_images=len(image_list)) - prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] + prompts_text = [ + maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts + ] prompt_inputs = self.processing_class( text=prompts_text, @@ -1085,6 +1074,7 @@ def _generate_and_score_completions( ) prompt_inputs = super()._prepare_inputs(prompt_inputs) prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] + forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} if self.max_prompt_length is not None: # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. @@ -1250,8 +1240,9 @@ def _generate_and_score_completions( # Pad the completions, and concatenate them with the prompts completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] + completion_mask = [torch.ones(len(ids), device=device, dtype=torch.long) for ids in completion_ids] completion_ids = pad(completion_ids, padding_value=self.pad_token_id) - prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) + completion_mask = pad(completion_mask, padding_value=0) elif self.use_transformers_paged: # Re-process inputs for paged generation if needed @@ -1286,9 +1277,9 @@ def _generate_and_score_completions( completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") prompt_ids = [torch.tensor(ids, device=device) for ids in paged_prompt_inputs.input_ids] prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") - prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # Restore the original attention implementation, training mode self.model_wrapped.config._attn_implementation = previous_attn + else: # Regular generation path with ( @@ -1299,9 +1290,12 @@ def _generate_and_score_completions( torch.no_grad(), FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), ): - prompt_inputs["input_ids"], prompt_inputs["attention_mask"] = prompt_ids, prompt_mask prompt_completion_ids = unwrapped_model.generate( - **prompt_inputs, generation_config=self.generation_config, disable_compile=True + input_ids=prompt_ids, + attention_mask=prompt_mask, + **forward_kwargs, + generation_config=self.generation_config, + disable_compile=True, ) # Compute prompt length and extract completion ids prompt_length = prompt_ids.size(1) @@ -1315,10 +1309,6 @@ def _generate_and_score_completions( sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() - # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need - # to re-tokenize completions if the reward is computed from tokens. 
- completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())] - # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging completion_lengths = completion_mask.sum(1) @@ -1327,7 +1317,66 @@ def _generate_and_score_completions( truncated_completions = ~is_eos.any(dim=1) completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int() + # Log the metrics + if mode == "train": + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item() + self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] + + # Log completion lengths, mean, min, max + agg_completion_lengths = self.accelerator.gather(completion_lengths) + self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) + + # Identify sequences that terminated with EOS and log their lengths + agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1)) + term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos] + clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths) + self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio) + if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found + term_completion_lengths = torch.zeros(1, device=device) + self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) + + return ( + prompt_ids, + completion_ids, + prompt_mask, + completion_mask, + forward_kwargs + ) + + def _generate_and_score_completions( + self, inputs: list[dict[str, Union[torch.Tensor, Any]]] + ) -> dict[str, Union[torch.Tensor, Any]]: + device = self.accelerator.device + mode = "train" if self.model.training else "eval" + + prompts = [x["prompt"] for x in inputs] + + if "images" in inputs[0]: + images = [example.get("images") for example in inputs] + elif "image" in inputs[0]: + images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] + else: + images = None + + ( + prompt_ids, + completion_ids, + prompt_mask, + completion_mask, + forward_kwargs, + ) = self._generate(prompts, images) + + # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need + # to re-tokenize completions if the reward is computed from tokens. 
+ completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())] + # Concatenate prompt_mask with completion_mask for logit computation + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens @@ -1343,11 +1392,8 @@ def _generate_and_score_completions( attention_mask, logits_to_keep, batch_size, - pixel_values=prompt_inputs.get("pixel_values"), - image_grid_thw=prompt_inputs.get("image_grid_thw"), num_images=num_images, - pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), - image_sizes=prompt_inputs.get("image_sizes"), + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes ) old_logps = (old_per_token_logps * completion_mask).sum(1) # mask out padding and tokens after EOS @@ -1360,11 +1406,8 @@ def _generate_and_score_completions( attention_mask, logits_to_keep, batch_size=batch_size, - pixel_values=prompt_inputs.get("pixel_values"), - image_grid_thw=prompt_inputs.get("image_grid_thw"), num_images=num_images, - pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), - image_sizes=prompt_inputs.get("image_sizes"), + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes ) else: with self.accelerator.unwrap_model(self.model).disable_adapter(): @@ -1374,16 +1417,14 @@ def _generate_and_score_completions( attention_mask, logits_to_keep, batch_size=batch_size, - pixel_values=prompt_inputs.get("pixel_values"), - image_grid_thw=prompt_inputs.get("image_grid_thw"), num_images=num_images, - pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), - image_sizes=prompt_inputs.get("image_sizes"), + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes ) else: ref_per_token_logps = None - # Decode the generated completions + # Decode + prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True) completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) if is_conversational(inputs[0]): completions = [] @@ -1436,33 +1477,12 @@ def _generate_and_score_completions( all_process_advantages = advantages.clone() # keep the aggregated advantages for logging advantages = advantages[process_slice] - # Log the metrics - if mode == "train": - self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item() - self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] # Calculate and log the mean KL divergence between current and reference model if self.beta != 0.0: mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item()) - # Log completion lengths, mean, min, max - agg_completion_lengths = self.accelerator.gather(completion_lengths) - self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) - - # Identify sequences that terminated with EOS and log their lengths - agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1)) - 
term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos] - clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths) - self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio) - if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found - term_completion_lengths = torch.zeros(1, device=device) - self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) for i, reward_func_name in enumerate(self.reward_func_names): mean_rewards = torch.nanmean(rewards_per_func[:, i]).item() @@ -1491,14 +1511,14 @@ def _generate_and_score_completions( "old_logps": old_logps, "advantages": advantages, } - if "pixel_values" in prompt_inputs: - output["pixel_values"] = prompt_inputs["pixel_values"] - if "image_grid_thw" in prompt_inputs: - output["image_grid_thw"] = prompt_inputs["image_grid_thw"] - if "pixel_attention_mask" in prompt_inputs: - output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"] - if "image_sizes" in prompt_inputs: - output["image_sizes"] = prompt_inputs["image_sizes"] + if "pixel_values" in forward_kwargs: + output["pixel_values"] = forward_kwargs["pixel_values"] + if "image_grid_thw" in forward_kwargs: + output["image_grid_thw"] = forward_kwargs["image_grid_thw"] + if "pixel_attention_mask" in forward_kwargs: + output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"] + if "image_sizes" in forward_kwargs: + output["image_sizes"] = forward_kwargs["image_sizes"] if images is not None: output["num_images"] = num_images return output From ec6ad259d22cbb817eb3b5a6cf948799721893bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 24 Sep 2025 17:26:25 +0000 Subject: [PATCH 28/29] nits style and align --- trl/trainer/grpo_trainer.py | 1 + trl/trainer/rloo_trainer.py | 18 +++--------------- 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index a947d52cec3..7d9138319e6 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1303,6 +1303,7 @@ def _generate(self, prompts: list[str], images: Optional[list]): prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") # Restore the original attention implementation, training mode self.model_wrapped.config._attn_implementation = previous_attn + sampling_per_token_logps = None # not used in this case else: # Regular generation path diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 5ff29112e9c..a4e31c7a9cf 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -787,6 +787,7 @@ def _get_per_token_logps_and_entropies( # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} + if image_grid_thw is not None and pixel_values is not None: rows_per_image = image_grid_thw.prod(dim=-1) rows_per_sample = torch.split(rows_per_image, num_images) @@ -1340,13 +1341,7 @@ def _generate(self, prompts: list[str], images: Optional[list]): 
self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - return ( - prompt_ids, - completion_ids, - prompt_mask, - completion_mask, - forward_kwargs - ) + return prompt_ids, completion_ids, prompt_mask, completion_mask, forward_kwargs def _generate_and_score_completions( self, inputs: list[dict[str, Union[torch.Tensor, Any]]] @@ -1363,13 +1358,7 @@ def _generate_and_score_completions( else: images = None - ( - prompt_ids, - completion_ids, - prompt_mask, - completion_mask, - forward_kwargs, - ) = self._generate(prompts, images) + prompt_ids, completion_ids, prompt_mask, completion_mask, forward_kwargs = self._generate(prompts, images) # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need # to re-tokenize completions if the reward is computed from tokens. @@ -1477,7 +1466,6 @@ def _generate_and_score_completions( all_process_advantages = advantages.clone() # keep the aggregated advantages for logging advantages = advantages[process_slice] - # Calculate and log the mean KL divergence between current and reference model if self.beta != 0.0: mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) From 04e4bd7228256106b3e91d6525db08b8ecc2f0f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Thu, 25 Sep 2025 21:16:05 -0600 Subject: [PATCH 29/29] Update trl/trainer/grpo_trainer.py Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> --- trl/trainer/grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 67590c0775d..35da0f157ae 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -235,7 +235,7 @@ def __init__( tools=None, ): self.tools = tools or [] - self._tool_dict = {name: tool for name, tool in zip([tool.__name__ for tool in self.tools], self.tools)} + self._tool_dict = {tool.__name__: tool for tool in self.tools} # Args if args is None: model_name = model if isinstance(model, str) else model.config._name_or_path