
Commit 4fc4314

qgallouedec authored and kashif committed
📽 Multi image support for GRPO replay buffer (#4157)
1 parent 3f62147 commit 4fc4314

File tree

1 file changed: +28 -24 lines changed


trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py

Lines changed: 28 additions & 24 deletions
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import copy
 import heapq
 import re
 from contextlib import nullcontext
@@ -89,22 +88,22 @@ def _generate_and_score_completions(
 
         prompts = [x["prompt"] for x in inputs]
 
-        # We don't yet support visual reward models/function, so we keep a copy of the original text-only prompts for
-        # later use in the reward computation. If images are present, we insert {"type": "image"} as required by the
-        # VLM chat template.
-        original_prompts = copy.deepcopy(prompts)
+        if "images" in inputs[0]:
+            images = [example.get("images") for example in inputs]
+        elif "image" in inputs[0]:
+            images = [[example.get("image")] if example.get("image") is not None else None for example in inputs]
+        else:
+            images = None
 
         # If the prompts are conversational and the inputs contain images, we need to convert the prompts from
         # [{"role": "user", "content": "What color is the sky?"}] to
         # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
         kwargs = {}
-        has_images = "image" in inputs[0]
-        if has_images:
-            images = [example.get("image") for example in inputs]
-            kwargs = {"images": [[img] for img in images]}
-            for prompt in prompts:
+        if images is not None:
+            kwargs = {"images": images}
+            for prompt, image_list in zip(prompts, images):
                 if isinstance(prompt, list):  # i.e., when using conversational data
-                    prepare_multimodal_messages(prompt, num_images=1)
+                    prepare_multimodal_messages(prompt, num_images=len(image_list))
 
         prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
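The single-image assumption (`has_images`, `num_images=1`) is replaced by a per-example list of images: an "images" column is used as-is, a legacy "image" column is wrapped into a one-element list, and a text-only batch yields `images = None`. Below is a minimal standalone sketch of that normalization, using hypothetical example data rather than a real dataset batch:

# Standalone sketch of the image-normalization logic introduced in the hunk above.
# The batch below is hypothetical; in the trainer, `inputs` comes from the dataloader.
def normalize_images(inputs):
    if "images" in inputs[0]:  # multi-image column: already a list per example
        return [example.get("images") for example in inputs]
    elif "image" in inputs[0]:  # single-image column: wrap each image in a one-element list
        return [[example.get("image")] if example.get("image") is not None else None for example in inputs]
    return None  # text-only batch


batch = [
    {"prompt": "Describe both pictures.", "images": ["img_a.png", "img_b.png"]},
    {"prompt": "Describe the picture.", "images": ["img_c.png"]},
]
print(normalize_images(batch))  # [['img_a.png', 'img_b.png'], ['img_c.png']]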

@@ -176,7 +175,7 @@ def _generate_and_score_completions(
         # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
         if self.vllm_mode == "server":
            all_prompts_text = gather_object(prompts_text)
-            if has_images:
+            if images is not None:
                all_images = gather_object(images)
 
            if self.accelerator.is_main_process:
@@ -185,7 +184,7 @@
                # prompt individually.
                ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
 
-                if has_images:
+                if images is not None:
                    ordered_set_of_images = all_images[:: self.num_generations]
                else:
                    ordered_set_of_images = None
@@ -250,23 +249,21 @@
                torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group)
                all_prompts_text = [p for sublist in gathered_prompts for p in sublist]
 
-                if has_images:
+                if images is not None:
                    gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)]
                    torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group)
                    all_images = [img for sublist in gathered_images for img in sublist]
                else:
                    all_images = None
            else:
                all_prompts_text = prompts_text
-                all_images = images if has_images else None
+                all_images = images
 
-            if has_images and all_images:
+            if images is not None and all_images:
                vllm_inputs = []
-                for prompt, image in zip(all_prompts_text, all_images):
-                    if image is not None:
-                        vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}})
-                    else:
-                        vllm_inputs.append(prompt)
+                for prompt, image_list in zip(all_prompts_text, all_images):
+                    vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}})
+
            else:
                vllm_inputs = all_prompts_text
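Each vLLM request now carries the full list of images for its prompt under `multi_modal_data`, rather than a single image (or a bare prompt string when no image was present). A standalone sketch of that construction, with hypothetical prompts and image lists standing in for the gather_object results:

# Sketch of the vllm_inputs construction above; the data is hypothetical.
all_prompts_text = ["<prompt for example 0>", "<prompt for example 1>"]
all_images = [["img_a.png", "img_b.png"], ["img_c.png"]]

vllm_inputs = []
for prompt, image_list in zip(all_prompts_text, all_images):
    # Every prompt gets a dict entry; its whole image list rides along as multimodal data.
    vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}})

print(vllm_inputs[0]["multi_modal_data"]["image"])  # ['img_a.png', 'img_b.png']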

@@ -381,6 +378,8 @@
        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
 
+        num_images = [len(img_list) for img_list in images] if images is not None else None
+
        with torch.no_grad():
            # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
            # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
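`num_images` records how many images each example contributes, presumably so the per-token log-probability calls below can tell how many of the flattened vision inputs belong to each prompt. A tiny sketch of the computation with hypothetical image lists:

# Sketch of the num_images computation added above (hypothetical image lists).
images = [["img_a.png", "img_b.png"], ["img_c.png"], ["img_d.png", "img_e.png"]]
num_images = [len(img_list) for img_list in images] if images is not None else None
print(num_images)  # [2, 1, 2]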
@@ -401,6 +400,7 @@
                    batch_size,
                    pixel_values=prompt_inputs.get("pixel_values"),
                    image_grid_thw=prompt_inputs.get("image_grid_thw"),
+                    num_images=num_images,
                    pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                    image_sizes=prompt_inputs.get("image_sizes"),
                )
@@ -425,6 +425,7 @@
                    batch_size=batch_size,
                    pixel_values=prompt_inputs.get("pixel_values"),
                    image_grid_thw=prompt_inputs.get("image_grid_thw"),
+                    num_images=num_images,
                    pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                    image_sizes=prompt_inputs.get("image_sizes"),
                )
@@ -438,6 +439,7 @@
                    batch_size=batch_size,
                    pixel_values=prompt_inputs.get("pixel_values"),
                    image_grid_thw=prompt_inputs.get("image_grid_thw"),
+                    num_images=num_images,
                    pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                    image_sizes=prompt_inputs.get("image_sizes"),
                )
@@ -457,7 +459,7 @@
        # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is
        # important because rewards will be normalized per group, and completions are distributed. We will later slice
        # rewards_per_func to extract each process's subset.
-        rewards_per_func = self._calculate_rewards(inputs, original_prompts, completions, completion_ids_list)
+        rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
 
        # Apply weights to each reward function's output and sum
        rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
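With `original_prompts` and the `copy` import removed, the reward computation now receives the same `prompts` list used for generation, so conversational prompts already carry the `{"type": "image"}` placeholders inserted in place by `prepare_multimodal_messages`. Based on the conversion example quoted in the comment earlier in the diff, a prompt seen by `_calculate_rewards` would look roughly like this (illustrative values only):

# Illustrative shape of a conversational prompt before and after the in-place
# conversion by prepare_multimodal_messages (values taken from the diff's comment).
prompt_before = [{"role": "user", "content": "What color is the sky?"}]
prompt_after = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What color is the sky?"},
        ],
    }
]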
@@ -535,8 +537,8 @@
            self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
        self._logs["advantages"].extend(all_process_advantages.tolist())
 
-        if has_images:
-            self._logs["image"].extend(gather_object(images))
+        if images is not None:
+            self._logs["images"].extend(gather_object(images))
 
        if self.use_vllm and self.vllm_importance_sampling_correction:
            delta = torch.abs(old_per_token_logps - sampling_per_token_logps)
@@ -607,6 +609,8 @@
            output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"]
        if "image_sizes" in prompt_inputs:
            output["image_sizes"] = prompt_inputs["image_sizes"]
+        if images is not None:
+            output["images"] = images
        return output
 
    def slice_group_data(
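The returned batch also gains an "images" entry alongside the existing vision tensors, presumably so the replay buffer can store and replay the per-example image lists referenced by the commit title. A rough sketch of the resulting keys, with placeholder strings standing in for real tensors:

# Sketch of the extra key added to the returned output dict (placeholder values).
prompt_inputs = {"pixel_values": "<tensor>", "image_grid_thw": "<tensor>"}
images = [["img_a.png", "img_b.png"], ["img_c.png"]]

output = {}
for key in ("pixel_values", "image_grid_thw"):
    if key in prompt_inputs:
        output[key] = prompt_inputs[key]
if images is not None:
    output["images"] = images  # new: per-example image lists travel with the batch
print(sorted(output))  # ['image_grid_thw', 'images', 'pixel_values']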
