@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import heapq
import re
from contextlib import nullcontext
@@ -89,22 +88,22 @@ def _generate_and_score_completions(

prompts = [x["prompt"] for x in inputs]

# We don't yet support visual reward models/function, so we keep a copy of the original text-only prompts for
# later use in the reward computation. If images are present, we insert {"type": "image"} as required by the
# VLM chat template.
original_prompts = copy.deepcopy(prompts)
if "images" in inputs[0]:
images = [example.get("images") for example in inputs]
elif "image" in inputs[0]:
images = [[example.get("image")] if example.get("image") is not None else None for example in inputs]
else:
images = None

# If the prompts are conversational and the inputs contain images, we need to convert the prompts from
# [{"role": "user", "content": "What color is the sky?"}] to
# [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
kwargs = {}
has_images = "image" in inputs[0]
if has_images:
images = [example.get("image") for example in inputs]
kwargs = {"images": [[img] for img in images]}
for prompt in prompts:
if images is not None:
kwargs = {"images": images}
for prompt, image_list in zip(prompts, images):
if isinstance(prompt, list): # i.e., when using conversational data
prepare_multimodal_messages(prompt, num_images=1)
prepare_multimodal_messages(prompt, num_images=len(image_list))

prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]

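For illustration, here is a minimal sketch (not the trainer's code) of what the new image handling produces: datasets may carry either a single "image" column or an "images" list per example, and both are normalized to one list of images per prompt; string names stand in for the actual PIL images. Conversational prompts are then expanded so that the user turn carries one {"type": "image"} placeholder per image, which is why prepare_multimodal_messages now receives num_images=len(image_list) instead of a hard-coded 1.

def normalize_images(inputs):
    # Mirrors the branch above: an "images" list wins over a single "image", otherwise None.
    if "images" in inputs[0]:
        return [example.get("images") for example in inputs]
    elif "image" in inputs[0]:
        return [[example.get("image")] if example.get("image") is not None else None for example in inputs]
    return None

single = [{"prompt": "Describe this.", "image": "cat.png"}]
multi = [{"prompt": "Compare these.", "images": ["cat.png", "dog.png"]}]
print(normalize_images(single))  # [['cat.png']]
print(normalize_images(multi))   # [['cat.png', 'dog.png']]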
@@ -176,7 +175,7 @@ def _generate_and_score_completions(
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
if self.vllm_mode == "server":
all_prompts_text = gather_object(prompts_text)
if has_images:
if images is not None:
all_images = gather_object(images)

if self.accelerator.is_main_process:
@@ -185,7 +184,7 @@ def _generate_and_score_completions(
# prompt individually.
ordered_set_of_prompts = all_prompts_text[:: self.num_generations]

if has_images:
if images is not None:
ordered_set_of_images = all_images[:: self.num_generations]
else:
ordered_set_of_images = None
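Since every prompt appears num_generations times in the gathered list, the stride slice above recovers exactly one copy per group, and the image lists are sliced the same way so they stay aligned with their prompts. A toy illustration, assuming num_generations is 2:

num_generations = 2
all_prompts_text = ["p0", "p0", "p1", "p1"]                      # each prompt repeated once per generation
all_images = [["a.png"], ["a.png"], ["b.png", "c.png"], ["b.png", "c.png"]]

ordered_set_of_prompts = all_prompts_text[::num_generations]     # ['p0', 'p1']
ordered_set_of_images = all_images[::num_generations]            # [['a.png'], ['b.png', 'c.png']]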
@@ -250,23 +249,21 @@ def _generate_and_score_completions(
torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group)
all_prompts_text = [p for sublist in gathered_prompts for p in sublist]

if has_images:
if images is not None:
gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)]
torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group)
all_images = [img for sublist in gathered_images for img in sublist]
else:
all_images = None
else:
all_prompts_text = prompts_text
all_images = images if has_images else None
all_images = images

if has_images and all_images:
if images is not None and all_images:
vllm_inputs = []
for prompt, image in zip(all_prompts_text, all_images):
if image is not None:
vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}})
else:
vllm_inputs.append(prompt)
for prompt, image_list in zip(all_prompts_text, all_images):
vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}})

else:
vllm_inputs = all_prompts_text

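A minimal sketch of the inputs this now hands to vLLM when images are present (plain strings stand in for the templated prompt texts and the PIL images): every entry is a dict whose multi_modal_data carries the full image list for its prompt, while text-only batches skip this branch and pass the raw prompt strings.

all_prompts_text = ["prompt for example 0", "prompt for example 1"]
all_images = [["scene.png"], ["left.png", "right.png"]]

vllm_inputs = [
    {"prompt": prompt, "multi_modal_data": {"image": image_list}}
    for prompt, image_list in zip(all_prompts_text, all_images)
]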
@@ -381,6 +378,8 @@ def _generate_and_score_completions(
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size

num_images = [len(img_list) for img_list in images] if images is not None else None

with torch.no_grad():
# If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
# a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
@@ -401,6 +400,7 @@ def _generate_and_score_completions(
batch_size,
pixel_values=prompt_inputs.get("pixel_values"),
image_grid_thw=prompt_inputs.get("image_grid_thw"),
num_images=num_images,
pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
image_sizes=prompt_inputs.get("image_sizes"),
)
@@ -425,6 +425,7 @@ def _generate_and_score_completions(
batch_size=batch_size,
pixel_values=prompt_inputs.get("pixel_values"),
image_grid_thw=prompt_inputs.get("image_grid_thw"),
num_images=num_images,
pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
image_sizes=prompt_inputs.get("image_sizes"),
)
@@ -438,6 +439,7 @@ def _generate_and_score_completions(
batch_size=batch_size,
pixel_values=prompt_inputs.get("pixel_values"),
image_grid_thw=prompt_inputs.get("image_grid_thw"),
num_images=num_images,
pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
image_sizes=prompt_inputs.get("image_sizes"),
)
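The new num_images argument tells these helpers how many images belong to each example, presumably so the concatenated visual tensors can be split back per prompt when the batch is processed in slices; the real splitting happens inside the helper, so the following is only a simplified sketch that assumes one pixel_values row per image.

import torch

pixel_values = torch.randn(3, 4)   # 3 images in the whole batch (toy feature dim of 4)
num_images = [1, 2]                # example 0 has 1 image, example 1 has 2

per_example = torch.split(pixel_values, num_images, dim=0)
# (tensor of shape (1, 4), tensor of shape (2, 4))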
@@ -457,7 +459,7 @@ def _generate_and_score_completions(
# Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is
# important because rewards will be normalized per group, and completions are distributed. We will later slice
# rewards_per_func to extract each process's subset.
rewards_per_func = self._calculate_rewards(inputs, original_prompts, completions, completion_ids_list)
rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)

# Apply weights to each reward function's output and sum
rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
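A worked toy example of the weighted sum above: a NaN entry stands for a reward function that produced no score for a given completion, and nansum ignores it instead of propagating it into the total.

import torch

rewards_per_func = torch.tensor([[0.5, float("nan")],
                                 [1.0, 0.2]])
reward_weights = torch.tensor([1.0, 0.5])

rewards = (rewards_per_func * reward_weights.unsqueeze(0)).nansum(dim=1)
# tensor([0.5000, 1.1000])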
@@ -535,8 +537,8 @@ def _generate_and_score_completions(
self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
self._logs["advantages"].extend(all_process_advantages.tolist())

if has_images:
self._logs["image"].extend(gather_object(images))
if images is not None:
self._logs["images"].extend(gather_object(images))

if self.use_vllm and self.vllm_importance_sampling_correction:
delta = torch.abs(old_per_token_logps - sampling_per_token_logps)
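A toy illustration of the quantity computed here: the per-token absolute gap between the training policy's log-probabilities and the ones recorded by the vLLM sampler, which the importance-sampling correction path uses to track how far the two distributions have drifted.

import torch

old_per_token_logps = torch.tensor([[-1.20, -0.50, -2.30]])
sampling_per_token_logps = torch.tensor([[-1.00, -0.55, -2.00]])

delta = torch.abs(old_per_token_logps - sampling_per_token_logps)
# tensor([[0.2000, 0.0500, 0.3000]])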
@@ -607,6 +609,8 @@ def _generate_and_score_completions(
output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"]
if "image_sizes" in prompt_inputs:
output["image_sizes"] = prompt_inputs["image_sizes"]
if images is not None:
output["images"] = images
return output

def slice_group_data(