
Commit 88f1746

[Fix] Fix XComposer2d5 (#668)

* [Fix] Fix XComposer2.5
* update

1 parent 8c2c130 commit 88f1746

File tree

1 file changed: +167 -67 lines

vlmeval/vlm/xcomposer/xcomposer2d5.py (+167 -67)
@@ -3,14 +3,28 @@
 import numpy as np
 import torch
 import torchvision.transforms as transforms
-from PIL import Image
+from PIL import Image, ImageDraw, ImageFont
 from transformers import AutoModel, AutoTokenizer
 
 from ...dataset import DATASET_TYPE
 from ...smp import *
 from ..base import BaseModel
 
 pattern = re.compile(r'[A-Z]')
+conv_pattern = '\\[UNUSED_TOKEN_146\\]user\\\n|\\[UNUSED_TOKEN_146\\]assistant\\\n|\\[UNUSED_TOKEN_145\\]'
+
+
+def get_font():
+    try:
+        truetype_url = "http://opencompass.openxlab.space/utils/Fonts/SimHei.ttf"
+        ff = urlopen(truetype_url)
+        # ff = '/fs-computility/mllm/shared/dongxiaoyi/share_data/SimHei.ttf'
+        font = ImageFont.truetype(ff, size=40)
+    except Exception as e:
+        logging.warning(f'{type(e)}: {e}')
+        logging.warning("Fail to download the font. Use the default one.")
+        font = ImageFont.load_default(size=40)
+    return font
 
 
 def padding_560(b):
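Note on the new get_font helper above: PIL's ImageFont.truetype accepts a file-like object, so the response from urlopen can be passed straight in, and the ImageFont.load_default(size=...) fallback needs Pillow >= 10.1, where load_default gained a size argument. A minimal standalone sketch of the same pattern (in the patched file, urlopen and logging presumably arrive via the `from ...smp import *` wildcard):

import logging
from urllib.request import urlopen

from PIL import ImageFont

def load_font_with_fallback(url='http://opencompass.openxlab.space/utils/Fonts/SimHei.ttf', size=40):
    # Fetch a TrueType font over HTTP; fall back to Pillow's built-in font.
    try:
        font = ImageFont.truetype(urlopen(url), size=size)  # file-like objects are accepted
    except Exception as e:
        logging.warning('%s: %s; falling back to the default font', type(e).__name__, e)
        font = ImageFont.load_default(size=size)  # size kwarg requires Pillow >= 10.1
    return font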
@@ -26,6 +40,29 @@ def padding_560(b):
     return b
 
 
+def Identity_transform(img, hd_num=25):
+    width, height = img.size
+    trans = False
+    if width < height:
+        img = img.transpose(Image.TRANSPOSE)
+        trans = True
+        width, height = img.size
+    ratio = (width / height)
+    scale = 1
+    new_h = int(scale * 560)
+    new_w = int(new_h * ratio)
+    # print (new_h, new_w)
+
+    img = transforms.functional.resize(img, [new_h, new_w],)
+    img = img.transpose(Image.TRANSPOSE)
+    img = padding_560(img)
+    width, height = img.size
+    if not trans:
+        img = img.transpose(Image.TRANSPOSE)
+
+    return img
+
+
 def HD_transform(img, im_num=36, id_scale=1.5):
     width, height = img.size
     trans = False
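Informally, Identity_transform transposes a portrait frame to landscape, resizes it so the shorter side becomes exactly 560 px (scale is fixed at 1), pads it via padding_560, then transposes back. A rough size predictor, under the assumption that padding_560 rounds the long side up to the next multiple of 560 (its body is not shown in this diff):

import math

def identity_transform_size(width, height, base=560):
    # Predict the output size of Identity_transform for a width x height frame.
    # Assumes padding_560 rounds the padded side up to a multiple of `base`.
    if width < height:                      # the function transposes portrait frames first
        width, height = height, width
    new_h = base                            # scale = 1, so the short side becomes 560
    new_w = int(new_h * (width / height))   # preserve the aspect ratio
    return math.ceil(new_w / base) * base, new_h

print(identity_transform_size(1280, 720))   # (1120, 560): 995 px rounded up to 2 * 560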
@@ -53,15 +90,70 @@ def HD_transform(img, im_num=36, id_scale=1.5):
     return img
 
 
+def img_process(imgs):
+    new_imgs = []
+    for img in imgs:
+        w, h = img.size
+        scale = w / h
+        if w > h:
+            new_w = 560 * 2
+            new_h = int(560 * 2 / scale)
+        else:
+            new_w = int(560 * 2 * scale)
+            new_h = 560 * 2
+        img = transforms.functional.resize(img, [new_h, new_w],)
+        new_imgs.append(img)
+    imgs = new_imgs
+    new_w = 0
+    new_h = 0
+    pad = 40
+    if w > h:
+        for im in imgs:
+            w, h = im.size
+            new_w = max(new_w, w)
+            new_h += h + 10 + pad
+        font = get_font()
+        new_img = Image.new('RGB', (new_w, new_h), 'white')
+        draw = ImageDraw.Draw(new_img)
+        curr_h = 0
+        for idx, im in enumerate(imgs):
+            w, h = im.size
+            new_img.paste(im, (0, pad + curr_h))
+            draw.text((0, curr_h), f'<IMAGE {idx}>', font=font, fill='black')
+            if idx + 1 < len(imgs):
+                draw.line([(0, pad + curr_h + h + 5), (new_w, pad + curr_h + h + 5)], fill='black', width=2)
+            curr_h += h + 10 + pad
+        # print (new_w, new_h)
+    else:
+        for im in imgs:
+            w, h = im.size
+            new_w += w + 10
+            new_h = max(new_h, h)
+        new_h += pad
+        font = get_font()
+        new_img = Image.new('RGB', (new_w, new_h), 'white')
+        draw = ImageDraw.Draw(new_img)
+        curr_w = 0
+        for idx, im in enumerate(imgs):
+            w, h = im.size
+            new_img.paste(im, (curr_w, pad))
+            draw.text((curr_w, 0), f'<IMAGE {idx}>', font=font, fill='black')
+            if idx + 1 < len(imgs):
+                draw.line([(curr_w + w + 5, 0), (curr_w + w + 5, new_h)], fill='black', width=2)
+            curr_w += w + 10
+    return new_img
+
+
 meta_instruction = """You are an AI assistant whose name is InternLM (书生·浦语).\n" + "- InternLM (书生·浦语) \
 is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室).
 It is designed to be helpful, honest, and harmless.\n"+"- InternLM (书生·浦语) \
 can understand and communicate fluently in the language chosen by the user such as English and 中文."""
 
 
-def model_gen(model, text, images, need_bos=True, padding=False, beams=3, max_token=500):
+def model_gen(model, text, images, need_bos=True, padding=False, beams=3, max_token=500, video_input=False):
     embeds = []
     im_mask = []
+    # print(text)
 
     im_idx = 0
     sub_q = text.split('<IM_POS>')
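The new img_process collages every frame onto one canvas: landscape frames are stacked vertically, each under a 40 px label band reading `<IMAGE {idx}>` and separated by a thin line; portrait frames are laid out side by side with the label band across the top. Note that the layout choice keys off the `w, h` left over from the last frame of the resize loop. A toy version of the vertical stacking (dummy solid-color frames, sizes hypothetical):

from PIL import Image, ImageDraw

frames = [Image.new('RGB', (1120, 560), c) for c in ('red', 'green', 'blue')]
pad, gap = 40, 10                                  # label band height and inter-frame gap
canvas_w = max(f.width for f in frames)
canvas_h = sum(f.height + gap + pad for f in frames)
canvas = Image.new('RGB', (canvas_w, canvas_h), 'white')
draw = ImageDraw.Draw(canvas)
y = 0
for idx, f in enumerate(frames):
    canvas.paste(f, (0, y + pad))                  # each frame sits below its label band
    draw.text((0, y), f'<IMAGE {idx}>', fill='black')
    y += f.height + gap + pad
canvas.save('stacked_frames.png')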
@@ -75,15 +167,16 @@ def model_gen(model, text, images, need_bos=True, padding=False, beams=3, max_to
             need_bos = False
 
         if im_idx < len(images) and add_im:
-            try:
-                image = Image.open(images[im_idx]).convert('RGB')
-            except:
-                image = images[im_idx].convert('RGB')
-            if len(images) > 1:
-                image = HD_transform(image, im_num=model.hd_num // len(images), id_scale=model.id_scale)
+            image = images[im_idx]
+            if video_input:
+                image = Identity_transform(image)
             else:
-                image = HD_transform(
-                    image, im_num=model.hd_num, id_scale=model.id_scale)
+                if len(images) > 1:
+                    image = HD_transform(image, im_num=model.hd_num // len(images), id_scale=model.id_scale)
+                else:
+                    image = HD_transform(
+                        image, im_num=model.hd_num, id_scale=model.id_scale)
+            # print(image.size)
             image = model.vis_processor(image).unsqueeze(0).to(model.device)
             image_embeds = model.encode_img(image)
             im_idx += 1
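With this hunk, model_gen receives already-opened PIL images (the Image.open try/except moved into message_to_promptimg) and routes video frames through Identity_transform instead of the HD tiling. For context, the surrounding function interleaves text and image embeddings and flags image positions through im_mask; a shape-level sketch of that bookkeeping (dimensions are hypothetical, not the model's real sizes):

import torch

d = 4096                             # hypothetical embedding width
text_emb = torch.randn(1, 7, d)      # 7 text-token embeddings
img_emb = torch.randn(1, 400, d)     # 400 visual-token embeddings for one image

embeds = torch.cat([text_emb, img_emb], dim=1)                              # (1, 407, d)
im_mask = torch.cat([torch.zeros(1, 7), torch.ones(1, 400)], dim=1).bool()  # True at image positions
print(embeds.shape, im_mask.sum().item())   # torch.Size([1, 407, 4096]) 400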
@@ -96,22 +189,16 @@ def model_gen(model, text, images, need_bos=True, padding=False, beams=3, max_to
     im_mask = torch.cat(im_mask, dim=1)
     im_mask = im_mask.bool()
 
-    outputs = model.generate(
-        inputs_embeds=embeds, im_mask=im_mask,
-        eos_token_id=[
-            model.tokenizer.eos_token_id,
-            model.tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0]
-        ],
-        temperature=1.0, max_new_tokens=max_token, num_beams=beams,
-        do_sample=False, repetition_penalty=1.0)
+    outputs = model.generate(inputs_embeds=embeds, im_mask=im_mask,
+                             temperature=1.0, max_new_tokens=max_token, num_beams=beams,
+                             do_sample=False, repetition_penalty=1.0)
 
     output_token = outputs[0]
     if output_token[0] == 0 or output_token[0] == 1:
         output_token = output_token[1:]
-    output_text = model.tokenizer.decode(
-        output_token, add_special_tokens=False)
-    output_text = output_text.split('[UNUSED_TOKEN_145]')[0].strip()
-    output_text = output_text.split('</s>')[0].strip()
+    output_text = model.tokenizer.decode(output_token, add_special_tokens=False)
+    output_text = output_text.split('[UNUSED_TOKEN_145]')[0].strip().split('<|im_end|>')[0].strip().split('The answer is')[-1].strip()  # noqa
+    # print(output_text)
     return output_text
 
 
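The rewritten decoding no longer passes an explicit eos_token_id list to model.generate; instead the decoded string is trimmed afterwards: everything after the '[UNUSED_TOKEN_145]' or '<|im_end|>' stop markers is dropped, and if the model phrased its reply as 'The answer is ...', only the tail is kept. When that phrase is absent, split(...)[-1] returns the whole string, so free-form answers pass through unchanged. The chain on an illustrative sample:

raw = 'Sure. The answer is B.[UNUSED_TOKEN_145]ignored trailing text'
text = raw.split('[UNUSED_TOKEN_145]')[0].strip()   # 'Sure. The answer is B.'
text = text.split('<|im_end|>')[0].strip()          # unchanged, marker absent
text = text.split('The answer is')[-1].strip()      # 'B.'
print(text)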
@@ -137,41 +224,43 @@ def __init__(self, model_path='internlm/internlm-xcomposer2d5-7b', id_scale=1.5,
         self.model.hd_num = 36
         self.model.id_scale = self.id_scale
 
-    def message_to_promptimg(self, message, dataset=None):
+    def message_to_promptimg(self, message, dataset=None, video_input=False):
         num_images = len([x for x in message if x['type'] == 'image'])
         if num_images == 0:
             prompt = '\n'.join([x['value']
                                 for x in message if x['type'] == 'text'])
             image = None
+
         else:
-            image = [x['value'] for x in message if x['type'] == 'image']
-            if len(image) == 1:
-                prompt = ''.join([x['value']
-                                  for x in message if x['type'] == 'text'])
-                im_prompt = '<IM_POS>'
-                prompt = prompt.replace('<image 1>', '')
-                prompt = im_prompt + prompt
+            image = [Image.open(x['value']).convert('RGB') for x in message if x['type'] == 'image']
+
+            if video_input:
+                im_prompt = '<IM_POS>Here are some frames of a video.'
+                if len(image) > 64:
+                    step = len(image) / 64
+                    image = [image[int(i * step)] for i in range(64)]
+                image = [img_process(image)]
+
             else:
-                prompt = ''
-                im_prompt = [
-                    f'Image{im_idx+1}: <IM_POS>;' for im_idx in range(len(image))]
-                add_im = len(im_prompt)
-                im_idx = 0
-                for x in message:
-                    if x['type'] == 'text':
-                        prompt += x['value']
-                        if add_im > im_idx:
-                            prompt += f'Image{im_idx + 1}'
-                            im_idx += 1
-                im_prompt = ' '.join(im_prompt)
-            for i in range(len(image)):
-                prompt = prompt.replace(f'<image {i+1}>', f'Image{i+1} ')
-            # fix bug for multi-image prompt
-            if dataset is not None and listinstr(['mmlongbench', 'dude', 'slidevqa'], dataset.lower()):
-                prompt = '[UNUSED_TOKEN_146]user\n' + im_prompt + re.sub(
-                    re.escape('[UNUSED_TOKEN_146]user\n'), '', prompt
-                )
-                prompt = re.sub('Image1$', '', prompt)
+                if len(image) > 1:
+                    im_prompt = ' '.join([
+                        f'Image{im_idx+1}: <IM_POS>;' for im_idx in range(len(image))])
+                else:
+                    im_prompt = '<IM_POS>'
+
+            prompt = ''
+            for x in message:
+                if x['type'] == 'text' and x.get('role', '') != 'system':
+                    prompt += x['value']
+            sp = [i for i in re.split(conv_pattern, prompt) if i != '' and i != '\n']
+            assert len(sp) <= 2
+            q = sp[0]
+            prompt = f'[UNUSED_TOKEN_146]user\n{im_prompt}{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
+
+            for idx in range(10):
+                idx = chr(65 + idx)
+                prompt = prompt.replace(f'({idx})', f'{idx}.')
+
         return prompt, image
 
     def generate_mme(self, image_path, text):
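Two behaviors introduced in message_to_promptimg are worth spelling out: videos longer than 64 frames are subsampled at a uniform stride before img_process collages them, and MCQ options written as '(A)' are normalized to 'A.' for letters A through J. Both, verified standalone:

# Uniform subsampling to at most 64 frames (same stride arithmetic as the diff).
frames = list(range(150))                 # stand-in for 150 decoded frame objects
if len(frames) > 64:
    step = len(frames) / 64
    frames = [frames[int(i * step)] for i in range(64)]
print(len(frames), frames[:3], frames[-1])     # 64 [0, 2, 4] 147

# '(A)' -> 'A.' rewriting, mirroring the chr(65 + idx) loop.
prompt = 'Options: (A) cat (B) dog'
for i in range(10):
    letter = chr(65 + i)
    prompt = prompt.replace(f'({letter})', f'{letter}.')
print(prompt)                                  # Options: A. cat B. dog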
@@ -209,32 +298,43 @@ def generate_brief(self, image_path, text):
                         need_bos=True, max_token=10)
         return out
 
+    def generate_video(self, image_path, text):
+        out = model_gen(
+            self.model, text, image_path, beams=1,  # self.beam,
+            need_bos=True, max_token=100, video_input=True)
+        return out
+
     def set_max_num(self, dataset):
         if dataset is not None and listinstr(['MME-RealWorld', 'MME-RealWorld-CN'], dataset):
             self.model.hd_num = 25
 
     def generate_inner(self, message, dataset=None):
         self.set_max_num(dataset)
-        prompt, image_path = self.message_to_promptimg(message, dataset=dataset)
-
         with torch.cuda.amp.autocast():
             if dataset is None:
+                prompt, image_path = self.message_to_promptimg(message, dataset=dataset)
                 return self.generate_vanilla(image_path, prompt)
             assert isinstance(dataset, str)
-            if dataset == 'MME':
-                return self.generate_mme(image_path, prompt)
-            elif listinstr(['hallu', 'pope'], dataset.lower()):
-                return self.generate_brief(image_path, prompt)
-            elif listinstr(['llava', 'mmvet'], dataset.lower()):
-                return self.generate_vanilla(image_path, prompt)
-            elif dataset is not None and DATASET_TYPE(dataset) == 'MCQ':
-                return self.generate_multichoice(image_path, prompt, dataset)
-            elif listinstr(['MME-RealWorld', 'MME-RealWorld-CN'], dataset):
-                return self.generate_multichoice(image_path, prompt, dataset)
-            elif dataset is not None and DATASET_TYPE(dataset) == 'VQA':
-                return self.generate_vqa(image_path, prompt)
+
+            if listinstr(['video', 'mvbench'], dataset.lower()):
+                prompt, image_path = self.message_to_promptimg(message, dataset=dataset, video_input=True)
+                return self.generate_video(image_path, prompt)
             else:
-                return self.generate_vanilla(image_path, prompt)
+                prompt, image_path = self.message_to_promptimg(message, dataset=dataset)
+                if dataset == 'MME':
+                    return self.generate_mme(image_path, prompt)
+                elif listinstr(['hallu', 'pope'], dataset.lower()):
+                    return self.generate_brief(image_path, prompt)
+                elif listinstr(['llava', 'mmvet'], dataset.lower()):
+                    return self.generate_vanilla(image_path, prompt)
+                elif dataset is not None and DATASET_TYPE(dataset) == 'MCQ':
+                    return self.generate_multichoice(image_path, prompt, dataset)
+                elif listinstr(['MME-RealWorld', 'MME-RealWorld-CN'], dataset):
+                    return self.generate_multichoice(image_path, prompt, dataset)
+                elif dataset is not None and DATASET_TYPE(dataset) == 'VQA':
+                    return self.generate_vqa(image_path, prompt)
+                else:
+                    return self.generate_vanilla(image_path, prompt)
 
     def use_custom_prompt(self, dataset):
         assert dataset is not None
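generate_inner now builds the prompt inside each branch: dataset names containing 'video' or 'mvbench' take the video path (video_input=True, answered by generate_video with a 100-token budget), while everything else keeps the previous routing. The listinstr helper comes from vlmeval's smp module; to my understanding it is a simple any-substring check, roughly:

def listinstr(lst, s):
    # True if any item of lst occurs as a substring of s (sketch of the smp helper).
    return any(item in s for item in lst)

for ds in ('MVBench_MP4', 'Video-MME', 'MMBench_DEV_EN'):
    print(ds, listinstr(['video', 'mvbench'], ds.lower()))
# MVBench_MP4 True / Video-MME True / MMBench_DEV_EN False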
@@ -291,8 +391,8 @@ def build_prompt(self, line, dataset=None):
             prompt = f'[UNUSED_TOKEN_146]user\n{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
         else:
             q = line['question']
-            prompt = f'[UNUSED_TOKEN_146]user\nAnswer the question using a single word or phrase.\
-{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
+            prefix = 'Answer the question using a single word or phrase.'
+            prompt = f'[UNUSED_TOKEN_146]user\n{prefix}{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
         ret = [dict(type='text', value=prompt)]
         ret.extend([dict(type='image', value=s) for s in tgt_path])
         return ret
