diff --git a/tests/test_vllm_client_server.py b/tests/test_vllm_client_server.py
index 23c2080289c..ad93022aff4 100644
--- a/tests/test_vllm_client_server.py
+++ b/tests/test_vllm_client_server.py
@@ -74,36 +74,42 @@ def setup_class(cls):

     def test_generate(self):
         prompts = ["Hello, AI!", "Tell me a joke"]
-        outputs = self.client.generate(prompts)["completion_ids"]
+        outputs = self.client.generate(prompts)
+        prompt_ids = outputs["prompt_ids"]
+        completion_ids = outputs["completion_ids"]

-        # Check that the output is a list
-        assert isinstance(outputs, list)
+        # Check that the outputs are lists
+        assert isinstance(prompt_ids, list)
+        assert isinstance(completion_ids, list)

-        # Check that the number of generated sequences is equal to the number of prompts
-        assert len(outputs) == len(prompts)
+        # Check that the number of sequences is equal to the number of prompts
+        assert len(prompt_ids) == len(prompts)
+        assert len(completion_ids) == len(prompts)

-        # Check that the generated sequences are lists of integers
-        for seq in outputs:
+        # Check that the sequences are lists of integers
+        for seq in prompt_ids:
+            assert all(isinstance(tok, int) for tok in seq)
+        for seq in completion_ids:
             assert all(isinstance(tok, int) for tok in seq)

     def test_generate_with_params(self):
         prompts = ["Hello, AI!", "Tell me a joke"]
-        outputs = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
+        completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
             "completion_ids"
         ]

         # Check that the output is a list
-        assert isinstance(outputs, list)
+        assert isinstance(completion_ids, list)

         # Check that the number of generated sequences is 2 times the number of prompts
-        assert len(outputs) == 2 * len(prompts)
+        assert len(completion_ids) == 2 * len(prompts)

         # Check that the generated sequences are lists of integers
-        for seq in outputs:
+        for seq in completion_ids:
             assert all(isinstance(tok, int) for tok in seq)

         # Check that the length of the generated sequences is less than or equal to 32
-        for seq in outputs:
+        for seq in completion_ids:
             assert len(seq) <= 32

     def test_update_model_params(self):
@@ -148,36 +154,42 @@ def setup_class(cls):

     def test_generate(self):
         prompts = ["Hello, AI!", "Tell me a joke"]
-        outputs = self.client.generate(prompts)["completion_ids"]
+        outputs = self.client.generate(prompts)
+        prompt_ids = outputs["prompt_ids"]
+        completion_ids = outputs["completion_ids"]

-        # Check that the output is a list
-        assert isinstance(outputs, list)
+        # Check that the outputs are lists
+        assert isinstance(prompt_ids, list)
+        assert isinstance(completion_ids, list)

-        # Check that the number of generated sequences is equal to the number of prompts
-        assert len(outputs) == len(prompts)
+        # Check that the number of sequences is equal to the number of prompts
+        assert len(prompt_ids) == len(prompts)
+        assert len(completion_ids) == len(prompts)

-        # Check that the generated sequences are lists of integers
-        for seq in outputs:
+        # Check that the sequences are lists of integers
+        for seq in prompt_ids:
+            assert all(isinstance(tok, int) for tok in seq)
+        for seq in completion_ids:
             assert all(isinstance(tok, int) for tok in seq)

     def test_generate_with_params(self):
         prompts = ["Hello, AI!", "Tell me a joke"]
-        outputs = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
+        completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
             "completion_ids"
         ]

         # Check that the output is a list
-        assert isinstance(outputs, list)
+        assert isinstance(completion_ids, list)

         # Check that the number of generated sequences is 2 times the number of prompts
-        assert len(outputs) == 2 * len(prompts)
+        assert len(completion_ids) == 2 * len(prompts)

         # Check that the generated sequences are lists of integers
-        for seq in outputs:
+        for seq in completion_ids:
             assert all(isinstance(tok, int) for tok in seq)

         # Check that the length of the generated sequences is less than or equal to 32
-        for seq in outputs:
+        for seq in completion_ids:
             assert len(seq) <= 32

     def test_update_model_params(self):
@@ -224,16 +236,22 @@ def setup_class(cls):

     def test_generate(self):
         prompts = ["Hello, AI!", "Tell me a joke"]
-        outputs = self.client.generate(prompts)["completion_ids"]
+        outputs = self.client.generate(prompts)
+        prompt_ids = outputs["prompt_ids"]
+        completion_ids = outputs["completion_ids"]

-        # Check that the output is a list
-        assert isinstance(outputs, list)
+        # Check that the outputs are lists
+        assert isinstance(prompt_ids, list)
+        assert isinstance(completion_ids, list)

-        # Check that the number of generated sequences is equal to the number of prompts
-        assert len(outputs) == len(prompts)
+        # Check that the number of sequences is equal to the number of prompts
+        assert len(prompt_ids) == len(prompts)
+        assert len(completion_ids) == len(prompts)

-        # Check that the generated sequences are lists of integers
-        for seq in outputs:
+        # Check that the sequences are lists of integers
+        for seq in prompt_ids:
+            assert all(isinstance(tok, int) for tok in seq)
+        for seq in completion_ids:
             assert all(isinstance(tok, int) for tok in seq)

     def test_update_model_params(self):
@@ -280,16 +298,22 @@ def setup_class(cls):

     def test_generate(self):
         prompts = ["Hello, AI!", "Tell me a joke"]
-        outputs = self.client.generate(prompts)["completion_ids"]
+        outputs = self.client.generate(prompts)
+        prompt_ids = outputs["prompt_ids"]
+        completion_ids = outputs["completion_ids"]

-        # Check that the output is a list
-        assert isinstance(outputs, list)
+        # Check that the outputs are lists
+        assert isinstance(prompt_ids, list)
+        assert isinstance(completion_ids, list)

-        # Check that the number of generated sequences is equal to the number of prompts
-        assert len(outputs) == len(prompts)
+        # Check that the number of sequences is equal to the number of prompts
+        assert len(prompt_ids) == len(prompts)
+        assert len(completion_ids) == len(prompts)

-        # Check that the generated sequences are lists of integers
-        for seq in outputs:
+        # Check that the sequences are lists of integers
+        for seq in prompt_ids:
+            assert all(isinstance(tok, int) for tok in seq)
+        for seq in completion_ids:
             assert all(isinstance(tok, int) for tok in seq)

     def test_update_model_params(self):
@@ -336,9 +360,13 @@ def test_init_communicator_with_device_int(self):

         # Test basic functionality
         prompts = ["Hello, AI!"]
-        outputs = client.generate(prompts)["completion_ids"]
-        assert isinstance(outputs, list)
-        assert len(outputs) == len(prompts)
+        outputs = client.generate(prompts)
+        prompt_ids = outputs["prompt_ids"]
+        completion_ids = outputs["completion_ids"]
+        assert isinstance(prompt_ids, list)
+        assert len(prompt_ids) == len(prompts)
+        assert isinstance(completion_ids, list)
+        assert len(completion_ids) == len(prompts)

         client.close_communicator()

diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py
index 0932697d6ee..00e5f2c817b 100644
--- a/trl/extras/vllm_client.py
+++ b/trl/extras/vllm_client.py
@@ -83,8 +83,12 @@ class VLLMClient:
         >>> client = VLLMClient()
         >>> client.generate(["Hello, AI!", "Tell me a joke"])
-        [[2980, 498, 1492, 752, 448, 264, 13027, 8645, 30, 358, 2776, 4460, 311, 3270, 264, 2025],
-         [911, 7988, 1251, 382, 3838, 653, 498, 1618, 4325, 879, 2581, 20027, 264, 21428, 30, 362]]
+        {'prompt_ids': [[9707, 11, 15235, 0],
+                        [40451, 752, 264, 21646]],
+         'completion_ids': [[11479, 752, 5046, 279, 1465, 304, 419, 23670, 2038, 358, 2776, 4378, 369, 847, 15549, 6733],
+                            [911, 19654, 382, 3838, 1558, 279, 16158, 1977, 979, 498, 2299, 4460, 311, 10542, 432, 518]],
+         'logprobs': [[-5.193126201629639, -0.05592319369316101, -4.861808776855469, -1.673396110534668, -2.6316866874694824, -0.2861405313014984, -0.35006725788116455, -5.23351526260376, -0.1447441577911377, -5.21489953994751, -1.6022650003433228, -1.9649192094802856, -2.1338791847229004, -1.2775304317474365, -10.004860877990723, -4.171003818511963],
+                      [-0.012896230444312096, -5.747106552124023, -1.5248860120773315, -1.9286258220672607, -2.8512537479400635, -2.8055880069732666, -3.019822835922241, -0.37132859230041504, -0.6311739087104797, -2.562908411026001, -3.1664533615112305, -2.685293436050415, -0.007259538397192955, -7.339841842651367, -1.188662052154541, -3.54781436920166]]}

         >>> from transformers import AutoModelForCausalLM
@@ -212,6 +216,8 @@ def generate(
         Returns:
             `dict` with keys:
+                - `prompt_ids` (`list[list[int]]`):
+                    List of lists of token IDs representing the tokenized input prompts.
                 - `completion_ids` (`list[list[int]]`):
                     List of lists of token IDs representing the model-generated completions for each prompt.
                 - `logprobs` (`list[list[float]]`):
                     List of lists of log probabilities for each token in the generated completions.
@@ -246,7 +252,11 @@ def pil_to_base64(image):
         )
         if response.status_code == 200:
             json_response = response.json()
-            return {"completion_ids": json_response["completion_ids"], "logprobs": json_response["logprobs"]}
+            return {
+                "prompt_ids": json_response["prompt_ids"],
+                "completion_ids": json_response["completion_ids"],
+                "logprobs": json_response["logprobs"],
+            }
         else:
             raise Exception(f"Request failed: {response.status_code}, {response.text}")

diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py
index 3e448aedf13..901f8177ce0 100644
--- a/trl/scripts/vllm_serve.py
+++ b/trl/scripts/vllm_serve.py
@@ -499,6 +499,7 @@ class GenerateRequest(BaseModel):
         generation_kwargs: dict = field(default_factory=dict)

     class GenerateResponse(BaseModel):
+        prompt_ids: list[list[int]]
        completion_ids: list[list[int]]
        logprobs: list[list[float]]

@@ -532,6 +533,7 @@ async def generate(request: GenerateRequest):

         Returns:
             `GenerateResponse`:
+                - `prompt_ids` (list of list of `int`): A list of lists of token IDs for each input prompt.
                 - `completion_ids` (list of list of `int`): A list of lists of token IDs for each generated completion.
                 - `logprobs` (list of list of `float`): A list of lists of log probabilities for each token in the
                     generated completions.
@@ -543,7 +545,11 @@ async def generate(request: GenerateRequest):

         Example response:
         ```json
-        {"completion_ids": [[101, 102, 103], [201, 202, 203]], "logprobs": [[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6]]}
+        {
+            "prompt_ids": [[101, 102], [201, 202]],
+            "completion_ids": [[103, 104, 105], [203, 204, 205]],
+            "logprobs": [[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6]]
+        }
         ```
         """
         request.images = request.images or [None] * len(request.prompts)
@@ -596,13 +602,14 @@ async def generate(request: GenerateRequest):

         # Flatten and combine all results
         all_outputs = list(chain.from_iterable(all_outputs))  # from list of list to single list
+        prompt_ids = [output.prompt_token_ids for output in all_outputs]
         completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs]
         logprobs: list[list[float]] = [
             [sanitize_logprob(next(iter(logprob.values()))) for logprob in output.logprobs]
             for outputs in all_outputs
             for output in outputs.outputs
         ]
-        return {"completion_ids": completion_ids, "logprobs": logprobs}
+        return {"prompt_ids": prompt_ids, "completion_ids": completion_ids, "logprobs": logprobs}

     class InitCommunicatorRequest(BaseModel):
         host: str

diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index a4c85783461..eca92f4bbf8 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -1101,11 +1101,12 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
             **kwargs,
         )
         prompt_inputs = super()._prepare_inputs(prompt_inputs)
-        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
         forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}

-        prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
         if self.max_prompt_length is not None:
+            prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
+            prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
+
             # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
             # Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special
             # tokens are needed for generation.
@@ -1187,19 +1188,23 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
                             guided_decoding_regex=self.guided_decoding_regex,
                             generation_kwargs=self.args.generation_kwargs,
                         )
-                    payload = (output["completion_ids"], output["logprobs"])
+                    payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"])
                 else:
                     payload = None

                 # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice.
                 obj_list = [payload]
                 broadcast_object_list(obj_list, from_process=0)
-                all_completion_ids, all_logprobs = obj_list[0]
+                all_prompt_ids, all_completion_ids, all_logprobs = obj_list[0]
+
+                # At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times
+                all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(self.num_generations)]

                 process_slice = slice(
                     self.accelerator.process_index * len(prompts),
                     (self.accelerator.process_index + 1) * len(prompts),
                 )
+                prompt_ids = all_prompt_ids[process_slice]
                 completion_ids = all_completion_ids[process_slice]
                 logprobs = all_logprobs[process_slice]

@@ -1254,6 +1259,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
                 with profiling_context(self, "vLLM.generate"):
                     all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False)

+                all_prompt_ids = [output.prompt_token_ids for output in all_outputs]
                 all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
                 all_logprobs = [
                     [next(iter(lp.values())).logprob for lp in output.logprobs]
                     for outputs in all_outputs
                     for output in outputs.outputs
@@ -1266,9 +1272,11 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
                     # Each rank generates all outputs — we keep only our share.
                     local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
                     tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
+                    prompt_ids = all_prompt_ids[tp_slice]
                     completion_ids = all_completion_ids[tp_slice]
                     logprobs = all_logprobs[tp_slice]
                 else:
+                    prompt_ids = all_prompt_ids
                     completion_ids = all_completion_ids
                     logprobs = all_logprobs
@@ -1311,10 +1319,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):

         else:
             # Regular generation path
-            prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids]
-            prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
-            prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
-            prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left")
+            prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

             with (
                 profiling_context(self, "transformers.generate"),

diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py
index a97d969f0f6..dbc1b7a1912 100644
--- a/trl/trainer/rloo_trainer.py
+++ b/trl/trainer/rloo_trainer.py
@@ -1090,11 +1090,12 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
             **kwargs,
         )
         prompt_inputs = super()._prepare_inputs(prompt_inputs)
-        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
         forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}

-        prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
         if self.max_prompt_length is not None:
+            prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
+            prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
+
             # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
             # Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special
             # tokens are needed for generation.
@@ -1176,19 +1177,23 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
                             guided_decoding_regex=self.guided_decoding_regex,
                             generation_kwargs=self.args.generation_kwargs,
                         )
-                    payload = (output["completion_ids"], output["logprobs"])
+                    payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"])
                 else:
                     payload = None

                 # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice.
                 obj_list = [payload]
                 broadcast_object_list(obj_list, from_process=0)
-                all_completion_ids, _ = obj_list[0]
+                all_prompt_ids, all_completion_ids, _ = obj_list[0]
+
+                # At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times
+                all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(self.num_generations)]

                 process_slice = slice(
                     self.accelerator.process_index * len(prompts),
                     (self.accelerator.process_index + 1) * len(prompts),
                 )
+                prompt_ids = all_prompt_ids[process_slice]
                 completion_ids = all_completion_ids[process_slice]

             # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts
@@ -1241,6 +1246,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
                 with profiling_context(self, "vLLM.generate"):
                     all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False)

+                all_prompt_ids = [output.prompt_token_ids for output in all_outputs]
                 all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]

                 if self.vllm_tensor_parallel_size > 1:
@@ -1248,8 +1254,10 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
                     # Each rank generates all outputs — we keep only our share.
                     local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
                     tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
+                    prompt_ids = all_prompt_ids[tp_slice]
                     completion_ids = all_completion_ids[tp_slice]
                 else:
+                    prompt_ids = all_prompt_ids
                     completion_ids = all_completion_ids

                 if self.args.vllm_enable_sleep_mode:
@@ -1290,10 +1298,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):

         else:
             # Regular generation path
-            prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids]
-            prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
-            prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
-            prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left")
+            prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

             with (
                 profiling_context(self, "transformers.generate"),
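For reference, here is a minimal sketch (not part of the patch) of how a caller might consume the extended `generate()` payload. It assumes a vLLM server is already reachable on the default host/port, for example one started with `trl vllm-serve --model Qwen/Qwen2.5-7B`, and that the default `n=1` is used so each prompt maps to exactly one completion.

```python
# Sketch: consuming the new response shape of VLLMClient.generate().
# Assumes a running `trl vllm-serve` instance on the default host/port and n=1.
from trl.extras.vllm_client import VLLMClient

client = VLLMClient()
output = client.generate(["Hello, AI!", "Tell me a joke"])

# The response is now a dict with `prompt_ids` alongside `completion_ids` and `logprobs`.
for prompt, completion, logps in zip(output["prompt_ids"], output["completion_ids"], output["logprobs"]):
    assert len(completion) == len(logps)  # one logprob per generated token
    print(f"prompt tokens: {len(prompt)}, completion tokens: {len(completion)}")
```

Note that when `n > 1` the server still returns one `prompt_ids` entry per prompt, which is why the trainers repeat each entry `num_generations` times before taking their per-process slice.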