2
2
import json
3
3
4
4
5
- def get_prompt (model_type , processor , conversation ):
6
- if model_type == "paligemma" :
7
- return conversation
8
-
9
- if "chat_template" in processor .__dict__ .keys ():
10
- prompt = processor .apply_chat_template (
11
- conversation ,
12
- tokenize = False ,
13
- add_generation_prompt = False ,
14
- )
15
- elif "tokenizer" in processor .__dict__ .keys ():
16
- prompt = processor .tokenizer .apply_chat_template (
17
- conversation ,
18
- tokenize = False ,
19
- add_generation_prompt = False ,
20
- )
21
-
22
- return prompt
23
-
24
-
25
5
class Dataset :
26
6
def __init__ (
27
7
self ,
@@ -53,9 +33,21 @@ def __getitem__(self, idx):
53
33
item = self .dataset [idx ]
54
34
55
35
images = item .get ("images" , item .get ("image" , None ))
56
- conversations = item . get ( "messages" , item . get ( "conversations" ))
57
- if images in ( None , "" , []) :
36
+
37
+ if images is None or images == "" or images == [] :
58
38
images = []
39
+ elif not isinstance (images , list ):
40
+ images = [images ]
41
+
42
+ image_paths = []
43
+ image_data = []
44
+ for img in images :
45
+ if isinstance (img , str ):
46
+ image_paths .append (img )
47
+ else :
48
+ image_data .append (img )
49
+
50
+ conversations = item .get ("messages" , item .get ("conversations" ))
59
51
prompts = []
60
52
61
53
if isinstance (conversations , list ) and isinstance (conversations [0 ], list ):
@@ -67,27 +59,52 @@ def __getitem__(self, idx):
67
59
"Pixtral batch processing is not supported yet. Set batch size to 1."
68
60
)
69
61
70
- prompt = get_prompt (
71
- self .config ["model_type" ], self .processor , conversation
72
- )
62
+ if "chat_template" in self .processor .__dict__ :
63
+ prompt = self .processor .apply_chat_template (
64
+ conversation ,
65
+ tokenize = False ,
66
+ add_generation_prompt = False ,
67
+ num_images = len (images ),
68
+ num_audios = 0 ,
69
+ )
70
+ else :
71
+ prompt = self .processor .tokenizer .apply_chat_template (
72
+ conversation ,
73
+ tokenize = False ,
74
+ add_generation_prompt = False ,
75
+ num_images = len (images ),
76
+ num_audios = 0 ,
77
+ )
73
78
prompts .append (prompt )
74
79
75
80
else :
76
81
if self .config ["model_type" ] == "pixtral" :
77
82
conversations = [json .loads (i ) for i in conversations ]
78
- prompt = get_prompt (
79
- self .config ["model_type" ], self .processor , conversations
80
- )
83
+ if "chat_template" in self .processor .__dict__ :
84
+ prompt = self .processor .apply_chat_template (
85
+ conversations ,
86
+ tokenize = False ,
87
+ add_generation_prompt = False ,
88
+ num_images = len (images ),
89
+ num_audios = 0 ,
90
+ )
91
+ else :
92
+ prompt = self .processor .tokenizer .apply_chat_template (
93
+ conversations ,
94
+ tokenize = False ,
95
+ add_generation_prompt = False ,
96
+ num_images = len (images ),
97
+ num_audios = 0 ,
98
+ )
81
99
prompts .append (prompt )
82
100
83
- image_token_index = getattr (self .config , "image_token_index" , "image_token_id" )
84
101
85
102
inputs = prepare_inputs (
86
103
processor = self .processor ,
87
- images = images ,
104
+ images = image_data ,
88
105
audio = None ,
89
106
prompts = prompts ,
90
- image_token_index = image_token_index ,
107
+ image_token_index = getattr ( self . config , " image_token_index" , "image_token_id" ) ,
91
108
resize_shape = self .image_resize_shape
92
109
)
93
110
0 commit comments