 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import copy
 import heapq
 import re
 from contextlib import nullcontext
@@ -89,22 +88,22 @@ def _generate_and_score_completions(
 
         prompts = [x["prompt"] for x in inputs]
 
-        # We don't yet support visual reward models/function, so we keep a copy of the original text-only prompts for
-        # later use in the reward computation. If images are present, we insert {"type": "image"} as required by the
-        # VLM chat template.
-        original_prompts = copy.deepcopy(prompts)
+        if "images" in inputs[0]:
+            images = [example.get("images") for example in inputs]
+        elif "image" in inputs[0]:
+            images = [[example.get("image")] if example.get("image") is not None else None for example in inputs]
+        else:
+            images = None
 
         # If the prompts are conversational and the inputs contain images, we need to convert the prompts from
         # [{"role": "user", "content": "What color is the sky?"}] to
         # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
         kwargs = {}
-        has_images = "image" in inputs[0]
-        if has_images:
-            images = [example.get("image") for example in inputs]
-            kwargs = {"images": [[img] for img in images]}
-            for prompt in prompts:
+        if images is not None:
+            kwargs = {"images": images}
+            for prompt, image_list in zip(prompts, images):
                 if isinstance(prompt, list):  # i.e., when using conversational data
-                    prepare_multimodal_messages(prompt, num_images=1)
+                    prepare_multimodal_messages(prompt, num_images=len(image_list))
 
         prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
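
For reference, a minimal sketch (toy inputs, not part of the diff) of what the new normalization produces: a per-example list of images built from either an `images` column or the legacy single-`image` column, or `None` for a text-only batch.

```python
# Sketch with assumed toy inputs: normalize the image column(s) into
# per-example lists, mirroring the branch added in the hunk above.
def normalize_images(inputs):
    if "images" in inputs[0]:  # multi-image column
        return [example.get("images") for example in inputs]
    elif "image" in inputs[0]:  # legacy single-image column
        return [[example.get("image")] if example.get("image") is not None else None for example in inputs]
    return None  # text-only batch

assert normalize_images([{"image": "img0"}]) == [["img0"]]
assert normalize_images([{"images": ["a", "b"]}]) == [["a", "b"]]
assert normalize_images([{"prompt": "hi"}]) is None
```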
@@ -176,7 +175,7 @@ def _generate_and_score_completions(
         # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
         if self.vllm_mode == "server":
             all_prompts_text = gather_object(prompts_text)
-            if has_images:
+            if images is not None:
                 all_images = gather_object(images)
 
             if self.accelerator.is_main_process:
@@ -185,7 +184,7 @@ def _generate_and_score_completions(
                 # prompt individually.
                 ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
 
-                if has_images:
+                if images is not None:
                     ordered_set_of_images = all_images[:: self.num_generations]
                 else:
                     ordered_set_of_images = None
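
The `[:: self.num_generations]` stride above relies on each prompt appearing `num_generations` times back-to-back in the gathered list; a toy check with assumed values:

```python
# Each prompt is repeated num_generations times after gathering, so a stride
# of num_generations recovers one copy per prompt and its matching images.
num_generations = 2
all_prompts_text = ["p0", "p0", "p1", "p1"]
all_images = [["i0"], ["i0"], ["i1"], ["i1"]]

ordered_set_of_prompts = all_prompts_text[::num_generations]
ordered_set_of_images = all_images[::num_generations]
assert ordered_set_of_prompts == ["p0", "p1"]
assert ordered_set_of_images == [["i0"], ["i1"]]
```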
@@ -250,23 +249,21 @@ def _generate_and_score_completions(
                 torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group)
                 all_prompts_text = [p for sublist in gathered_prompts for p in sublist]
 
-                if has_images:
+                if images is not None:
                     gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)]
                     torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group)
                     all_images = [img for sublist in gathered_images for img in sublist]
                 else:
                     all_images = None
             else:
                 all_prompts_text = prompts_text
-                all_images = images if has_images else None
+                all_images = images
 
-            if has_images and all_images:
+            if images is not None and all_images:
                 vllm_inputs = []
-                for prompt, image in zip(all_prompts_text, all_images):
-                    if image is not None:
-                        vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}})
-                    else:
-                        vllm_inputs.append(prompt)
+                for prompt, image_list in zip(all_prompts_text, all_images):
+                    vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}})
+
             else:
                 vllm_inputs = all_prompts_text
 
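
The resulting payload matches vLLM's multi-modal prompt format, where `multi_modal_data["image"]` may carry the whole per-example image list; a sketch with placeholder strings standing in for PIL images:

```python
# Sketch: one dict per prompt; the full per-example image list rides along.
all_prompts_text = ["<image><image> Describe both.", "<image> What is shown?"]
all_images = [["img_a", "img_b"], ["img_c"]]  # placeholders for PIL images

vllm_inputs = [
    {"prompt": prompt, "multi_modal_data": {"image": image_list}}
    for prompt, image_list in zip(all_prompts_text, all_images)
]
assert vllm_inputs[0]["multi_modal_data"]["image"] == ["img_a", "img_b"]
```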
@@ -381,6 +378,8 @@ def _generate_and_score_completions(
         logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
         batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
 
+        num_images = [len(img_list) for img_list in images] if images is not None else None
+
         with torch.no_grad():
             # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
             # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
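
`num_images` records how many images each example contributed, which is what lets the log-prob calls below split flattened vision inputs back per example; a toy check:

```python
# Per-example image counts derived from the normalized `images` lists.
images = [["a"], ["b", "c"], []]
num_images = [len(img_list) for img_list in images] if images is not None else None
assert num_images == [1, 2, 0]
```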
@@ -401,6 +400,7 @@ def _generate_and_score_completions(
                     batch_size,
                     pixel_values=prompt_inputs.get("pixel_values"),
                     image_grid_thw=prompt_inputs.get("image_grid_thw"),
+                    num_images=num_images,
                     pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                     image_sizes=prompt_inputs.get("image_sizes"),
                 )
@@ -425,6 +425,7 @@ def _generate_and_score_completions(
                     batch_size=batch_size,
                     pixel_values=prompt_inputs.get("pixel_values"),
                     image_grid_thw=prompt_inputs.get("image_grid_thw"),
+                    num_images=num_images,
                     pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                     image_sizes=prompt_inputs.get("image_sizes"),
                 )
@@ -438,6 +439,7 @@ def _generate_and_score_completions(
                     batch_size=batch_size,
                     pixel_values=prompt_inputs.get("pixel_values"),
                     image_grid_thw=prompt_inputs.get("image_grid_thw"),
+                    num_images=num_images,
                     pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                     image_sizes=prompt_inputs.get("image_sizes"),
                 )
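
To motivate threading `num_images` into these three calls: processors typically flatten vision tensors such as `pixel_values` across the whole batch, and per-example counts are enough to split them back. A hedged sketch of that idea (an assumption for illustration, not TRL's actual helper):

```python
import torch

# Assume 3 examples contributed 1, 2, and 0 images, and each image yields
# one row of features after the processor flattens them across the batch.
num_images = [1, 2, 0]
pixel_values = torch.randn(sum(num_images), 4)  # toy feature dim of 4

# torch.split with a list of sizes recovers the per-example grouping.
per_example = torch.split(pixel_values, num_images, dim=0)
assert [t.shape[0] for t in per_example] == num_images
```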
@@ -457,7 +459,7 @@ def _generate_and_score_completions(
         # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is
         # important because rewards will be normalized per group, and completions are distributed. We will later slice
         # rewards_per_func to extract each process's subset.
-        rewards_per_func = self._calculate_rewards(inputs, original_prompts, completions, completion_ids_list)
+        rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
 
         # Apply weights to each reward function's output and sum
         rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
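
The weighting step can be verified in isolation: each column of `rewards_per_func` is one reward function's output, `NaN` marks functions that did not apply to a sample, and `nansum` skips them. Toy tensors are assumptions:

```python
import torch

rewards_per_func = torch.tensor([[1.0, 0.5],
                                 [float("nan"), 2.0]])
reward_weights = torch.tensor([2.0, 1.0])

# Weight each column, then sum per row while ignoring NaNs.
rewards = (rewards_per_func * reward_weights.unsqueeze(0)).nansum(dim=1)
assert torch.allclose(rewards, torch.tensor([2.5, 2.0]))
```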
@@ -535,8 +537,8 @@ def _generate_and_score_completions(
             self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
         self._logs["advantages"].extend(all_process_advantages.tolist())
 
-        if has_images:
-            self._logs["image"].extend(gather_object(images))
+        if images is not None:
+            self._logs["images"].extend(gather_object(images))
 
         if self.use_vllm and self.vllm_importance_sampling_correction:
             delta = torch.abs(old_per_token_logps - sampling_per_token_logps)
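
Downstream, the importance-sampling correction starts from the absolute per-token log-prob gap between the trainer's recomputed policy and vLLM's sampler; a toy check with assumed values:

```python
import torch

old_per_token_logps = torch.tensor([-1.2, -0.7])
sampling_per_token_logps = torch.tensor([-1.0, -0.9])

# Absolute per-token disagreement between trainer and sampler log-probs.
delta = torch.abs(old_per_token_logps - sampling_per_token_logps)
assert torch.allclose(delta, torch.tensor([0.2, 0.2]))
```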
@@ -607,6 +609,8 @@ def _generate_and_score_completions(
             output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"]
         if "image_sizes" in prompt_inputs:
             output["image_sizes"] = prompt_inputs["image_sizes"]
+        if images is not None:
+            output["images"] = images
         return output
 
     def slice_group_data(