diff --git a/run.py b/run.py
index 26930a227..bee2c23b2 100644
--- a/run.py
+++ b/run.py
@@ -322,7 +322,7 @@ def main():
judge_kwargs['model'] = 'chatgpt-0125'
elif listinstr(['MMVet', 'LLaVABench', 'MMBench-Video'], dataset_name):
judge_kwargs['model'] = 'gpt-4-turbo'
- elif listinstr(['MathVista', 'MathVerse', 'MathVision', 'DynaMath'], dataset_name):
+ elif listinstr(['MathVista', 'MathVerse', 'MathVision', 'DynaMath', 'VL-RewardBench'], dataset_name): # noqa: E501
judge_kwargs['model'] = 'gpt-4o-mini'
elif listinstr(['MMLongBench', 'MMDU', 'DUDE', 'SLIDEVQA', 'MIA-Bench', 'WildVision'], dataset_name): # noqa: E501
judge_kwargs['model'] = 'gpt-4o'
diff --git a/vlmeval/api/siliconflow.py b/vlmeval/api/siliconflow.py
index 6fea69e2f..eeba793f6 100644
--- a/vlmeval/api/siliconflow.py
+++ b/vlmeval/api/siliconflow.py
@@ -26,7 +26,7 @@ def resize_image(image: Image.Image, max_height: int, max_width: int) -> Image.I
def encode_image(path: str, max_height: int = 1024, max_width: int = 1024) -> str:
image = Image.open(path).convert("RGB")
image = resize_image(image, max_height, max_width)
- height, width = image.size
+ width, height = image.size
if min(height, width) < 50:
scale = 50 / min(width, height)
image = image.resize((int(width * scale), int(height * scale)))
diff --git a/vlmeval/config.py b/vlmeval/config.py
index 6eb09f790..8545b045a 100644
--- a/vlmeval/config.py
+++ b/vlmeval/config.py
@@ -13,7 +13,7 @@
VideoChatGPT_ROOT = None
PLLaVA_ROOT = None
RBDash_ROOT = None
-VITA_ROOT = None
+VITA_ROOT = '/fs-computility/mllm1/shared/dhd/VITA'
LLAVA_V1_7B_MODEL_PTH = 'Please set your local path to LLaVA-7B-v1.1 here, the model weight is obtained by merging LLaVA delta weight based on vicuna-7b-v1.1 in https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md with vicuna-7b-v1.1. '
video_models = {
@@ -332,6 +332,7 @@
qwen2vl_series = {
'Qwen-VL-Max-0809': partial(Qwen2VLAPI, model='qwen-vl-max-0809', min_pixels=1280*28*28, max_pixels=16384*28*28),
'Qwen-VL-Plus-0809': partial(Qwen2VLAPI, model='qwen-vl-plus-0809', min_pixels=1280*28*28, max_pixels=16384*28*28),
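+    # QVQ-72B-Preview: long chain-of-thought variant served through the Qwen2VLChat wrapper; note the step-by-step system prompt and the 8192-token generation budget.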
+ 'QVQ-72B-Preview': partial(Qwen2VLChat, model_path='Qwen/QVQ-72B-Preview', min_pixels=1280*28*28, max_pixels=16384*28*28, system_prompt='You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.', max_new_tokens=8192, post_process=False),
'Qwen2-VL-72B-Instruct': partial(Qwen2VLChat, model_path='Qwen/Qwen2-VL-72B-Instruct', min_pixels=1280*28*28, max_pixels=16384*28*28),
'Qwen2-VL-7B-Instruct': partial(Qwen2VLChat, model_path='Qwen/Qwen2-VL-7B-Instruct', min_pixels=1280*28*28, max_pixels=16384*28*28),
'Qwen2-VL-7B-Instruct-AWQ': partial(Qwen2VLChat, model_path='Qwen/Qwen2-VL-7B-Instruct-AWQ', min_pixels=1280*28*28, max_pixels=16384*28*28),
diff --git a/vlmeval/dataset/__init__.py b/vlmeval/dataset/__init__.py
index ca3b39845..a1bab65c0 100644
--- a/vlmeval/dataset/__init__.py
+++ b/vlmeval/dataset/__init__.py
@@ -20,6 +20,7 @@
from .mmlongbench import MMLongBench
from .dude import DUDE
from .slidevqa import SlideVQA
+from .vl_rewardbench import VLRewardBench
from .mmbench_video import MMBenchVideo
from .videomme import VideoMME
@@ -132,7 +133,7 @@ def evaluate(self, eval_file, **judge_kwargs):
MMLongBench, VCRDataset, MMDUDataset, DUDE, SlideVQA, MUIRDataset, CCOCRDataset,
GMAIMMBenchDataset, MMERealWorld, HRBenchDataset, CRPE, MathVerse, NaturalBenchDataset,
MIABench, OlympiadBench, WildVision, MMMath, QSpatial, Dynamath, MMGenBench, VizWiz, MMNIAH,
- CMMMU
+ CMMMU, VLRewardBench
]
VIDEO_DATASET = [
diff --git a/vlmeval/dataset/image_ccocr.py b/vlmeval/dataset/image_ccocr.py
index 0af2f0776..290bfceef 100644
--- a/vlmeval/dataset/image_ccocr.py
+++ b/vlmeval/dataset/image_ccocr.py
@@ -1,10 +1,18 @@
+# flake8: noqa
+
import os
import re
import tempfile
from functools import partial
import pandas as pd
-from .utils import ccocr_evaluator_map
+try:
+ from .utils.ccocr_evaluator import evaluator_map_info as ccocr_evaluator_map
+except ImportError as err:
+ import warnings
+ warnings.warn('The dependency of CCOCR evaluator is not properly installed')
+ warnings.warn(f'{type(err)}: {err}')
+
from .image_base import ImageBaseDataset
from ..smp import *
@@ -157,7 +165,7 @@ def evaluate(self, eval_file, **judge_kwargs):
for data_info in dict_list:
image_name = data_info['image_name']
gt_info[image_name] = data_info['answer']
-
+
# warning the FAIL samples
if data_info['prediction'] != FAIL_MSG:
ptd_info[image_name] = data_info['prediction']
diff --git a/vlmeval/dataset/image_mcq.py b/vlmeval/dataset/image_mcq.py
index 1a70304ad..145209ddc 100644
--- a/vlmeval/dataset/image_mcq.py
+++ b/vlmeval/dataset/image_mcq.py
@@ -143,7 +143,7 @@ class ImageMCQDataset(ImageBaseDataset):
'AI2D_TEST': '0f593e0d1c7df9a3d69bf1f947e71975',
'AI2D_TEST_NO_MASK': 'fd8f463634d4fe9fbd23b876e8eea5be',
'MMStar': 'e1ecd2140806c1b1bbf54b43372efb9e',
- 'RealWorldQA': '92321028d2bc29040284b6674721e48f',
+ 'RealWorldQA': '4de008f55dc4fd008ca9e15321dc44b7',
'MLLMGuard_DS': '975fc0dd7119386e198c37d71e274b3f',
'BLINK': '3b6649b6a662184ea046908e5506260e',
'TaskMeAnything_v1_imageqa_random': '023fef69e2ca21827afb77c5ec3bc889',
@@ -286,7 +286,7 @@ class MMMUDataset(ImageMCQDataset):
}
DATASET_MD5 = {
- 'MMMU_DEV_VAL': '521afc0f3bf341e6654327792781644d',
+ 'MMMU_DEV_VAL': '585e8ad75e73f75dcad265dfd0417d64',
'MMMU_TEST': 'c19875d11a2d348d07e5eb4bdf33166d',
}
diff --git a/vlmeval/dataset/utils/__init__.py b/vlmeval/dataset/utils/__init__.py
index c9cb7cbcc..ecad31278 100644
--- a/vlmeval/dataset/utils/__init__.py
+++ b/vlmeval/dataset/utils/__init__.py
@@ -1,10 +1,9 @@
from .judge_util import build_judge, DEBUG_MESSAGE
from .multiple_choice import extract_answer_from_item, prefetch_answer
from .vqa_eval import levenshtein_distance
-from .ccocr_evaluator import evaluator_map_info as ccocr_evaluator_map
__all__ = [
- 'build_judge', 'extract_answer_from_item', 'prefetch_answer', 'ccocr_evaluator_map',
+ 'build_judge', 'extract_answer_from_item', 'prefetch_answer',
'levenshtein_distance', 'DEBUG_MESSAGE',
]
diff --git a/vlmeval/dataset/utils/ccocr_evaluator/README.md b/vlmeval/dataset/utils/ccocr_evaluator/README.md
index a636aa4a4..99572ef58 100644
--- a/vlmeval/dataset/utils/ccocr_evaluator/README.md
+++ b/vlmeval/dataset/utils/ccocr_evaluator/README.md
@@ -1,12 +1,12 @@
# CC-OCR: A Comprehensive and Challenging OCR Benchmark for Evaluating Large Multimodal Models in Literacy
## Introduction
-
+
Please refer to our [GitHub](https://github.com/AlibabaResearch/AdvancedLiterateMachinery/tree/main/Benchmarks/CC-OCR) for more information.
## Running Scripts
-Once the environment is ready, execute the following script from the root directory of VLMEvalKit
+Once the environment is ready, execute the following script from the root directory of VLMEvalKit
to perform inference and evaluation tasks in batch.
```shell
@@ -44,13 +44,13 @@ If you find our work helpful, feel free to give us a cite.
```
@misc{yang2024ccocr,
- title={CC-OCR: A Comprehensive and Challenging OCR Benchmark for Evaluating Large Multimodal Models in Literacy},
+ title={CC-OCR: A Comprehensive and Challenging OCR Benchmark for Evaluating Large Multimodal Models in Literacy},
author={Zhibo Yang and Jun Tang and Zhaohai Li and Pengfei Wang and Jianqiang Wan and Humen Zhong and Xuejing Liu and Mingkun Yang and Peng Wang and Shuai Bai and LianWen Jin and Junyang Lin},
year={2024},
eprint={2412.02210},
archivePrefix={arXiv},
primaryClass={cs.CV},
- url={https://arxiv.org/abs/2412.02210},
+ url={https://arxiv.org/abs/2412.02210},
}
```
diff --git a/vlmeval/dataset/utils/ccocr_evaluator/__init__.py b/vlmeval/dataset/utils/ccocr_evaluator/__init__.py
index cd3ff5ac8..d89f6f6b7 100644
--- a/vlmeval/dataset/utils/ccocr_evaluator/__init__.py
+++ b/vlmeval/dataset/utils/ccocr_evaluator/__init__.py
@@ -9,4 +9,4 @@
"doc_parsing": ParsingEvaluator("doc_parsing"),
"multi_lan_ocr": OcrEvaluator("multi_lan_ocr"),
"multi_scene_ocr": OcrEvaluator("multi_scene_ocr")
-}
\ No newline at end of file
+}
diff --git a/vlmeval/dataset/utils/ccocr_evaluator/common.py b/vlmeval/dataset/utils/ccocr_evaluator/common.py
index a3e803495..6ce9bcb55 100644
--- a/vlmeval/dataset/utils/ccocr_evaluator/common.py
+++ b/vlmeval/dataset/utils/ccocr_evaluator/common.py
@@ -26,7 +26,7 @@ def pick_response_text(json_path):
response_text = None
if model_name.startswith("gpt") or model_name.startswith("o1"):
- response_text = model_response.get("data", {}).get("response", {}).get("choices", [{}])[0].get("message", {}).get("content", None)
+ response_text = model_response.get("data", {}).get("response", {}).get("choices", [{}])[0].get("message", {}).get("content", None) # noqa: E501
elif model_name.startswith("local_"):
response_text = model_response
else:
@@ -35,7 +35,7 @@ def pick_response_text(json_path):
elif model_name.startswith("gemini"):
content_list = model_response.get("candidates", [{}])[0].get("content", {}).get("parts", None)
elif model_name.startswith("qwen"):
- content_list = model_response.get("output", {}).get("choices", [{}])[0].get("message", {}).get("content", None)
+ content_list = model_response.get("output", {}).get("choices", [{}])[0].get("message", {}).get("content", None) # noqa: E501
else:
raise NotImplementedError("The pick_response_text NOT implemented for model: {}".format(model_name))
@@ -115,7 +115,7 @@ def __call__(self, pdt_res_dir, gt_info, with_response_ratio=True, **kwargs):
# add response_success_ratio
if "summary" in eval_info and with_response_ratio:
success_ratio = (len(response_info) + len(post_error_list)) / (len(gt_info) + 1e-9)
- eval_info["summary"].update({"response_success_ratio": success_ratio })
+ eval_info["summary"].update({"response_success_ratio": success_ratio})
return meta_info, eval_info
@@ -149,9 +149,9 @@ def summary_multi_exp(exp_dir_base, dataset_list=None, is_weighted_sum=False):
data_status_info = json.load(f)
all_dataset_name.extend(data_status_info.keys())
dataset_list = sorted(set(all_dataset_name))
-
+
# summary main code
- all_evaluate_info, line_index = {}, 0
+    all_evaluate_info = {}
for exp_name in os.listdir(exp_dir_base):
dir_status_path = os.path.join(exp_dir_base, exp_name, "status.json")
if not os.path.exists(dir_status_path):
@@ -219,4 +219,4 @@ def summary_multi_exp(exp_dir_base, dataset_list=None, is_weighted_sum=False):
summary_path = summary_multi_exp(exp_base_dir, dataset_list=None, is_weighted_sum=False)
print("--> info: summary saved at : {}".format(summary_path))
- print("happy coding.")
\ No newline at end of file
+ print("happy coding.")
diff --git a/vlmeval/dataset/utils/ccocr_evaluator/doc_parsing_evaluator.py b/vlmeval/dataset/utils/ccocr_evaluator/doc_parsing_evaluator.py
index 382c4be88..d059adc09 100644
--- a/vlmeval/dataset/utils/ccocr_evaluator/doc_parsing_evaluator.py
+++ b/vlmeval/dataset/utils/ccocr_evaluator/doc_parsing_evaluator.py
@@ -241,7 +241,7 @@ def eval_formula(self, response_info, gt_info, op_name='formula'):
pred = response_info[img_name]
if op_name == 'formula':
- pred = pred.replace("\n", " ").replace("```latex", "").replace("```", "").replace("\t", " ").replace(" ", "")
+ pred = pred.replace("\n", " ").replace("```latex", "").replace("```", "").replace("\t", " ").replace(" ", "") # noqa: E501
gt = gt.replace(" ", "")
elif op_name == 'molecular':
pred = pred.replace("\n", "").replace(" ", "").replace("", "").replace("", "")
diff --git a/vlmeval/dataset/utils/ccocr_evaluator/kie_evaluator.py b/vlmeval/dataset/utils/ccocr_evaluator/kie_evaluator.py
index 2613a338f..797d42446 100644
--- a/vlmeval/dataset/utils/ccocr_evaluator/kie_evaluator.py
+++ b/vlmeval/dataset/utils/ccocr_evaluator/kie_evaluator.py
@@ -66,11 +66,11 @@ def update_cost(node1: Node, node2: Node):
label2 = node2.label
         label1_leaf = "<leaf>" in label1
         label2_leaf = "<leaf>" in label2
-        if label1_leaf == True and label2_leaf == True:
+        if label1_leaf and label2_leaf:
             return edit_distance(label1.replace("<leaf>", ""), label2.replace("<leaf>", ""))
-        elif label1_leaf == False and label2_leaf == True:
+        elif not label1_leaf and label2_leaf:
             return 1 + len(label2.replace("<leaf>", ""))
-        elif label1_leaf == True and label2_leaf == False:
+        elif label1_leaf and not label2_leaf:
             return 1 + len(label1.replace("<leaf>", ""))
else:
return int(label1 != label2)
@@ -121,7 +121,8 @@ def normalize_dict(data: Union[Dict, List, Any]):
def cal_f1_all(preds, answers):
"""
- Calculate global F1 accuracy score (field-level, micro-averaged) by counting all true positives, false negatives and false positives
+ Calculate global F1 accuracy score (field-level, micro-averaged) by counting all true positives,
+ false negatives and false positives
"""
metric_info, error_info = {}, {}
total_tp, total_fn_or_fp = 0, 0
@@ -233,35 +234,28 @@ def cal_acc(pred: dict, answer: dict):
"""
pred = construct_tree_from_dict(normalize_dict(pred))
answer = construct_tree_from_dict(normalize_dict(answer))
- return max(
- 0,
- 1
- - (
- zss.distance(
- pred,
- answer,
- get_children=zss.Node.get_children,
- insert_cost=insert_and_remove_cost,
- remove_cost=insert_and_remove_cost,
- update_cost=update_cost,
- return_operations=False,
- )
- / zss.distance(
- construct_tree_from_dict(normalize_dict({})),
- answer,
- get_children=zss.Node.get_children,
- insert_cost=insert_and_remove_cost,
- remove_cost=insert_and_remove_cost,
- update_cost=update_cost,
- return_operations=False,
- )
- ),
+ val1 = zss.distance(
+ pred,
+ answer,
+ get_children=zss.Node.get_children,
+ insert_cost=insert_and_remove_cost,
+ remove_cost=insert_and_remove_cost,
+ update_cost=update_cost,
+ return_operations=False,
+ )
+ val2 = zss.distance(
+ construct_tree_from_dict(normalize_dict({})),
+ answer,
+ get_children=zss.Node.get_children,
+ insert_cost=insert_and_remove_cost,
+ remove_cost=insert_and_remove_cost,
+ update_cost=update_cost,
+ return_operations=False,
)
+ return max(0, 1 - val1 / val2)
def cal_acc_all(pred_info, answer_info):
- """
- """
acc_info, error_info = {}, {}
for file_name, answer in answer_info.items():
# if file_name not in pred_info:
@@ -303,13 +297,11 @@ def eval_donut(pdt_info, gt_info, normalize_func=None, data_name=None):
acc_average, acc_error_info = cal_acc_all(pdt_info, gt_info)
eval_info = {"f1_score": f1_score, "acc": acc_average, "class_f1_score": class_eval_info,
"f1_error_info": error_info, "acc_error_info": acc_error_info}
- print(data_name, "f1_score", f1_score, "acc", acc_average)
+ print(data_name, "f1_score", f1_score, "acc", acc_average)
return eval_info
def post_process_to_json(qwen_info_str, file_name=None):
- """
- """
try:
if "```json" in qwen_info_str:
if "```" not in qwen_info_str:
@@ -320,10 +312,7 @@ def post_process_to_json(qwen_info_str, file_name=None):
json_str = qwen_info_str.strip().replace("\n", "")
json_data = json.loads(json_str)
return json_data
- except Exception as e:
- # print("--> post error: {}, file_name: {}".format(e, file_name))
- # print("json_raw", qwen_info_str)
- # print("json_str", json_str)
+ except Exception as err: # noqa: F841
return None
diff --git a/vlmeval/dataset/utils/ccocr_evaluator/ocr_evaluator.py b/vlmeval/dataset/utils/ccocr_evaluator/ocr_evaluator.py
index e33b95c67..5c82abccc 100644
--- a/vlmeval/dataset/utils/ccocr_evaluator/ocr_evaluator.py
+++ b/vlmeval/dataset/utils/ccocr_evaluator/ocr_evaluator.py
@@ -92,8 +92,10 @@ def evaluate(self, response_info, gt_info, **kwargs):
image_pdt_info, image_gt_info = {}, {}
for file_name, gt_src in gt_info.items():
pred_src = response_info.get(file_name, "")
- pdt_token_list = text_normalize_and_tokenize(str(pred_src).strip(), is_word_level, is_lower, is_alphanum_only)
- gt_token_list = text_normalize_and_tokenize(str(gt_src).strip(), is_word_level, is_lower, is_alphanum_only)
+ pdt_token_list = text_normalize_and_tokenize(
+ str(pred_src).strip(), is_word_level, is_lower, is_alphanum_only)
+ gt_token_list = text_normalize_and_tokenize(
+ str(gt_src).strip(), is_word_level, is_lower, is_alphanum_only)
image_pdt_info[file_name] = pdt_token_list
image_gt_info[file_name] = gt_token_list
eval_result = calculate_metrics(image_pdt_info, image_gt_info, is_verbose=False)
diff --git a/vlmeval/dataset/vl_rewardbench.py b/vlmeval/dataset/vl_rewardbench.py
new file mode 100644
index 000000000..d8dad7383
--- /dev/null
+++ b/vlmeval/dataset/vl_rewardbench.py
@@ -0,0 +1,174 @@
+from ast import literal_eval
+
+from .image_base import ImageBaseDataset
+from .utils import build_judge, DEBUG_MESSAGE
+from ..smp import *
+from ..utils import track_progress_rich
+
+
+LLM_PARSE_ANSWER_PROMPT = '''
+You are given a pairwise judgement for two responses. Please return the better response according to the judgement.
+Return the Answer X ONLY. e.g., Answer 1 or Answer 2.
+
+Judgement: {judgement}
+'''
+
+
+PROMPT_TEMPLATE = '''\
+You are a highly capable multimodal AI assistant tasked with evaluating answers to visual questions.
+Please analyze the following image and question, then determine which of the two provided answers is better.
+
+Question: {query}
+
+Answer 1: {answer_0}
+
+Answer 2: {answer_1}
+
+Please evaluate both answers based on the following criteria:
+1. Accuracy: How well does the answer align with the visual information in the image?
+2. Completeness: Does the answer fully address all aspects of the question?
+3. Clarity: Is the answer easy to understand and well-articulated?
+4. Relevance: Does the answer directly relate to the question and the image?
+
+After your evaluation, please:
+1. Explain your reasoning for each criterion.
+2. Provide an overall judgment on which answer is better (Answer 1 or Answer 2).\
+For example: Overall Judgment: Answer X is better.
+
+Your response should be structured and detailed, \
+demonstrating your understanding of both the visual and textual elements of the task.'''
+
+
+def get_score(line, parsed_response, random_number):
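+    # 'human_ranking' marks the preferred response with rank 0; random_number records whether build_prompt reversed the answer order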
+ gt_ans = line['human_ranking'].index(0 if random_number == 0 else 1) + 1
+ if 'Answer 1'.lower() in parsed_response.lower():
+ pred = 1
+ elif 'Answer 2'.lower() in parsed_response.lower():
+ pred = 2
+ else: # failed
+        pred = None  # judgement could not be parsed; counted as incorrect
+
+ if pred == gt_ans:
+ return 1.0
+ else:
+ return 0.0
+
+
+def VLRewardBench_eval_answer(model, line):
+ response = toliststr(line['response'])
+ random_number = sum(len(res) for res in response) % 2
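+    # recompute the same length-parity bit used in build_prompt so the judgement maps back to the original answer order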
+
+ prompt = LLM_PARSE_ANSWER_PROMPT.format(judgement=line['prediction'])
+ messages = [dict(type='text', value=prompt)]
+
+ resp = model.generate(messages)
+ score = get_score(line, resp, random_number)
+
+ if score is None:
+ return 'Unknown'
+ return score
+
+
+class VLRewardBench(ImageBaseDataset):
+ TYPE = 'VQA'
+ DATASET_URL = {
+ 'VL-RewardBench': 'https://huggingface.co/datasets/MMInstruction/VL-RewardBench/resolve/main/vl_rewardbench.tsv'
+ }
+ DATASET_MD5 = {'VL-RewardBench': '1d2676f4ab4a5f755019ec0af2b28189'}
+
+ # Given one data record, return the built prompt (a multi-modal message), can override
+ def build_prompt(self, line):
+ if isinstance(line, int):
+ line = self.data.iloc[line]
+ tgt_path = self.dump_image(line) # save image to local
+ question = line['question']
+ msgs = []
+ if isinstance(tgt_path, list):
+ msgs.extend([dict(type='image', value=p) for p in tgt_path])
+ else:
+ msgs = [dict(type='image', value=tgt_path)]
+
+ response = toliststr(line['response'])
+ random_number = sum(len(res) for res in response) % 2
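+        # deterministic pseudo-random swap based on response lengths; the evaluator recomputes this parity to undo it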
+ if random_number == 1:
+ # randomly shuffle the order of the responses
+ response = response[::-1]
+ query_prompt = PROMPT_TEMPLATE.format(
+ query=question, answer_0=response[0], answer_1=response[1]
+ )
+ msgs = msgs + [dict(type='text', value=query_prompt)]
+ return msgs
+
+ # It returns a DataFrame
+ @classmethod
+ def evaluate(self, eval_file, **judge_kwargs):
+ suffix = eval_file.split('.')[-1]
+ model = judge_kwargs['model']
+ storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
+ score_file = eval_file.replace(f'.{suffix}', f'_{model}_score.csv')
+ tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
+ nproc = judge_kwargs.pop('nproc', 4)
+
+ if not osp.exists(storage):
+ raw_data = VLRewardBench('VL-RewardBench').data
+ data = load(eval_file)
+ data['prediction'] = [str(x) for x in data['prediction']]
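+            # the dataset's 'answer' column holds the stringified human ranking; this assumes eval_file rows follow the original TSV order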
+ data['human_ranking'] = [literal_eval(x) for x in raw_data['answer']]
+
+ judge_kwargs['temperature'] = 0
+ judge_kwargs['timeout'] = 60
+ model = build_judge(max_tokens=128, **judge_kwargs)
+
+ assert model.working(), (
+ 'VLRewardBench evaluation requires a working OPENAI API\n'
+ + DEBUG_MESSAGE
+ )
+
+ lt = len(data)
+ lines = [data.iloc[i] for i in range(lt)]
+ tups = [(model, line) for line in lines]
+ indices = [line['index'] for line in lines]
+
+ ans = load(tmp_file) if osp.exists(tmp_file) else {}
+ tups = [x for x, i in zip(tups, indices) if i not in ans]
+ indices = [i for i in indices if i not in ans]
+
+ if len(indices):
+ new_results = track_progress_rich(
+ VLRewardBench_eval_answer,
+ tups,
+ nproc=nproc,
+ chunksize=nproc,
+ keys=indices,
+ save=tmp_file,
+ )
+ ans = load(tmp_file)
+ for k, v in zip(indices, new_results):
+ ans[k] = v
+
+ data['score'] = [ans[idx] for idx in data['index']]
+ # data.pop('image')
+ dump(data, storage)
+
+ data = load(storage)
+ lt = len(data)
+
+ category_scores = defaultdict(lambda: 0)
+ category_cnt = defaultdict(lambda: 0)
+ scores = defaultdict(lambda: 0)
+ for i in range(lt):
+ item = data.iloc[i]
+ category_scores[item['category']] += item['score']
+ category_cnt[item['category']] += 1
+ # calculate the average score for each category
+ for k, v in category_scores.items():
+ scores[k] = v / category_cnt[k]
+ # calculate category macro accuracy (average across categories)
+ scores['Macro Accuracy'] = sum(scores.values()) / len(scores)
+ # calculate the total average score
+ scores['Overall Consistency'] = sum(category_scores.values()) / lt
+
+ scores = {k: [v] for k, v in scores.items()}
+ scores = pd.DataFrame(scores)
+ dump(scores, score_file)
+ return scores
diff --git a/vlmeval/smp/misc.py b/vlmeval/smp/misc.py
index 81b1ae5c4..d0d9e909f 100644
--- a/vlmeval/smp/misc.py
+++ b/vlmeval/smp/misc.py
@@ -277,4 +277,15 @@ def get_gpu_memory():
def auto_split_flag():
flag = os.environ.get('AUTO_SPLIT', '0')
- return flag == '1'
+ if flag == '1':
+ return True
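+    # AUTO_SPLIT unset: fall back to splitting automatically when the visible GPUs outnumber the launched processes and divide evenly among them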
+ _, world_size = get_rank_and_world_size()
+ try:
+ import torch
+ device_count = torch.cuda.device_count()
+ if device_count > world_size and device_count % world_size == 0:
+ return True
+ else:
+ return False
+    except Exception:
+ return False
diff --git a/vlmeval/vlm/internvl/utils.py b/vlmeval/vlm/internvl/utils.py
index a4e16c8d8..7e3a917fd 100644
--- a/vlmeval/vlm/internvl/utils.py
+++ b/vlmeval/vlm/internvl/utils.py
@@ -119,7 +119,7 @@ def get_local_rank_and_local_world_size():
def split_model(model_path):
- num_gpus_per_node = 8
+ num_gpus_per_node = torch.cuda.device_count()
rank, world_size = get_rank_and_world_size()
try:
local_rank, local_world_size = get_local_rank_and_local_world_size()
@@ -130,7 +130,7 @@ def split_model(model_path):
gpus_per_process = int(os.environ['GPUS_PER_PROCESS'])
else:
gpus_per_process = 8 # default to use 8 GPUs for one model
-
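+    # cap gpus_per_process so a node's GPUs are shared evenly across its local processes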
+ gpus_per_process = min(gpus_per_process, num_gpus_per_node // local_world_size)
start_gpu = local_rank * gpus_per_process
end_gpu = start_gpu + gpus_per_process
@@ -159,6 +159,7 @@ def split_model(model_path):
device_map['language_model.model.embed_tokens'] = visible_devices[0]
device_map['language_model.output'] = visible_devices[0]
device_map['language_model.model.norm'] = visible_devices[0]
+ device_map['language_model.model.rotary_emb'] = visible_devices[0]
device_map['language_model.lm_head'] = visible_devices[0]
device_map[f'language_model.model.layers.{num_layers - 1}'] = visible_devices[0]
diff --git a/vlmeval/vlm/llama_vision.py b/vlmeval/vlm/llama_vision.py
index 9abbcfd95..649f963f0 100644
--- a/vlmeval/vlm/llama_vision.py
+++ b/vlmeval/vlm/llama_vision.py
@@ -79,7 +79,7 @@ def __init__(self, model_path='meta-llama/Llama-3.2-11B-Vision-Instruct', **kwar
self.device = 'cuda'
self.processor = AutoProcessor.from_pretrained(model_path)
- if 'Instruct' in model_path:
+ if 'Instruct' in model_path or 'cot' in model_path or 'CoT' in model_path:
kwargs_default = dict(do_sample=True, temperature=0.6, top_p=0.9)
else:
kwargs_default = dict(do_sample=False, max_new_tokens=512, temperature=0.0, top_p=None, num_beams=1)
@@ -200,5 +200,7 @@ def generate_inner(self, message, dataset=None):
self.kwargs['max_new_tokens'] = 128
else:
self.kwargs['max_new_tokens'] = 512
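+        # CoT-tuned checkpoints emit long reasoning traces, so allow a larger generation budget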
+ if "cot" in self.model_name or "CoT" in self.model_name:
+ self.kwargs['max_new_tokens'] = 2048
output = self.model.generate(**inputs, **self.kwargs)
return self.processor.decode(output[0][inputs['input_ids'].shape[1]:]).replace('<|eot_id|>', '')
diff --git a/vlmeval/vlm/molmo.py b/vlmeval/vlm/molmo.py
index fe111a06f..1227995fc 100644
--- a/vlmeval/vlm/molmo.py
+++ b/vlmeval/vlm/molmo.py
@@ -1,11 +1,33 @@
import torch
from PIL import Image
-import os.path as osp
-import sys
from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE
+TYPE_PROMPTS = {
+    'Y/N': 'vqa2:',
+    'VQA': 'vqa2:',
+    'MCQ': 'a_okvqa_mc:',
+}
+
+DATASET_PROMPTS = {
+    'AI2D_TEST': 'ai2_diagram:',
+    'AI2D_TEST_NO_MASK': 'ai2_diagram:',
+    'COCO_VAL': 'coco_captioning:',
+    'ChartQA_TEST': 'chart_qa:',
+    'ChartQA_VAL': 'chart_qa:',
+    'DocVQA_VAL': 'doc_qa:',
+    'DocVQA_TEST': 'doc_qa:',
+    'InfoVQA_TEST': 'info_qa:',
+    'InfoVQA_VAL': 'info_qa:',
+    'OCRVQA_TEST': 'ocr_vqa:',
+    'OCRVQA_TESTCORE': 'ocr_vqa:',
+    'ScienceQA_VAL': 'science_qa:',
+    'ScienceQA_TEST': 'science_qa:',
+    'TableVQABench': 'tabwmp_da:',
+    'TextVQA_VAL': 'text_vqa:',
+}
+
class molmo(BaseModel):
@@ -36,6 +58,106 @@ def __init__(self, model_path='allenai/Molmo-7B-D-0924', **kwargs):
self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
self.kwargs = kwargs
self.model_name = model_path
+ # set default maximum number of crops to 36
+ self.max_crops = kwargs.get('max_crops', 36)
+
+ def use_custom_prompt(self, dataset):
+ if DATASET_TYPE(dataset) in ['Y/N', 'MCQ', 'VQA']:
+ return True
+ return False
+
+ def build_prompt(self, line, dataset=None):
+ assert self.use_custom_prompt(dataset)
+ assert dataset is None or isinstance(dataset, str)
+ tgt_path = self.dump_image(line, dataset)
+ prefix = None
+ if dataset in ['MMMU_DEV_VAL', 'MMMU_TEST']:
+ prompt = self.build_prompt_mcq_vqa(line)
+ elif dataset in ['MathVista_MINI']:
+ prompt = self.build_prompt_mathvista(line)
+ elif dataset in ['AI2D_TEST', 'AI2D_TEST_NO_MASK']:
+ prompt = self.build_prompt_ai2d(line)
+ elif dataset is not None and listinstr(list(DATASET_PROMPTS.keys()), dataset):
+ prefix = DATASET_PROMPTS[dataset] # rest of supervised datasets are in VQA format
+ prompt = self.build_prompt_vqa(line, prefix)
+ elif dataset is not None and listinstr(['MCQ'], DATASET_TYPE(dataset)):
+ prompt = self.build_prompt_multiple_choice(line)
+ else:
+ prompt = self.build_prompt_vqa(line)
+
+ message = [dict(type='text', value=prompt)]
+ message.extend([dict(type='image', value=s) for s in tgt_path])
+
+ # interleave dataset
+        if dataset is not None and dataset.startswith('MMMU_'):
+ from .. import MMMUDataset
+ message = MMMUDataset.split_MMMU(message)
+ return message
+
+ def build_prompt_mathvista(self, line):
+ if line['question_type'] == 'multi_choice':
+ prompt = self.build_prompt_multiple_choice(line)
+ else:
+ prompt = self.build_prompt_vqa(line)
+ return prompt
+
+ def build_prompt_ai2d(self, line):
+ def option_is_abc(line):
+ for cand in string.ascii_uppercase:
+ if cand in line and not pd.isna(line[cand]):
+ # check if option is single letter
+ if not line[cand].strip().isalpha() or len(line[cand].strip()) > 1:
+ return False
+ return True
+
+ if line['abcLabel'] and option_is_abc(line):
+ prompt = line['question']
+ options = {
+ cand: line[cand]
+ for cand in string.ascii_uppercase
+ if cand in line and not pd.isna(line[cand])
+ }
+ for key, item in options.items():
+ prompt += f'\n{item}'
+ prompt = f"ai2_diagram_no_letter: {prompt}"
+ # prompt = self.build_prompt_multiple_choice(line, prefix='ai2_diagram_no_letter:')
+ else:
+ prompt = self.build_prompt_multiple_choice(line, prefix='ai2_diagram:')
+ return prompt
+
+ def build_prompt_mcq_vqa(self, line):
+ if line['question_type'] == 'multiple-choice':
+ prompt = self.build_prompt_multiple_choice(line)
+ else:
+ prompt = self.build_prompt_vqa(line)
+ return prompt
+
+ def build_prompt_multiple_choice(self, line, prefix=None):
+ question = line['question']
+ hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
+ if hint is not None:
+ question = hint + '\n' + question
+ options = {
+ cand: line[cand]
+ for cand in string.ascii_uppercase
+ if cand in line and not pd.isna(line[cand])
+ }
+ for key, item in options.items():
+ question += f'\n{key}: {item}'
+ if prefix is None:
+ prompt = f"{TYPE_PROMPTS['MCQ']} {question}"
+ else:
+ prompt = f"{prefix} {question}"
+
+ return prompt
+
+ def build_prompt_vqa(self, line, prefix=None):
+ question = line['question']
+ if prefix is None:
+ prompt = f"{TYPE_PROMPTS['VQA']} {question}"
+ else:
+ prompt = f"{prefix} {question}"
+ return prompt
def generate_inner(self, message, dataset=None):
from transformers import GenerationConfig
@@ -44,10 +166,15 @@ def generate_inner(self, message, dataset=None):
image = Image.open(image_path)
if image.mode != "RGB":
image = image.convert("RGB")
+
# process the image and text
+ max_crops = self.max_crops
inputs = self.processor.process(
images=[image],
- text=prompt
+ text=prompt,
+ images_kwargs={
+ "max_crops": max_crops
+ }
)
# move inputs to the correct device and make a batch of size 1
@@ -63,7 +190,16 @@ def generate_inner(self, message, dataset=None):
# only get generated tokens; decode them to text
generated_tokens = output[0, inputs['input_ids'].size(1):]
- generated_text = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+ generated_text = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
+
+ # AI2D: map direct answer to letter option
+ if dataset in ['AI2D_TEST', 'AI2D_TEST_NO_MASK']:
+ # 'ai2_diagram_no_letter: Which of the following is the magma chamber?\nK\nB\nC\nH'
+ if 'ai2_diagram_no_letter' in prompt:
+ options = prompt.split('\n')[1:]
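+                # build_prompt_ai2d appended the options one per line, so the position of the predicted text gives its letter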
+ answer = options.index(generated_text)
+ generated_text = chr(answer + ord('A'))
+
+ # print(dataset, prompt, generated_text, inputs['images'].size()) # uncomment to debug
- # print the generated text
return generated_text
diff --git a/vlmeval/vlm/qwen2_vl/model.py b/vlmeval/vlm/qwen2_vl/model.py
index 65629b0a9..0c4f28868 100644
--- a/vlmeval/vlm/qwen2_vl/model.py
+++ b/vlmeval/vlm/qwen2_vl/model.py
@@ -76,6 +76,7 @@ def __init__(
repetition_penalty=1.0,
use_custom_prompt: bool = True,
system_prompt: str | None = None,
+ post_process: bool = False, # if True, will try to only extract stuff in the last \boxed{}.
verbose: bool = False,
):
super().__init__(use_custom_prompt=use_custom_prompt)
@@ -90,6 +91,7 @@ def __init__(
)
self.system_prompt = system_prompt
self.verbose = verbose
+ self.post_process = post_process
self.fps = 2.0
self.nframe = 64
self.FRAME_FACTOR = 2
@@ -106,22 +108,22 @@ def __init__(
assert max_gpu_mem > 0
# If only one process and GPU memory is less than 40GB
- if auto_split_flag():
+ if '72b' in self.model_path.lower():
+ self.model = Qwen2VLForConditionalGeneration.from_pretrained(
+ model_path, torch_dtype='auto', device_map=split_model(), attn_implementation='flash_attention_2'
+ )
+ self.model.eval()
+ elif auto_split_flag():
assert world_size == 1, 'Only support world_size == 1 when AUTO_SPLIT is set for non-72B Qwen2-VL'
# Will Use All GPUs to run one model
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
model_path, torch_dtype='auto', device_map='auto', attn_implementation='flash_attention_2'
)
- elif '72b' not in self.model_path.lower():
+ else:
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
model_path, torch_dtype='auto', device_map='cpu', attn_implementation='flash_attention_2'
)
self.model.cuda().eval()
- else:
- self.model = Qwen2VLForConditionalGeneration.from_pretrained(
- model_path, torch_dtype='auto', device_map=split_model(), attn_implementation='flash_attention_2'
- )
- self.model.eval()
torch.cuda.empty_cache()
@@ -195,6 +197,24 @@ def generate_inner(self, message, dataset=None):
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
response = out[0]
+ if self.post_process:
+ resp = response.split('\\boxed{')[-1]
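+            # scan for the brace that closes the last \boxed{...}; if it never closes, keep the whole remainder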
+ lt = len(resp)
+ counter, end = 1, None
+ for i in range(lt):
+ if resp[i] == '{':
+ counter += 1
+ elif resp[i] == '}':
+ counter -= 1
+ if counter == 0:
+ end = i
+ break
+ elif i == lt - 1:
+ end = lt
+ break
+ if end is not None:
+ response = resp[:end]
+
if self.verbose:
print(f'\033[32m{response}\033[0m')
return response
diff --git a/vlmeval/vlm/valley/valley_eagle_chat.py b/vlmeval/vlm/valley/valley_eagle_chat.py
index 4c82ccc9d..374fad561 100644
--- a/vlmeval/vlm/valley/valley_eagle_chat.py
+++ b/vlmeval/vlm/valley/valley_eagle_chat.py
@@ -124,7 +124,7 @@ def __init__(self,
from .valley_eagle.util.mm_utils import process_anyres_image
from .valley_eagle import conversation as conversation_lib
from .valley_eagle.util.data_util import dynamic_preprocess, preprocess
-
+
torch_dtype = torch.float16
padding_side = 'left'
use_fast = True
diff --git a/vlmeval/vlm/vita.py b/vlmeval/vlm/vita.py
index db0b31d0a..e7604cee2 100644
--- a/vlmeval/vlm/vita.py
+++ b/vlmeval/vlm/vita.py
@@ -28,9 +28,9 @@ def __init__(self, model_path='VITA/vita', root=None, **kwargs):
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, _ = load_pretrained_model(
- model_path, None, model_name, model_type='qwen2p5_instruct', device_map='auto'
+ model_path, None, model_name, model_type='qwen2p5_instruct', device_map='cuda'
)
- # model.cuda().eval()
+ model = model.eval()
# model.tie_weights()
audio_encoder = model.get_audio_encoder()