diff --git a/src/axolotl/utils/collators/mm_chat.py b/src/axolotl/utils/collators/mm_chat.py index f49e97f37..73429518b 100644 --- a/src/axolotl/utils/collators/mm_chat.py +++ b/src/axolotl/utils/collators/mm_chat.py @@ -7,7 +7,7 @@ from transformers import PreTrainedTokenizerBase, ProcessorMixin from transformers.data.data_collator import DataCollatorMixin from transformers.utils import PaddingStrategy - +from PIL import Image @dataclass class MultiModalChatDataCollator(DataCollatorMixin): @@ -52,7 +52,7 @@ def process_rows(examples, processor, chat_template, max_images, length_only=Fal ) for example in examples ] - images = [example["images"] for example in examples] + images = [Image.open(example["images"]) if type(example["images"])==str else example["images"] for example in examples] if max_images > 0: images = [img_batch[:max_images] for img_batch in images]