@@ -758,7 +758,14 @@ def __init__(
         else:
             self.completion_only_loss = args.completion_only_loss

-        if data_collator is None and not self._is_vlm:
+        self._is_vision_dataset = "image" in dataset_sample or "images" in dataset_sample
+        if self._is_vision_dataset and not self._is_vlm:
+            raise ValueError(
+                "The dataset appears to be vision-related (contains 'image' or 'images' keys), but the provided "
+                "model does not seem to be a vision-language model. Please check your model and dataset."
+            )
+
+        if data_collator is None and not self._is_vision_dataset:
             # Get the pad token: if not provided, use the one from the processing class or the eos token
             # if the processing class does not have a pad token.
             pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token
@@ -777,7 +784,7 @@ def __init__(
                 return_position_ids=use_flash_attention,
                 pad_to_multiple_of=args.pad_to_multiple_of,
             )
-        elif data_collator is None and self._is_vlm:
+        elif data_collator is None and self._is_vision_dataset:
             data_collator = DataCollatorForVisionLanguageModeling(
                 processor=processing_class,
                 max_length=args.max_length,
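Net effect of the two hunks above: collator selection now keys off the dataset contents (`self._is_vision_dataset`) rather than the model type (`self._is_vlm`), with an explicit error for the one invalid combination. A minimal sketch of the resulting dispatch (`pick_collator` is a hypothetical helper for illustration; the strings stand in for the real collator constructors):

def pick_collator(is_vision_dataset: bool, is_vlm: bool) -> str:
    # Vision data with a text-only model is the one invalid combination.
    if is_vision_dataset and not is_vlm:
        raise ValueError("vision dataset, but the model is not a vision-language model")
    if is_vision_dataset:
        return "DataCollatorForVisionLanguageModeling"
    return "DataCollatorForLanguageModeling"

# A VLM fine-tuned on text-only data now gets the standard LM collator,
# which the old `self._is_vlm` check did not allow.
assert pick_collator(is_vision_dataset=False, is_vlm=True) == "DataCollatorForLanguageModeling"
assert pick_collator(is_vision_dataset=True, is_vlm=True) == "DataCollatorForVisionLanguageModeling"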
@@ -805,7 +812,9 @@ def __init__(
         # Skip dataset preparation if `skip_prepare_dataset=True` in `dataset_kwargs`, or if it's a VLM, where
         # preprocessing (e.g., image-to-pixel conversion) is too costly and done on the fly instead.
         skip_prepare_dataset = (
-            args.dataset_kwargs is not None and args.dataset_kwargs.get("skip_prepare_dataset", False) or self._is_vlm
+            args.dataset_kwargs is not None
+            and args.dataset_kwargs.get("skip_prepare_dataset", False)
+            or self._is_vision_dataset
         )
         if not skip_prepare_dataset:
             if self.completion_only_loss and formatting_func:
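Besides switching the flag, the reformatted expression makes the precedence explicit: `and` binds tighter than `or` in Python, so both the old one-liner and the new multi-line form group as `(a and b) or c`. A quick check:

dataset_kwargs = None  # no dataset_kwargs passed
is_vision_dataset = True

# Short-circuiting keeps `.get` from running on None, and the `and` group
# is evaluated before the `or`, exactly as the new layout spells out.
skip = (
    dataset_kwargs is not None
    and dataset_kwargs.get("skip_prepare_dataset", False)
    or is_vision_dataset
)
assert skip is True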
@@ -959,22 +968,36 @@ def add_eos(example, eos_token):
             if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                 map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"

-            def tokenize(example, processing_class, dataset_text_field, assistant_only_loss):
+            def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_loss):
                 if "prompt" in example:  # prompt-completion case
                     output = {}
                     if is_conversational(example):
+                        if self._is_vlm:
+                            prepare_multimodal_messages(example["prompt"], num_images=0)
+                            prepare_multimodal_messages(example["completion"], num_images=0)
                         prompt_ids = processing_class.apply_chat_template(
                             example["prompt"],
+                            tokenize=True,
                             tools=example.get("tools"),
                             **example.get("chat_template_kwargs", {}),
                         )
+                        # Fix transformers inconsistency: for VLMs, apply_chat_template returns lists of lists
+                        # even for single examples, while for LLMs it returns lists of ints.
+                        prompt_ids = prompt_ids[0] if isinstance(prompt_ids[0], list) else prompt_ids
                         prompt_completion_processed = processing_class.apply_chat_template(
                             example["prompt"] + example["completion"],
                             return_dict=True,
+                            tokenize=True,
                             return_assistant_tokens_mask=assistant_only_loss,
                             tools=example.get("tools"),
                             **example.get("chat_template_kwargs", {}),
                         )
+                        # Fix transformers inconsistency: for VLMs, apply_chat_template returns lists of lists
+                        # even for single examples, while for LLMs it returns lists of ints.
+                        prompt_completion_processed = {
+                            k: v[0] if isinstance(v[0], list) else v
+                            for k, v in prompt_completion_processed.items()
+                        }
                         prompt_completion_ids = prompt_completion_processed["input_ids"]
                         if "assistant_masks" in prompt_completion_processed:
                             output["assistant_masks"] = prompt_completion_processed["assistant_masks"]
@@ -999,13 +1022,19 @@ def tokenize(example, processing_class, dataset_text_field, assistant_only_loss)

                 else:  # language modeling case
                     if is_conversational(example):
+                        if self._is_vlm:
+                            prepare_multimodal_messages(example["messages"], num_images=0)
                         processed = processing_class.apply_chat_template(
                             example["messages"],
                             return_dict=True,
+                            tokenize=True,
                             return_assistant_tokens_mask=assistant_only_loss,
                             tools=example.get("tools"),
                             **example.get("chat_template_kwargs", {}),
                         )
+                        # Fix transformers inconsistency: for VLMs, apply_chat_template returns lists of lists
+                        # even for single examples, while for LLMs it returns lists of ints.
+                        processed = {k: v[0] if isinstance(v[0], list) else v for k, v in processed.items()}
                         if "assistant_masks" in processed and 1 not in processed["assistant_masks"]:
                             raise RuntimeError(
                                 "You're using `assistant_only_loss=True`, but at least one example has no "
@@ -1020,7 +1049,7 @@ def tokenize(example, processing_class, dataset_text_field, assistant_only_loss)
                 return output

             dataset = dataset.map(
-                tokenize,
+                tokenize_fn,
                 fn_kwargs={
                     "processing_class": processing_class,
                     "dataset_text_field": args.dataset_text_field,
@@ -1064,7 +1093,7 @@ def _set_signature_columns_if_needed(self):
         # and "attention_mask"). When using `train_on_completion_only` we add a "completion_mask" column to the
         # dataset. So we need to override the default signature columns to include "completion_mask" as well.
         if self._signature_columns is None:
-            if self._is_vlm:
+            if self._is_vision_dataset:
                 self._signature_columns = ["messages", "prompt", "completion", "images"]
             else:
                 self._signature_columns = ["input_ids", "labels", "seq_lengths", "completion_mask", "assistant_masks"]