import numpy as np
import torch
import torchvision.transforms as transforms
-from PIL import Image
+from PIL import Image, ImageDraw, ImageFont
from transformers import AutoModel, AutoTokenizer

from ...dataset import DATASET_TYPE
from ...smp import *
from ..base import BaseModel

pattern = re.compile(r'[A-Z]')
+conv_pattern = '\\[UNUSED_TOKEN_146\\]user\\n|\\[UNUSED_TOKEN_146\\]assistant\\n|\\[UNUSED_TOKEN_145\\]'
+
+
+def get_font():
+    try:
+        truetype_url = "http://opencompass.openxlab.space/utils/Fonts/SimHei.ttf"
+        ff = urlopen(truetype_url)
+        # ff = '/fs-computility/mllm/shared/dongxiaoyi/share_data/SimHei.ttf'
+        font = ImageFont.truetype(ff, size=40)
+    except Exception as e:
+        logging.warning(f'{type(e)}: {e}')
+        logging.warning("Failed to download the font; using the default one.")
+        font = ImageFont.load_default(size=40)  # size arg needs Pillow >= 10.1
+    return font


def padding_560(b):
@@ -26,6 +40,29 @@ def padding_560(b):
    return b


+def Identity_transform(img, hd_num=25):
+    width, height = img.size
+    trans = False
+    if width < height:
+        img = img.transpose(Image.TRANSPOSE)
+        trans = True
+        width, height = img.size
+    ratio = width / height
+    scale = 1
+    new_h = int(scale * 560)
+    new_w = int(new_h * ratio)
+    # print(new_h, new_w)
+
+    img = transforms.functional.resize(img, [new_h, new_w])
+    img = img.transpose(Image.TRANSPOSE)
+    img = padding_560(img)
+    width, height = img.size
+    if not trans:
+        img = img.transpose(Image.TRANSPOSE)
+
+    return img
+
+
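`Identity_transform` is the new fixed-scale path for video frames: the short side is resized to 560 px (with a transpose so one code path handles portrait and landscape), then `padding_560` pads to the patch grid. A quick sketch of the resulting geometry, assuming a 1280×720 frame and assuming `padding_560` pads each side up to the next multiple of 560, as its name suggests (its body is outside this hunk):

```python
# hypothetical shape check, not part of the diff
w, h = 1280, 720
new_h = 560                       # short side -> 560
new_w = int(new_h * (w / h))      # 995, aspect ratio kept
padded_w = -(-new_w // 560) * 560  # ceil to a multiple of 560
print(new_w, padded_w)            # 995 1120
```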
def HD_transform(img, im_num=36, id_scale=1.5):
    width, height = img.size
    trans = False
@@ -53,15 +90,70 @@ def HD_transform(img, im_num=36, id_scale=1.5):
    return img


+def img_process(imgs):
+    new_imgs = []
+    for img in imgs:
+        w, h = img.size
+        scale = w / h
+        if w > h:
+            new_w = 560 * 2
+            new_h = int(560 * 2 / scale)
+        else:
+            new_w = int(560 * 2 * scale)
+            new_h = 560 * 2
+        img = transforms.functional.resize(img, [new_h, new_w])
+        new_imgs.append(img)
+    imgs = new_imgs
+    new_w = 0
+    new_h = 0
+    pad = 40
+    if w > h:  # orientation taken from the last resized frame
+        for im in imgs:
+            w, h = im.size
+            new_w = max(new_w, w)
+            new_h += h + 10 + pad
+        font = get_font()
+        new_img = Image.new('RGB', (new_w, new_h), 'white')
+        draw = ImageDraw.Draw(new_img)
+        curr_h = 0
+        for idx, im in enumerate(imgs):
+            w, h = im.size
+            new_img.paste(im, (0, pad + curr_h))
+            draw.text((0, curr_h), f'<IMAGE {idx}>', font=font, fill='black')
+            if idx + 1 < len(imgs):
+                draw.line([(0, pad + curr_h + h + 5), (new_w, pad + curr_h + h + 5)], fill='black', width=2)
+            curr_h += h + 10 + pad
+        # print(new_w, new_h)
+    else:
+        for im in imgs:
+            w, h = im.size
+            new_w += w + 10
+            new_h = max(new_h, h)
+        new_h += pad
+        font = get_font()
+        new_img = Image.new('RGB', (new_w, new_h), 'white')
+        draw = ImageDraw.Draw(new_img)
+        curr_w = 0
+        for idx, im in enumerate(imgs):
+            w, h = im.size
+            new_img.paste(im, (curr_w, pad))
+            draw.text((curr_w, 0), f'<IMAGE {idx}>', font=font, fill='black')
+            if idx + 1 < len(imgs):
+                draw.line([(curr_w + w + 5, 0), (curr_w + w + 5, new_h)], fill='black', width=2)
+            curr_w += w + 10
+    return new_img
+
+
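`img_process` stitches the sampled frames into a single canvas: each frame is resized so its long side is 1120 px, then the frames are stacked vertically (landscape) or horizontally (portrait) with a 40 px label band (`<IMAGE idx>` drawn with `get_font`) and a 2 px separator between frames. A minimal usage sketch, assuming the module above is importable and using dummy frames:

```python
from PIL import Image

frames = [Image.new('RGB', (1280, 720), 'gray') for _ in range(4)]
canvas = img_process(frames)
# each frame resized to 1120x630; stacked height = 4 * (630 + 10 + 40)
print(canvas.size)  # (1120, 2720)
```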
meta_instruction = """You are an AI assistant whose name is InternLM (书生·浦语).\n" + "- InternLM (书生·浦语) \
is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室).
It is designed to be helpful, honest, and harmless.\n" + "- InternLM (书生·浦语) \
can understand and communicate fluently in the language chosen by the user such as English and 中文."""


-def model_gen(model, text, images, need_bos=True, padding=False, beams=3, max_token=500):
+def model_gen(model, text, images, need_bos=True, padding=False, beams=3, max_token=500, video_input=False):
    embeds = []
    im_mask = []
+    # print(text)

    im_idx = 0
    sub_q = text.split('<IM_POS>')
@@ -75,15 +167,16 @@ def model_gen(model, text, images, need_bos=True, padding=False, beams=3, max_token=500):
            need_bos = False

        if im_idx < len(images) and add_im:
-            try:
-                image = Image.open(images[im_idx]).convert('RGB')
-            except:
-                image = images[im_idx].convert('RGB')
-            if len(images) > 1:
-                image = HD_transform(image, im_num=model.hd_num // len(images), id_scale=model.id_scale)
+            image = images[im_idx]
+            if video_input:
+                image = Identity_transform(image)
            else:
-                image = HD_transform(
-                    image, im_num=model.hd_num, id_scale=model.id_scale)
+                if len(images) > 1:
+                    image = HD_transform(image, im_num=model.hd_num // len(images), id_scale=model.id_scale)
+                else:
+                    image = HD_transform(
+                        image, im_num=model.hd_num, id_scale=model.id_scale)
+            # print(image.size)
            image = model.vis_processor(image).unsqueeze(0).to(model.device)
            image_embeds = model.encode_img(image)
            im_idx += 1
@@ -96,22 +189,16 @@ def model_gen(model, text, images, need_bos=True, padding=False, beams=3, max_token=500):
    im_mask = torch.cat(im_mask, dim=1)
    im_mask = im_mask.bool()

-    outputs = model.generate(
-        inputs_embeds=embeds, im_mask=im_mask,
-        eos_token_id=[
-            model.tokenizer.eos_token_id,
-            model.tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0]
-        ],
-        temperature=1.0, max_new_tokens=max_token, num_beams=beams,
-        do_sample=False, repetition_penalty=1.0)
+    outputs = model.generate(inputs_embeds=embeds, im_mask=im_mask,
+                             temperature=1.0, max_new_tokens=max_token, num_beams=beams,
+                             do_sample=False, repetition_penalty=1.0)

    output_token = outputs[0]
    if output_token[0] == 0 or output_token[0] == 1:
        output_token = output_token[1:]
-    output_text = model.tokenizer.decode(
-        output_token, add_special_tokens=False)
-    output_text = output_text.split('[UNUSED_TOKEN_145]')[0].strip()
-    output_text = output_text.split('</s>')[0].strip()
+    output_text = model.tokenizer.decode(output_token, add_special_tokens=False)
+    output_text = output_text.split('[UNUSED_TOKEN_145]')[0].strip().split('<|im_end|>')[0].strip().split('The answer is')[-1].strip()  # noqa
+    # print(output_text)
    return output_text


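The explicit `eos_token_id` list is gone; stopping is left to the model's defaults, and cleanup now happens on the decoded string, which also keeps only what follows a trailing "The answer is". A standalone illustration of that string pipeline (the sample text is made up):

```python
raw = 'Let me check the chart. The answer is B.[UNUSED_TOKEN_145]'
ans = raw.split('[UNUSED_TOKEN_145]')[0].strip() \
         .split('<|im_end|>')[0].strip() \
         .split('The answer is')[-1].strip()
print(ans)  # 'B.'
```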
@@ -137,41 +224,43 @@ def __init__(self, model_path='internlm/internlm-xcomposer2d5-7b', id_scale=1.5,
        self.model.hd_num = 36
        self.model.id_scale = self.id_scale

-    def message_to_promptimg(self, message, dataset=None):
+    def message_to_promptimg(self, message, dataset=None, video_input=False):
        num_images = len([x for x in message if x['type'] == 'image'])
        if num_images == 0:
            prompt = '\n'.join([x['value']
                                for x in message if x['type'] == 'text'])
            image = None
+
        else:
-            image = [x['value'] for x in message if x['type'] == 'image']
-            if len(image) == 1:
-                prompt = ''.join([x['value']
-                                  for x in message if x['type'] == 'text'])
-                im_prompt = '<IM_POS>'
-                prompt = prompt.replace('<image 1>', '')
-                prompt = im_prompt + prompt
+            image = [Image.open(x['value']).convert('RGB') for x in message if x['type'] == 'image']
+
+            if video_input:
+                im_prompt = '<IM_POS>Here are some frames of a video.'
+                if len(image) > 64:
+                    step = len(image) / 64  # uniform temporal subsampling to 64 frames
+                    image = [image[int(i * step)] for i in range(64)]
+                image = [img_process(image)]
+
            else:
-                prompt = ''
-                im_prompt = [
-                    f'Image{im_idx + 1}: <IM_POS>;' for im_idx in range(len(image))]
-                add_im = len(im_prompt)
-                im_idx = 0
-                for x in message:
-                    if x['type'] == 'text':
-                        prompt += x['value']
-                        if add_im > im_idx:
-                            prompt += f'Image{im_idx + 1}'
-                            im_idx += 1
-                im_prompt = ' '.join(im_prompt)
-                for i in range(len(image)):
-                    prompt = prompt.replace(f'<image {i + 1}>', f'Image{i + 1}')
-                # fix bug for multi-image prompt
-                if dataset is not None and listinstr(['mmlongbench', 'dude', 'slidevqa'], dataset.lower()):
-                    prompt = '[UNUSED_TOKEN_146]user\n' + im_prompt + re.sub(
-                        re.escape('[UNUSED_TOKEN_146]user\n'), '', prompt
-                    )
-                    prompt = re.sub('Image1$', '', prompt)
+                if len(image) > 1:
+                    im_prompt = ' '.join([
+                        f'Image{im_idx + 1}: <IM_POS>;' for im_idx in range(len(image))])
+                else:
+                    im_prompt = '<IM_POS>'
+
+            prompt = ''
+            for x in message:
+                if x['type'] == 'text' and x.get('role', '') != 'system':
+                    prompt += x['value']
+            sp = [i for i in re.split(conv_pattern, prompt) if i != '' and i != '\n']
+            assert len(sp) <= 2
+            q = sp[0]
+            prompt = f'[UNUSED_TOKEN_146]user\n{im_prompt}{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
+
+            for idx in range(10):
+                idx = chr(65 + idx)  # 'A'..'J'
+                prompt = prompt.replace(f'({idx})', f'{idx}.')  # '(A)' -> 'A.'
+
        return prompt, image


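Both branches now funnel into one chat template: the user turn is wrapped in `[UNUSED_TOKEN_146]user … [UNUSED_TOKEN_145]`, and MCQ option labels are normalized from `(A)` to `A.`. A sketch of the resulting string for a two-image query (question text is invented):

```python
im_prompt = 'Image1: <IM_POS>; Image2: <IM_POS>;'
q = 'Which image is brighter? (A) the first (B) the second'
prompt = f'[UNUSED_TOKEN_146]user\n{im_prompt}{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
for i in range(10):
    c = chr(65 + i)
    prompt = prompt.replace(f'({c})', f'{c}.')
print(prompt)
# [UNUSED_TOKEN_146]user
# Image1: <IM_POS>; Image2: <IM_POS>;Which image is brighter? A. the first B. the second[UNUSED_TOKEN_145]
# [UNUSED_TOKEN_146]assistant
```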
    def generate_mme(self, image_path, text):
@@ -209,32 +298,43 @@ def generate_brief(self, image_path, text):
            need_bos=True, max_token=10)
        return out

+    def generate_video(self, image_path, text):
+        out = model_gen(
+            self.model, text, image_path, beams=1,  # self.beam,
+            need_bos=True, max_token=100, video_input=True)
+        return out
+
    def set_max_num(self, dataset):
        if dataset is not None and listinstr(['MME-RealWorld', 'MME-RealWorld-CN'], dataset):
            self.model.hd_num = 25

    def generate_inner(self, message, dataset=None):
        self.set_max_num(dataset)
-        prompt, image_path = self.message_to_promptimg(message, dataset=dataset)
-
        with torch.cuda.amp.autocast():
            if dataset is None:
+                prompt, image_path = self.message_to_promptimg(message, dataset=dataset)
                return self.generate_vanilla(image_path, prompt)
            assert isinstance(dataset, str)
-            if dataset == 'MME':
-                return self.generate_mme(image_path, prompt)
-            elif listinstr(['hallu', 'pope'], dataset.lower()):
-                return self.generate_brief(image_path, prompt)
-            elif listinstr(['llava', 'mmvet'], dataset.lower()):
-                return self.generate_vanilla(image_path, prompt)
-            elif dataset is not None and DATASET_TYPE(dataset) == 'MCQ':
-                return self.generate_multichoice(image_path, prompt, dataset)
-            elif listinstr(['MME-RealWorld', 'MME-RealWorld-CN'], dataset):
-                return self.generate_multichoice(image_path, prompt, dataset)
-            elif dataset is not None and DATASET_TYPE(dataset) == 'VQA':
-                return self.generate_vqa(image_path, prompt)
+
+            if listinstr(['video', 'mvbench'], dataset.lower()):
+                prompt, image_path = self.message_to_promptimg(message, dataset=dataset, video_input=True)
+                return self.generate_video(image_path, prompt)
            else:
-                return self.generate_vanilla(image_path, prompt)
+                prompt, image_path = self.message_to_promptimg(message, dataset=dataset)
+                if dataset == 'MME':
+                    return self.generate_mme(image_path, prompt)
+                elif listinstr(['hallu', 'pope'], dataset.lower()):
+                    return self.generate_brief(image_path, prompt)
+                elif listinstr(['llava', 'mmvet'], dataset.lower()):
+                    return self.generate_vanilla(image_path, prompt)
+                elif dataset is not None and DATASET_TYPE(dataset) == 'MCQ':
+                    return self.generate_multichoice(image_path, prompt, dataset)
+                elif listinstr(['MME-RealWorld', 'MME-RealWorld-CN'], dataset):
+                    return self.generate_multichoice(image_path, prompt, dataset)
+                elif dataset is not None and DATASET_TYPE(dataset) == 'VQA':
+                    return self.generate_vqa(image_path, prompt)
+                else:
+                    return self.generate_vanilla(image_path, prompt)

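Routing to the video path keys off substring matches in the lowercased dataset name. `listinstr` comes from the `...smp` star import; a simplified sketch of how I read its behavior (the real helper may differ in details):

```python
def listinstr(lst, s):  # sketch: is any needle a substring of s?
    return any(needle in s for needle in lst)

print(listinstr(['video', 'mvbench'], 'MVBench_MP4'.lower()))  # True
print(listinstr(['video', 'mvbench'], 'MME'.lower()))          # False
```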
    def use_custom_prompt(self, dataset):
        assert dataset is not None
@@ -291,8 +391,8 @@ def build_prompt(self, line, dataset=None):
            prompt = f'[UNUSED_TOKEN_146]user\n{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
        else:
            q = line['question']
-            prompt = f'[UNUSED_TOKEN_146]user\nAnswer the question using a single word or phrase.\
-                {q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
+            prefix = 'Answer the question using a single word or phrase.'
+            prompt = f'[UNUSED_TOKEN_146]user\n{prefix}{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
        ret = [dict(type='text', value=prompt)]
        ret.extend([dict(type='image', value=s) for s in tgt_path])
        return ret
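The old VQA prompt used a backslash continuation inside the f-string, so the source indentation of the second line leaked into the prompt as literal spaces before `{q}`; hoisting the instruction into `prefix` avoids that. A quick check of the new form (question invented):

```python
prefix = 'Answer the question using a single word or phrase.'
q = 'What color is the car?'
prompt = f'[UNUSED_TOKEN_146]user\n{prefix}{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
print(repr(prompt))
# '[UNUSED_TOKEN_146]user\nAnswer the question using a single word or phrase.What color is the car?[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
```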