diff --git a/trl/data_utils.py b/trl/data_utils.py
index 75e7a76f979..9f1fec62010 100644
--- a/trl/data_utils.py
+++ b/trl/data_utils.py
@@ -148,7 +148,7 @@ def apply_chat_template(
     # Apply the chat template to the prompt, adding the generation prompt
     if "prompt" in example:
         last_role = example["prompt"][-1]["role"]
-        if last_role == "user":
+        if last_role in ["user", "tool"]:
             add_generation_prompt = True
             continue_final_message = False
         elif last_role == "assistant":
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 216591acf6a..9c02ad6ae97 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import inspect
+import json
 import os
 import re
 import textwrap
@@ -61,6 +62,10 @@
     disable_dropout_in_model,
     ensure_master_addr_port,
     entropy_from_logits,
+    flush_left,
+    flush_right,
+    generate_model_card,
+    get_comet_experiment_url,
     identity,
     nanmax,
     nanmin,
@@ -97,7 +102,21 @@
 RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
 
 
-class GRPOTrainer(BaseTrainer):
+def extract_tool_calls(text: str) -> Optional[dict[str, Any]]:
+    """
+    Extract the first `<tool_call>` JSON block from `text` and return it as a dict, or `None` if no block parses.
+    """
+    pattern = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)
+
+    for match in pattern.findall(text):
+        try:
+            return json.loads(match)
+        except json.JSONDecodeError:
+            pass
+    return None
+
+
+class GRPOTrainer(Trainer):
     """
     Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
     paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language
@@ -227,7 +246,10 @@ def __init__(
         callbacks: Optional[list[TrainerCallback]] = None,
         optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
         peft_config: Optional["PeftConfig"] = None,
+        tools=None,
     ):
+        self.tools = tools or []
+        self._tool_dict = {tool.__name__: tool for tool in self.tools}
         # Args
         if args is None:
             model_name = model if isinstance(model, str) else model.config._name_or_path
@@ -1085,7 +1107,8 @@ def _generate(self, prompts: list[str], images: Optional[list]):
                 prepare_multimodal_messages(prompt, num_images=len(image_list))
 
         prompts_text = [
-            maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
+            maybe_apply_chat_template({"prompt": prompt}, self.processing_class, tools=self.tools)["prompt"]
+            for prompt in prompts
         ]
 
         prompt_inputs = self.processing_class(
@@ -1413,6 +1436,58 @@
             sampling_per_token_logps,
             forward_kwargs,
         ) = self._generate(prompts, images)
+        # Decode the completions and try to extract a tool call from each one
+        completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
+        tool_calls = [extract_tool_calls(completion) for completion in completions]
+        tool_results = [self._tool_dict[tc["name"]](**tc["arguments"]) if tc else None for tc in tool_calls]
+        tool_messages = [
+            [{"role": "tool", "name": tc["name"], "content": str(tr)}] if tc else None
+            for tc, tr in zip(tool_calls, tool_results)
+        ]
+        # Build follow-up prompts for the completions that called a tool: prompt + completion + tool message
+        new_prompts = [
+            p + [{"role": "user", "content": c}] + t for p, c, t in zip(prompts, completions, tool_messages) if t
+        ]
+        needs_tool = torch.tensor([tc is not None for tc in tool_calls], device=device)
+        if new_prompts:
+            (
+                new_prompt_ids,
+                new_completion_ids,
+                new_prompt_mask,
+                new_completion_mask,
+                new_num_items_in_batch,
+                new_sampling_per_token_logps,
+                new_forward_kwargs,
+            ) = self._generate(new_prompts, images)
+            # Number of tokens the tool message adds on top of the original prompt + completion
+            num_tool_ids = new_prompt_mask.sum(-1) - torch.cat(
+                [prompt_mask[needs_tool], completion_mask[needs_tool]], dim=1
+            ).sum(-1)
+            tool_ids = [ids[-num:] for ids, num in zip(new_prompt_ids, num_tool_ids)]
+            tool_mask = [torch.ones_like(ids) for ids in tool_ids]
+            # Merge original completion tokens, tool tokens, and follow-up completion tokens
+            r_completion_mask, r_completion_ids = flush_right(completion_mask[needs_tool], completion_ids[needs_tool])
+            ci = [torch.cat(x) for x in zip(r_completion_ids, tool_ids, new_completion_ids)]
+            cm = [torch.cat(x) for x in zip(r_completion_mask, tool_mask, new_completion_mask)]
+
+            new_ci = []
+            new_cm = []
+            true_idx = 0
+            for i, m in enumerate(needs_tool):
+                if m:
+                    # take the next merged tensor (completions that needed a tool call)
+                    new_ci.append(ci[true_idx])
+                    new_cm.append(cm[true_idx])
+                    true_idx += 1
+                else:
+                    new_ci.append(completion_ids[i])
+                    new_cm.append(completion_mask[i])
+
+            # Re-pad the merged sequences and move all padding back to the right
+            completion_ids = pad(new_ci, self.pad_token_id)
+            completion_mask = pad(new_cm, 0)
+            completion_mask, completion_ids = flush_left(completion_mask, completion_ids)
+            num_items_in_batch += new_num_items_in_batch
 
         # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need
         # to re-tokenize completions if the reward is computed from tokens.
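
A minimal usage sketch of the new `tools` argument, not part of the diff: the model name, dataset, and reward function below are placeholders, and the prompts are assumed to be conversational so the chat template can expose the tool.

```python
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer


def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b


def reward_correct_product(completions, **kwargs):
    # Placeholder reward: 1.0 if the completion mentions the expected product.
    return [1.0 if "84" in str(completion) else 0.0 for completion in completions]


# Placeholder conversational prompt dataset; tool calling requires message-formatted prompts.
dataset = Dataset.from_dict(
    {"prompt": [[{"role": "user", "content": "Use the multiply tool to compute 12 * 7."}]] * 64}
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=reward_correct_product,
    args=GRPOConfig(output_dir="grpo-tools-demo", per_device_train_batch_size=8, num_generations=8),
    train_dataset=dataset,
    tools=[multiply],  # new argument from this PR; passed to the chat template and invoked on <tool_call> blocks
)
trainer.train()
```

Assuming a Qwen-style chat template, a completion that uses the tool emits a `<tool_call>{"name": "multiply", "arguments": {"a": 12, "b": 7}}</tool_call>` block; `extract_tool_calls` parses it, the tool result is appended as a `{"role": "tool", ...}` message, and generation continues from the extended prompt.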