Commits (46, all by qgallouedec):
552e899  Sep 19, 2025  Refactor image handling: replace `image_split_sizes` with `image_grid…`
449ef07  Sep 19, 2025  simpler
c8933aa  Sep 19, 2025  gfpo
229c554  Sep 19, 2025  multi-image grpo
3ca6ad5  Sep 19, 2025  log with wandb
dcf4b92  Sep 20, 2025  no vlm reward models
30ad7ca  Sep 20, 2025  rloo
86cc30b  Sep 20, 2025  gfpo
088897b  Sep 20, 2025  fix
d2adc63  Sep 20, 2025  test peft
f4c82bf  Sep 20, 2025  fix gfpo
1257796  Sep 20, 2025  rloo test
099a39b  Sep 20, 2025  peft rloo
529add6  Sep 20, 2025  oops
fc6b11f  Sep 20, 2025  update test
ae1f497  Sep 20, 2025  generate method
f998432  Sep 20, 2025  debug
fa73876  Sep 20, 2025  skip failing test
52d8bd9  Sep 20, 2025  Merge branch 'main' into drop-image_split_sizes
dfc0d38  Sep 20, 2025  Merge branch 'drop-image_split_sizes' into multi-image-support
fc52e68  Sep 20, 2025  test fixed!
4d12aeb  Sep 20, 2025  Merge branch 'multi-image-support' into generate-method
4fc2b5b  Sep 20, 2025  gfpo
b628744  Sep 20, 2025  rm vllm
d3a769f  Sep 20, 2025  fix doc
c9693b2  Sep 21, 2025  a bit messy!
e17ec42  Sep 22, 2025  Merge branch 'main' into drop-image_split_sizes
efbb03a  Sep 22, 2025  Merge branch 'drop-image_split_sizes' into multi-image-support
562c662  Sep 22, 2025  Merge branch 'main' into multi-image-support
485781c  Sep 22, 2025  Merge branch 'main' into multi-image-support
05270f8  Sep 22, 2025  update layers to ignore
1c53094  Sep 22, 2025  clarify image column desc
9b6652e  Sep 23, 2025  rm VLM x RM warning
c500440  Sep 23, 2025  Merge branch 'multi-image-support' into generate-method
a6a8c44  Sep 23, 2025  Merge branch 'main' into generate-method
b8656e0  Sep 23, 2025  Merge branch 'generate-method' into multi-turn
d8665e1  Sep 23, 2025  Merge branch 'main' into generate-method
365d501  Sep 23, 2025  Merge branch 'main' into generate-method
acb44bc  Sep 23, 2025  Merge branch 'generate-method' into multi-turn
cdb4c76  Sep 24, 2025  Merge branch 'main' into generate-method
c83e710  Sep 24, 2025  same for rloo
ec6ad25  Sep 24, 2025  nits style and align
b4cadde  Sep 24, 2025  Merge branch 'main' into generate-method
594a07d  Sep 24, 2025  Merge branch 'generate-method' into multi-turn
04e4bd7  Sep 26, 2025  Update trl/trainer/grpo_trainer.py
242d66a  Sep 26, 2025  Merge branch 'main' into multi-turn
2 changes: 1 addition & 1 deletion docs/source/experimental.md
@@ -66,7 +66,7 @@ class GroupFilter:
return group_scores

training_args = GFPOConfig(
output_dir="Qwen3-0.6B-GFPO"
output_dir="Qwen3-0.6B-GFPO",
per_device_train_batch_size=4,
num_remains_in_group=2,
bf16=True,
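This hunk only restores the comma that was dropped after `output_dir`; without it, the documented `GFPOConfig(...)` call is a Python syntax error. A minimal sketch of the corrected call, with the import path given as an assumption rather than taken from this page:

from trl.experimental.gfpo import GFPOConfig  # import path assumed, see the experimental docs

training_args = GFPOConfig(
    output_dir="Qwen3-0.6B-GFPO",      # the fixed line: trailing comma restored
    per_device_train_batch_size=4,
    num_remains_in_group=2,            # completions kept per group after filtering
    bf16=True,
)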
1 change: 1 addition & 0 deletions scripts/generate_tiny_models.py
@@ -292,6 +292,7 @@ def init_weights_tiny_model(model):
"hidden_size": 16,
"num_attention_heads": 4,
"num_key_value_heads": 2,
"embed_dim": 64,
}
config = AutoConfig.from_pretrained(model_id, text_config=text_config, vision_config=vision_config)

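For context, `generate_tiny_models.py` builds miniature test checkpoints by overriding a pretrained config with tiny dimensions; the added `embed_dim` presumably completes the vision settings so the tiny vision tower can be built. A rough sketch of that pattern, where the checkpoint id and the exact dict values are assumptions and only the `AutoConfig.from_pretrained` call mirrors the script:

from transformers import AutoConfig

model_id = "Qwen/Qwen2.5-VL-3B-Instruct"  # assumed base checkpoint, not taken from the script
text_config = {"hidden_size": 16, "num_attention_heads": 4, "num_key_value_heads": 2}
vision_config = {"hidden_size": 16, "num_attention_heads": 4, "embed_dim": 64}  # embed_dim now set

# Same call as in the script: shrink both sub-configs to get a tiny model config
config = AutoConfig.from_pretrained(model_id, text_config=text_config, vision_config=vision_config)
print(config)
# The script then instantiates a randomly initialized tiny model from this config and saves it.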
94 changes: 86 additions & 8 deletions tests/test_grpo_trainer.py
@@ -1258,6 +1258,10 @@ def test_prepare_input_called_with_correct_data(self):
def test_training_vlm(self, model_id):
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")

def reward_func(completions, **kwargs):
"""Reward function that rewards longer completions."""
return [float(len(completion[0]["content"])) for completion in completions]

training_args = GRPOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
@@ -1269,7 +1273,7 @@ def test_training_vlm(self, model_id):
)
trainer = GRPOTrainer(
model=model_id,
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)
@@ -1301,6 +1305,10 @@ def test_training_vlm(self, model_id):
def test_training_vlm_beta_non_zero(self):
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")

def reward_func(completions, **kwargs):
"""Reward function that rewards longer completions."""
return [float(len(completion[0]["content"])) for completion in completions]

training_args = GRPOConfig(
output_dir=self.tmp_dir,
beta=0.1, # set beta to non-zero value to test the case where the reference model is used
@@ -1312,7 +1320,7 @@ def test_training_vlm_beta_non_zero(self):
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)
@@ -1340,7 +1348,11 @@ def test_training_vlm_peft(self):
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration"
)
base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")

def reward_func(completions, **kwargs):
"""Reward function that rewards longer completions."""
return [float(len(completion[0]["content"])) for completion in completions]

training_args = GRPOConfig(
output_dir=self.tmp_dir,
@@ -1352,7 +1364,7 @@ def test_training_vlm_peft(self):
)
trainer = GRPOTrainer(
model=model,
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
peft_config=LoraConfig(target_modules=["q_proj", "v_proj"]),
@@ -1376,6 +1388,10 @@ def test_training_vlm_peft(self):
def test_training_vlm_and_importance_sampling(self):
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")

def reward_func(completions, **kwargs):
"""Reward function that rewards longer completions."""
return [float(len(completion[0]["content"])) for completion in completions]

training_args = GRPOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
@@ -1387,7 +1403,7 @@ def test_training_vlm_and_importance_sampling(self):
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)
@@ -1413,6 +1429,10 @@ def test_training_vlm_and_importance_sampling(self):
def test_training_vlm_and_liger(self):
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")

def reward_func(completions, **kwargs):
"""Reward function that rewards longer completions."""
return [float(len(completion[0]["content"])) for completion in completions]

training_args = GRPOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
@@ -1425,7 +1445,7 @@ def test_training_vlm_and_liger(self):
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)
@@ -1451,6 +1471,10 @@ def test_training_vlm_and_prompt_truncation(self):
# If not handled properly, prompt truncation may truncate image token
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")

def reward_func(completions, **kwargs):
"""Reward function that rewards longer completions."""
return [float(len(completion[0]["content"])) for completion in completions]

training_args = GRPOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
@@ -1462,7 +1486,7 @@ def test_training_vlm_and_prompt_truncation(self):
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)
@@ -1495,6 +1519,10 @@ def test_training_vlm_and_prompt_truncation(self):
def test_training_vlm_and_vllm(self, model_id) -> None:
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")

def reward_func(completions, **kwargs):
"""Reward function that rewards longer completions."""
return [float(len(completion[0]["content"])) for completion in completions]

training_args = GRPOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1,
@@ -1508,7 +1536,44 @@ def test_training_vlm_and_vllm(self, model_id) -> None:
)
trainer = GRPOTrainer(
model=model_id,
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

@require_vision
def test_training_vlm_multi_image(self):
dataset = load_dataset("trl-internal-testing/zen-multi-image", "conversational_prompt_only", split="train")

# For now, mixing image+text and text-only examples is not supported, so we filter out text-only examples
dataset = dataset.filter(lambda x: len(x["images"]) > 0)

def reward_func(completions, **kwargs):
"""Reward function that rewards longer completions."""
return [float(len(completion[0]["content"])) for completion in completions]

training_args = GRPOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=8, # reduce the completion length to reduce memory usage
max_prompt_length=None, # disable prompt truncation, because usually, models don't support it
report_to="none",
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)
@@ -1519,7 +1584,20 @@ def test_training_vlm_and_vllm(self, model_id) -> None:

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# Check that the params have changed
# Because of the way the tiny models are initialized, the gradient does not flow properly through the
# vision parts of the model, so we skip them. Ideally, we should fix the init of these models.
params_to_skip = (
# "model.vision_tower.",
# "model.multi_modal_projector.",
# "model.vision_model.",
# "model.connector.modality_projection.",
# "model.visual.",
# "model.image_newline",
)
for n, param in previous_trainable_params.items():
if n.startswith(params_to_skip):
continue
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

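The recurring change in these GRPO tests is that VLM completions are no longer scored with a sequence-classification reward model; each test now defines a plain Python reward function. Outside the test suite, the same setup looks roughly like the sketch below. The dataset, model id, and config values mirror the test bodies above; the output directory is a placeholder:

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-internal-testing/zen-multi-image", "conversational_prompt_only", split="train")
# Mixing image+text and text-only examples is not supported yet, so keep only examples with images
dataset = dataset.filter(lambda x: len(x["images"]) > 0)

def reward_func(completions, **kwargs):
    """Reward longer completions (the toy reward used throughout the tests)."""
    return [float(len(completion[0]["content"])) for completion in completions]

training_args = GRPOConfig(
    output_dir="grpo-vlm-multi-image",  # placeholder output directory
    per_device_train_batch_size=3,
    num_generations=3,
    max_completion_length=8,
    max_prompt_length=None,  # disable prompt truncation so image tokens are never cut
    report_to="none",
)
trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
    reward_funcs=reward_func,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()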
69 changes: 63 additions & 6 deletions tests/test_rloo_trainer.py
@@ -1089,6 +1089,10 @@ def test_prepare_input_called_with_correct_data(self):
def test_training_vlm(self, model_id):
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")

def reward_func(completions, **kwargs):
"""Reward function that rewards longer completions."""
return [float(len(completion[0]["content"])) for completion in completions]

training_args = RLOOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
@@ -1100,7 +1104,7 @@ def test_training_vlm(self, model_id):
)
trainer = RLOOTrainer(
model=model_id,
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)
@@ -1132,6 +1136,10 @@ def test_training_vlm(self, model_id):
def test_training_vlm_beta_non_zero(self):
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")

def reward_func(completions, **kwargs):
"""Reward function that rewards longer completions."""
return [float(len(completion[0]["content"])) for completion in completions]

training_args = RLOOConfig(
output_dir=self.tmp_dir,
beta=0.1, # set beta to non-zero value to test the case where the reference model is used
@@ -1143,7 +1151,7 @@ def test_training_vlm_beta_non_zero(self):
)
trainer = RLOOTrainer(
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)
@@ -1171,7 +1179,11 @@ def test_training_vlm_peft(self):
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration"
)
base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")

def reward_func(completions, **kwargs):
"""Reward function that rewards longer completions."""
return [float(len(completion[0]["content"])) for completion in completions]

training_args = RLOOConfig(
output_dir=self.tmp_dir,
@@ -1183,7 +1195,7 @@ def test_training_vlm_peft(self):
)
trainer = RLOOTrainer(
model=model,
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
peft_config=LoraConfig(target_modules=["q_proj", "v_proj"]),
@@ -1208,6 +1220,10 @@ def test_training_vlm_and_prompt_truncation(self):
# If not handled properly, prompt truncation may truncate image token
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")

def reward_func(completions, **kwargs):
"""Reward function that rewards longer completions."""
return [float(len(completion[0]["content"])) for completion in completions]

training_args = RLOOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
@@ -1219,7 +1235,7 @@ def test_training_vlm_and_prompt_truncation(self):
)
trainer = RLOOTrainer(
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)
@@ -1252,6 +1268,10 @@ def test_training_vlm_and_vllm(self, model_id) -> None:
def test_training_vlm_and_vllm(self, model_id) -> None:
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")

def reward_func(completions, **kwargs):
"""Reward function that rewards longer completions."""
return [float(len(completion[0]["content"])) for completion in completions]

training_args = RLOOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1,
@@ -1265,7 +1285,7 @@ def test_training_vlm_and_vllm(self, model_id) -> None:
)
trainer = RLOOTrainer(
model=model_id,
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

@require_vision
def test_training_vlm_multi_image(self):
dataset = load_dataset("trl-internal-testing/zen-multi-image", "conversational_prompt_only", split="train")

# For now, mixing image+text and text-only examples is not supported, so we filter out text-only examples
dataset = dataset.filter(lambda x: len(x["images"]) > 0)

def reward_func(completions, **kwargs):
"""Reward function that rewards longer completions."""
return [float(len(completion[0]["content"])) for completion in completions]

training_args = RLOOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=8, # reduce the completion length to reduce memory usage
max_prompt_length=None, # disable prompt truncation, because usually, models don't support it
report_to="none",
)
trainer = RLOOTrainer(
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)
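The RLOO changes mirror the GRPO ones, and every test in both files verifies training the same way: snapshot the trainable parameters, run a short training pass, then assert that each parameter moved. A condensed sketch of that check, assuming `trainer` is any of the trainers constructed above (plain asserts stand in for the unittest assertions used in the tests):

import torch

# Snapshot parameters before training
previous_trainable_params = {n: p.clone() for n, p in trainer.model.named_parameters()}

trainer.train()

# A train_loss entry must have been logged
assert trainer.state.log_history[-1]["train_loss"] is not None

# Every parameter should have been updated by the optimizer step
for n, param in previous_trainable_params.items():
    new_param = trainer.model.get_parameter(n)
    assert not torch.equal(param, new_param), f"Parameter {n} has not changed."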