📽 Multi image support for GRPO/RLOO #4113
Changes from 22 commits
552e899
449ef07
c8933aa
229c554
3ca6ad5
dcf4b92
30ad7ca
86cc30b
088897b
d2adc63
f4c82bf
1257796
099a39b
529add6
fc6b11f
f998432
fa73876
52d8bd9
dfc0d38
fc52e68
e17ec42
efbb03a
562c662
485781c
05270f8
1c53094
9b6652e
@@ -1089,6 +1089,10 @@ def test_prepare_input_called_with_correct_data(self):
     def test_training_vlm(self, model_id):
         dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")

+        def reward_func(completions, **kwargs):
+            """Reward function that rewards longer completions."""
+            return [float(len(completion[0]["content"])) for completion in completions]
+
         training_args = RLOOConfig(
             output_dir=self.tmp_dir,
             learning_rate=0.1,  # increase the learning rate to speed up the test
@@ -1100,7 +1104,7 @@ def test_training_vlm(self, model_id):
         )
         trainer = RLOOTrainer(
             model=model_id,
-            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            reward_funcs=reward_func,
             args=training_args,
             train_dataset=dataset,
         )
@@ -1132,6 +1136,10 @@ def test_training_vlm(self, model_id):
     def test_training_vlm_beta_non_zero(self):
         dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")

+        def reward_func(completions, **kwargs):
+            """Reward function that rewards longer completions."""
+            return [float(len(completion[0]["content"])) for completion in completions]
+
         training_args = RLOOConfig(
             output_dir=self.tmp_dir,
             beta=0.1,  # set beta to non-zero value to test the case where the reference model is used
@@ -1143,7 +1151,7 @@ def test_training_vlm_beta_non_zero(self):
         )
         trainer = RLOOTrainer(
             model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
-            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            reward_funcs=reward_func,
             args=training_args,
             train_dataset=dataset,
         )
@@ -1171,7 +1179,11 @@ def test_training_vlm_peft(self):
             "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration"
         )
         base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
-        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+        dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")

+        def reward_func(completions, **kwargs):
+            """Reward function that rewards longer completions."""
+            return [float(len(completion[0]["content"])) for completion in completions]
+
         training_args = RLOOConfig(
             output_dir=self.tmp_dir,
@@ -1183,7 +1195,7 @@ def test_training_vlm_peft(self):
         )
         trainer = RLOOTrainer(
             model=model,
-            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            reward_funcs=reward_func,
             args=training_args,
             train_dataset=dataset,
             peft_config=LoraConfig(target_modules=["q_proj", "v_proj"]),
@@ -1208,6 +1220,10 @@ def test_training_vlm_and_prompt_truncation(self):
         # If not handled properly, prompt truncation may truncate image token
         dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")

+        def reward_func(completions, **kwargs):
+            """Reward function that rewards longer completions."""
+            return [float(len(completion[0]["content"])) for completion in completions]
+
         training_args = RLOOConfig(
             output_dir=self.tmp_dir,
             learning_rate=0.1,  # increase the learning rate to speed up the test
@@ -1219,7 +1235,7 @@ def test_training_vlm_and_prompt_truncation(self):
         )
         trainer = RLOOTrainer(
             model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
-            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            reward_funcs=reward_func,
             args=training_args,
             train_dataset=dataset,
         )
@@ -1252,6 +1268,10 @@ def test_training_vlm_and_prompt_truncation(self):
     def test_training_vlm_and_vllm(self, model_id) -> None:
         dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")

+        def reward_func(completions, **kwargs):
+            """Reward function that rewards longer completions."""
+            return [float(len(completion[0]["content"])) for completion in completions]
+
         training_args = RLOOConfig(
             output_dir=self.tmp_dir,
             learning_rate=0.1,
@@ -1265,7 +1285,44 @@ def test_training_vlm_and_vllm(self, model_id) -> None:
         )
         trainer = RLOOTrainer(
             model=model_id,
-            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            reward_funcs=reward_func,
             args=training_args,
             train_dataset=dataset,
         )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+        for n, param in previous_trainable_params.items():
Member: Does the same comment for GRPO apply here? https://github.com/huggingface/trl/pull/4113/files#diff-96dca172e696190fc3e1469166e88aface95ebae959284c6806f2e25d2217c16R1587

Member (Author): answered here #4113 (comment)
+            new_param = trainer.model.get_parameter(n)
+            self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+
+    @require_vision
+    def test_training_vlm_multi_image(self):
+        dataset = load_dataset("trl-internal-testing/zen-multi-image", "conversational_prompt_only", split="train")
+
+        # For now, mixing image+text and text-only examples is not supported, so we filter out text-only examples
+        dataset = dataset.filter(lambda x: len(x["images"]) > 0)
+
+        def reward_func(completions, **kwargs):
+            """Reward function that rewards longer completions."""
+            return [float(len(completion[0]["content"])) for completion in completions]
+
+        training_args = RLOOConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,  # increase the learning rate to speed up the test
+            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+            num_generations=3,  # reduce the number of generations to reduce memory usage
+            max_completion_length=8,  # reduce the completion length to reduce memory usage
+            max_prompt_length=None,  # disable prompt truncation, because usually, models don't support it
+            report_to="none",
+        )
+        trainer = RLOOTrainer(
+            model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
+            reward_funcs=reward_func,
+            args=training_args,
+            train_dataset=dataset,
+        )
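For context, the new `test_training_vlm_multi_image` test above expects each example to carry a conversational `prompt` together with an `images` list, and drops text-only rows before training. The sketch below only illustrates that shape and the filter step; the concrete prompts and image placeholders are made up for illustration and are not taken from the actual `trl-internal-testing/zen-multi-image` dataset.

```python
from datasets import Dataset

# Illustrative rows in the shape the multi-image test assumes:
# a chat-style "prompt" plus a per-example "images" list
# (real datasets store PIL images; strings stand in here to keep the sketch self-contained).
rows = [
    {"prompt": [{"role": "user", "content": "Compare the two pictures."}], "images": ["img_a", "img_b"]},
    {"prompt": [{"role": "user", "content": "Describe the picture."}], "images": ["img_c"]},
    {"prompt": [{"role": "user", "content": "No picture here."}], "images": []},  # text-only row
]
dataset = Dataset.from_list(rows)

# Mixing image+text and text-only examples is not supported yet,
# so text-only rows are filtered out, exactly as in the test.
dataset = dataset.filter(lambda x: len(x["images"]) > 0)
print(len(dataset))  # 2
```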
We don't support visual reward models, so it doesn't really make sense to test this case, where the image is dropped and a warning is raised.
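That is why the tests swap the sequence-classification reward model for the plain Python `reward_func` shown in the diff, which scores only the generated text. Below is a quick standalone check of that function's behaviour; the toy completions are illustrative and not taken from the test data.

```python
# The reward function used throughout the updated tests: longer completions get higher reward.
def reward_func(completions, **kwargs):
    """Reward function that rewards longer completions."""
    return [float(len(completion[0]["content"])) for completion in completions]

# Toy chat-style completions, as produced for conversational prompts.
completions = [
    [{"role": "assistant", "content": "hi"}],
    [{"role": "assistant", "content": "a longer answer"}],
]
print(reward_func(completions))  # [2.0, 15.0]
```

In the trainer it is wired in directly via `reward_funcs=reward_func`, as the hunks above show.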