diff --git a/docs/source/experimental.md b/docs/source/experimental.md
index 3cdddc5fefb..b00b89efc82 100644
--- a/docs/source/experimental.md
+++ b/docs/source/experimental.md
@@ -66,7 +66,7 @@ class GroupFilter:
         return group_scores
 
 training_args = GFPOConfig(
-    output_dir="Qwen3-0.6B-GFPO"
+    output_dir="Qwen3-0.6B-GFPO",
     per_device_train_batch_size=4,
     num_remains_in_group=2,
     bf16=True,
diff --git a/trl/experimental/gfpo/gfpo_trainer.py b/trl/experimental/gfpo/gfpo_trainer.py
index 6ac4b6acc7f..5e228c1e883 100644
--- a/trl/experimental/gfpo/gfpo_trainer.py
+++ b/trl/experimental/gfpo/gfpo_trainer.py
@@ -13,32 +13,20 @@
 # limitations under the License.
 
 import logging
-import re
-from contextlib import nullcontext
 from typing import Any, Callable
 
 import torch
-import torch.utils.data
-from accelerate.utils import broadcast_object_list, gather_object
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from transformers.utils import is_flash_attn_2_available
-
-from ...data_utils import is_conversational, maybe_apply_chat_template, prepare_multimodal_messages
-from ...extras.profiling import profiling_context
-from ...import_utils import is_vllm_available
-from ...models import unwrap_model_for_generation
+from accelerate.utils import gather_object
+
+from ...data_utils import is_conversational
 from ...trainer.grpo_trainer import GRPOTrainer as _GRPOTrainer
-from ...trainer.utils import nanmax, nanmin, nanstd, pad, truncate_with_protected_tokens
+from ...trainer.utils import nanmax, nanmin, nanstd
 
 
 logger = logging.getLogger(__name__)
 
 GroupFilterFunc = Callable[[list[list[Any]], list[list[Any]]], list[list[float]]]
 
-if is_vllm_available():
-    from vllm import SamplingParams
-    from vllm.sampling_params import GuidedDecodingParams
-
 
 class GFPOTrainer(_GRPOTrainer):
     def __init__(
@@ -89,284 +77,22 @@ def _generate_and_score_completions(self, inputs):
         else:
             images = None
 
-        # If the prompts are conversational and the inputs contain images, we need to convert the prompts from
-        # [{"role": "user", "content": "What color is the sky?"}] to
-        # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
-        kwargs = {}
-        if images is not None:
-            kwargs = {"images": images}
-            for prompt, image_list in zip(prompts, images):
-                if isinstance(prompt, list):  # i.e., when using conversational data
-                    prepare_multimodal_messages(prompt, num_images=len(image_list))
-
-        prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
-
-        prompt_inputs = self.processing_class(
-            text=prompts_text,
-            return_tensors="pt",
-            padding=True,
-            padding_side="left",
-            add_special_tokens=False,
-            **kwargs,
-        )
-        prompt_inputs = super()._prepare_inputs(prompt_inputs)
-        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
-
-        if self.max_prompt_length is not None:
-            # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
-            # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
-            # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation).
-            protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]
-            protected = [token for token in protected if token is not None]
-            prompt_ids, prompt_mask = truncate_with_protected_tokens(
-                prompt_ids, prompt_mask, self.max_prompt_length, protected
-            )
-
-            prompts_text = self.processing_class.batch_decode(
-                prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
-            )
-            prompts_text = [re.sub(rf"^({re.escape(self.pad_token)})+", "", text) for text in prompts_text]
-
-            # The chat template sometimes inserts a single image token into the prompt text. However, when this text is
-            # later tokenized, the single image token string is expanded into multiple image token IDs, depending on the
-            # image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We
-            # collapse them back into a single token string to match the original chat template in case it originally
-            # applies it. Otherwise, it assumes that the chat template uses only vision_start_token_id to indicate images
-            # (e.g. Gemma 3) and removes all image_token instances and vision_end_token_id as well, leaving only
-            # the vision_start_token_id (e.g. <start_of_image>).
-            if self.image_token is not None:
-                escaped_img_token = re.escape(self.image_token)
-                # Search for the image token in the chat template
-                if re.search(escaped_img_token, self.processing_class.chat_template):
-                    prompts_text = [
-                        re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text
-                    ]
-                else:
-                    # If the chat template doesn't use the image token, we remove all instances of it + vision_end_token_id
-                    if self.vision_end_token_id is not None:
-                        escaped_eoi_token = re.escape(
-                            self.processing_class.tokenizer.decode([self.vision_end_token_id])
-                        )
-                        prompts_text = [
-                            re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text
-                        ]
-                    else:
-                        # If vision_end_token_id is None, just remove the image tokens
-                        prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text]
-
-        # Generate completions using either vLLM or regular generation
-        if self.use_vllm:
-            if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode:
-                # wake up colocated vLLM instances if needed
-                torch.cuda.empty_cache()  # required to avoid OOM in some cases
-                self.llm.wake_up()
-
-            # First, update the vLLM weights if needed
-            if self.state.global_step != self._last_loaded_step:
-                self._move_model_to_vllm()
-                self._last_loaded_step = self.state.global_step
-
-            # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
-            if self.vllm_mode == "server":
-                all_prompts_text = gather_object(prompts_text)
-                if images is not None:
-                    all_images = gather_object(images)
-
-                if self.accelerator.is_main_process:
-                    # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
-                    # num_generations outputs for each one. This is faster than generating outputs for each duplicate
-                    # prompt individually.
-                    ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
-
-                    if images is not None:
-                        ordered_set_of_images = all_images[:: self.num_generations]
-                    else:
-                        ordered_set_of_images = None
-
-                    with profiling_context(self, "vLLM.generate"):
-                        output = self.vllm_client.generate(
-                            prompts=ordered_set_of_prompts,
-                            images=ordered_set_of_images,
-                            n=self.num_generations,
-                            repetition_penalty=self.repetition_penalty,
-                            temperature=self.temperature,
-                            top_p=self.top_p,
-                            top_k=-1 if self.top_k is None else self.top_k,
-                            min_p=0.0 if self.min_p is None else self.min_p,
-                            max_tokens=self.max_completion_length,
-                            guided_decoding_regex=self.guided_decoding_regex,
-                            generation_kwargs=self.args.generation_kwargs,
-                        )
-                    payload = (output["completion_ids"], output["logprobs"])
-                else:
-                    payload = None
-
-                # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice.
-                obj_list = [payload]
-                broadcast_object_list(obj_list, from_process=0)
-                completion_ids, all_logprobs = obj_list[0]
-
-                process_slice = slice(
-                    self.accelerator.process_index * len(prompts),
-                    (self.accelerator.process_index + 1) * len(prompts),
-                )
-                completion_ids = completion_ids[process_slice]
-                all_logprobs = all_logprobs[process_slice]
-
-            # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts
-            elif self.vllm_mode == "colocate":
-                if self.guided_decoding_regex:
-                    guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex)
-                else:
-                    guided_decoding = None
-
-                generation_kwargs = {
-                    "n": 1,  # vLLM on each GPU generates only 1 in colocate mode
-                    "repetition_penalty": self.repetition_penalty,
-                    "temperature": self.temperature,
-                    "top_p": self.top_p,
-                    "top_k": -1 if self.top_k is None else self.top_k,
-                    "min_p": 0.0 if self.min_p is None else self.min_p,
-                    "max_tokens": self.max_completion_length,
-                    "guided_decoding": guided_decoding,
-                    "logprobs": 0,  # only return the logprob of the generated token
-                }
-                if self.args.generation_kwargs is not None:
-                    generation_kwargs.update(self.args.generation_kwargs)
-                sampling_params = SamplingParams(**generation_kwargs)
-
-                if self.vllm_tensor_parallel_size > 1:
-                    # Gather prompts from all ranks in the TP group and flatten.
-                    # Each rank starts with its own prompts; after gathering, all ranks see the full group set.
-                    orig_size = len(prompts_text)
-                    gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)]
-                    torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group)
-                    all_prompts_text = [p for sublist in gathered_prompts for p in sublist]
-
-                    if images is not None:
-                        gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)]
-                        torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group)
-                        all_images = [img for sublist in gathered_images for img in sublist]
-                    else:
-                        all_images = None
-                else:
-                    all_prompts_text = prompts_text
-                    all_images = images
-
-                if images is not None and all_images:
-                    vllm_inputs = []
-                    for prompt, image_list in zip(all_prompts_text, all_images):
-                        vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}})
-
-                else:
-                    vllm_inputs = all_prompts_text
-
-                with profiling_context(self, "vLLM.generate"):
-                    all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False)
-
-                completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
-                all_logprobs = [
-                    [next(iter(lp.values())).logprob for lp in output.logprobs]
-                    for outputs in all_outputs
-                    for output in outputs.outputs
-                ]
-
-                if self.vllm_tensor_parallel_size > 1:
-                    # Slice completions for this rank within its TP group.
-                    # Each rank generates all outputs — we keep only our share.
-                    local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
-                    tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
-                    completion_ids = completion_ids[tp_slice]
-                    all_logprobs = all_logprobs[tp_slice]
-
-            if self.args.vllm_enable_sleep_mode:
-                self.llm.sleep(level=1)
-
-            # Pad the completions, and concatenate them with the prompts
-            completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
-            completion_ids = pad(completion_ids, padding_value=self.pad_token_id)
-            prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
-            sampling_per_token_logps = [
-                torch.tensor(logprobs, device=device, dtype=torch.float32) for logprobs in all_logprobs
-            ]
-            sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0)
-
-        elif self.use_transformers_paged:
-            # Re-process inputs for paged generation if needed
-            # Note: images are already validated and preprocessed above
-            paged_prompt_inputs = self.processing_class(text=prompts_text, **kwargs)
-            previous_attn = self.model_wrapped.config._attn_implementation
-
-            if is_flash_attn_2_available():
-                self.model_wrapped.config._attn_implementation = "paged_attention"
-            else:
-                self.model_wrapped.config._attn_implementation = "sdpa_paged"
-            with (
-                profiling_context(self, "transformers.generate_batch"),
-                unwrap_model_for_generation(
-                    self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
-                ) as unwrapped_model,
-                torch.no_grad(),
-                FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
-            ):
-                # Cast to the appropriate dtype based on training configuration
-                if self.args.bf16:
-                    unwrapped_model.to(torch.bfloat16)
-                elif self.args.fp16:
-                    unwrapped_model.to(torch.float16)
-                with torch.inference_mode():
-                    all_outputs = unwrapped_model.generate_batch(
-                        paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False
-                    )
-                completion_ids = [output.generated_tokens for output in all_outputs.values()]
-                completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
-                completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
-                prompt_ids = [torch.tensor(ids, device=device) for ids in paged_prompt_inputs.input_ids]
-                prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
-                prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
-            # Restore the original attention implementation, training mode
-            self.model_wrapped.config._attn_implementation = previous_attn
-        else:
-            # Regular generation path
-            with (
-                profiling_context(self, "transformers.generate"),
-                unwrap_model_for_generation(
-                    self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
-                ) as unwrapped_model,
-                torch.no_grad(),
-                FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
-            ):
-                prompt_inputs["input_ids"], prompt_inputs["attention_mask"] = prompt_ids, prompt_mask
-                prompt_completion_ids = unwrapped_model.generate(
-                    **prompt_inputs, generation_config=self.generation_config, disable_compile=True
-                )
-            # Compute prompt length and extract completion ids
-            prompt_length = prompt_ids.size(1)
-            prompt_ids = prompt_completion_ids[:, :prompt_length]
-            completion_ids = prompt_completion_ids[:, prompt_length:]
-
-        # Mask everything after the first EOS token
-        is_eos = completion_ids == self.eos_token_id
-        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
-        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
-        sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
-        completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
+        (
+            prompt_ids,
+            completion_ids,
+            prompt_mask,
+            completion_mask,
+            num_items_in_batch,
+            sampling_per_token_logps,
+            forward_kwargs,
+        ) = self._generate(prompts, images)
 
         # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need
         # to re-tokenize completions if the reward is computed from tokens.
         completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())]
 
-        # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging
-        completion_lengths = completion_mask.sum(1)
-        agg_completion_lengths = self.accelerator.gather(completion_lengths)
-        num_items_in_batch = agg_completion_lengths.sum()  # this is required for the DAPO loss
-
-        # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
-        if self.mask_truncated_completions:
-            truncated_completions = ~is_eos.any(dim=1)
-            completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int()
-
         # Concatenate prompt_mask with completion_mask for logit computation
+        prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)  # (B, P+C)
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)
         logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
@@ -392,11 +118,8 @@ def _generate_and_score_completions(self, inputs):
                     attention_mask,
                     logits_to_keep,
                     batch_size,
-                    pixel_values=prompt_inputs.get("pixel_values"),
-                    image_grid_thw=prompt_inputs.get("image_grid_thw"),
                     num_images=num_images,
-                    pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
-                    image_sizes=prompt_inputs.get("image_sizes"),
+                    **forward_kwargs,  # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
                 )
             else:
                 old_per_token_logps = None
@@ -417,11 +140,8 @@ def _generate_and_score_completions(self, inputs):
                         attention_mask,
                         logits_to_keep,
                         batch_size=batch_size,
-                        pixel_values=prompt_inputs.get("pixel_values"),
-                        image_grid_thw=prompt_inputs.get("image_grid_thw"),
                         num_images=num_images,
-                        pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
-                        image_sizes=prompt_inputs.get("image_sizes"),
+                        **forward_kwargs,  # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
                     )
                 else:
                     with self.accelerator.unwrap_model(self.model).disable_adapter():
@@ -431,16 +151,14 @@ def _generate_and_score_completions(self, inputs):
                             attention_mask,
                             logits_to_keep,
                             batch_size=batch_size,
-                            pixel_values=prompt_inputs.get("pixel_values"),
-                            image_grid_thw=prompt_inputs.get("image_grid_thw"),
                             num_images=num_images,
-                            pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
-                            image_sizes=prompt_inputs.get("image_sizes"),
+                            **forward_kwargs,  # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
                         )
         else:
             ref_per_token_logps = None
 
-        # Decode the generated completions
+        # Decode
+        prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True)
         completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
         if is_conversational(inputs[0]):
             completions = []
@@ -529,28 +247,6 @@ def _generate_and_score_completions(self, inputs):
             completion_lengths = completion_mask.sum(1)
             agg_completion_lengths = self.accelerator.gather(completion_lengths)
             num_items_in_batch = agg_completion_lengths.sum()
-        is_eos = completion_ids == self.eos_token_id
-
-        # Log the metrics
-        if mode == "train":
-            self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item()
-        self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]
-
-        # Log completion lengths, mean, min, max
-        self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item())
-        self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item())
-        self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item())
-
-        # Identify sequences that terminated with EOS and log their lengths
-        agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1))
-        term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos]
-        clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths)
-        self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio)
-        if len(term_completion_lengths) == 0:  # edge case where no terminated sequences are found
-            term_completion_lengths = torch.zeros(1, device=device)
-        self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item())
-        self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
-        self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())
 
         # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
         for i, reward_func_name in enumerate(self.reward_func_names):
@@ -633,14 +329,14 @@ def _generate_and_score_completions(self, inputs):
             output["importance_sampling_ratio"] = importance_sampling_ratio
         if ref_per_token_logps is not None:
             output["ref_per_token_logps"] = ref_per_token_logps
-        if "pixel_values" in prompt_inputs:
-            output["pixel_values"] = prompt_inputs["pixel_values"]
-        if "image_grid_thw" in prompt_inputs:
-            output["image_grid_thw"] = prompt_inputs["image_grid_thw"]
-        if "pixel_attention_mask" in prompt_inputs:
-            output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"]
-        if "image_sizes" in prompt_inputs:
-            output["image_sizes"] = prompt_inputs["image_sizes"]
+        if "pixel_values" in forward_kwargs:
+            output["pixel_values"] = forward_kwargs["pixel_values"]
+        if "image_grid_thw" in forward_kwargs:
+            output["image_grid_thw"] = forward_kwargs["image_grid_thw"]
+        if "pixel_attention_mask" in forward_kwargs:
+            output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"]
+        if "image_sizes" in forward_kwargs:
+            output["image_sizes"] = forward_kwargs["image_sizes"]
         if images is not None:
             output["num_images"] = num_images
         return output
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 70228fd111e..216591acf6a 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -492,7 +492,7 @@ def __init__(
             if not is_vllm_available():
                 raise ImportError(
                     "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
-                    "`pip install [vllm]` to use it."
+                    "`pip install trl[vllm]` to use it."
                 )
 
             if self.vllm_mode == "server":
@@ -545,7 +545,7 @@ def __init__(
                     distributed_executor_backend="external_launcher",
                     # Feed identical seed for tp groups to ensure sampling results are the same across workers
                     seed=self.accelerator.process_index // self.vllm_tensor_parallel_size,
-                    # Latest vLLM v1 memory profiler is misled by the high default value (i.e., 32768) - thinking there's not enough memory
+                    # Latest vLLM v1 memory profiler is misled by the high default value (i.e., 32768) - thinking there's not enough memory
                     max_num_batched_tokens=4096,
                     model_impl=self.args.vllm_model_impl,
                     enable_sleep_mode=self.args.vllm_enable_sleep_mode,
@@ -1070,21 +1070,10 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list):
         rewards_per_func = gather(rewards_per_func)
         return rewards_per_func
 
-    def _generate_and_score_completions(
-        self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
-    ) -> dict[str, Union[torch.Tensor, Any]]:
+    def _generate(self, prompts: list[str], images: Optional[list]):
         device = self.accelerator.device
         mode = "train" if self.model.training else "eval"
 
-        prompts = [x["prompt"] for x in inputs]
-
-        if "images" in inputs[0]:
-            images = [example.get("images") for example in inputs]
-        elif "image" in inputs[0]:
-            images = [[example.get("image")] if example.get("image") is not None else None for example in inputs]
-        else:
-            images = None
-
         # If the prompts are conversational and the inputs contain images, we need to convert the prompts from
         # [{"role": "user", "content": "What color is the sky?"}] to
         # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
@@ -1095,7 +1084,9 @@ def _generate_and_score_completions(
             if isinstance(prompt, list):  # i.e., when using conversational data
                 prepare_multimodal_messages(prompt, num_images=len(image_list))
 
-        prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
+        prompts_text = [
+            maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
+        ]
 
         prompt_inputs = self.processing_class(
             text=prompts_text,
@@ -1107,6 +1098,7 @@ def _generate_and_score_completions(
         )
         prompt_inputs = super()._prepare_inputs(prompt_inputs)
         prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
+        forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
 
         if self.max_prompt_length is not None:
             # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
@@ -1280,8 +1272,9 @@ def _generate_and_score_completions(
 
             # Pad the completions, and concatenate them with the prompts
             completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
+            completion_mask = [torch.ones(len(ids), device=device, dtype=torch.long) for ids in completion_ids]
             completion_ids = pad(completion_ids, padding_value=self.pad_token_id)
-            prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+            completion_mask = pad(completion_mask, padding_value=0)
             sampling_per_token_logps = [
                 torch.tensor(logprobs, device=device, dtype=torch.float32) for logprobs in all_logprobs
             ]
@@ -1320,9 +1313,10 @@ def _generate_and_score_completions(
                 completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
                 prompt_ids = [torch.tensor(ids, device=device) for ids in paged_prompt_inputs.input_ids]
                 prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
-                prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
             # Restore the original attention implementation, training mode
             self.model_wrapped.config._attn_implementation = previous_attn
+            sampling_per_token_logps = None  # not used in this case
+
         else:
             # Regular generation path
             with (
@@ -1333,14 +1327,18 @@ def _generate_and_score_completions(
                 torch.no_grad(),
                 FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
             ):
-                prompt_inputs["input_ids"], prompt_inputs["attention_mask"] = prompt_ids, prompt_mask
                 prompt_completion_ids = unwrapped_model.generate(
-                    **prompt_inputs, generation_config=self.generation_config, disable_compile=True
+                    input_ids=prompt_ids,
+                    attention_mask=prompt_mask,
+                    **forward_kwargs,
+                    generation_config=self.generation_config,
+                    disable_compile=True,
                 )
             # Compute prompt length and extract completion ids
             prompt_length = prompt_ids.size(1)
             prompt_ids = prompt_completion_ids[:, :prompt_length]
             completion_ids = prompt_completion_ids[:, prompt_length:]
+            sampling_per_token_logps = None  # not used in this case
 
         # Mask everything after the first EOS token
         is_eos = completion_ids == self.eos_token_id
@@ -1349,10 +1347,6 @@ def _generate_and_score_completions(
         sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
         completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
 
-        # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need
-        # to re-tokenize completions if the reward is computed from tokens.
-        completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())]
-
         # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging
         completion_lengths = completion_mask.sum(1)
         agg_completion_lengths = self.accelerator.gather(completion_lengths)
@@ -1363,7 +1357,69 @@ def _generate_and_score_completions(
             truncated_completions = ~is_eos.any(dim=1)
             completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int()
 
+        # Log the metrics
+        if mode == "train":
+            attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
+            self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item()
+        self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]
+
+        # Log completion lengths, mean, min, max
+        self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item())
+        self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item())
+        self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item())
+
+        # Identify sequences that terminated with EOS and log their lengths
+        agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1))
+        term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos]
+        clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths)
+        self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio)
+        if len(term_completion_lengths) == 0:  # edge case where no terminated sequences are found
+            term_completion_lengths = torch.zeros(1, device=device)
+        self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item())
+        self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
+        self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())
+
+        return (
+            prompt_ids,
+            completion_ids,
+            prompt_mask,
+            completion_mask,
+            num_items_in_batch,
+            sampling_per_token_logps,
+            forward_kwargs,
+        )
+
+    def _generate_and_score_completions(
+        self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
+    ) -> dict[str, Union[torch.Tensor, Any]]:
+        device = self.accelerator.device
+        mode = "train" if self.model.training else "eval"
+
+        prompts = [x["prompt"] for x in inputs]
+
+        if "images" in inputs[0]:
+            images = [example.get("images") for example in inputs]
+        elif "image" in inputs[0]:
+            images = [[example.get("image")] if example.get("image") is not None else None for example in inputs]
+        else:
+            images = None
+
+        (
+            prompt_ids,
+            completion_ids,
+            prompt_mask,
+            completion_mask,
+            num_items_in_batch,
+            sampling_per_token_logps,
+            forward_kwargs,
+        ) = self._generate(prompts, images)
+
+        # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need
+        # to re-tokenize completions if the reward is computed from tokens.
+        completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())]
+
         # Concatenate prompt_mask with completion_mask for logit computation
+        prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)  # (B, P+C)
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)
         logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
 
@@ -1389,11 +1445,8 @@ def _generate_and_score_completions(
                     attention_mask,
                     logits_to_keep,
                     batch_size,
-                    pixel_values=prompt_inputs.get("pixel_values"),
-                    image_grid_thw=prompt_inputs.get("image_grid_thw"),
                     num_images=num_images,
-                    pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
-                    image_sizes=prompt_inputs.get("image_sizes"),
+                    **forward_kwargs,  # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
                 )
             else:
                 old_per_token_logps = None
@@ -1414,11 +1467,8 @@ def _generate_and_score_completions(
                         attention_mask,
                         logits_to_keep,
                         batch_size=batch_size,
-                        pixel_values=prompt_inputs.get("pixel_values"),
-                        image_grid_thw=prompt_inputs.get("image_grid_thw"),
                         num_images=num_images,
-                        pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
-                        image_sizes=prompt_inputs.get("image_sizes"),
+                        **forward_kwargs,  # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
                     )
                 else:
                     with self.accelerator.unwrap_model(self.model).disable_adapter():
@@ -1428,16 +1478,14 @@ def _generate_and_score_completions(
                             attention_mask,
                             logits_to_keep,
                             batch_size=batch_size,
-                            pixel_values=prompt_inputs.get("pixel_values"),
-                            image_grid_thw=prompt_inputs.get("image_grid_thw"),
                             num_images=num_images,
-                            pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
-                            image_sizes=prompt_inputs.get("image_sizes"),
+                            **forward_kwargs,  # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
                         )
         else:
             ref_per_token_logps = None
 
-        # Decode the generated completions
+        # Decode
+        prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True)
         completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
         if is_conversational(inputs[0]):
             completions = []
@@ -1486,27 +1534,6 @@ def _generate_and_score_completions(
         all_process_advantages = advantages.clone()  # keep the aggregated advantages for logging
         advantages = advantages[process_slice]
 
-        # Log the metrics
-        if mode == "train":
-            self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item()
-        self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]
-
-        # Log completion lengths, mean, min, max
-        self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item())
-        self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item())
-        self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item())
-
-        # Identify sequences that terminated with EOS and log their lengths
-        agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1))
-        term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos]
-        clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths)
-        self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio)
-        if len(term_completion_lengths) == 0:  # edge case where no terminated sequences are found
-            term_completion_lengths = torch.zeros(1, device=device)
-        self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item())
-        self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
-        self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())
-
         # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
         for i, reward_func_name in enumerate(self.reward_func_names):
             mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
@@ -1573,14 +1600,14 @@ def _generate_and_score_completions(
             output["importance_sampling_ratio"] = importance_sampling_ratio
         if ref_per_token_logps is not None:
             output["ref_per_token_logps"] = ref_per_token_logps
-        if "pixel_values" in prompt_inputs:
-            output["pixel_values"] = prompt_inputs["pixel_values"]
-        if "image_grid_thw" in prompt_inputs:
-            output["image_grid_thw"] = prompt_inputs["image_grid_thw"]
-        if "pixel_attention_mask" in prompt_inputs:
-            output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"]
-        if "image_sizes" in prompt_inputs:
-            output["image_sizes"] = prompt_inputs["image_sizes"]
+        if "pixel_values" in forward_kwargs:
+            output["pixel_values"] = forward_kwargs["pixel_values"]
+        if "image_grid_thw" in forward_kwargs:
+            output["image_grid_thw"] = forward_kwargs["image_grid_thw"]
+        if "pixel_attention_mask" in forward_kwargs:
+            output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"]
+        if "image_sizes" in forward_kwargs:
+            output["image_sizes"] = forward_kwargs["image_sizes"]
         if images is not None:
             output["num_images"] = num_images
         return output
diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py
index 8cf844154eb..3c7490eaac3 100644
--- a/trl/trainer/rloo_trainer.py
+++ b/trl/trainer/rloo_trainer.py
@@ -79,7 +79,6 @@
 if is_peft_available():
     from peft import PeftConfig, PeftModel
 
-
 if is_vllm_available():
     from vllm import LLM, SamplingParams
     from vllm.sampling_params import GuidedDecodingParams
@@ -1062,21 +1061,10 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list):
         rewards_per_func = gather(rewards_per_func)
         return rewards_per_func
 
-    def _generate_and_score_completions(
-        self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
-    ) -> dict[str, Union[torch.Tensor, Any]]:
+    def _generate(self, prompts: list[str], images: Optional[list]):
         device = self.accelerator.device
         mode = "train" if self.model.training else "eval"
 
-        prompts = [x["prompt"] for x in inputs]
-
-        if "images" in inputs[0]:
-            images = [example.get("images") for example in inputs]
-        elif "image" in inputs[0]:
-            images = [[example.get("image")] if example.get("image") is not None else None for example in inputs]
-        else:
-            images = None
-
         # If the prompts are conversational and the inputs contain images, we need to convert the prompts from
         # [{"role": "user", "content": "What color is the sky?"}] to
         # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
@@ -1087,7 +1075,9 @@ def _generate_and_score_completions(
             if isinstance(prompt, list):  # i.e., when using conversational data
                 prepare_multimodal_messages(prompt, num_images=len(image_list))
 
-        prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
+        prompts_text = [
+            maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
+        ]
 
         prompt_inputs = self.processing_class(
             text=prompts_text,
@@ -1099,6 +1089,7 @@ def _generate_and_score_completions(
         )
         prompt_inputs = super()._prepare_inputs(prompt_inputs)
         prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
+        forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
 
         if self.max_prompt_length is not None:
             # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
@@ -1264,8 +1255,9 @@ def _generate_and_score_completions(
 
             # Pad the completions, and concatenate them with the prompts
             completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
+            completion_mask = [torch.ones(len(ids), device=device, dtype=torch.long) for ids in completion_ids]
             completion_ids = pad(completion_ids, padding_value=self.pad_token_id)
-            prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+            completion_mask = pad(completion_mask, padding_value=0)
 
         elif self.use_transformers_paged:
             # Re-process inputs for paged generation if needed
@@ -1300,9 +1292,9 @@ def _generate_and_score_completions(
                 completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
                 prompt_ids = [torch.tensor(ids, device=device) for ids in paged_prompt_inputs.input_ids]
                 prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
-                prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
             # Restore the original attention implementation, training mode
             self.model_wrapped.config._attn_implementation = previous_attn
+
         else:
             # Regular generation path
             with (
@@ -1313,9 +1305,12 @@ def _generate_and_score_completions(
                 torch.no_grad(),
                 FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
             ):
-                prompt_inputs["input_ids"], prompt_inputs["attention_mask"] = prompt_ids, prompt_mask
                 prompt_completion_ids = unwrapped_model.generate(
-                    **prompt_inputs, generation_config=self.generation_config, disable_compile=True
+                    input_ids=prompt_ids,
+                    attention_mask=prompt_mask,
+                    **forward_kwargs,
+                    generation_config=self.generation_config,
+                    disable_compile=True,
                 )
             # Compute prompt length and extract completion ids
             prompt_length = prompt_ids.size(1)
@@ -1329,10 +1324,6 @@ def _generate_and_score_completions(
         sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
         completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
 
-        # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need
-        # to re-tokenize completions if the reward is computed from tokens.
-        completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())]
-
         # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging
         completion_lengths = completion_mask.sum(1)
 
@@ -1341,7 +1332,54 @@ def _generate_and_score_completions(
             truncated_completions = ~is_eos.any(dim=1)
             completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int()
 
+        # Log the metrics
+        if mode == "train":
+            attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
+            self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item()
+        self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]
+
+        # Log completion lengths, mean, min, max
+        agg_completion_lengths = self.accelerator.gather(completion_lengths)
+        self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item())
+        self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item())
+        self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item())
+
+        # Identify sequences that terminated with EOS and log their lengths
+        agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1))
+        term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos]
+        clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths)
+        self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio)
+        if len(term_completion_lengths) == 0:  # edge case where no terminated sequences are found
+            term_completion_lengths = torch.zeros(1, device=device)
+        self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item())
+        self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
+        self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())
+
+        return prompt_ids, completion_ids, prompt_mask, completion_mask, forward_kwargs
+
+    def _generate_and_score_completions(
+        self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
+    ) -> dict[str, Union[torch.Tensor, Any]]:
+        device = self.accelerator.device
+        mode = "train" if self.model.training else "eval"
+
+        prompts = [x["prompt"] for x in inputs]
+
+        if "images" in inputs[0]:
+            images = [example.get("images") for example in inputs]
+        elif "image" in inputs[0]:
+            images = [[example.get("image")] if example.get("image") is not None else None for example in inputs]
+        else:
+            images = None
+
+        prompt_ids, completion_ids, prompt_mask, completion_mask, forward_kwargs = self._generate(prompts, images)
+
+        # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need
+        # to re-tokenize completions if the reward is computed from tokens.
+        completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())]
+
         # Concatenate prompt_mask with completion_mask for logit computation
+        prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)  # (B, P+C)
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)
         logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
 
@@ -1357,11 +1395,8 @@ def _generate_and_score_completions(
                 attention_mask,
                 logits_to_keep,
                 batch_size,
-                pixel_values=prompt_inputs.get("pixel_values"),
-                image_grid_thw=prompt_inputs.get("image_grid_thw"),
                 num_images=num_images,
-                pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
-                image_sizes=prompt_inputs.get("image_sizes"),
+                **forward_kwargs,  # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
             )
 
         old_logps = (old_per_token_logps * completion_mask).sum(1)  # mask out padding and tokens after EOS
@@ -1374,11 +1409,8 @@ def _generate_and_score_completions(
                         attention_mask,
                         logits_to_keep,
                         batch_size=batch_size,
-                        pixel_values=prompt_inputs.get("pixel_values"),
-                        image_grid_thw=prompt_inputs.get("image_grid_thw"),
                         num_images=num_images,
-                        pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
-                        image_sizes=prompt_inputs.get("image_sizes"),
+                        **forward_kwargs,  # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
                     )
                 else:
                     with self.accelerator.unwrap_model(self.model).disable_adapter():
@@ -1388,16 +1420,14 @@ def _generate_and_score_completions(
                             attention_mask,
                             logits_to_keep,
                             batch_size=batch_size,
-                            pixel_values=prompt_inputs.get("pixel_values"),
-                            image_grid_thw=prompt_inputs.get("image_grid_thw"),
                             num_images=num_images,
-                            pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
-                            image_sizes=prompt_inputs.get("image_sizes"),
+                            **forward_kwargs,  # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
                         )
         else:
             ref_per_token_logps = None
 
-        # Decode the generated completions
+        # Decode
+        prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True)
         completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
         if is_conversational(inputs[0]):
             completions = []
@@ -1450,33 +1480,11 @@ def _generate_and_score_completions(
         all_process_advantages = advantages.clone()  # keep the aggregated advantages for logging
         advantages = advantages[process_slice]
 
-        # Log the metrics
-        if mode == "train":
-            self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item()
-        self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]
-
         # Calculate and log the mean KL divergence between current and reference model
         if self.beta != 0.0:
             mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
             self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item())
 
-        # Log completion lengths, mean, min, max
-        agg_completion_lengths = self.accelerator.gather(completion_lengths)
-        self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item())
-        self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item())
-        self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item())
-
-        # Identify sequences that terminated with EOS and log their lengths
-        agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1))
-        term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos]
-        clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths)
-        self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio)
-        if len(term_completion_lengths) == 0:  # edge case where no terminated sequences are found
-            term_completion_lengths = torch.zeros(1, device=device)
-        self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item())
-        self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
-        self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())
-
         # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
         for i, reward_func_name in enumerate(self.reward_func_names):
             mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
@@ -1505,14 +1513,14 @@ def _generate_and_score_completions(
             "old_logps": old_logps,
             "advantages": advantages,
         }
-        if "pixel_values" in prompt_inputs:
-            output["pixel_values"] = prompt_inputs["pixel_values"]
-        if "image_grid_thw" in prompt_inputs:
-            output["image_grid_thw"] = prompt_inputs["image_grid_thw"]
-        if "pixel_attention_mask" in prompt_inputs:
-            output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"]
-        if "image_sizes" in prompt_inputs:
-            output["image_sizes"] = prompt_inputs["image_sizes"]
+        if "pixel_values" in forward_kwargs:
+            output["pixel_values"] = forward_kwargs["pixel_values"]
+        if "image_grid_thw" in forward_kwargs:
+            output["image_grid_thw"] = forward_kwargs["image_grid_thw"]
+        if "pixel_attention_mask" in forward_kwargs:
+            output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"]
+        if "image_sizes" in forward_kwargs:
+            output["image_sizes"] = forward_kwargs["image_sizes"]
         if images is not None:
             output["num_images"] = num_images
         return output
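
Reviewer note: the EOS masking and truncation masking that `_generate` now owns (shared by the GRPO, GFPO and RLOO trainers) is self-contained and easy to sanity-check in isolation. A minimal, runnable sketch; the toy values (eos_token_id=2, a 2x5 batch) are made up for illustration and are not taken from the patch:

    import torch

    # Two toy completions of 5 tokens each; eos_token_id = 2 (illustrative only).
    completion_ids = torch.tensor([[5, 7, 2, 9, 9],   # emits EOS at position 2
                                   [5, 7, 9, 9, 9]])  # never emits EOS (truncated)
    eos_token_id = 2

    # Same construction as in _generate: eos_idx defaults to the sequence length,
    # so rows without EOS keep a full mask; rows with EOS keep tokens up to and
    # including the first EOS.
    is_eos = completion_ids == eos_token_id
    eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long)
    eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
    sequence_indices = torch.arange(is_eos.size(1)).expand(is_eos.size(0), -1)
    completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
    print(completion_mask)  # tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])

    # mask_truncated_completions then zeroes out rows that never terminated:
    truncated_completions = ~is_eos.any(dim=1)
    completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int()
    print(completion_mask)  # tensor([[1, 1, 1, 0, 0], [0, 0, 0, 0, 0]])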
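
On the GFPOConfig snippet fixed in docs/source/experimental.md: a group filter only has to satisfy the GroupFilterFunc alias declared in gfpo_trainer.py, returning one score per completion in each group; the trainer then keeps num_remains_in_group completions per group. A hypothetical, runnable sketch; the argument semantics (per-group completions and their rewards) and the class name are assumptions for illustration, not the documented API:

    # Hypothetical filter matching
    # GroupFilterFunc = Callable[[list[list[Any]], list[list[Any]]], list[list[float]]].
    # Assumption: the arguments are per-group completions and per-group rewards,
    # and a higher score means the completion is more likely to be kept.
    class ShortestCorrectFilter:
        def __call__(self, group_completions, group_rewards):
            group_scores = []
            for completions, rewards in zip(group_completions, group_rewards):
                # Prefer high-reward completions; break ties toward shorter ones.
                scores = [
                    float(reward) - 1e-4 * len(str(completion))
                    for completion, reward in zip(completions, rewards)
                ]
                group_scores.append(scores)
            return group_scores

    filter_fn = ShortestCorrectFilter()
    print(filter_fn([["short", "a much longer completion"]], [[1.0, 1.0]]))
    # approximately [[0.9995, 0.9976]]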