Commit 9e5e60c

👩‍🦯 Fix usage of VLM using text only (#4080)

Authored by SamuelBarryCS, with sergiopaniego and qgallouedec

Co-authored-by: Sergio Paniego Blanco <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>

1 parent 5c52f46 · commit 9e5e60c
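For context, this is the scenario the commit enables, shown as a minimal sketch mirroring the new test below; the model and dataset identifiers come from that test, while the `output_dir` value is an arbitrary placeholder.

```python
# Fine-tuning a VLM checkpoint on a text-only conversational dataset,
# which previously failed inside SFTTrainer's VLM code path.
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Rows contain "messages" but no "image"/"images" column
dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")

trainer = SFTTrainer(
    model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
    args=SFTConfig(output_dir="out", report_to="none"),  # "out" is a placeholder
    train_dataset=dataset,
)
trainer.train()
```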

File tree

2 files changed: +65 −6 lines changed

tests/test_sft_trainer.py

Lines changed: 30 additions & 0 deletions

```diff
@@ -1374,6 +1374,36 @@ def test_train_vlm_gemma_3n(self):
                 continue
             self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated")
 
+    @require_vision
+    def test_train_vlm_text_only_data(self):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")
+
+        # Initialize the trainer
+        training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none")
+        trainer = SFTTrainer(
+            model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            if n.startswith("model.visual"):
+                self.assertTrue(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is updated")
+            else:
+                self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated")
+
     @require_peft
     def test_prompt_tuning(self):
         """Test that SFT works with Prompt Tuning."""
```

trl/trainer/sft_trainer.py

Lines changed: 35 additions & 6 deletions
```diff
@@ -758,7 +758,14 @@ def __init__(
         else:
             self.completion_only_loss = args.completion_only_loss
 
-        if data_collator is None and not self._is_vlm:
+        self._is_vision_dataset = "image" in dataset_sample or "images" in dataset_sample
+        if self._is_vision_dataset and not self._is_vlm:
+            raise ValueError(
+                "The dataset appears to be vision-related (contains 'image' or 'images' keys), but the provided "
+                "model does not seem to be a vision-language model. Please check your model and dataset."
+            )
+
+        if data_collator is None and not self._is_vision_dataset:
             # Get the pad token: if not provided, use the one from the processing class or the eos token
             # if the processing class does not have a pad token.
             pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token
```
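This is the core change: the collator and preprocessing path are now selected from the dataset's columns rather than from the model class alone, with a guard for the mismatched case. A standalone sketch of the same key check (hypothetical helper, for illustration only):

```python
# Mirrors the new `_is_vision_dataset` attribute: a dataset counts as
# vision-related if its first row carries an "image" or "images" column.
def is_vision_dataset(dataset_sample: dict) -> bool:
    return "image" in dataset_sample or "images" in dataset_sample

assert not is_vision_dataset({"messages": [{"role": "user", "content": "Hi"}]})
assert is_vision_dataset({"images": [], "messages": []})
```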
```diff
@@ -777,7 +784,7 @@ def __init__(
                 return_position_ids=use_flash_attention,
                 pad_to_multiple_of=args.pad_to_multiple_of,
             )
-        elif data_collator is None and self._is_vlm:
+        elif data_collator is None and self._is_vision_dataset:
             data_collator = DataCollatorForVisionLanguageModeling(
                 processor=processing_class,
                 max_length=args.max_length,
```
```diff
@@ -805,7 +812,9 @@ def __init__(
         # Skip dataset preparation if `skip_prepare_dataset=True` in `dataset_kwargs`, or if it's a VLM, where
         # preprocessing (e.g., image-to-pixel conversion) is too costly and done on the fly instead.
         skip_prepare_dataset = (
-            args.dataset_kwargs is not None and args.dataset_kwargs.get("skip_prepare_dataset", False) or self._is_vlm
+            args.dataset_kwargs is not None
+            and args.dataset_kwargs.get("skip_prepare_dataset", False)
+            or self._is_vision_dataset
         )
         if not skip_prepare_dataset:
             if self.completion_only_loss and formatting_func:
```
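The multi-line split is behavior-preserving: Python's `and` binds tighter than `or`, so the condition still groups as `(A and B) or C`, and a vision dataset skips preparation regardless of `dataset_kwargs`. A self-contained restatement (hypothetical helper, illustration only):

```python
# `A and B or C` parses as `(A and B) or C`; shown with plain values.
def skip_prepare(dataset_kwargs, is_vision_dataset):
    return (
        dataset_kwargs is not None
        and dataset_kwargs.get("skip_prepare_dataset", False)
    ) or is_vision_dataset

assert skip_prepare(None, True)                                  # vision data always skips
assert not skip_prepare({"skip_prepare_dataset": False}, False)  # text data is prepared
```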
```diff
@@ -959,22 +968,36 @@ def add_eos(example, eos_token):
             if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                 map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
 
-            def tokenize(example, processing_class, dataset_text_field, assistant_only_loss):
+            def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_loss):
                 if "prompt" in example:  # prompt-completion case
                     output = {}
                     if is_conversational(example):
+                        if self._is_vlm:
+                            prepare_multimodal_messages(example["prompt"], num_images=0)
+                            prepare_multimodal_messages(example["completion"], num_images=0)
                         prompt_ids = processing_class.apply_chat_template(
                             example["prompt"],
+                            tokenize=True,
                             tools=example.get("tools"),
                             **example.get("chat_template_kwargs", {}),
                         )
+                        # Fix transformers inconsistency: for VLMs, apply_chat_template returns lists of lists
+                        # even for single examples, while for LLMs it returns lists of ints.
+                        prompt_ids = prompt_ids[0] if isinstance(prompt_ids[0], list) else prompt_ids
                         prompt_completion_processed = processing_class.apply_chat_template(
                             example["prompt"] + example["completion"],
                             return_dict=True,
+                            tokenize=True,
                             return_assistant_tokens_mask=assistant_only_loss,
                             tools=example.get("tools"),
                             **example.get("chat_template_kwargs", {}),
                         )
+                        # Fix transformers inconsistency: for VLMs, apply_chat_template returns lists of lists
+                        # even for single examples, while for LLMs it returns lists of ints.
+                        prompt_completion_processed = {
+                            k: v[0] if isinstance(v[0], list) else v
+                            for k, v in prompt_completion_processed.items()
+                        }
                         prompt_completion_ids = prompt_completion_processed["input_ids"]
                         if "assistant_masks" in prompt_completion_processed:
                             output["assistant_masks"] = prompt_completion_processed["assistant_masks"]
```
```diff
@@ -999,13 +1022,19 @@ def tokenize(example, processing_class, dataset_text_field, assistant_only_loss)
 
                 else:  # language modeling case
                     if is_conversational(example):
+                        if self._is_vlm:
+                            prepare_multimodal_messages(example["messages"], num_images=0)
                         processed = processing_class.apply_chat_template(
                             example["messages"],
                             return_dict=True,
+                            tokenize=True,
                             return_assistant_tokens_mask=assistant_only_loss,
                             tools=example.get("tools"),
                             **example.get("chat_template_kwargs", {}),
                         )
+                        # Fix transformers inconsistency: for VLMs, apply_chat_template returns lists of lists
+                        # even for single examples, while for LLMs it returns lists of ints.
+                        processed = {k: v[0] if isinstance(v[0], list) else v for k, v in processed.items()}
                         if "assistant_masks" in processed and 1 not in processed["assistant_masks"]:
                             raise RuntimeError(
                                 "You're using `assistant_only_loss=True`, but at least one example has no "
```
```diff
@@ -1020,7 +1049,7 @@ def tokenize(example, processing_class, dataset_text_field, assistant_only_loss)
                 return output
 
             dataset = dataset.map(
-                tokenize,
+                tokenize_fn,
                 fn_kwargs={
                     "processing_class": processing_class,
                     "dataset_text_field": args.dataset_text_field,
```
```diff
@@ -1064,7 +1093,7 @@ def _set_signature_columns_if_needed(self):
         # and "attention_mask"). When using `train_on_completion_only` we add a "completion_mask" column to the
         # dataset. So we need to override the default signature columns to include "completion_mask" as well.
         if self._signature_columns is None:
-            if self._is_vlm:
+            if self._is_vision_dataset:
                 self._signature_columns = ["messages", "prompt", "completion", "images"]
             else:
                 self._signature_columns = ["input_ids", "labels", "seq_lengths", "completion_mask", "assistant_masks"]
```
