[Benchmark] mmniah #434

Open · wants to merge 7 commits into main
4 changes: 2 additions & 2 deletions vlmeval/dataset/__init__.py
@@ -6,7 +6,7 @@
from .image_mcq import ImageMCQDataset, MMMUDataset, CustomMCQDataset, GMAIMMBenchDataset
from .image_mt import MMDUDataset
from .image_vqa import (
ImageVQADataset, MathVision, OCRBench, MathVista, LLaVABench, MMVet, MTVQADataset, CustomVQADataset
ImageVQADataset, MathVision, OCRBench, MathVista, LLaVABench, MMVet, MTVQADataset, CustomVQADataset, MMNIAH
)

from .vcr import VCRDataset
@@ -110,7 +110,7 @@ def evaluate(self, eval_file, **judge_kwargs):
IMAGE_DATASET = [
ImageCaptionDataset, ImageYORNDataset, ImageMCQDataset, ImageVQADataset, MathVision,
MMMUDataset, OCRBench, MathVista, LLaVABench, MMVet, MTVQADataset,
MMLongBench, VCRDataset, MMDUDataset, DUDE, SlideVQA, GMAIMMBenchDataset
MMLongBench, VCRDataset, MMDUDataset, DUDE, SlideVQA, MUIRDataset, GMAIMMBenchDataset, MMNIAH
]

VIDEO_DATASET = [
2 changes: 2 additions & 0 deletions vlmeval/dataset/image_base.py
@@ -4,6 +4,8 @@


def img_root_map(dataset):
if 'MM_NIAH' in dataset:
return 'MMNIAH'
if 'OCRVQA' in dataset:
return 'OCRVQA'
if 'COCO_VAL' == dataset:
162 changes: 162 additions & 0 deletions vlmeval/dataset/image_vqa.py
@@ -431,3 +431,165 @@ def load_data(self, dataset):

def evaluate(self, eval_file, **judge_kwargs):
raise NotImplementedError


class MMNIAH(ImageBaseDataset):
TYPE = 'VQA'
DATASET_URL = {
'MM_NIAH_VAL':
'https://huggingface.co/datasets/petter12321/MM-NIAH-VLMEvalKit/resolve/main/MM_NIAH_VAL.tsv',
'MM_NIAH_TEST':
['https://huggingface.co/datasets/petter12321/MM-NIAH-VLMEvalKit/resolve/main/part-aa',
'https://huggingface.co/datasets/petter12321/MM-NIAH-VLMEvalKit/resolve/main/part-ab',
'https://huggingface.co/datasets/petter12321/MM-NIAH-VLMEvalKit/resolve/main/part-ac',
'https://huggingface.co/datasets/petter12321/MM-NIAH-VLMEvalKit/resolve/main/part-ad',
'https://huggingface.co/datasets/petter12321/MM-NIAH-VLMEvalKit/resolve/main/part-ae']}
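# The TEST split is hosted as several binary chunks; prepare_tsv below downloads and concatenates them into a single tsv.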
DATASET_MD5 = {'MM_NIAH_VAL': 'bcce50780324152a00495cfa466558a1',
'MM_NIAH_TEST': '49b19b787b7b9f565380d47fd6888c74'}

def prepare_tsv(self, url, file_md5=None):
import os
data_root = LMUDataRoot()
os.makedirs(data_root, exist_ok=True)
update_flag = False
file_name = 'MM_NIAH_VAL.tsv' if 'MM_NIAH_VAL' in url else 'MM_NIAH_TEST.tsv'
data_path = osp.join(data_root, file_name)
if osp.exists(data_path) and (file_md5 is None or md5(data_path) == file_md5):
pass
elif file_name == 'MM_NIAH_TEST.tsv':
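# MM_NIAH_TEST.tsv is reassembled locally: fetch any missing part-a* chunk, then concatenate the parts in order.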
warnings.warn('The dataset tsv has not been downloaded yet, downloading it now.')
for i in range(len(url)):
part_name = 'part-a' + chr(ord('a') + i)
if osp.exists(osp.join(data_root, part_name)):
print(part_name + ' already exists, skipping download')
continue
download_file(url[i], osp.join(data_root, part_name))
file_prefix = 'part-'
output_file = data_path
split_files = sorted([f for f in os.listdir(data_root) if f.startswith(file_prefix)])
with open(output_file, 'wb') as outfile:
# Read each part file in order and append its bytes to the merged tsv
for filename in split_files:
with open(osp.join(data_root, filename), 'rb') as infile:
outfile.write(infile.read())
update_flag = True
else:
warnings.warn('The dataset tsv has not been downloaded yet, downloading it now.')
download_file(url, data_path)
update_flag = True

if file_size(data_path, 'GB') > 1:
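# For tsvs above 1 GB, build (or reuse) a localized copy via LOCALIZE and load that instead.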
local_path = data_path.replace('.tsv', '_local.tsv')
if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL', None) or update_flag:
from ..tools import LOCALIZE
LOCALIZE(data_path, local_path)
data_path = local_path
return load(data_path)

@classmethod
def evaluate(self, eval_file, **judge_kwargs):
from .utils.mmniah import is_correct
# The six MM-NIAH task categories, plus an aggregate 'total' bucket.
categories = [
'count-text', 'find-image', 'find-text',
'infer-choose', 'count-image', 'visual-reasoning', 'total',
]
MMNIAH_score = {c: 0 for c in categories}
MMNIAH_num = {c: 0 for c in categories}
final_score_dict = {c: 0 for c in categories}
data = load(eval_file)
lt = len(data)
lines = [data.iloc[i] for i in range(lt)]
for i in tqdm(range(len(lines))):
line = lines[i]
predict = line['prediction']
answers = line['answer']
# print("predict =", predict)
# print("answers =", answers)
category = line['category']
if is_correct(answers, predict):
MMNIAH_score[category] += 1
MMNIAH_score['total'] += 1
MMNIAH_num[category] += 1
MMNIAH_num['total'] += 1

for category in categories:
if MMNIAH_num[category] != 0:
final_score_dict[category] = MMNIAH_score[category] / MMNIAH_num[category]
else:
final_score_dict[category] = None

score_pth = eval_file.replace('.xlsx', '_score.json')
dump(final_score_dict, score_pth)
return final_score_dict

def build_prompt(self, line):
msgs = super().build_prompt(line)
if isinstance(line, int):
line = self.data.iloc[line]
totalchoice = line['multi-choice options']
totalchoice = eval(totalchoice)
# find-image, count-text, find-text,
# infer-choose, count-image, visual-reasoning
# data_type = line['category']
context = msgs[-1]['value']
context = eval(context)
question = context[0] + '\n' + context[1]
# tgt_path collects the paths of all image messages produced by the base class
tgt_path = []
for i in range(len(msgs) - 1):
tgt_path.append(msgs[i]['value'])
choices = totalchoice[0]
choices_image = totalchoice[1]
if choices:
for c_idx, c in enumerate(choices):
question = f"{question}\n{chr(c_idx + ord('A'))}. {c}"
question += "\nAnswer with the option's letter from the given choices directly."
elif choices_image:
for c_idx in range(len(choices_image)):
question = f"{question}\n{chr(c_idx + ord('A'))}. <image>"
question += "\nAnswer with the option's letter from the given choices directly."
else:
question += '\nAnswer the question using a single word or phrase.'
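# Wrap the prompt in sentinel tags so that splitting on '<image>' never yields an empty first or last segment; the tags are stripped again after the text and images are interleaved.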
question = '<start>' + question + '<end>'
question = question.split('<image>')
assert len(tgt_path) + 1 == len(question)
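# Interleave the text segments with the image paths: text[0], image[0], text[1], ..., text[-1].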
context = []
for i in range(len(tgt_path)):
context.append(question[i])
context.append(tgt_path[i])
context.append(question[-1])
context[0] = context[0][len('<start>'):]
context[-1] = context[-1][:-len('<end>')]
msgs = []
ROOT = LMUDataRoot()
for i in range(len(context)):
if i % 2 == 0:
msgs.append(dict(type='text', value=context[i]))
else:
msgs.append(dict(type='image', value=osp.join(ROOT, 'images', self.dataset_name, context[i])))
# Drop empty text entries (e.g. when the prompt starts or ends with an image)
msgs = [m for m in msgs if m['value'] != '']
return msgs
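
For context when reviewing, a minimal sketch of how the new dataset class can be exercised end to end. The constructor call and the prediction file name below are illustrative assumptions, not part of this diff:

```python
# Rough usage sketch (not part of this PR); assumes the usual VLMEvalKit layout.
from vlmeval.dataset.image_vqa import MMNIAH

# Constructing the dataset triggers prepare_tsv above (download / localize the tsv).
dataset = MMNIAH('MM_NIAH_VAL')

# build_prompt accepts an integer row index and returns the interleaved
# text / image message list assembled by build_prompt above.
messages = dataset.build_prompt(0)

# Once inference has produced an xlsx with 'prediction', 'answer' and 'category'
# columns, evaluate writes per-category accuracies next to it as *_score.json.
scores = MMNIAH.evaluate('MM_NIAH_VAL_SomeModel.xlsx')  # hypothetical file name
print(scores)
```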