Merged
92 commits
552e899
Refactor image handling: replace `image_split_sizes` with `image_grid…
qgallouedec Sep 19, 2025
449ef07
simpler
qgallouedec Sep 19, 2025
c8933aa
gfpo
qgallouedec Sep 19, 2025
229c554
multi-image grpo
qgallouedec Sep 19, 2025
3ca6ad5
log with wandb
qgallouedec Sep 19, 2025
dcf4b92
no vlm reward models
qgallouedec Sep 20, 2025
30ad7ca
rloo
qgallouedec Sep 20, 2025
86cc30b
gfpo
qgallouedec Sep 20, 2025
088897b
fix
qgallouedec Sep 20, 2025
d2adc63
test peft
qgallouedec Sep 20, 2025
f4c82bf
fix gfpo
qgallouedec Sep 20, 2025
1257796
rloo test
qgallouedec Sep 20, 2025
099a39b
peft rloo
qgallouedec Sep 20, 2025
529add6
oops
qgallouedec Sep 20, 2025
fc6b11f
update test
qgallouedec Sep 20, 2025
ae1f497
generate method
qgallouedec Sep 20, 2025
f998432
debug
qgallouedec Sep 20, 2025
fa73876
skip failing test
qgallouedec Sep 20, 2025
52d8bd9
Merge branch 'main' into drop-image_split_sizes
qgallouedec Sep 20, 2025
dfc0d38
Merge branch 'drop-image_split_sizes' into multi-image-support
qgallouedec Sep 20, 2025
fc52e68
test fixed!
qgallouedec Sep 20, 2025
4d12aeb
Merge branch 'multi-image-support' into generate-method
qgallouedec Sep 20, 2025
4fc2b5b
gfpo
qgallouedec Sep 20, 2025
b628744
rm vllm
qgallouedec Sep 20, 2025
d3a769f
fix doc
qgallouedec Sep 20, 2025
e17ec42
Merge branch 'main' into drop-image_split_sizes
qgallouedec Sep 22, 2025
efbb03a
Merge branch 'drop-image_split_sizes' into multi-image-support
qgallouedec Sep 22, 2025
562c662
Merge branch 'main' into multi-image-support
qgallouedec Sep 22, 2025
485781c
Merge branch 'main' into multi-image-support
qgallouedec Sep 22, 2025
05270f8
update layers to ignore
qgallouedec Sep 22, 2025
1c53094
clarify image column desc
qgallouedec Sep 22, 2025
9b6652e
rm VLM x RM warning
qgallouedec Sep 23, 2025
c500440
Merge branch 'multi-image-support' into generate-method
qgallouedec Sep 23, 2025
a6a8c44
Merge branch 'main' into generate-method
qgallouedec Sep 23, 2025
d8665e1
Merge branch 'main' into generate-method
qgallouedec Sep 23, 2025
365d501
Merge branch 'main' into generate-method
qgallouedec Sep 23, 2025
cdb4c76
Merge branch 'main' into generate-method
qgallouedec Sep 24, 2025
c83e710
same for rloo
qgallouedec Sep 24, 2025
ec6ad25
nits style and align
qgallouedec Sep 24, 2025
b4cadde
Merge branch 'main' into generate-method
qgallouedec Sep 24, 2025
b0dceb9
restart
qgallouedec Sep 25, 2025
ebe32c2
progress
qgallouedec Sep 25, 2025
0213662
progress continues
qgallouedec Sep 25, 2025
8b3a724
progress again again
qgallouedec Sep 25, 2025
c1ae6aa
back to working point
qgallouedec Sep 25, 2025
1a66b43
revert chage data utils
qgallouedec Sep 25, 2025
2dc69a6
Merge branch 'main' into generate-method
qgallouedec Sep 26, 2025
9435a94
refactor in grpo
qgallouedec Sep 26, 2025
d3f1d3c
Merge branch 'main' into refactor_generate
qgallouedec Sep 26, 2025
3d8ea27
wrong merge commit
qgallouedec Sep 26, 2025
27dc958
fix num_input_tokens_seen
qgallouedec Sep 26, 2025
53772ef
getting closer
qgallouedec Sep 26, 2025
8766fa5
consistent naming
qgallouedec Sep 26, 2025
236b78b
better
qgallouedec Sep 26, 2025
9da4830
simplify a bit + comment
qgallouedec Sep 26, 2025
b3bd0b0
another one
qgallouedec Sep 26, 2025
d79b9e1
get prompt ids from generation
qgallouedec Sep 26, 2025
8d34d54
remove pad token removal
qgallouedec Sep 26, 2025
e770efe
Merge branch 'refactor_generate' into refactor_generate_2
qgallouedec Sep 26, 2025
55a2480
rloo + doc
qgallouedec Sep 26, 2025
c8041e1
Merge branch 'refactor_generate' into refactor_generate_2
qgallouedec Sep 26, 2025
7b7a11d
test and doc
qgallouedec Sep 27, 2025
c5064d6
gfpo
qgallouedec Sep 27, 2025
effb41b
Merge branch 'main' into refactor_generate
qgallouedec Sep 27, 2025
e82bfb4
Merge branch 'main' into refactor_generate
qgallouedec Sep 27, 2025
4b9c126
Merge branch 'refactor_generate' into refactor_generate_2
qgallouedec Sep 27, 2025
f11759e
Merge branch 'main' into refactor_generate_2
qgallouedec Sep 30, 2025
e7aa945
fix vllm client server
qgallouedec Sep 30, 2025
e164ec5
repicate all_prompt_ids
qgallouedec Oct 1, 2025
49577ad
Same for RLOO
qgallouedec Oct 1, 2025
5fca5b8
fix normal generation path
qgallouedec Oct 1, 2025
d599c20
Merge branch 'main' into refactor_generate_2
qgallouedec Oct 1, 2025
e82db74
🔣 Fix test: replace `trainer.tokenizer` by `trainer.processing_class`…
qgallouedec Oct 1, 2025
192deb3
Fix CI ImportError: FlashAttention2 and decorator order for all param…
albertvillanova Oct 1, 2025
cf9d8e7
Hotfix wrong formatting of docstrings with blockquote tips (#4187)
albertvillanova Oct 1, 2025
f9c3c3c
🌡️ Have vLLM return processed (temperature scaled) log probs (#4163)
YonatanGideoni Oct 1, 2025
6489479
Replace remaining trainer.tokenizer with trainer.processing_class in …
albertvillanova Oct 3, 2025
21a67fc
[DOCS] Lora without regret (#4181)
burtenshaw Oct 3, 2025
c1e7ad2
[DOCS/FIX] lora without regrets - fix lr (#4207)
burtenshaw Oct 6, 2025
5d34144
Remove custome_container for building the docs (#4198)
albertvillanova Oct 6, 2025
ae2a0e7
Remove tokenizer creation from `sft` example script (#4197)
sergiopaniego Oct 6, 2025
6543f51
Hotfix: Exclude transformers 4.57.0 for Python 3.9 (#4209)
albertvillanova Oct 6, 2025
8319ce0
Replace unittest with pytest (#4188)
albertvillanova Oct 6, 2025
4fdaa4c
Updated vLLM integration guide (#4162)
sergiopaniego Oct 6, 2025
d258e36
Remove `Optional` from `processing_class` in `PPOTrainer` (#4212)
sergiopaniego Oct 6, 2025
7f5b499
Replace setup with pyproject and fix packaging unintended modules (#4…
albertvillanova Oct 6, 2025
df386f9
Merge branch 'main' into refactor_generate_2
qgallouedec Oct 6, 2025
5b9a6ab
Merge branch 'main' into refactor_generate_2
qgallouedec Oct 6, 2025
4a274d5
Merge branch 'main' into refactor_generate_2
qgallouedec Oct 6, 2025
6324eda
Merge branch 'main' into refactor_generate_2
qgallouedec Oct 7, 2025
3955643
fix prompt mask
qgallouedec Oct 7, 2025
ee6638c
remove no-op
qgallouedec Oct 7, 2025
110 changes: 69 additions & 41 deletions tests/test_vllm_client_server.py
@@ -74,36 +74,42 @@ def setup_class(cls):
 
     def test_generate(self):
         prompts = ["Hello, AI!", "Tell me a joke"]
-        outputs = self.client.generate(prompts)["completion_ids"]
+        outputs = self.client.generate(prompts)
+        prompt_ids = outputs["prompt_ids"]
+        completion_ids = outputs["completion_ids"]
 
-        # Check that the output is a list
-        assert isinstance(outputs, list)
+        # Check that the outputs are lists
+        assert isinstance(prompt_ids, list)
+        assert isinstance(completion_ids, list)
 
-        # Check that the number of generated sequences is equal to the number of prompts
-        assert len(outputs) == len(prompts)
+        # Check that the number of sequences is equal to the number of prompts
+        assert len(prompt_ids) == len(prompts)
+        assert len(completion_ids) == len(prompts)
 
-        # Check that the generated sequences are lists of integers
-        for seq in outputs:
+        # Check that the sequences are lists of integers
+        for seq in prompt_ids:
+            assert all(isinstance(tok, int) for tok in seq)
+        for seq in completion_ids:
             assert all(isinstance(tok, int) for tok in seq)
 
     def test_generate_with_params(self):
         prompts = ["Hello, AI!", "Tell me a joke"]
-        outputs = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
+        completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
             "completion_ids"
         ]
 
         # Check that the output is a list
-        assert isinstance(outputs, list)
+        assert isinstance(completion_ids, list)
 
         # Check that the number of generated sequences is 2 times the number of prompts
-        assert len(outputs) == 2 * len(prompts)
+        assert len(completion_ids) == 2 * len(prompts)
 
         # Check that the generated sequences are lists of integers
-        for seq in outputs:
+        for seq in completion_ids:
             assert all(isinstance(tok, int) for tok in seq)
 
         # Check that the length of the generated sequences is less than or equal to 32
-        for seq in outputs:
+        for seq in completion_ids:
             assert len(seq) <= 32
 
     def test_update_model_params(self):
@@ -148,36 +154,42 @@ def setup_class(cls):
 
     def test_generate(self):
         prompts = ["Hello, AI!", "Tell me a joke"]
-        outputs = self.client.generate(prompts)["completion_ids"]
+        outputs = self.client.generate(prompts)
+        prompt_ids = outputs["prompt_ids"]
+        completion_ids = outputs["completion_ids"]
 
-        # Check that the output is a list
-        assert isinstance(outputs, list)
+        # Check that the outputs are lists
+        assert isinstance(prompt_ids, list)
+        assert isinstance(completion_ids, list)
 
-        # Check that the number of generated sequences is equal to the number of prompts
-        assert len(outputs) == len(prompts)
+        # Check that the number of sequences is equal to the number of prompts
+        assert len(prompt_ids) == len(prompts)
+        assert len(completion_ids) == len(prompts)
 
-        # Check that the generated sequences are lists of integers
-        for seq in outputs:
+        # Check that the sequences are lists of integers
+        for seq in prompt_ids:
+            assert all(isinstance(tok, int) for tok in seq)
+        for seq in completion_ids:
             assert all(isinstance(tok, int) for tok in seq)
 
     def test_generate_with_params(self):
         prompts = ["Hello, AI!", "Tell me a joke"]
-        outputs = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
+        completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
             "completion_ids"
         ]
 
         # Check that the output is a list
-        assert isinstance(outputs, list)
+        assert isinstance(completion_ids, list)
 
         # Check that the number of generated sequences is 2 times the number of prompts
-        assert len(outputs) == 2 * len(prompts)
+        assert len(completion_ids) == 2 * len(prompts)
 
         # Check that the generated sequences are lists of integers
-        for seq in outputs:
+        for seq in completion_ids:
             assert all(isinstance(tok, int) for tok in seq)
 
         # Check that the length of the generated sequences is less than or equal to 32
-        for seq in outputs:
+        for seq in completion_ids:
             assert len(seq) <= 32
 
     def test_update_model_params(self):
@@ -224,16 +236,22 @@ def setup_class(cls):
 
     def test_generate(self):
         prompts = ["Hello, AI!", "Tell me a joke"]
-        outputs = self.client.generate(prompts)["completion_ids"]
+        outputs = self.client.generate(prompts)
+        prompt_ids = outputs["prompt_ids"]
+        completion_ids = outputs["completion_ids"]
 
-        # Check that the output is a list
-        assert isinstance(outputs, list)
+        # Check that the outputs are lists
+        assert isinstance(prompt_ids, list)
+        assert isinstance(completion_ids, list)
 
-        # Check that the number of generated sequences is equal to the number of prompts
-        assert len(outputs) == len(prompts)
+        # Check that the number of sequences is equal to the number of prompts
+        assert len(prompt_ids) == len(prompts)
+        assert len(completion_ids) == len(prompts)
 
-        # Check that the generated sequences are lists of integers
-        for seq in outputs:
+        # Check that the sequences are lists of integers
+        for seq in prompt_ids:
+            assert all(isinstance(tok, int) for tok in seq)
+        for seq in completion_ids:
             assert all(isinstance(tok, int) for tok in seq)
 
     def test_update_model_params(self):
@@ -280,16 +298,22 @@ def setup_class(cls):
 
     def test_generate(self):
         prompts = ["Hello, AI!", "Tell me a joke"]
-        outputs = self.client.generate(prompts)["completion_ids"]
+        outputs = self.client.generate(prompts)
+        prompt_ids = outputs["prompt_ids"]
+        completion_ids = outputs["completion_ids"]
 
-        # Check that the output is a list
-        assert isinstance(outputs, list)
+        # Check that the outputs are lists
+        assert isinstance(prompt_ids, list)
+        assert isinstance(completion_ids, list)
 
-        # Check that the number of generated sequences is equal to the number of prompts
-        assert len(outputs) == len(prompts)
+        # Check that the number of sequences is equal to the number of prompts
+        assert len(prompt_ids) == len(prompts)
+        assert len(completion_ids) == len(prompts)
 
-        # Check that the generated sequences are lists of integers
-        for seq in outputs:
+        # Check that the sequences are lists of integers
+        for seq in prompt_ids:
+            assert all(isinstance(tok, int) for tok in seq)
+        for seq in completion_ids:
             assert all(isinstance(tok, int) for tok in seq)
 
     def test_update_model_params(self):
@@ -336,9 +360,13 @@ def test_init_communicator_with_device_int(self):
 
         # Test basic functionality
         prompts = ["Hello, AI!"]
-        outputs = client.generate(prompts)["completion_ids"]
-        assert isinstance(outputs, list)
-        assert len(outputs) == len(prompts)
+        outputs = client.generate(prompts)
+        prompt_ids = outputs["prompt_ids"]
+        completion_ids = outputs["completion_ids"]
+        assert isinstance(prompt_ids, list)
+        assert len(prompt_ids) == len(prompts)
+        assert isinstance(completion_ids, list)
+        assert len(completion_ids) == len(prompts)
 
         client.close_communicator()
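Together, these tests pin down the new contract: `generate` now returns a dict keyed by `prompt_ids`, `completion_ids`, and `logprobs` instead of a bare list. A minimal consumer sketch, assuming a TRL vLLM server is already running on the client's default host and port (the prompts are illustrative):

```python
from trl.extras.vllm_client import VLLMClient

# Assumes a server started with `trl vllm-serve --model <model_name>` is
# reachable at the client's default host and port.
client = VLLMClient()

out = client.generate(["Hello, AI!", "Tell me a joke"], n=1, max_tokens=32)

# One prompt_ids entry per prompt; one completion_ids/logprobs entry per completion.
for p_ids, c_ids, lps in zip(out["prompt_ids"], out["completion_ids"], out["logprobs"]):
    assert all(isinstance(tok, int) for tok in p_ids)
    assert len(c_ids) == len(lps)  # one logprob per generated token
```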
16 changes: 13 additions & 3 deletions trl/extras/vllm_client.py
@@ -83,8 +83,12 @@ class VLLMClient:
 
         >>> client = VLLMClient()
         >>> client.generate(["Hello, AI!", "Tell me a joke"])
-        [[2980, 498, 1492, 752, 448, 264, 13027, 8645, 30, 358, 2776, 4460, 311, 3270, 264, 2025],
-         [911, 7988, 1251, 382, 3838, 653, 498, 1618, 4325, 879, 2581, 20027, 264, 21428, 30, 362]]
+        {'prompt_ids': [[9707, 11, 15235, 0],
+                        [40451, 752, 264, 21646]],
+         'completion_ids': [[11479, 752, 5046, 279, 1465, 304, 419, 23670, 2038, 358, 2776, 4378, 369, 847, 15549, 6733],
+                            [911, 19654, 382, 3838, 1558, 279, 16158, 1977, 979, 498, 2299, 4460, 311, 10542, 432, 518]],
+         'logprobs': [[-5.193126201629639, -0.05592319369316101, -4.861808776855469, -1.673396110534668, -2.6316866874694824, -0.2861405313014984, -0.35006725788116455, -5.23351526260376, -0.1447441577911377, -5.21489953994751, -1.6022650003433228, -1.9649192094802856, -2.1338791847229004, -1.2775304317474365, -10.004860877990723, -4.171003818511963],
+                      [-0.012896230444312096, -5.747106552124023, -1.5248860120773315, -1.9286258220672607, -2.8512537479400635, -2.8055880069732666, -3.019822835922241, -0.37132859230041504, -0.6311739087104797, -2.562908411026001, -3.1664533615112305, -2.685293436050415, -0.007259538397192955, -7.339841842651367, -1.188662052154541, -3.54781436920166]]}
 
         >>> from transformers import AutoModelForCausalLM
@@ -212,6 +216,8 @@ def generate(
 
         Returns:
             `dict` with keys:
+                - `prompt_ids` (`list[list[int]]`):
+                    List of lists of token IDs representing the tokenized input prompts.
                 - `completion_ids` (`list[list[int]]`):
                     List of lists of token IDs representing the model-generated completions for each prompt.
                 - `logprobs` (`list[list[float]]`):
@@ -246,7 +252,11 @@ def pil_to_base64(image):
         )
         if response.status_code == 200:
             json_response = response.json()
-            return {"completion_ids": json_response["completion_ids"], "logprobs": json_response["logprobs"]}
+            return {
+                "prompt_ids": json_response["prompt_ids"],
+                "completion_ids": json_response["completion_ids"],
+                "logprobs": json_response["logprobs"],
+            }
         else:
             raise Exception(f"Request failed: {response.status_code}, {response.text}")
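One practical payoff of returning `prompt_ids` is that callers can reconstruct the full sequence without re-tokenizing the prompt text themselves. A minimal sketch, assuming a running server and a tokenizer matching the served model (the model name below is a placeholder):

```python
from transformers import AutoTokenizer

from trl.extras.vllm_client import VLLMClient

# Placeholder name: use the tokenizer of whatever model the server was started with.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
client = VLLMClient()

out = client.generate(["Hello, AI!"])

# Because the server returns the exact prompt token IDs it saw, prompt and
# completion can be stitched together with no re-tokenization drift.
full_ids = out["prompt_ids"][0] + out["completion_ids"][0]
print(tokenizer.decode(full_ids, skip_special_tokens=True))
```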
11 changes: 9 additions & 2 deletions trl/scripts/vllm_serve.py
@@ -499,6 +499,7 @@ class GenerateRequest(BaseModel):
         generation_kwargs: dict = field(default_factory=dict)
 
     class GenerateResponse(BaseModel):
+        prompt_ids: list[list[int]]
         completion_ids: list[list[int]]
         logprobs: list[list[float]]
@@ -532,6 +533,7 @@ async def generate(request: GenerateRequest):
 
         Returns:
             `GenerateResponse`:
+                - `prompt_ids` (list of list of `int`): A list of lists of token IDs for each input prompt.
                 - `completion_ids` (list of list of `int`): A list of lists of token IDs for each generated completion.
                 - `logprobs` (list of list of `float`): A list of lists of log probabilities for each token in the
                     generated completions.
@@ -543,7 +545,11 @@ async def generate(request: GenerateRequest):
 
         Example response:
         ```json
-        {"completion_ids": [[101, 102, 103], [201, 202, 203]], "logprobs": [[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6]]}
+        {
+            "prompt_ids": [[101, 102], [201, 202]],
+            "completion_ids": [[103, 104, 105], [203, 204, 205]],
+            "logprobs": [[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6]]
+        }
         ```
         """
         request.images = request.images or [None] * len(request.prompts)
@@ -596,13 +602,14 @@ async def generate(request: GenerateRequest):
 
         # Flatten and combine all results
         all_outputs = list(chain.from_iterable(all_outputs))  # from list of list to single list
+        prompt_ids = [output.prompt_token_ids for output in all_outputs]
         completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs]
         logprobs: list[list[float]] = [
             [sanitize_logprob(next(iter(logprob.values()))) for logprob in output.logprobs]
             for outputs in all_outputs
             for output in outputs.outputs
         ]
-        return {"completion_ids": completion_ids, "logprobs": logprobs}
+        return {"prompt_ids": prompt_ids, "completion_ids": completion_ids, "logprobs": logprobs}
 
     class InitCommunicatorRequest(BaseModel):
         host: str
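The same schema change is visible over raw HTTP. A hedged sketch, assuming the server runs on the local defaults and that top-level request fields (`prompts`, `n`, `max_tokens`) mirror the `VLLMClient.generate` signature shown earlier:

```python
import requests

response = requests.post(
    "http://0.0.0.0:8000/generate/",
    json={"prompts": ["Hello, AI!", "Tell me a joke"], "n": 2, "max_tokens": 16},
)
response.raise_for_status()
data = response.json()

# Note the asymmetry: prompt_ids has one entry per prompt, while completion_ids
# and logprobs have one entry per completion (n per prompt), so callers needing
# aligned lists must repeat each prompt_ids entry n times.
assert len(data["prompt_ids"]) == 2
assert len(data["completion_ids"]) == 2 * 2
```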
21 changes: 13 additions & 8 deletions trl/trainer/grpo_trainer.py
@@ -1101,11 +1101,12 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
             **kwargs,
         )
         prompt_inputs = super()._prepare_inputs(prompt_inputs)
-        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
         forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
-        prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
 
         if self.max_prompt_length is not None:
+            prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
+            prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
+
             # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
             # Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special
             # tokens are needed for generation.
@@ -1187,19 +1188,23 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
                     guided_decoding_regex=self.guided_decoding_regex,
                     generation_kwargs=self.args.generation_kwargs,
                 )
-                payload = (output["completion_ids"], output["logprobs"])
+                payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"])
             else:
                 payload = None
 
             # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice.
             obj_list = [payload]
             broadcast_object_list(obj_list, from_process=0)
-            all_completion_ids, all_logprobs = obj_list[0]
+            all_prompt_ids, all_completion_ids, all_logprobs = obj_list[0]
+
+            # At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times
+            all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(self.num_generations)]
 
             process_slice = slice(
                 self.accelerator.process_index * len(prompts),
                 (self.accelerator.process_index + 1) * len(prompts),
             )
+            prompt_ids = all_prompt_ids[process_slice]
             completion_ids = all_completion_ids[process_slice]
             logprobs = all_logprobs[process_slice]
 
@@ -1254,6 +1259,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
             with profiling_context(self, "vLLM.generate"):
                 all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False)
 
+            all_prompt_ids = [output.prompt_token_ids for output in all_outputs]
             all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
             all_logprobs = [
                 [next(iter(lp.values())).logprob for lp in output.logprobs]
@@ -1266,9 +1272,11 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
                 # Each rank generates all outputs — we keep only our share.
                 local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
                 tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
+                prompt_ids = all_prompt_ids[tp_slice]
                 completion_ids = all_completion_ids[tp_slice]
                 logprobs = all_logprobs[tp_slice]
             else:
+                prompt_ids = all_prompt_ids
                 completion_ids = all_completion_ids
                 logprobs = all_logprobs
@@ -1311,10 +1319,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
 
         else:
             # Regular generation path
-            prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids]
-            prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
-            prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
-            prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left")
+            prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
 
             with (
                 profiling_context(self, "transformers.generate"),
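The prompt-repetition and per-process slicing in the broadcast hunk above is the subtle part of this change, so here is a toy, standalone illustration of the arithmetic (made-up numbers, not trainer code):

```python
# The server returned one prompt_ids entry per unique prompt.
all_prompt_ids = [[101, 102], [201, 202]]
num_generations = 2

# Repeat each prompt num_generations times so entries line up with completions.
all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(num_generations)]
# -> [[101, 102], [101, 102], [201, 202], [201, 202]]

# Each process then takes its contiguous slice, e.g. process 1 of 2 with
# 2 local (already repeated) prompts per process.
process_index, local_len = 1, 2
process_slice = slice(process_index * local_len, (process_index + 1) * local_len)
assert all_prompt_ids[process_slice] == [[201, 202], [201, 202]]
```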