[Fix] Fix XComposer2d5 #668

Merged · 2 commits · Dec 15, 2024
Changes from all commits
234 changes: 167 additions & 67 deletions vlmeval/vlm/xcomposer/xcomposer2d5.py
@@ -3,14 +3,28 @@
import numpy as np
import torch
import torchvision.transforms as transforms
-from PIL import Image
+from PIL import Image, ImageDraw, ImageFont
from transformers import AutoModel, AutoTokenizer

from ...dataset import DATASET_TYPE
from ...smp import *
from ..base import BaseModel

pattern = re.compile(r'[A-Z]')
+conv_pattern = '\\[UNUSED_TOKEN_146\\]user\\\n|\\[UNUSED_TOKEN_146\\]assistant\\\n|\\[UNUSED_TOKEN_145\\]'
+
+
+def get_font():
+    try:
+        truetype_url = "http://opencompass.openxlab.space/utils/Fonts/SimHei.ttf"
+        ff = urlopen(truetype_url)
+        # ff = '/fs-computility/mllm/shared/dongxiaoyi/share_data/SimHei.ttf'
+        font = ImageFont.truetype(ff, size=40)
+    except Exception as e:
+        logging.warning(f'{type(e)}: {e}')
+        logging.warning("Fail to download the font. Use the default one.")
+        font = ImageFont.load_default(size=40)
+    return font


def padding_560(b):
@@ -26,6 +40,29 @@ def padding_560(b):
    return b


+def Identity_transform(img, hd_num=25):
+    width, height = img.size
+    trans = False
+    if width < height:
+        img = img.transpose(Image.TRANSPOSE)
+        trans = True
+        width, height = img.size
+    ratio = (width / height)
+    scale = 1
+    new_h = int(scale * 560)
+    new_w = int(new_h * ratio)
+    # print (new_h, new_w)
+
+    img = transforms.functional.resize(img, [new_h, new_w],)
+    img = img.transpose(Image.TRANSPOSE)
+    img = padding_560(img)
+    width, height = img.size
+    if not trans:
+        img = img.transpose(Image.TRANSPOSE)
+
+    return img


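Note: a quick sanity check of the new Identity_transform, as a minimal sketch. It assumes this module's Identity_transform and padding_560 are in scope, and that padding_560 pads each side up to the next multiple of 560 (its body is collapsed above).

from PIL import Image

frame = Image.new('RGB', (1280, 720), 'gray')  # hypothetical landscape video frame
out = Identity_transform(frame)
# The short side is scaled to 560 and both sides are padded to multiples of 560,
# so the result fits the 560x560 patch grid of the vision encoder.
print(out.size)  # (1120, 560)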
def HD_transform(img, im_num=36, id_scale=1.5):
    width, height = img.size
    trans = False
@@ -53,15 +90,70 @@ def HD_transform(img, im_num=36, id_scale=1.5):
    return img


+def img_process(imgs):
+    new_imgs = []
+    for img in imgs:
+        w, h = img.size
+        scale = w / h
+        if w > h:
+            new_w = 560 * 2
+            new_h = int(560 * 2 / scale)
+        else:
+            new_w = int(560 * 2 * scale)
+            new_h = 560 * 2
+        img = transforms.functional.resize(img, [new_h, new_w],)
+        new_imgs.append(img)
+    imgs = new_imgs
+    new_w = 0
+    new_h = 0
+    pad = 40
+    if w > h:
+        for im in imgs:
+            w, h = im.size
+            new_w = max(new_w, w)
+            new_h += h + 10 + pad
+        font = get_font()
+        new_img = Image.new('RGB', (new_w, new_h), 'white')
+        draw = ImageDraw.Draw(new_img)
+        curr_h = 0
+        for idx, im in enumerate(imgs):
+            w, h = im.size
+            new_img.paste(im, (0, pad + curr_h))
+            draw.text((0, curr_h), f'<IMAGE {idx}>', font=font, fill='black')
+            if idx + 1 < len(imgs):
+                draw.line([(0, pad + curr_h + h + 5), (new_w, pad + curr_h + h + 5)], fill='black', width=2)
+            curr_h += h + 10 + pad
+        # print (new_w, new_h)
+    else:
+        for im in imgs:
+            w, h = im.size
+            new_w += w + 10
+            new_h = max(new_h, h)
+        new_h += pad
+        font = get_font()
+        new_img = Image.new('RGB', (new_w, new_h), 'white')
+        draw = ImageDraw.Draw(new_img)
+        curr_w = 0
+        for idx, im in enumerate(imgs):
+            w, h = im.size
+            new_img.paste(im, (curr_w, pad))
+            draw.text((curr_w, 0), f'<IMAGE {idx}>', font=font, fill='black')
+            if idx + 1 < len(imgs):
+                draw.line([(curr_w + w + 5, 0), (curr_w + w + 5, new_h)], fill='black', width=2)
+            curr_w += w + 10
+    return new_img


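For reference, a usage sketch of img_process (hypothetical frames; assumes img_process from this diff is importable): landscape frames are resized to a long side of 1120 and stacked vertically on one white canvas, each with an '<IMAGE i>' caption and a divider line between frames; portrait frames are laid out horizontally instead.

from PIL import Image

frames = [Image.new('RGB', (1280, 720), c) for c in ('red', 'green', 'blue')]  # hypothetical video frames
stitched = img_process(frames)
# Three 1120x630 tiles stacked vertically, each under a '<IMAGE i>' label strip.
stitched.save('stitched_frames.png')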
meta_instruction = """You are an AI assistant whose name is InternLM (书生·浦语).\n" + "- InternLM (书生·浦语) \
is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室).
It is designed to be helpful, honest, and harmless.\n"+"- InternLM (书生·浦语) \
can understand and communicate fluently in the language chosen by the user such as English and 中文."""


-def model_gen(model, text, images, need_bos=True, padding=False, beams=3, max_token=500):
+def model_gen(model, text, images, need_bos=True, padding=False, beams=3, max_token=500, video_input=False):
    embeds = []
    im_mask = []
    # print(text)

    im_idx = 0
    sub_q = text.split('<IM_POS>')
@@ -75,15 +167,16 @@ def model_gen(model, text, images, need_bos=True, padding=False, beams=3, max_to
            need_bos = False

        if im_idx < len(images) and add_im:
-            try:
-                image = Image.open(images[im_idx]).convert('RGB')
-            except:
-                image = images[im_idx].convert('RGB')
-            if len(images) > 1:
-                image = HD_transform(image, im_num=model.hd_num // len(images), id_scale=model.id_scale)
-            else:
-                image = HD_transform(
-                    image, im_num=model.hd_num, id_scale=model.id_scale)
+            image = images[im_idx]
+            if video_input:
+                image = Identity_transform(image)
+            else:
+                if len(images) > 1:
+                    image = HD_transform(image, im_num=model.hd_num // len(images), id_scale=model.id_scale)
+                else:
+                    image = HD_transform(
+                        image, im_num=model.hd_num, id_scale=model.id_scale)
+            # print(image.size)
            image = model.vis_processor(image).unsqueeze(0).to(model.device)
            image_embeds = model.encode_img(image)
            im_idx += 1
@@ -96,22 +189,16 @@ def model_gen(model, text, images, need_bos=True, padding=False, beams=3, max_to
    im_mask = torch.cat(im_mask, dim=1)
    im_mask = im_mask.bool()

-    outputs = model.generate(
-        inputs_embeds=embeds, im_mask=im_mask,
-        eos_token_id=[
-            model.tokenizer.eos_token_id,
-            model.tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0]
-        ],
-        temperature=1.0, max_new_tokens=max_token, num_beams=beams,
-        do_sample=False, repetition_penalty=1.0)
+    outputs = model.generate(inputs_embeds=embeds, im_mask=im_mask,
+                             temperature=1.0, max_new_tokens=max_token, num_beams=beams,
+                             do_sample=False, repetition_penalty=1.0)

    output_token = outputs[0]
    if output_token[0] == 0 or output_token[0] == 1:
        output_token = output_token[1:]
-    output_text = model.tokenizer.decode(
-        output_token, add_special_tokens=False)
-    output_text = output_text.split('[UNUSED_TOKEN_145]')[0].strip()
-    output_text = output_text.split('</s>')[0].strip()
+    output_text = model.tokenizer.decode(output_token, add_special_tokens=False)
+    output_text = output_text.split('[UNUSED_TOKEN_145]')[0].strip().split('<|im_end|>')[0].strip().split('The answer is')[-1].strip()  # noqa
+    # print(output_text)
    return output_text

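The collapsed post-processing above is equivalent to the following step-by-step version (illustrative only; the decoded string is made up):

decoded = 'The answer is B. The plot shows a cat.[UNUSED_TOKEN_145]'  # hypothetical model output
text = decoded.split('[UNUSED_TOKEN_145]')[0].strip()   # drop the end-of-turn token and anything after it
text = text.split('<|im_end|>')[0].strip()              # also handle the chat-template end marker
text = text.split('The answer is')[-1].strip()          # keep only the final answer span
print(text)  # 'B. The plot shows a cat.'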

@@ -137,41 +224,43 @@ def __init__(self, model_path='internlm/internlm-xcomposer2d5-7b', id_scale=1.5,
        self.model.hd_num = 36
        self.model.id_scale = self.id_scale

-    def message_to_promptimg(self, message, dataset=None):
+    def message_to_promptimg(self, message, dataset=None, video_input=False):
        num_images = len([x for x in message if x['type'] == 'image'])
        if num_images == 0:
            prompt = '\n'.join([x['value']
                                for x in message if x['type'] == 'text'])
            image = None

        else:
-            image = [x['value'] for x in message if x['type'] == 'image']
-            if len(image) == 1:
-                prompt = ''.join([x['value']
-                                  for x in message if x['type'] == 'text'])
-                im_prompt = '<IM_POS>'
-                prompt = prompt.replace('<image 1>', '')
-                prompt = im_prompt + prompt
+            image = [Image.open(x['value']).convert('RGB') for x in message if x['type'] == 'image']
+
+            if video_input:
+                im_prompt = '<IM_POS>Here are some frames of a video.'
+                if len(image) > 64:
+                    step = len(image) / 64
+                    image = [image[int(i * step)] for i in range(64)]
+                image = [img_process(image)]

            else:
-                prompt = ''
-                im_prompt = [
-                    f'Image{im_idx+1}: <IM_POS>;' for im_idx in range(len(image))]
-                add_im = len(im_prompt)
-                im_idx = 0
-                for x in message:
-                    if x['type'] == 'text':
-                        prompt += x['value']
-                        if add_im > im_idx:
-                            prompt += f'Image{im_idx + 1}'
-                            im_idx += 1
-                im_prompt = ' '.join(im_prompt)
-                for i in range(len(image)):
-                    prompt = prompt.replace(f'<image {i+1}>', f'Image{i+1} ')
-                # fix bug for multi-image prompt
-                if dataset is not None and listinstr(['mmlongbench', 'dude', 'slidevqa'], dataset.lower()):
-                    prompt = '[UNUSED_TOKEN_146]user\n' + im_prompt + re.sub(
-                        re.escape('[UNUSED_TOKEN_146]user\n'), '', prompt
-                    )
-                prompt = re.sub('Image1$', '', prompt)
+                if len(image) > 1:
+                    im_prompt = ' '.join([
+                        f'Image{im_idx+1}: <IM_POS>;' for im_idx in range(len(image))])
+                else:
+                    im_prompt = '<IM_POS>'
+
+            prompt = ''
+            for x in message:
+                if x['type'] == 'text' and x.get('role', '') != 'system':
+                    prompt += x['value']
+            sp = [i for i in re.split(conv_pattern, prompt) if i != '' and i != '\n']
+            assert len(sp) <= 2
+            q = sp[0]
+            prompt = f'[UNUSED_TOKEN_146]user\n{im_prompt}{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
+
+            for idx in range(10):
+                idx = chr(65 + idx)
+                prompt = prompt.replace(f'({idx})', f'{idx}.')

        return prompt, image

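A sketch of the new prompt rewriting above (hypothetical question text; conv_pattern is the regex defined at the top of this file): the incoming text is split into at most a question and an answer prefix, re-wrapped into the [UNUSED_TOKEN_146] chat template, and lettered options '(A)' are normalized to 'A.'.

import re

conv_pattern = '\\[UNUSED_TOKEN_146\\]user\\\n|\\[UNUSED_TOKEN_146\\]assistant\\\n|\\[UNUSED_TOKEN_145\\]'
raw = '[UNUSED_TOKEN_146]user\nWhat is shown? (A) cat (B) dog[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
sp = [i for i in re.split(conv_pattern, raw) if i != '' and i != '\n']
q = sp[0]  # 'What is shown? (A) cat (B) dog'
prompt = f'[UNUSED_TOKEN_146]user\n<IM_POS>{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
for idx in range(10):
    idx = chr(65 + idx)
    prompt = prompt.replace(f'({idx})', f'{idx}.')
# prompt now asks: 'What is shown? A. cat B. dog'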
    def generate_mme(self, image_path, text):
@@ -209,32 +298,43 @@ def generate_brief(self, image_path, text):
            need_bos=True, max_token=10)
        return out

+    def generate_video(self, image_path, text):
+        out = model_gen(
+            self.model, text, image_path, beams=1,  # self.beam,
+            need_bos=True, max_token=100, video_input=True)
+        return out

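A call sketch for the new video path (everything named here is hypothetical except the two methods from this diff): message_to_promptimg opens the frame paths as PIL images and stitches them via img_process, then generate_video decodes with beam=1 and at most 100 new tokens.

# Hypothetical usage; 'model' is an instantiated XComposer2d5 wrapper and the
# frame paths are made up for illustration.
message = [dict(type='text', value='Describe the video.')] \
    + [dict(type='image', value=f'frames/{i:03d}.jpg') for i in range(8)]
prompt, images = model.message_to_promptimg(message, dataset='MVBench', video_input=True)
out = model.generate_video(images, prompt)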
    def set_max_num(self, dataset):
        if dataset is not None and listinstr(['MME-RealWorld', 'MME-RealWorld-CN'], dataset):
            self.model.hd_num = 25

    def generate_inner(self, message, dataset=None):
        self.set_max_num(dataset)
-        prompt, image_path = self.message_to_promptimg(message, dataset=dataset)

        with torch.cuda.amp.autocast():
            if dataset is None:
+                prompt, image_path = self.message_to_promptimg(message, dataset=dataset)
                return self.generate_vanilla(image_path, prompt)
            assert isinstance(dataset, str)
-            if dataset == 'MME':
-                return self.generate_mme(image_path, prompt)
-            elif listinstr(['hallu', 'pope'], dataset.lower()):
-                return self.generate_brief(image_path, prompt)
-            elif listinstr(['llava', 'mmvet'], dataset.lower()):
-                return self.generate_vanilla(image_path, prompt)
-            elif dataset is not None and DATASET_TYPE(dataset) == 'MCQ':
-                return self.generate_multichoice(image_path, prompt, dataset)
-            elif listinstr(['MME-RealWorld', 'MME-RealWorld-CN'], dataset):
-                return self.generate_multichoice(image_path, prompt, dataset)
-            elif dataset is not None and DATASET_TYPE(dataset) == 'VQA':
-                return self.generate_vqa(image_path, prompt)
-            else:
-                return self.generate_vanilla(image_path, prompt)
+
+            if listinstr(['video', 'mvbench'], dataset.lower()):
+                prompt, image_path = self.message_to_promptimg(message, dataset=dataset, video_input=True)
+                return self.generate_video(image_path, prompt)
+            else:
+                prompt, image_path = self.message_to_promptimg(message, dataset=dataset)
+                if dataset == 'MME':
+                    return self.generate_mme(image_path, prompt)
+                elif listinstr(['hallu', 'pope'], dataset.lower()):
+                    return self.generate_brief(image_path, prompt)
+                elif listinstr(['llava', 'mmvet'], dataset.lower()):
+                    return self.generate_vanilla(image_path, prompt)
+                elif dataset is not None and DATASET_TYPE(dataset) == 'MCQ':
+                    return self.generate_multichoice(image_path, prompt, dataset)
+                elif listinstr(['MME-RealWorld', 'MME-RealWorld-CN'], dataset):
+                    return self.generate_multichoice(image_path, prompt, dataset)
+                elif dataset is not None and DATASET_TYPE(dataset) == 'VQA':
+                    return self.generate_vqa(image_path, prompt)
+                else:
+                    return self.generate_vanilla(image_path, prompt)

    def use_custom_prompt(self, dataset):
        assert dataset is not None
@@ -291,8 +391,8 @@ def build_prompt(self, line, dataset=None):
            prompt = f'[UNUSED_TOKEN_146]user\n{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
        else:
            q = line['question']
-            prompt = f'[UNUSED_TOKEN_146]user\nAnswer the question using a single word or phrase.\
-{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
+            prefix = 'Answer the question using a single word or phrase.'
+            prompt = f'[UNUSED_TOKEN_146]user\n{prefix}{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
        ret = [dict(type='text', value=prompt)]
        ret.extend([dict(type='image', value=s) for s in tgt_path])
        return ret