diff --git a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py
index 9d0cd58546b..9fc7b26cbba 100644
--- a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py
+++ b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import copy
 import heapq
 import re
 from contextlib import nullcontext
@@ -89,22 +88,22 @@ def _generate_and_score_completions(
 
         prompts = [x["prompt"] for x in inputs]
 
-        # We don't yet support visual reward models/function, so we keep a copy of the original text-only prompts for
-        # later use in the reward computation. If images are present, we insert {"type": "image"} as required by the
-        # VLM chat template.
-        original_prompts = copy.deepcopy(prompts)
+        if "images" in inputs[0]:
+            images = [example.get("images") for example in inputs]
+        elif "image" in inputs[0]:
+            images = [[example.get("image")] if example.get("image") is not None else None for example in inputs]
+        else:
+            images = None
 
         # If the prompts are conversational and the inputs contain images, we need to convert the prompts from
         # [{"role": "user", "content": "What color is the sky?"}] to
         # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
         kwargs = {}
-        has_images = "image" in inputs[0]
-        if has_images:
-            images = [example.get("image") for example in inputs]
-            kwargs = {"images": [[img] for img in images]}
-            for prompt in prompts:
+        if images is not None:
+            kwargs = {"images": images}
+            for prompt, image_list in zip(prompts, images):
                 if isinstance(prompt, list):  # i.e., when using conversational data
-                    prepare_multimodal_messages(prompt, num_images=1)
+                    prepare_multimodal_messages(prompt, num_images=len(image_list))
 
         prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
 
@@ -176,7 +175,7 @@ def _generate_and_score_completions(
             # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
             if self.vllm_mode == "server":
                 all_prompts_text = gather_object(prompts_text)
-                if has_images:
+                if images is not None:
                     all_images = gather_object(images)
 
                 if self.accelerator.is_main_process:
@@ -185,7 +184,7 @@ def _generate_and_score_completions(
                     # prompt individually.
                     ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
-                    if has_images:
+                    if images is not None:
                         ordered_set_of_images = all_images[:: self.num_generations]
                     else:
                         ordered_set_of_images = None
 
@@ -250,7 +249,7 @@ def _generate_and_score_completions(
                     torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group)
                     all_prompts_text = [p for sublist in gathered_prompts for p in sublist]
 
-                    if has_images:
+                    if images is not None:
                         gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)]
                         torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group)
                         all_images = [img for sublist in gathered_images for img in sublist]
@@ -258,15 +257,13 @@ def _generate_and_score_completions(
                         all_images = None
                 else:
                     all_prompts_text = prompts_text
-                    all_images = images if has_images else None
+                    all_images = images
 
-                if has_images and all_images:
+                if images is not None and all_images:
                     vllm_inputs = []
-                    for prompt, image in zip(all_prompts_text, all_images):
-                        if image is not None:
-                            vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}})
-                        else:
-                            vllm_inputs.append(prompt)
+                    for prompt, image_list in zip(all_prompts_text, all_images):
+                        vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}})
+
                 else:
                     vllm_inputs = all_prompts_text
 
@@ -381,6 +378,8 @@ def _generate_and_score_completions(
         logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
         batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
 
+        num_images = [len(img_list) for img_list in images] if images is not None else None
+
         with torch.no_grad():
             # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
             # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
@@ -401,6 +400,7 @@ def _generate_and_score_completions(
                     batch_size,
                     pixel_values=prompt_inputs.get("pixel_values"),
                     image_grid_thw=prompt_inputs.get("image_grid_thw"),
+                    num_images=num_images,
                     pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                     image_sizes=prompt_inputs.get("image_sizes"),
                 )
@@ -425,6 +425,7 @@ def _generate_and_score_completions(
                         batch_size=batch_size,
                         pixel_values=prompt_inputs.get("pixel_values"),
                         image_grid_thw=prompt_inputs.get("image_grid_thw"),
+                        num_images=num_images,
                         pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                         image_sizes=prompt_inputs.get("image_sizes"),
                     )
@@ -438,6 +439,7 @@ def _generate_and_score_completions(
                            batch_size=batch_size,
                             pixel_values=prompt_inputs.get("pixel_values"),
                             image_grid_thw=prompt_inputs.get("image_grid_thw"),
+                            num_images=num_images,
                             pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                             image_sizes=prompt_inputs.get("image_sizes"),
                         )
@@ -457,7 +459,7 @@ def _generate_and_score_completions(
         # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is
         # important because rewards will be normalized per group, and completions are distributed. We will later slice
         # rewards_per_func to extract each process's subset.
-        rewards_per_func = self._calculate_rewards(inputs, original_prompts, completions, completion_ids_list)
+        rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
 
         # Apply weights to each reward function's output and sum
         rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
@@ -535,8 +537,8 @@ def _generate_and_score_completions(
             self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
         self._logs["advantages"].extend(all_process_advantages.tolist())
 
-        if has_images:
-            self._logs["image"].extend(gather_object(images))
+        if images is not None:
+            self._logs["images"].extend(gather_object(images))
 
         if self.use_vllm and self.vllm_importance_sampling_correction:
             delta = torch.abs(old_per_token_logps - sampling_per_token_logps)
@@ -607,6 +609,8 @@ def _generate_and_score_completions(
             output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"]
         if "image_sizes" in prompt_inputs:
             output["image_sizes"] = prompt_inputs["image_sizes"]
+        if images is not None:
+            output["images"] = images
         return output
 
     def slice_group_data(