[Benchmark] Add VL-RewardBench (#703)
* update vlrewardbench

* pre-commit fix

* formatter

* [Improvement] Better `AUTO_SPLIT` and model split for InternVL2

* [Minor] Improve CC-OCR Import

* [Model] Support QVQ

* [Model] Update Molmo Eval to Match Official Implementation (#648)

* add molmo prompts

* fix lint format

* [Fix] Refine Qwen-VL2 device assignment

* [Fix] Fix RealWorldQA md5

* update MMMU_DEV_VAL tsv

* [Fix] Fix confusing image width&height (#704)

Co-authored-by: Yuan Ye <[email protected]>

* Update llama_vision.py (#705)

* [Fix] Fix Lint

* Fix Lint

* Fix Lint

---------

Co-authored-by: kennymckormick <[email protected]>
Co-authored-by: jamespark3922 <[email protected]>
Co-authored-by: CMeteor <[email protected]>
Co-authored-by: Yuan Ye <[email protected]>
Co-authored-by: Guowei Xu <[email protected]>
6 people authored Jan 1, 2025
1 parent ae10ca4 commit 276d90a
Showing 6 changed files with 207 additions and 41 deletions.
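For orientation before the diff: once these changes land, the new benchmark can be exercised directly through the dataset class added below. A minimal sketch (the class and dataset names come from this diff; driving the class directly instead of via run.py is purely illustrative):

# Hedged sketch: load the new dataset and build one pairwise-judging prompt.
from vlmeval.dataset import VLRewardBench

ds = VLRewardBench('VL-RewardBench')   # fetches the TSV listed in DATASET_URL on first use
msgs = ds.build_prompt(0)              # image message(s) followed by the comparison prompt
print([m['type'] for m in msgs])       # e.g. ['image', 'text']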
2 changes: 1 addition & 1 deletion run.py
@@ -322,7 +322,7 @@ def main():
judge_kwargs['model'] = 'chatgpt-0125'
elif listinstr(['MMVet', 'LLaVABench', 'MMBench-Video'], dataset_name):
judge_kwargs['model'] = 'gpt-4-turbo'
elif listinstr(['MathVista', 'MathVerse', 'MathVision', 'DynaMath'], dataset_name):
elif listinstr(['MathVista', 'MathVerse', 'MathVision', 'DynaMath', 'VL-RewardBench'], dataset_name): # noqa: E501
judge_kwargs['model'] = 'gpt-4o-mini'
elif listinstr(['MMLongBench', 'MMDU', 'DUDE', 'SLIDEVQA', 'MIA-Bench', 'WildVision'], dataset_name): # noqa: E501
judge_kwargs['model'] = 'gpt-4o'
3 changes: 2 additions & 1 deletion vlmeval/dataset/__init__.py
@@ -20,6 +20,7 @@
from .mmlongbench import MMLongBench
from .dude import DUDE
from .slidevqa import SlideVQA
from .vl_rewardbench import VLRewardBench

from .mmbench_video import MMBenchVideo
from .videomme import VideoMME
@@ -132,7 +133,7 @@ def evaluate(self, eval_file, **judge_kwargs):
MMLongBench, VCRDataset, MMDUDataset, DUDE, SlideVQA, MUIRDataset, CCOCRDataset,
GMAIMMBenchDataset, MMERealWorld, HRBenchDataset, CRPE, MathVerse, NaturalBenchDataset,
MIABench, OlympiadBench, WildVision, MMMath, QSpatial, Dynamath, MMGenBench, VizWiz, MMNIAH,
CMMMU
CMMMU, VLRewardBench
]

VIDEO_DATASET = [
@@ -241,7 +241,7 @@ def eval_formula(self, response_info, gt_info, op_name='formula'):
pred = response_info[img_name]

if op_name == 'formula':
pred = pred.replace("\n", " ").replace("```latex", "").replace("```", "").replace("\t", " ").replace(" ", "")
pred = pred.replace("\n", " ").replace("```latex", "").replace("```", "").replace("\t", " ").replace(" ", "") # noqa: E501
gt = gt.replace(" ", "")
elif op_name == 'molecular':
pred = pred.replace("\n", "").replace(" ", "").replace("<smiles>", "").replace("</smiles>", "")
61 changes: 25 additions & 36 deletions vlmeval/dataset/utils/ccocr_evaluator/kie_evaluator.py
@@ -66,11 +66,11 @@ def update_cost(node1: Node, node2: Node):
label2 = node2.label
label1_leaf = "<leaf>" in label1
label2_leaf = "<leaf>" in label2
if label1_leaf == True and label2_leaf == True:
if label1_leaf and label2_leaf:
return edit_distance(label1.replace("<leaf>", ""), label2.replace("<leaf>", ""))
elif label1_leaf == False and label2_leaf == True:
elif not label1_leaf and label2_leaf:
return 1 + len(label2.replace("<leaf>", ""))
elif label1_leaf == True and label2_leaf == False:
elif label1_leaf and not label2_leaf:
return 1 + len(label1.replace("<leaf>", ""))
else:
return int(label1 != label2)
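For readers unfamiliar with the cost function above: leaf labels are compared by string edit distance after stripping the '<leaf>' tag, while a leaf/non-leaf mismatch costs 1 plus the length of the leaf content. A small illustration, assuming Node is zss.Node and edit_distance is character-level Levenshtein distance:

from zss import Node

print(update_cost(Node('<leaf>3.00'), Node('<leaf>3.50')))  # 1: one character differs
print(update_cost(Node('total'), Node('<leaf>3.50')))       # 5: 1 + len('3.50')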
@@ -121,7 +121,8 @@ def normalize_dict(data: Union[Dict, List, Any]):

def cal_f1_all(preds, answers):
"""
Calculate global F1 accuracy score (field-level, micro-averaged) by counting all true positives, false negatives and false positives
Calculate global F1 accuracy score (field-level, micro-averaged) by counting all true positives,
false negatives and false positives
"""
metric_info, error_info = {}, {}
total_tp, total_fn_or_fp = 0, 0
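As a rough illustration of the micro-averaged, field-level F1 the docstring describes (true positives and errors are pooled over the whole set before one global F1 is computed; the helper below is illustrative, not the exact matching logic used in this evaluator):

def micro_f1(pred_fields, gt_fields):
    # pred_fields / gt_fields: sets of (key, value) pairs pooled across all samples
    tp = len(pred_fields & gt_fields)
    fp = len(pred_fields - gt_fields)
    fn = len(gt_fields - pred_fields)
    return 2 * tp / (2 * tp + fp + fn) if (tp + fp + fn) else 1.0

print(micro_f1({('name', 'A'), ('date', 'B')}, {('name', 'A'), ('date', 'C')}))  # 0.5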
@@ -233,7 +234,28 @@ def cal_acc(pred: dict, answer: dict):
"""
pred = construct_tree_from_dict(normalize_dict(pred))
answer = construct_tree_from_dict(normalize_dict(answer))
return max(
0,
1
- (
zss.distance(
pred,
answer,
get_children=zss.Node.get_children,
insert_cost=insert_and_remove_cost,
remove_cost=insert_and_remove_cost,
update_cost=update_cost,
return_operations=False,
)
/ zss.distance(
construct_tree_from_dict(normalize_dict({})),
answer,
get_children=zss.Node.get_children,
insert_cost=insert_and_remove_cost,
remove_cost=insert_and_remove_cost,
update_cost=update_cost,
return_operations=False,
)
),
val1 = zss.distance(
pred,
answer,
get_children=zss.Node.get_children,
insert_cost=insert_and_remove_cost,
remove_cost=insert_and_remove_cost,
update_cost=update_cost,
return_operations=False,
)
val2 = zss.distance(
construct_tree_from_dict(normalize_dict({})),
answer,
get_children=zss.Node.get_children,
insert_cost=insert_and_remove_cost,
remove_cost=insert_and_remove_cost,
update_cost=update_cost,
return_operations=False,
)
return max(0, 1 - val1 / val2)
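The refactor above changes only the layout: the metric is still a tree-edit-distance accuracy normalized by the distance from an empty tree to the answer, i.e. acc = max(0, 1 - TED(pred, answer) / TED(empty, answer)). A tiny self-contained illustration using zss default costs (the real code uses the custom insert/remove/update costs defined earlier in this file):

from zss import Node, simple_distance

answer = Node('menu').addkid(Node('<leaf>coffee')).addkid(Node('<leaf>3.00'))
pred = Node('menu').addkid(Node('<leaf>coffee'))   # one field missing
empty = Node('<root>')                             # stand-in for the empty tree

acc = max(0, 1 - simple_distance(pred, answer) / simple_distance(empty, answer))
print(acc)  # below 1.0 because of the missing leaf; 1.0 only for an exact match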


def cal_acc_all(pred_info, answer_info):
"""
"""
acc_info, error_info = {}, {}
for file_name, answer in answer_info.items():
# if file_name not in pred_info:
@@ -303,13 +297,11 @@ def eval_donut(pdt_info, gt_info, normalize_func=None, data_name=None):
acc_average, acc_error_info = cal_acc_all(pdt_info, gt_info)
eval_info = {"f1_score": f1_score, "acc": acc_average, "class_f1_score": class_eval_info,
"f1_error_info": error_info, "acc_error_info": acc_error_info}
print(data_name, "f1_score", f1_score, "acc", acc_average)
print(data_name, "f1_score", f1_score, "acc", acc_average)
return eval_info


def post_process_to_json(qwen_info_str, file_name=None):
"""
"""
try:
if "```json" in qwen_info_str:
if "```" not in qwen_info_str:
@@ -320,10 +312,7 @@ def post_process_to_json(qwen_info_str, file_name=None):
json_str = qwen_info_str.strip().replace("\n", "")
json_data = json.loads(json_str)
return json_data
except Exception as e:
# print("--> post error: {}, file_name: {}".format(e, file_name))
# print("json_raw", qwen_info_str)
# print("json_str", json_str)
except Exception as err: # noqa: F841
return None


6 changes: 4 additions & 2 deletions vlmeval/dataset/utils/ccocr_evaluator/ocr_evaluator.py
@@ -92,8 +92,10 @@ def evaluate(self, response_info, gt_info, **kwargs):
image_pdt_info, image_gt_info = {}, {}
for file_name, gt_src in gt_info.items():
pred_src = response_info.get(file_name, "")
pdt_token_list = text_normalize_and_tokenize(str(pred_src).strip(), is_word_level, is_lower, is_alphanum_only)
gt_token_list = text_normalize_and_tokenize(str(gt_src).strip(), is_word_level, is_lower, is_alphanum_only)
pdt_token_list = text_normalize_and_tokenize(
str(pred_src).strip(), is_word_level, is_lower, is_alphanum_only)
gt_token_list = text_normalize_and_tokenize(
str(gt_src).strip(), is_word_level, is_lower, is_alphanum_only)
image_pdt_info[file_name] = pdt_token_list
image_gt_info[file_name] = gt_token_list
eval_result = calculate_metrics(image_pdt_info, image_gt_info, is_verbose=False)
174 changes: 174 additions & 0 deletions vlmeval/dataset/vl_rewardbench.py
@@ -0,0 +1,174 @@
from ast import literal_eval

from .image_base import ImageBaseDataset
from .utils import build_judge, DEBUG_MESSAGE
from ..smp import *
from ..utils import track_progress_rich


LLM_PARSE_ANSWER_PROMPT = '''
You are given a pairwise judgement for two responses. Please return the better response according to the judgement.
Return the Answer X ONLY. e.g., Answer 1 or Answer 2.
Judgement: {judgement}
'''


PROMPT_TEMPLATE = '''\
You are a highly capable multimodal AI assistant tasked with evaluating answers to visual questions.
Please analyze the following image and question, then determine which of the two provided answers is better.
Question: {query}
Answer 1: {answer_0}
Answer 2: {answer_1}
Please evaluate both answers based on the following criteria:
1. Accuracy: How well does the answer align with the visual information in the image?
2. Completeness: Does the answer fully address all aspects of the question?
3. Clarity: Is the answer easy to understand and well-articulated?
4. Relevance: Does the answer directly relate to the question and the image?
After your evaluation, please:
1. Explain your reasoning for each criterion.
2. Provide an overall judgment on which answer is better (Answer 1 or Answer 2).\
For example: Overall Judgment: Answer X is better.
Your response should be structured and detailed, \
demonstrating your understanding of both the visual and textual elements of the task.'''


def get_score(line, parsed_response, random_number):
gt_ans = line['human_ranking'].index(0 if random_number == 0 else 1) + 1
if 'Answer 1'.lower() in parsed_response.lower():
pred = 1
elif 'Answer 2'.lower() in parsed_response.lower():
pred = 2
else: # failed
pred = 'None' # random.choice([1, 2])

if pred == gt_ans:
return 1.0
else:
return 0.0
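A short walk-through of the order handling (hypothetical record): human_ranking marks the human-preferred response with 0 in the original order, and build_prompt below reverses the pair whenever random_number is 1, so the preferred response may be presented to the judge as either Answer 1 or Answer 2.

# Hypothetical record in which the first response is the human-preferred one.
line = {'human_ranking': [0, 1]}

# Pair not reversed: the preferred response was shown as "Answer 1".
assert get_score(line, 'Overall Judgment: Answer 1 is better.', random_number=0) == 1.0

# Pair reversed in build_prompt: the same preferred response became "Answer 2".
assert get_score(line, 'Overall Judgment: Answer 2 is better.', random_number=1) == 1.0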


def VLRewardBench_eval_answer(model, line):
response = toliststr(line['response'])
random_number = sum(len(res) for res in response) % 2

prompt = LLM_PARSE_ANSWER_PROMPT.format(judgement=line['prediction'])
messages = [dict(type='text', value=prompt)]

resp = model.generate(messages)
score = get_score(line, resp, random_number)

if score is None:
return 'Unknown'
return score


class VLRewardBench(ImageBaseDataset):
TYPE = 'VQA'
DATASET_URL = {
'VL-RewardBench': 'https://huggingface.co/datasets/MMInstruction/VL-RewardBench/resolve/main/vl_rewardbench.tsv'
}
DATASET_MD5 = {'VL-RewardBench': '1d2676f4ab4a5f755019ec0af2b28189'}

# Given one data record, return the built prompt (a multi-modal message); can be overridden
def build_prompt(self, line):
if isinstance(line, int):
line = self.data.iloc[line]
tgt_path = self.dump_image(line) # save image to local
question = line['question']
msgs = []
if isinstance(tgt_path, list):
msgs.extend([dict(type='image', value=p) for p in tgt_path])
else:
msgs = [dict(type='image', value=tgt_path)]

response = toliststr(line['response'])
random_number = sum(len(res) for res in response) % 2
if random_number == 1:
# randomly shuffle the order of the responses
response = response[::-1]
query_prompt = PROMPT_TEMPLATE.format(
query=question, answer_0=response[0], answer_1=response[1]
)
msgs = msgs + [dict(type='text', value=query_prompt)]
return msgs

# It returns a DataFrame
@classmethod
def evaluate(self, eval_file, **judge_kwargs):
suffix = eval_file.split('.')[-1]
model = judge_kwargs['model']
storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
score_file = eval_file.replace(f'.{suffix}', f'_{model}_score.csv')
tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
nproc = judge_kwargs.pop('nproc', 4)

if not osp.exists(storage):
raw_data = VLRewardBench('VL-RewardBench').data
data = load(eval_file)
data['prediction'] = [str(x) for x in data['prediction']]
data['human_ranking'] = [literal_eval(x) for x in raw_data['answer']]

judge_kwargs['temperature'] = 0
judge_kwargs['timeout'] = 60
model = build_judge(max_tokens=128, **judge_kwargs)

assert model.working(), (
'VLRewardBench evaluation requires a working OPENAI API\n'
+ DEBUG_MESSAGE
)

lt = len(data)
lines = [data.iloc[i] for i in range(lt)]
tups = [(model, line) for line in lines]
indices = [line['index'] for line in lines]

ans = load(tmp_file) if osp.exists(tmp_file) else {}
tups = [x for x, i in zip(tups, indices) if i not in ans]
indices = [i for i in indices if i not in ans]

if len(indices):
new_results = track_progress_rich(
VLRewardBench_eval_answer,
tups,
nproc=nproc,
chunksize=nproc,
keys=indices,
save=tmp_file,
)
ans = load(tmp_file)
for k, v in zip(indices, new_results):
ans[k] = v

data['score'] = [ans[idx] for idx in data['index']]
# data.pop('image')
dump(data, storage)

data = load(storage)
lt = len(data)

category_scores = defaultdict(lambda: 0)
category_cnt = defaultdict(lambda: 0)
scores = defaultdict(lambda: 0)
for i in range(lt):
item = data.iloc[i]
category_scores[item['category']] += item['score']
category_cnt[item['category']] += 1
# calculate the average score for each category
for k, v in category_scores.items():
scores[k] = v / category_cnt[k]
# calculate category macro accuracy (average across categories)
scores['Macro Accuracy'] = sum(scores.values()) / len(scores)
# calculate the total average score
scores['Overall Consistency'] = sum(category_scores.values()) / lt

scores = {k: [v] for k, v in scores.items()}
scores = pd.DataFrame(scores)
dump(scores, score_file)
return scores
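The two summary rows differ in weighting: 'Macro Accuracy' is the mean of the per-category averages, while 'Overall Consistency' averages over all samples regardless of category. A small made-up illustration of the gap (counts and category names are illustrative):

category_scores = {'hallucination': 2.0, 'reasoning': 0.0}   # summed scores per category
category_cnt = {'hallucination': 2, 'reasoning': 8}          # samples per category

macro = sum(category_scores[k] / category_cnt[k] for k in category_scores) / len(category_scores)
overall = sum(category_scores.values()) / sum(category_cnt.values())
print(macro, overall)  # 0.5 0.2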
