Support VSI-Bench Evaluation #511

Open · wants to merge 2 commits into main
153 changes: 153 additions & 0 deletions lmms_eval/tasks/vsibench/utils.py
@@ -0,0 +1,153 @@

import os
from functools import partial
from pathlib import Path

import datasets
import numpy as np
import pandas as pd
import yaml
from loguru import logger as eval_logger

MCA_QUESTION_TYPES = [
    "object_rel_direction_easy",
    "object_rel_direction_medium",
    "object_rel_direction_hard",
    "object_rel_distance",
    "route_planning",
    "obj_appearance_order",
]
NA_QUESTION_TYPES = [
    "object_abs_distance",
    "object_counting",
    "object_size_estimation",
    "room_size_estimation",
]

METRICS_FOR_MCA = {
    "accuracy": "exact_match",
}

METRICS_FOR_NA = {
    "MRA:.5:.95:.05": "partial(mean_relative_accuracy, start=.5, end=.95, interval=.05)",
}
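# The metric values above are expression strings; vsibench_process_results below
# eval()s them at scoring time, so `partial` and `mean_relative_accuracy` must be
# available at module scope.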


hf_home = os.getenv("HF_HOME", "~/.cache/huggingface/")
base_cache_dir = os.path.expanduser(hf_home)
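# Read this task's YAML with the !function lines stripped out (yaml.safe_load cannot
# parse custom tags) so the configured cache_dir can be recovered.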
with open(Path(__file__).parent / "vsibench.yaml", "r") as f:
    raw_data = f.readlines()
    safe_data = []
    for line in raw_data:
        if "!function" not in line:
            safe_data.append(line)
cache_name = yaml.safe_load("".join(safe_data))["dataset_kwargs"]["cache_dir"]


def vsibench_doc_to_visual(doc):
    cache_dir = os.path.join(base_cache_dir, cache_name)
    video_path = doc["dataset"] + "/" + doc["scene_name"] + ".mp4"
    video_path = os.path.join(cache_dir, video_path)
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"video path: {video_path} does not exist.")
    return [video_path]


def vsibench_doc_to_text(doc, lmms_eval_specific_kwargs=None):
    lmms_eval_specific_kwargs = lmms_eval_specific_kwargs or {}
    question = doc["question"]

    pre_prompt = lmms_eval_specific_kwargs.get("pre_prompt", "") or "These are frames of a video."

    if doc["question_type"] in NA_QUESTION_TYPES:
        post_prompt = lmms_eval_specific_kwargs.get("na_post_prompt", "") or "Please answer the question using a single word or phrase."
        return pre_prompt + "\n" + question + "\n" + post_prompt
    elif doc["question_type"] in MCA_QUESTION_TYPES:
        options = "Options:\n" + "\n".join(doc["options"])
        post_prompt = lmms_eval_specific_kwargs.get("mca_post_prompt", "") or "Answer with the option's letter from the given choices directly."
        return "\n".join([pre_prompt, question, options, post_prompt])
    else:
        raise ValueError(f"Unknown question type: {doc['question_type']}")


def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    if os.getenv("LMMS_EVAL_SHUFFLE_DOCS", None):
        eval_logger.info("Environment variable LMMS_EVAL_SHUFFLE_DOCS detected, dataset will be shuffled.")
        return dataset.shuffle(seed=42)
    return dataset

def fuzzy_matching(pred):
    # Keep only the first whitespace-separated token and drop a trailing period,
    # e.g. "B. lamp" -> "B", "3.5 meters" -> "3.5".
    return pred.split(" ")[0].rstrip(".").strip()

def exact_match(pred, target):
    return 1.0 if pred.lower() == target.lower() else 0.0

def abs_dist_norm(pred, target):
    return abs(pred - target) / target

def mean_relative_accuracy(pred, target, start, end, interval):
    # Confidence thresholds from `start` to `end`; with the defaults (.5, .95, .05)
    # this evaluates to the ten thresholds 0.50, 0.55, ..., 0.95.
    num_pts = (end - start) / interval + 2
    conf_intervs = np.linspace(start, end, int(num_pts))
    # MRA = fraction of thresholds theta for which the relative error
    # |pred - target| / target does not exceed 1 - theta.
    accuracy = abs_dist_norm(pred, target) <= 1 - conf_intervs
    return accuracy.mean()
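# Illustrative example (not part of the metric): a prediction of 4.5 against a
# ground truth of 4.0 has relative error 0.125, which passes the checks for
# thresholds 0.50 through 0.85 but fails at 0.90 and 0.95, so
# mean_relative_accuracy(4.5, 4.0, .5, .95, .05) -> 0.8.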

WORST_CASE_FOR_METRICS = {
    "accuracy": 0.0,
    "MRA:.5:.95:.05": 0.0,
}

def to_float(pred):
    try:
        return float(pred)
    except (TypeError, ValueError):
        return None

def vsibench_process_results(doc, results):
    doc["prediction"] = results[0]
    if doc["question_type"] in MCA_QUESTION_TYPES:
        for key, value in METRICS_FOR_MCA.items():
            doc[key] = eval(value)(fuzzy_matching(doc["prediction"]), doc["ground_truth"])
    elif doc["question_type"] in NA_QUESTION_TYPES:
        for key, value in METRICS_FOR_NA.items():
            try:
                doc[key] = eval(value)(to_float(fuzzy_matching(doc["prediction"])), to_float(doc["ground_truth"]))
            except TypeError:
                doc[key] = WORST_CASE_FOR_METRICS[key]
    else:
        raise ValueError(f"Unknown question type: {doc['question_type']}")

    return {"vsibench_score": doc}

def vsibench_aggregate_results(results):
    results = pd.DataFrame(results)

    output = {}

    for question_type, question_type_indexes in results.groupby("question_type").groups.items():
        per_question_type = results.iloc[question_type_indexes]

        if question_type in MCA_QUESTION_TYPES:
            for metric in METRICS_FOR_MCA.keys():
                output[f"{question_type}_{metric}"] = per_question_type[metric].mean()
        elif question_type in NA_QUESTION_TYPES:
            for metric in METRICS_FOR_NA.keys():
                output[f"{question_type}_{metric}"] = per_question_type[metric].mean()
        else:
            raise ValueError(f"Unknown question type: {question_type}")

    # The three relative-direction splits are reported as a single accuracy.
    output["object_rel_direction_accuracy"] = sum([
        output.pop("object_rel_direction_easy_accuracy"),
        output.pop("object_rel_direction_medium_accuracy"),
        output.pop("object_rel_direction_hard_accuracy"),
    ]) / 3.0

    output["overall"] = sum(output.values()) / len(output)
    eval_logger.info(f"Evaluation results: {output}")
    return output
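
For intuition, here is a small round trip through the two hooks above on hand-made docs (the values are illustrative only, not drawn from the benchmark):

docs = [
    {"question_type": "object_rel_direction_easy", "ground_truth": "A"},
    {"question_type": "object_rel_direction_medium", "ground_truth": "B"},
    {"question_type": "object_rel_direction_hard", "ground_truth": "C"},
    {"question_type": "object_counting", "ground_truth": "4"},
]
predictions = ["A.", "B.", "D.", "4.5"]
scored = [vsibench_process_results(doc, [pred])["vsibench_score"] for doc, pred in zip(docs, predictions)]
vsibench_aggregate_results(scored)
# -> {'object_counting_MRA:.5:.95:.05': 0.8,
#     'object_rel_direction_accuracy': ~0.667,
#     'overall': ~0.733}
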
40 changes: 40 additions & 0 deletions lmms_eval/tasks/vsibench/vsibench.yaml
@@ -0,0 +1,40 @@
dataset_path: nyu-visionx/VSI-Bench
dataset_kwargs:
  token: True
  cache_dir: vsibench
  video: True
task: vsibench
test_split: test
output_type: generate_until
process_docs: !function utils.process_docs
doc_to_visual: !function utils.vsibench_doc_to_visual
doc_to_text: !function utils.vsibench_doc_to_text
doc_to_target: "ground_truth"
generation_kwargs:
  max_new_tokens: 16
  temperature: 0
  top_p: 1.0
  num_beams: 1
  do_sample: false
# The return value of process_results will be used by metrics
process_results: !function utils.vsibench_process_results
# Note that the metric name can be either a registered metric function (such as the case for GQA) or a key name returned by process_results
metric_list:
  - metric: vsibench_score
    aggregation: !function utils.vsibench_aggregate_results
    higher_is_better: true
lmms_eval_specific_kwargs:
  default:
    pre_prompt: ""
    mca_post_prompt: "Answer with the option's letter from the given choices directly."
    na_post_prompt: "Please answer the question using a single word or phrase."
  gemini_api:
    pre_prompt: ""
    mca_post_prompt: "Answer with the option's letter from the given choices directly."
    na_post_prompt: "Do not respond with anything other than a single number!"
  gpt4v:
    pre_prompt: ""
    mca_post_prompt: "Answer with the option's letter from the given choices directly."
    na_post_prompt: "Do not respond with anything other than a single number!"
metadata:
  - version: 0.0
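
For reference, a sketch of how the default prompts above are assembled by vsibench_doc_to_text; the sample doc below is made up for illustration:

doc = {
    "question_type": "object_rel_distance",
    "question": "Which of these objects is closest to the bed?",
    "options": ["A. sofa", "B. lamp", "C. desk"],
}
kwargs = {
    "pre_prompt": "",
    "mca_post_prompt": "Answer with the option's letter from the given choices directly.",
    "na_post_prompt": "Please answer the question using a single word or phrase.",
}
# pre_prompt is empty, so the function falls back to "These are frames of a video."
print(vsibench_doc_to_text(doc, kwargs))
# These are frames of a video.
# Which of these objects is closest to the bed?
# Options:
# A. sofa
# B. lamp
# C. desk
# Answer with the option's letter from the given choices directly.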