diff --git a/lmms_eval/models/gpt4v.py b/lmms_eval/models/gpt4v.py index b5e94e06..dd6696f6 100755 --- a/lmms_eval/models/gpt4v.py +++ b/lmms_eval/models/gpt4v.py @@ -53,6 +53,7 @@ def __init__( timeout: int = 120, continual_mode: bool = False, response_persistent_folder: str = None, + interleaved: bool = True, **kwargs, ) -> None: super().__init__() @@ -65,6 +66,7 @@ def __init__( self.image_token = "<image>" self.timeout = timeout self.continual_mode = continual_mode + self.interleaved = interleaved if self.continual_mode: if response_persistent_folder is None: raise ValueError("Continual mode requires a persistent path for the response. Please provide a valid path.") @@ -136,6 +138,21 @@ def flatten(self, input): new_list.append(j) return new_list + def construct_interleaved_input(self, content, media): + print(content, len(media)) + pattern = r"" + parts = re.split(pattern, content) + result = [] + for i, part in enumerate(parts): + if i % 2 == 0: + if part == "": + continue + result.append({"type": "text", "text": part}) + else: + result.append(media[int(part)]) + + return result + def generate_until(self, requests) -> List[str]: res = [] pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding") @@ -167,23 +184,28 @@ def generate_until(self, requests) -> List[str]: response_json = {"role": "user", "content": []} # When there is no image token in the context, append the image to the text - if self.image_token not in contexts: - payload["messages"].append(deepcopy(response_json)) - payload["messages"][0]["content"].append({"type": "text", "text": contexts}) - for img in imgs: - payload["messages"][0]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}}) - else: - contexts = contexts.split(self.image_token) - for idx, img in enumerate(imgs): + if not self.interleaved: + if self.image_token not in contexts: payload["messages"].append(deepcopy(response_json)) - payload["messages"][idx]["content"].append({"type": "text", "text": contexts[idx]}) - payload["messages"][idx]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}}) - - # If n image tokens are in the contexts - # contexts will be splitted into n+1 chunks - # Manually add it into the payload + payload["messages"][0]["content"].append({"type": "text", "text": contexts}) + for img in imgs: + payload["messages"][0]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}}) + else: + contexts = contexts.split(self.image_token) + for idx, img in enumerate(imgs): + payload["messages"].append(deepcopy(response_json)) + payload["messages"][idx]["content"].append({"type": "text", "text": contexts[idx]}) + payload["messages"][idx]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}}) + + # If n image tokens are in the contexts + # contexts will be splitted into n+1 chunks + # Manually add it into the payload + payload["messages"].append(deepcopy(response_json)) + payload["messages"][-1]["content"].append({"type": "text", "text": contexts[-1]}) + else: + media = [{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}} for img in imgs] payload["messages"].append(deepcopy(response_json)) - payload["messages"][-1]["content"].append({"type": "text", "text": contexts[-1]}) + payload["messages"][0]["content"].extend(self.construct_interleaved_input(contexts, media)) if "max_new_tokens" not in gen_kwargs: gen_kwargs["max_new_tokens"] = 1024 diff --git
a/lmms_eval/tasks/megabench/breakdown/analysis_utils.py b/lmms_eval/tasks/megabench/breakdown/analysis_utils.py index abcd7ff7..ea47d562 100644 --- a/lmms_eval/tasks/megabench/breakdown/analysis_utils.py +++ b/lmms_eval/tasks/megabench/breakdown/analysis_utils.py @@ -1,10 +1,11 @@ import json -from collections import defaultdict import os +from collections import defaultdict # Add path definition at the top after imports all_task_meta_path = os.path.join(os.path.dirname(__file__), "all_task_meta.json") + def task_list_refine(task_list): task_results = [] for task in task_list: @@ -47,12 +48,7 @@ def derive_keyword_stats(task_results_with_meta, include_per_task_info=False): if include_per_task_info: skills_stats[skill]["tasks"].append((task_name, score)) - for stat_dict, key in [ - (input_format_stats, "input_format"), - (output_format_stats, "output_format"), - (input_num_stats, "num_input"), - (app_stats, "app") - ]: + for stat_dict, key in [(input_format_stats, "input_format"), (output_format_stats, "output_format"), (input_num_stats, "num_input"), (app_stats, "app")]: if value := task.get(key): stat_dict[value]["count"] += 1 stat_dict[value]["total_score"] += score @@ -83,10 +79,10 @@ def collect_task_metadata(model_results): # Load the complete task metadata with open(all_task_meta_path, "r") as f: all_meta = json.load(f) - + # Create result dictionary all_task_meta = {} - + # Match results with metadata for task_result in model_results: task_name = task_result["name"] @@ -94,5 +90,5 @@ def collect_task_metadata(model_results): meta = all_meta[task_name].copy() # Create a copy to avoid modifying original meta.update(task_result) all_task_meta[task_name] = meta - + return all_task_meta diff --git a/lmms_eval/tasks/megabench/breakdown/derive_breakdown_results.py b/lmms_eval/tasks/megabench/breakdown/derive_breakdown_results.py index be049c67..3377d334 100644 --- a/lmms_eval/tasks/megabench/breakdown/derive_breakdown_results.py +++ b/lmms_eval/tasks/megabench/breakdown/derive_breakdown_results.py @@ -1,20 +1,18 @@ -import json import argparse +import json from pathlib import Path -from analysis_utils import ( - task_list_refine, - collect_task_metadata, - derive_keyword_stats, -) + +from analysis_utils import collect_task_metadata, derive_keyword_stats, task_list_refine + def calculate_model_summary(task_results_with_meta): """ Re-calculate model performance summary statistics across core and open tasks. 
- + Args: task_results: List of task results with scores task_metadata: Dictionary containing task metadata including task types - + Returns: Dictionary containing summary statistics for core and open tasks """ @@ -23,27 +21,27 @@ def calculate_model_summary(task_results_with_meta): # Separate core and open tasks for task in task_results_with_meta.values(): - if task['eval_type'] == 'llm': + if task["eval_type"] == "llm": open_tasks.append(task) else: core_tasks.append(task) - + def calculate_stats(tasks): if not tasks: return None - - total_samples = sum(task.get('num_query', 0) for task in tasks) - macro_scores = [task.get('score', 0) for task in tasks] - + + total_samples = sum(task.get("num_query", 0) for task in tasks) + macro_scores = [task.get("score", 0) for task in tasks] + return { "num_eval_tasks": len(tasks), "num_eval_samples": total_samples, "macro_mean_score": sum(macro_scores) / len(tasks) if tasks else 0, } - + core_stats = calculate_stats(core_tasks) open_stats = calculate_stats(open_tasks) - + # Calculate overall score (weighted average based on number of tasks) # If either stat is None, use only the available stat if core_stats is None: @@ -53,17 +51,11 @@ def calculate_stats(tasks): overall_score = core_stats["macro_mean_score"] if core_stats else 0 total_tasks = core_stats["num_eval_tasks"] if core_stats else 0 else: - total_tasks = (core_stats["num_eval_tasks"] + open_stats["num_eval_tasks"]) - overall_score = ( - (core_stats["macro_mean_score"] * core_stats["num_eval_tasks"] + - open_stats["macro_mean_score"] * open_stats["num_eval_tasks"]) / total_tasks - ) - - return { - "core": core_stats, - "open": open_stats, - "overall_score": overall_score - } + total_tasks = core_stats["num_eval_tasks"] + open_stats["num_eval_tasks"] + overall_score = (core_stats["macro_mean_score"] * core_stats["num_eval_tasks"] + open_stats["macro_mean_score"] * open_stats["num_eval_tasks"]) / total_tasks + + return {"core": core_stats, "open": open_stats, "overall_score": overall_score} + def merge_json_files(input_dir, output_path, key="name"): """ @@ -72,11 +64,11 @@ def merge_json_files(input_dir, output_path, key="name"): Prioritizes LLM evaluations over rule-based ones when duplicates exist. 
""" data_dict = {} # Using name as key for easy lookup and updates - + # Find all matching JSON files in the directory json_paths = list(Path(input_dir).glob("megabench*data_with_scores*.json")) print(f"Found {len(json_paths)} files to merge") - + # Load and merge all JSON files for path in json_paths: print(f"Processing {path}") @@ -84,64 +76,61 @@ def merge_json_files(input_dir, output_path, key="name"): data = json.load(f) if isinstance(data, dict) and "data" in data: data = task_list_refine(data["data"]) - + # Update or add entries for item in data: item_key = item[key] # If new item or if new item is LLM-evaluated (prioritize LLM eval) - if item_key not in data_dict or ( - item.get("eval_type") == "llm" and data_dict[item_key].get("eval_type") != "llm" - ): + if item_key not in data_dict or (item.get("eval_type") == "llm" and data_dict[item_key].get("eval_type") != "llm"): data_dict[item_key] = item # Convert back to list merged_data = list(data_dict.values()) - + # Save the merged result output_path.parent.mkdir(parents=True, exist_ok=True) with open(output_path, "w") as f: json.dump(merged_data, f, indent=4) - + print(f"Merged file with {len(merged_data)} tasks saved to {output_path}") return merged_data + def main(): # Parse command line arguments - parser = argparse.ArgumentParser(description='Merge and process evaluation score files.') - parser.add_argument('--input_dir', type=str, help='Directory containing score files') + parser = argparse.ArgumentParser(description="Merge and process evaluation score files.") + parser.add_argument("--input_dir", type=str, help="Directory containing score files") args = parser.parse_args() # Convert path to Path object input_dir = Path(args.input_dir) - + # Create analysis directory under input directory output_dir = input_dir / "analysis" output_dir.mkdir(parents=True, exist_ok=True) - + # Merge files output_path = output_dir / "task_results.json" task_results = merge_json_files(input_dir, output_path) - + # Collect metadata and derive keyword stats task_results_with_meta = collect_task_metadata(task_results) keyword_stats = derive_keyword_stats(task_results_with_meta) - + # Calculate model summary model_summary = calculate_model_summary(task_results_with_meta) - summary_results = { - "model_summary": model_summary, - "keyword_stats": keyword_stats - } - + summary_results = {"model_summary": model_summary, "keyword_stats": keyword_stats} + # Save keyword stats stats_output = output_dir / "summary_and_keyword_stats.json" with open(stats_output, "w") as f: json.dump(summary_results, f, indent=4) - + print(f"\nResults saved in {output_dir}:") print(f"- Merged data: {output_path}") print(f"- Multi-dimensional keywords stats: {stats_output}") + if __name__ == "__main__": main() diff --git a/lmms_eval/tasks/megabench/evaluator.py b/lmms_eval/tasks/megabench/evaluator.py index a4147776..5114b83f 100644 --- a/lmms_eval/tasks/megabench/evaluator.py +++ b/lmms_eval/tasks/megabench/evaluator.py @@ -1,11 +1,11 @@ import argparse +import ast import json import os from typing import Any, Dict, List -import ast from datasets import load_dataset -from metrics import MetricType, AggregationType, ResponseParseType +from metrics import AggregationType, MetricType, ResponseParseType from metrics.parsing.common.utils import evaluate_as_string from metrics.scoring.vlm_as_judge import VLMJudgeScore @@ -24,21 +24,14 @@ def __init__( """ self.hf_data = self._load_hf(subset_name) # e.g. 
same structure used previously self.data = self._load_json(responses_file) # The model's output - self.eval_results = ( - self._load_json(output_file) - if os.path.exists(output_file) - else {"data": self.data} - ) + self.eval_results = self._load_json(output_file) if os.path.exists(output_file) else {"data": self.data} self.output_file = output_file - # Build a dict of {task_name -> metric configuration} for quick lookup self.scoring_functions = {} for task_name, task_samples in self.hf_data.items(): - self.scoring_functions[task_name] = ast.literal_eval( - task_samples[0]["metric_info"] - ) - + self.scoring_functions[task_name] = ast.literal_eval(task_samples[0]["metric_info"]) + def _load_hf(self, subset_name: str) -> List[Dict[str, Any]]: """ Load the HF dataset for the given subset name. @@ -50,21 +43,21 @@ def _load_hf(self, subset_name: str) -> List[Dict[str, Any]]: if task_name not in task_dict: task_dict[task_name] = [] task_dict[task_name].append(sample) - + return task_dict - + def _get_eval_context(self, task_name, query): if "query_idx" in query: - query_idx = query["query_idx"] + query_idx = query["query_idx"] eval_context = self.hf_data[task_name][query_idx]["eval_context"] else: global_idx = query["global_idx"] global_idx_to_sample = {sample["id"]: sample for sample in self.hf_data[task_name]} eval_context = global_idx_to_sample[global_idx]["eval_context"] - + eval_context = ast.literal_eval(eval_context) return eval_context - + def _task_needs_eval(self, task: Dict) -> bool: task_in_results = False for existing_task in self.eval_results["data"]: @@ -75,26 +68,14 @@ def _task_needs_eval(self, task: Dict) -> bool: if len(task["query_response"]) != len(existing_task["query_response"]): return True - for res_example, saved_example in zip( - task["query_response"], existing_task["query_response"] - ): - if ( - res_example["response"] != saved_example["response"] - or res_example["correct_answer"] - != saved_example["correct_answer"] - ): + for res_example, saved_example in zip(task["query_response"], existing_task["query_response"]): + if res_example["response"] != saved_example["response"] or res_example["correct_answer"] != saved_example["correct_answer"]: # model response or gt answer changed return True - elif ( - "scores" not in saved_example - or "query" not in saved_example["scores"] - ): + elif "scores" not in saved_example or "query" not in saved_example["scores"]: # no existing eval results (not evaluated before) return True - elif ( - saved_example["scores"]["query"] == -1 - and len(saved_example["scores"]["field"]) == 0 - ): + elif saved_example["scores"]["query"] == -1 and len(saved_example["scores"]["field"]) == 0: return True else: # nothing changed, using the old eval results @@ -133,7 +114,7 @@ def evaluate(self): # If no scoring config is found for the given task_name, skip score_config = self.scoring_functions.get( - task_name, + task_name, { "field_score_function": {}, "aggregation": {"function": None, "field_weights": {}}, @@ -185,14 +166,7 @@ def evaluate(self): # 2) Evaluate each field for fld, fld_metric_name in field_score_functions.items(): metric = self._build_metric(fld_metric_name, score_config) - self._evaluate_field( - task_name, - metric, - fld, - response_obj, - correct_answer, - query - ) + self._evaluate_field(task_name, metric, fld, response_obj, correct_answer, query) # Evaluate global auxiliary metrics (if any) for fld, fld_metric_name in global_aux_metrics.items(): @@ -202,10 +176,10 @@ def evaluate(self): tmp_obj = {fld: response_obj} 
self._evaluate_field( task_name, - metric, - fld, - tmp_obj, - correct_answer, + metric, + fld, + tmp_obj, + correct_answer, query, is_aux=True, ) @@ -226,7 +200,7 @@ def evaluate(self): mean_score = 0.0 task["task_score"] = task_score_sum task["mean_task_score"] = mean_score - task['eval_type'] = 'llm' if isinstance(metric, VLMJudgeScore) else 'rule' + task["eval_type"] = "llm" if isinstance(metric, VLMJudgeScore) else "rule" total_query_score += task_score_sum total_task_score += mean_score @@ -281,7 +255,7 @@ def _evaluate_field( ) -> float: """Compute score for a single field using the given metric.""" eval_context = self._get_eval_context(task_name, query) - + if metric == MetricType.UNSUPPORTED: print(f"The metric for {field} in task {task_name} is not supported") return 0.0 @@ -299,11 +273,7 @@ def _evaluate_field( query["scores"]["field"][field] = score query["scores"]["info"][field] = eval_info elif isinstance(metric, VLMJudgeScore): - response_info = ( - response_obj.get(field) - if isinstance(response_obj, dict) - else response_obj - ) + response_info = response_obj.get(field) if isinstance(response_obj, dict) else response_obj score, eval_info = metric.match( response_info, correct_answer, @@ -335,9 +305,7 @@ def _parse_response( res_parsing_pass = True if parser.is_single_field_parser(): # single field - assert ( - len(answer_fields) == 1 - ), "The answer_string parse must be used when the answer has a single field" + assert len(answer_fields) == 1, "The answer_string parse must be used when the answer has a single field" answer_key = answer_fields[0] global_description = task["global_description"] @@ -356,9 +324,7 @@ def _parse_response( # Structural output (using JSON parser or other specified parsing func) or dummy parse (return all) response_obj = parser.parse(response_text) - if parser == ResponseParseType.JSON and ( - not isinstance(response_obj, dict) or not response_obj - ): + if parser == ResponseParseType.JSON and (not isinstance(response_obj, dict) or not response_obj): # Expect a JSON, but parsing failed, # Record the failure parsing, and use the raw string for each field of the answer res_parsing_pass = False @@ -367,9 +333,7 @@ def _parse_response( response_obj[field] = response_text if not res_parsing_pass: - print( - f"Task:{task_name}, cannot parse query with global idx {query['global_idx']}" - ) + print(f"Task:{task_name}, cannot parse query with global idx {query['global_idx']}") return response_obj def _build_metric(self, metric_name: str, score_config: Dict[str, Any]): diff --git a/lmms_eval/tasks/megabench/image_video_utils.py b/lmms_eval/tasks/megabench/image_video_utils.py index 199c6fac..ad1d2d2f 100644 --- a/lmms_eval/tasks/megabench/image_video_utils.py +++ b/lmms_eval/tasks/megabench/image_video_utils.py @@ -1,10 +1,11 @@ +import os +import re +from ast import literal_eval from mimetypes import guess_type + import cv2 import numpy as np from PIL import Image -import re -import os -from ast import literal_eval ## Image reading utils @@ -77,6 +78,7 @@ def is_video_file(file_path): ## Handle tasks with mixed image and video inputs. 
## Need to subsample video frames to multiple images + def load_media_content(media_path, max_nframes): # normalize media path if is_video_file(media_path): diff --git a/lmms_eval/tasks/megabench/metrics/aggregation/mean_agg.py b/lmms_eval/tasks/megabench/metrics/aggregation/mean_agg.py index 13d675fa..8bffc722 100644 --- a/lmms_eval/tasks/megabench/metrics/aggregation/mean_agg.py +++ b/lmms_eval/tasks/megabench/metrics/aggregation/mean_agg.py @@ -1,5 +1,6 @@ from numbers import Number from typing import Dict + import numpy as np diff --git a/lmms_eval/tasks/megabench/metrics/aggregation_type.py b/lmms_eval/tasks/megabench/metrics/aggregation_type.py index 9adf56bf..8a3cbfd4 100644 --- a/lmms_eval/tasks/megabench/metrics/aggregation_type.py +++ b/lmms_eval/tasks/megabench/metrics/aggregation_type.py @@ -1,5 +1,6 @@ from enum import Enum from functools import cached_property + from metrics.aggregation.mean_agg import MeanAggregation from metrics.aggregation.min_agg import MinAggregation from metrics.aggregation.unsupported_agg import UnsupportedAggregation diff --git a/lmms_eval/tasks/megabench/metrics/metric_type.py b/lmms_eval/tasks/megabench/metrics/metric_type.py index cc3b9c0f..215339b9 100644 --- a/lmms_eval/tasks/megabench/metrics/metric_type.py +++ b/lmms_eval/tasks/megabench/metrics/metric_type.py @@ -1,56 +1,66 @@ -from functools import cached_property -from enum import Enum import logging +from enum import Enum +from functools import cached_property -# Import all metrics -from metrics.scoring.simple_str_match import SimpleStrMatch -from metrics.scoring.exact_str_match import ExactStrMatch, CodeResultExactStrMatch -from metrics.scoring.dict_exact_match_agg_recall import DictExactStrMatchAggRecall -from metrics.scoring.exact_str_match_case_insensitive import ExactStrMatchCaseInsensitive -from metrics.scoring.normalized_similarity_damerau_levenshtein import NormalizedSimilarityDamerauLevenshtein -from metrics.scoring.near_str_match import NearStrMatch -from metrics.scoring.number_rel_diff_ratio import NumberRelDiffRatio -from metrics.scoring.set_equality import ( - SetEquality, - SetEqualityCaseInsensitive, - StringSetEqualityLineSplit, - StringSetEqualityCommaSplit -) -from metrics.scoring.dict_set_equality_agg_jaccard import DictSetEqualityAggJaccard -from metrics.scoring.dict_equality import DictEquality, DictPrecision -from metrics.scoring.jaccard import Jaccard, JaccardCaseInsensitive -from metrics.scoring.dict_jaccard_agg_jaccard import DictJaccardAggJaccard -from metrics.scoring.set_precision import SetPrecision -from metrics.scoring.positive_int_match import PositiveIntMatch +from metrics.scoring.ascii_art_vlm_judge import AsciiArtVLMJudgeScore from metrics.scoring.chess_jaccard import ChessMoveJaccard -from metrics.scoring.longest_common_list_prefix_ratio import LongestCommonListPrefixRatio -from metrics.scoring.nli_entailment import NliEntailment -from metrics.scoring.sacrebleu_bleu import Bleu -from metrics.scoring.gleu import GLEUChinese -from metrics.scoring.xml_nbbox_iou import XmlNbboxIouSingle -from metrics.scoring.general_numerical_match import BoxedSingleNumericalMatch, GeneralSingleNumericalMatch +from metrics.scoring.constrained_generation import ConstrainedGenerationEval from metrics.scoring.coordinate_sequence_match import CoordsSequenceSimilarity -from metrics.scoring.latex_expr_equality import LatexExprEquality, TextLatexExprEquality -from metrics.scoring.nbbox_iou import NbboxIouTuple, NbboxIouSingle, NbboxIouSequence +from metrics.scoring.dict_equality 
import DictEquality, DictPrecision +from metrics.scoring.dict_exact_match_agg_recall import DictExactStrMatchAggRecall +from metrics.scoring.dict_jaccard_agg_jaccard import DictJaccardAggJaccard from metrics.scoring.dict_nbbox_iou_tuple_agg_jaccard import DictNbboxIouTupleAggJaccard -from metrics.scoring.xml_norm_point_in_bbox import XmlNormPointInBbox -from metrics.scoring.xml_norm_point_distance import XmlNormPointDistance +from metrics.scoring.dict_set_equality_agg_jaccard import DictSetEqualityAggJaccard +from metrics.scoring.exact_str_match import CodeResultExactStrMatch, ExactStrMatch +from metrics.scoring.exact_str_match_case_insensitive import ( + ExactStrMatchCaseInsensitive, +) +from metrics.scoring.general_numerical_match import ( + BoxedSingleNumericalMatch, + GeneralSingleNumericalMatch, +) from metrics.scoring.geo_proximity import GeoProximityLocationDict -from metrics.scoring.mse import NormalizedRMSE, AngleSeqFloatRMSE +from metrics.scoring.gleu import GLEUChinese +from metrics.scoring.jaccard import Jaccard, JaccardCaseInsensitive +from metrics.scoring.latex_expr_equality import LatexExprEquality, TextLatexExprEquality +from metrics.scoring.longest_common_list_prefix_ratio import ( + LongestCommonListPrefixRatio, +) +from metrics.scoring.mse import AngleSeqFloatRMSE, NormalizedRMSE +from metrics.scoring.multi_ref_phrase import MultipleReferencePhraseEval +from metrics.scoring.nbbox_iou import NbboxIouSequence, NbboxIouSingle, NbboxIouTuple +from metrics.scoring.near_str_match import NearStrMatch +from metrics.scoring.nli_entailment import NliEntailment +from metrics.scoring.normalized_similarity_damerau_levenshtein import ( + NormalizedSimilarityDamerauLevenshtein, +) +from metrics.scoring.number_rel_diff_ratio import NumberRelDiffRatio +from metrics.scoring.positive_int_match import PositiveIntMatch from metrics.scoring.program_judge import ProgramJudge +from metrics.scoring.sacrebleu_bleu import Bleu from metrics.scoring.sequence_equality import ( + SequenceAccuracyCaseInsensitive, SequenceEquality, SequenceEqualityCaseInsensitive, - SequenceAccuracyCaseInsensitive ) +from metrics.scoring.set_equality import ( + SetEquality, + SetEqualityCaseInsensitive, + StringSetEqualityCommaSplit, + StringSetEqualityLineSplit, +) +from metrics.scoring.set_precision import SetPrecision + +# Import all metrics +from metrics.scoring.simple_str_match import SimpleStrMatch from metrics.scoring.symbolic_planning import SymbolicPlanningMetricTest -from metrics.scoring.multi_ref_phrase import MultipleReferencePhraseEval -from metrics.scoring.constrained_generation import ConstrainedGenerationEval from metrics.scoring.unsupported_scoring import UnsupportedScoring ## The vlm-judge metrics from metrics.scoring.vlm_as_judge import VLMJudgeScore -from metrics.scoring.ascii_art_vlm_judge import AsciiArtVLMJudgeScore +from metrics.scoring.xml_nbbox_iou import XmlNbboxIouSingle +from metrics.scoring.xml_norm_point_distance import XmlNormPointDistance +from metrics.scoring.xml_norm_point_in_bbox import XmlNormPointInBbox class MetricType(Enum): @@ -164,7 +174,7 @@ def class_impl(self): if self not in implementations: logging.error(f"Metric {self} not implemented...") return UnsupportedScoring() - + return implementations[self] def match(self, response: str, correct_answer: str): diff --git a/lmms_eval/tasks/megabench/metrics/parsing/answer_str_parse.py b/lmms_eval/tasks/megabench/metrics/parsing/answer_str_parse.py index 982f64c9..2b409ebe 100644 --- 
a/lmms_eval/tasks/megabench/metrics/parsing/answer_str_parse.py +++ b/lmms_eval/tasks/megabench/metrics/parsing/answer_str_parse.py @@ -1,10 +1,11 @@ import logging + from metrics.parsing.common.parsers import parse_json from metrics.parsing.common.utils import ( - extract_code_block_content, - extract_answer_content, - evaluate_as_string, drop_additional_text, + evaluate_as_string, + extract_answer_content, + extract_code_block_content, ) logger = logging.getLogger("errorLogger") @@ -58,7 +59,7 @@ def _parse( # ) if "[]" not in answer_content: return answer_content - return str(response_obj) # make sure the response to the metric is always a string + return str(response_obj) # make sure the response to the metric is always a string else: # drop the redundant string quotes answer_content = evaluate_as_string(answer_content) diff --git a/lmms_eval/tasks/megabench/metrics/parsing/common/parsers.py b/lmms_eval/tasks/megabench/metrics/parsing/common/parsers.py index b3d652a8..9aca7b6d 100644 --- a/lmms_eval/tasks/megabench/metrics/parsing/common/parsers.py +++ b/lmms_eval/tasks/megabench/metrics/parsing/common/parsers.py @@ -1,10 +1,11 @@ import ast import json import re -import regex # Supports the non-standard ?R regex operator from typing import List -from .utils import extract_code_block_content, extract_answer_at_beginning_of_line +import regex # Supports the non-standard ?R regex operator + +from .utils import extract_answer_at_beginning_of_line, extract_code_block_content PARSING_TIMEOUT = 0.1 @@ -23,9 +24,7 @@ def parse_json(response: str): # Find all potential JSON objects try: - potential_jsons = regex.findall( - json_pattern, response_, timeout=PARSING_TIMEOUT - ) + potential_jsons = regex.findall(json_pattern, response_, timeout=PARSING_TIMEOUT) except TimeoutError: if response_.startswith("["): return [] @@ -44,11 +43,7 @@ def parse_json(response: str): # Process each string literal for s in strings: # Unescape the string content - unescaped = ( - s[1:-1] - .replace("__DOUBLE_QUOTE__", '"') - .replace("__SINGLE_QUOTE__", "'") - ) + unescaped = s[1:-1].replace("__DOUBLE_QUOTE__", '"').replace("__SINGLE_QUOTE__", "'") # Try to parse it as JSON try: parsed = json.loads(unescaped) diff --git a/lmms_eval/tasks/megabench/metrics/parsing/common/utils.py b/lmms_eval/tasks/megabench/metrics/parsing/common/utils.py index 9db1c6bd..d4a1575b 100644 --- a/lmms_eval/tasks/megabench/metrics/parsing/common/utils.py +++ b/lmms_eval/tasks/megabench/metrics/parsing/common/utils.py @@ -1,5 +1,5 @@ -import re import ast +import re def extract_code_block_content( @@ -39,7 +39,7 @@ def extract_code_block_content( def keep_the_last_answer(s: str): # 1. Find the last occurrence - s = s.replace('answer:', 'Answer:') + s = s.replace("answer:", "Answer:") last_index = s.rfind("Answer:") # If "Answer:" is found in the string @@ -47,10 +47,10 @@ def keep_the_last_answer(s: str): # 2. Separate into prefix and suffix prefix = s[:last_index] suffix = s[last_index:] - + # 3. Remove all earlier occurrences of "Answer:" cleaned_prefix = prefix.replace("Answer:", "") - + # 4. 
Combine them back together result = cleaned_prefix + suffix else: @@ -60,16 +60,12 @@ def keep_the_last_answer(s: str): return result -def extract_answer_content( - response, is_ascii_art=False, should_remove_surrounding_whitespace=True -): +def extract_answer_content(response, is_ascii_art=False, should_remove_surrounding_whitespace=True): response = keep_the_last_answer(response) if is_ascii_art: match = re.search(r"\*\*?Answer:(.*?)\*\*?|\bAnswer:(.*)", response, re.DOTALL) else: - match = re.search( - r"\*\*?Answer:\s*(.*?)\*\*?|\bAnswer:\s*(.*)", response, re.DOTALL - ) + match = re.search(r"\*\*?Answer:\s*(.*?)\*\*?|\bAnswer:\s*(.*)", response, re.DOTALL) if match: # Extract the content after "Answer:" response = match.group(1) or match.group(2) @@ -113,11 +109,7 @@ def drop_additional_text(result): result_first_paragraph, ) - only_return_first_paragraph = ( - potential_ans_in_single_line - and result_first_paragraph.strip() != "" - and not _is_multiline_answer(result) - ) + only_return_first_paragraph = potential_ans_in_single_line and result_first_paragraph.strip() != "" and not _is_multiline_answer(result) if only_return_first_paragraph: return result_first_paragraph diff --git a/lmms_eval/tasks/megabench/metrics/parsing/dummy_parse.py b/lmms_eval/tasks/megabench/metrics/parsing/dummy_parse.py index 21b5a2b1..b868107a 100644 --- a/lmms_eval/tasks/megabench/metrics/parsing/dummy_parse.py +++ b/lmms_eval/tasks/megabench/metrics/parsing/dummy_parse.py @@ -1,5 +1,4 @@ class DummyParse: - @staticmethod def parse(response: str, *args, **kwargs) -> dict: """return the raw string without doing anything""" diff --git a/lmms_eval/tasks/megabench/metrics/response_parse_type.py b/lmms_eval/tasks/megabench/metrics/response_parse_type.py index 5262ba24..a80ccd47 100644 --- a/lmms_eval/tasks/megabench/metrics/response_parse_type.py +++ b/lmms_eval/tasks/megabench/metrics/response_parse_type.py @@ -1,12 +1,13 @@ -from functools import cached_property from enum import Enum -from metrics.parsing.json_parse import JsonParse +from functools import cached_property + from metrics.parsing.answer_str_parse import ( AnswerStrParse, AsciiAnswerStrParse, VerbatimAnswerStrParse, ) from metrics.parsing.dummy_parse import DummyParse +from metrics.parsing.json_parse import JsonParse class ResponseParseType(Enum): diff --git a/lmms_eval/tasks/megabench/metrics/scoring/ascii_art_vlm_judge.py b/lmms_eval/tasks/megabench/metrics/scoring/ascii_art_vlm_judge.py index 2cc2a689..d4f98710 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/ascii_art_vlm_judge.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/ascii_art_vlm_judge.py @@ -1,6 +1,7 @@ """Return if two ASCII art images depict the same thing.""" from numbers import Number + import requests from metrics.scoring.common.conversions import ascii_text_to_image from metrics.scoring.vlm_as_judge import OpenAIVLMJudger @@ -15,7 +16,7 @@ def __init__(self, metric_config, model="gpt-4o-2024-08-06"): metric_config, model, ) - + def encode_image(self, image): """Encode an image into base64 and return its mime type.""" mime_type = "image/jpeg" @@ -31,7 +32,7 @@ def encode_image(self, image): encoded_image = self._encode_image(image, image_format) return encoded_image, mime_type - + def create_image_content(self, image): base64_image, mime_type = self.encode_image(image) return { @@ -72,16 +73,14 @@ def query(self, images): json=query_payload, ) except (requests.exceptions.JSONDecodeError, requests.exceptions.ConnectionError) as e: - print(f'Error in requests: {e}') - 
print('Retry...') + print(f"Error in requests: {e}") + print("Retry...") continue response_ = response.json() if "error" in response_: error_info = response_["error"] - print( - f"Got error with type: {error_info['type']}. Message: {error_info['message']}" - ) + print(f"Got error with type: {error_info['type']}. Message: {error_info['message']}") print(f"Retry...") else: response_data = response_ @@ -93,9 +92,7 @@ def query(self, images): choices = response_data["choices"] if choices and "message" in choices[0]: message_content = choices[0]["message"]["content"] - print( - f"gpt-4o judge results: {message_content}; tokens:{total_tokens}" - ) + print(f"gpt-4o judge results: {message_content}; tokens:{total_tokens}") else: print(f"gpt-4o judge query failed...") message_content = "" diff --git a/lmms_eval/tasks/megabench/metrics/scoring/chess_jaccard.py b/lmms_eval/tasks/megabench/metrics/scoring/chess_jaccard.py index 63250bf3..7a5317a7 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/chess_jaccard.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/chess_jaccard.py @@ -1,5 +1,6 @@ import logging -from typing import Dict, Any +from typing import Any, Dict + from metrics.scoring.common.conversions import str_to_set from metrics.scoring.common.metrics import jaccard_index diff --git a/lmms_eval/tasks/megabench/metrics/scoring/common/conversions.py b/lmms_eval/tasks/megabench/metrics/scoring/common/conversions.py index 8c999bdd..a1874f13 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/common/conversions.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/common/conversions.py @@ -1,12 +1,13 @@ import ast import json import re -from matplotlib import font_manager -from PIL import Image, ImageDraw, ImageFont -from metrics.parsing.common.parsers import parse_json from numbers import Number from typing import Tuple, Union +from matplotlib import font_manager +from metrics.parsing.common.parsers import parse_json +from PIL import Image, ImageDraw, ImageFont + def freeze_structure(obj): """Freeze a structure and make it hashable.""" @@ -161,7 +162,6 @@ def parse_point_2d_from_xml(xml_string) -> Union[Tuple[float, float], None]: def parse_bboxes_from_xml(xml_string: str) -> list: - if not isinstance(xml_string, str): return [] @@ -170,7 +170,6 @@ def parse_bboxes_from_xml(xml_string: str) -> list: new_bboxes = [] for match in matches: - coords = match.split(",") if len(coords) != 4: continue @@ -190,9 +189,7 @@ def parse_bboxes_from_xml(xml_string: str) -> list: MONOSPACE_FONT_FILES = [] for font_name in MONOSPACE_FONTS: try: - MONOSPACE_FONT_FILES.append( - font_manager.findfont(font_name, fallback_to_default=False) - ) + MONOSPACE_FONT_FILES.append(font_manager.findfont(font_name, fallback_to_default=False)) except ValueError: continue @@ -214,9 +211,7 @@ def ascii_text_to_image( # Calculate initial image size based on text char_width = font_size * 0.6 # Approximate width of a character init_width = int(max(len(line) for line in lines) * char_width + 2 * padding) - init_height = int( - (len(lines) * font_size * line_spacing) + 2 * padding - ) # 1.2 for line spacing + init_height = int((len(lines) * font_size * line_spacing) + 2 * padding) # 1.2 for line spacing # Create a new image with the calculated size image = Image.new("RGB", (init_width, init_height), color=bg_color) diff --git a/lmms_eval/tasks/megabench/metrics/scoring/common/metrics.py b/lmms_eval/tasks/megabench/metrics/scoring/common/metrics.py index 8b7b2807..17444845 100644 --- 
a/lmms_eval/tasks/megabench/metrics/scoring/common/metrics.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/common/metrics.py @@ -1,5 +1,5 @@ -from collections.abc import Iterable import math +from collections.abc import Iterable from numbers import Number @@ -94,9 +94,5 @@ def mse(predicted: Number, target: Number) -> Number: def point_distance(predicted: tuple[float, ...], target: tuple[float, ...]): """Return the distance between two points.""" if len(predicted) != len(target): - raise ValueError( - "point_distance: Predicted and target points have different dimensions." - ) - return math.sqrt( - sum((comp_res - comp_tar) ** 2 for comp_res, comp_tar in zip(predicted, target)) - ) + raise ValueError("point_distance: Predicted and target points have different dimensions.") + return math.sqrt(sum((comp_res - comp_tar) ** 2 for comp_res, comp_tar in zip(predicted, target))) diff --git a/lmms_eval/tasks/megabench/metrics/scoring/constrained_generation.py b/lmms_eval/tasks/megabench/metrics/scoring/constrained_generation.py index 1fea6bf6..8424979a 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/constrained_generation.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/constrained_generation.py @@ -1,15 +1,16 @@ import collections import itertools import logging -from numbers import Number import re -from nltk.tokenize import sent_tokenize, word_tokenize -from nltk.stem import WordNetLemmatizer -from nltk.corpus import wordnet +import signal +from numbers import Number + import pronouncing -from metrics.scoring.common.conversions import str_to_iterable from metrics.parsing.common.parsers import parse_nested_str_list, parse_syllable_ranges -import signal +from metrics.scoring.common.conversions import str_to_iterable +from nltk.corpus import wordnet +from nltk.stem import WordNetLemmatizer +from nltk.tokenize import sent_tokenize, word_tokenize def custom_lemmatize(word, lemmatizer): @@ -78,10 +79,7 @@ def phones_for_word(text: str) -> list[str]: text = text.removesuffix("'s") if text in custom_phones_for_word: - return [ - pr + suffix - for pr, suffix in itertools.product(custom_phones_for_word[text], suffixes) - ] + return [pr + suffix for pr, suffix in itertools.product(custom_phones_for_word[text], suffixes)] pronunciations = pronouncing.phones_for_word(text) @@ -115,12 +113,7 @@ def phones_for_word(text: str) -> list[str]: prefixes_to_remove = ["AH0 "] text = "a" + text.removeprefix("'") pronunciations = pronouncing.phones_for_word(text) - pronunciations = [ - (prefix + pr + suffix).removeprefix(prefix_to_remove) - for prefix, pr, suffix, prefix_to_remove in itertools.product( - prefixes, pronunciations, suffixes, prefixes_to_remove - ) - ] + pronunciations = [(prefix + pr + suffix).removeprefix(prefix_to_remove) for prefix, pr, suffix, prefix_to_remove in itertools.product(prefixes, pronunciations, suffixes, prefixes_to_remove)] if not pronunciations: file_logger.error(f"OOV: {text}") @@ -158,9 +151,7 @@ def count_syllables(text: str) -> list[int]: pronunciations = [phones_for_word(p) for p in text.split()] syllable_counts = [] for pronun_possibility in itertools.product(*pronunciations): - syllable_counts.append( - sum([pronouncing.syllable_count(p) for p in pronun_possibility]) - ) + syllable_counts.append(sum([pronouncing.syllable_count(p) for p in pronun_possibility])) return syllable_counts @@ -193,10 +184,7 @@ def find_string_occurrences_with_variations(text, search_string): def word_to_stresses(word: str) -> list[list[int]]: """Convert a word to a list of stresses, for 
each valid pronunciation.""" pronunciations = phones_for_word(word) - stresses = { - tuple(int(stress) for stress in pronouncing.stresses(pronunc)) - for pronunc in pronunciations - } + stresses = {tuple(int(stress) for stress in pronouncing.stresses(pronunc)) for pronunc in pronunciations} return [list(pronunc_stresses) for pronunc_stresses in stresses] @@ -233,9 +221,7 @@ def backtrack(word_index: int, syllable_index: int, prev_stress: int) -> bool: word_syllable_index += 1 word_valid_iambic_pairs = True - for stress1, stress2 in grouper_ignore_last( - stress_pattern[word_syllable_index:], 2 - ): + for stress1, stress2 in grouper_ignore_last(stress_pattern[word_syllable_index:], 2): if not is_iambic_pair(stress1, stress2): word_valid_iambic_pairs = False break @@ -259,9 +245,7 @@ def backtrack(word_index: int, syllable_index: int, prev_stress: int) -> bool: return False - return backtrack( - 0, 0, -1 - ) # Start with -1 as prev_stress as a placeholder for the first syllable + return backtrack(0, 0, -1) # Start with -1 as prev_stress as a placeholder for the first syllable def parse_constraints(key_string, value_string): @@ -280,9 +264,7 @@ def parse_constraints(key_string, value_string): # Combine keys and values into a dictionary if len(key_components) == len(value_components): - result = { - key.lower(): value for key, value in zip(key_components, value_components) - } + result = {key.lower(): value for key, value in zip(key_components, value_components)} elif len(key_components) == 1 and len(value_components) == 1: result = {key_components[0].lower(): value_components[0]} else: @@ -305,9 +287,7 @@ def check_constraint(response, constraint, constraint_val): for cond in conditions: count, occurs = 0, [] for item in cond: # check one condition - count_, occurs_ = find_string_occurrences_with_variations( - response, item - ) + count_, occurs_ = find_string_occurrences_with_variations(response, item) if count_ > 0: count += count_ occurs.extend(occurs_) @@ -319,9 +299,7 @@ def check_constraint(response, constraint, constraint_val): items = str_to_iterable(list, parsed_constraint["contain"]) count, occurs = 0, [] for item in items: - count_, occurs_ = find_string_occurrences_with_variations( - response, item - ) + count_, occurs_ = find_string_occurrences_with_variations(response, item) if count_ > 0: count += count_ occurs.extend(occurs_) @@ -402,9 +380,7 @@ def check_constraint(response, constraint, constraint_val): response = response.replace('"', "") response = response.replace("-", " ") response = response.replace("—", " ") - response = re.sub( - " *\(\w\) *(?=\n|$)", "", response - ) # The parenthesized letter in the rhyming scheme + response = re.sub(" *\(\w\) *(?=\n|$)", "", response) # The parenthesized letter in the rhyming scheme lines = response.lower().split("\n") match constraint: @@ -413,15 +389,7 @@ def check_constraint(response, constraint, constraint_val): if len(lines) != len(syllable_count_intervals): return 0 try: - all_match = all( - any( - min_count <= syll_count <= max_count - for syll_count in count_syllables(line) - ) - for line, (min_count, max_count) in zip( - lines, syllable_count_intervals - ) - ) + all_match = all(any(min_count <= syll_count <= max_count for syll_count in count_syllables(line)) for line, (min_count, max_count) in zip(lines, syllable_count_intervals)) except IndexError: all_match = None score = 1 if all_match else 0 @@ -443,13 +411,7 @@ def check_constraint(response, constraint, constraint_val): # Check that 1. 
The words for the same letter all rhyme letter_to_rhyming_parts = {} for letter, words in letter_to_words.items(): - rhyming_parts: list[set[str]] = [ - { - rhyming_part_include_unstressed(pronunciations) - for pronunciations in phones_for_word(word) - } - for word in words - ] + rhyming_parts: list[set[str]] = [{rhyming_part_include_unstressed(pronunciations) for pronunciations in phones_for_word(word)} for word in words] common_rhyming_parts = set.intersection(*rhyming_parts) if not common_rhyming_parts: return 0 diff --git a/lmms_eval/tasks/megabench/metrics/scoring/coordinate_sequence_match.py b/lmms_eval/tasks/megabench/metrics/scoring/coordinate_sequence_match.py index 10e1ed1b..702cc7e7 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/coordinate_sequence_match.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/coordinate_sequence_match.py @@ -1,6 +1,7 @@ import logging -from metrics.scoring.common.conversions import str_to_coords + import numpy as np +from metrics.scoring.common.conversions import str_to_coords class CoordsSequenceSimilarity: diff --git a/lmms_eval/tasks/megabench/metrics/scoring/dict_equality.py b/lmms_eval/tasks/megabench/metrics/scoring/dict_equality.py index 7f719293..32090b10 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/dict_equality.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/dict_equality.py @@ -22,7 +22,6 @@ def match(cls, responses, targets) -> float: class DictPrecision: - @classmethod def match(cls, responses, targets) -> float: """Return the aggregated Jaccard index between targets and responses.""" diff --git a/lmms_eval/tasks/megabench/metrics/scoring/dict_jaccard_agg_jaccard.py b/lmms_eval/tasks/megabench/metrics/scoring/dict_jaccard_agg_jaccard.py index 3d8e1d4a..62035a0d 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/dict_jaccard_agg_jaccard.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/dict_jaccard_agg_jaccard.py @@ -1,5 +1,5 @@ -from metrics.scoring.jaccard import Jaccard from metrics.scoring.common.conversions import cast_to_dict +from metrics.scoring.jaccard import Jaccard class DictJaccardAggJaccard: diff --git a/lmms_eval/tasks/megabench/metrics/scoring/dict_nbbox_iou_tuple_agg_jaccard.py b/lmms_eval/tasks/megabench/metrics/scoring/dict_nbbox_iou_tuple_agg_jaccard.py index a57b6a24..135d3ec3 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/dict_nbbox_iou_tuple_agg_jaccard.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/dict_nbbox_iou_tuple_agg_jaccard.py @@ -1,5 +1,5 @@ -from metrics.scoring.nbbox_iou import NbboxIouTuple from metrics.scoring.common.conversions import cast_to_dict +from metrics.scoring.nbbox_iou import NbboxIouTuple class DictNbboxIouTupleAggJaccard: @@ -22,9 +22,7 @@ def match(cls, responses, targets) -> float: num_keys = 0 total_score = 0 for key in all_keys: - total_score += NbboxIouTuple.match( - responses.get(key, []), targets.get(key, []) - ) + total_score += NbboxIouTuple.match(responses.get(key, []), targets.get(key, [])) num_keys += 1 return total_score / num_keys diff --git a/lmms_eval/tasks/megabench/metrics/scoring/dict_set_equality_agg_jaccard.py b/lmms_eval/tasks/megabench/metrics/scoring/dict_set_equality_agg_jaccard.py index a7a4b937..2dab9e2a 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/dict_set_equality_agg_jaccard.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/dict_set_equality_agg_jaccard.py @@ -1,5 +1,6 @@ -from metrics.scoring.set_equality import SetEquality from metrics.scoring.common.conversions import cast_to_dict +from 
metrics.scoring.set_equality import SetEquality + class DictSetEqualityAggJaccard: """Calculates the average set equality across the dict. @@ -22,9 +23,7 @@ def match(cls, responses, targets) -> float: num_keys = 0 total_score = 0 for key in all_keys: - total_score += SetEquality.match( - responses.get(key, []), targets.get(key, []) - ) + total_score += SetEquality.match(responses.get(key, []), targets.get(key, [])) num_keys += 1 return total_score / num_keys diff --git a/lmms_eval/tasks/megabench/metrics/scoring/exact_str_match.py b/lmms_eval/tasks/megabench/metrics/scoring/exact_str_match.py index 8b51d63f..524f73ef 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/exact_str_match.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/exact_str_match.py @@ -1,4 +1,5 @@ import re + from metrics.parsing.common.utils import extract_code_block_content diff --git a/lmms_eval/tasks/megabench/metrics/scoring/general_numerical_match.py b/lmms_eval/tasks/megabench/metrics/scoring/general_numerical_match.py index d033e81c..e3b600e2 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/general_numerical_match.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/general_numerical_match.py @@ -1,12 +1,11 @@ -import re -from metrics.scoring.simple_str_match import SimpleStrMatch - -from sympy.parsing.latex import parse_latex import math import multiprocessing - +import re import signal +from metrics.scoring.simple_str_match import SimpleStrMatch +from sympy.parsing.latex import parse_latex + class TimeoutException(Exception): pass @@ -202,12 +201,7 @@ def match(cls, responses, targets) -> float: tgt = number_it(targets) if res is not None and tgt is not None: - if ( - isinstance(res, list) - and isinstance(tgt, list) - or isinstance(res, tuple) - and isinstance(tgt, tuple) - ): + if isinstance(res, list) and isinstance(tgt, list) or isinstance(res, tuple) and isinstance(tgt, tuple): score = float(compare_two_list(res, tgt)) else: score = float(compare_two_numbers(res, tgt)) diff --git a/lmms_eval/tasks/megabench/metrics/scoring/geo_proximity.py b/lmms_eval/tasks/megabench/metrics/scoring/geo_proximity.py index 171cb27a..da28e892 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/geo_proximity.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/geo_proximity.py @@ -1,21 +1,21 @@ -from datetime import timedelta import functools import logging import math import random import ssl +from datetime import timedelta + +import requests_cache from geopy.adapters import ( RequestsAdapter, RequestsHTTPAdapter, RequestsHTTPWithSSLContextAdapter, - requests_available, _normalize_proxies, + requests_available, ) from geopy.distance import distance -from geopy.geocoders import Nominatim from geopy.extra.rate_limiter import RateLimiter -import requests_cache - +from geopy.geocoders import Nominatim error_logger = logging.getLogger("errorLogger") @@ -53,12 +53,7 @@ def __init__( pool_block=False, ): if not requests_available: - raise ImportError( - "`requests` must be installed in order to use RequestsAdapter. " - "If you have installed geopy via pip, you may use " - "this command to install requests: " - '`pip install "geopy[requests]"`.' - ) + raise ImportError("`requests` must be installed in order to use RequestsAdapter. " "If you have installed geopy via pip, you may use " "this command to install requests: " '`pip install "geopy[requests]"`.') proxies = _normalize_proxies(proxies) if ssl_context is None: # By default requests uses CA bundle from `certifi` package. 
@@ -77,9 +72,7 @@ def __init__( ssl_context = ssl.create_default_context() super().__init__(proxies=proxies, ssl_context=ssl_context) - self.session = requests_cache.CachedSession( - backend="sqlite", expire_after=timedelta(days=30) - ) + self.session = requests_cache.CachedSession(backend="sqlite", expire_after=timedelta(days=30)) self.session.trust_env = False # don't use system proxies self.session.proxies = proxies @@ -130,9 +123,7 @@ def calculate_proximity_score(guess_coords, actual_coords, k=100): MAX_RETRIES = 30 -geocode = RateLimiter( - geolocator.geocode, min_delay_seconds=GEOLOCATION_TIMEOUT, max_retries=MAX_RETRIES -) +geocode = RateLimiter(geolocator.geocode, min_delay_seconds=GEOLOCATION_TIMEOUT, max_retries=MAX_RETRIES) @functools.cache @@ -140,21 +131,15 @@ def try_geolocate(query): """Try to look up the location.""" location = geocode(query) if location is None: - error_logger.error( - f"Geolocation API request failed due to timeout: exceeded {MAX_RETRIES} retries!" - ) + error_logger.error(f"Geolocation API request failed due to timeout: exceeded {MAX_RETRIES} retries!") return location -def location_to_coords( - country: str, province_or_state: str, municipality: str -) -> tuple[float, float] | None: +def location_to_coords(country: str, province_or_state: str, municipality: str) -> tuple[float, float] | None: if country == "" or province_or_state == "" or municipality == "": return None """Convert the location to longitude and latitude.""" - location = geolocator.geocode( - query={"country": country, "state": province_or_state, "city": municipality} - ) + location = geolocator.geocode(query={"country": country, "state": province_or_state, "city": municipality}) if location is not None: return (location.latitude, location.longitude) # Try searching without the province/state, as it can be non-standard for some questions @@ -183,15 +168,11 @@ def match(cls, responses, targets) -> float: return 0 if guess_coords is None: - error_logger.error( - f"GeoProximityLocationDict: could not load co-ordinates for {responses=}" - ) + error_logger.error(f"GeoProximityLocationDict: could not load co-ordinates for {responses=}") return 0 actual_coords = location_to_coords(**targets) if actual_coords is None: - error_logger.error( - f"GeoProximityLocationDict: could not load co-ordinates for {targets=}" - ) + error_logger.error(f"GeoProximityLocationDict: could not load co-ordinates for {targets=}") return 0 return calculate_proximity_score(guess_coords, actual_coords) diff --git a/lmms_eval/tasks/megabench/metrics/scoring/gleu.py b/lmms_eval/tasks/megabench/metrics/scoring/gleu.py index 6fb514fb..7390febb 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/gleu.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/gleu.py @@ -1,4 +1,5 @@ from numbers import Number + import jieba from nltk.translate.gleu_score import sentence_gleu diff --git a/lmms_eval/tasks/megabench/metrics/scoring/jaccard.py b/lmms_eval/tasks/megabench/metrics/scoring/jaccard.py index 9bfb8600..6d69953e 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/jaccard.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/jaccard.py @@ -30,44 +30,19 @@ def match(cls, responses, targets) -> float: targets = cast_to_set(targets) if isinstance(list(targets)[0], str): - new_responses = { - item.lower() if isinstance(item, str) else str(item).lower() - for item in responses - } + new_responses = {item.lower() if isinstance(item, str) else str(item).lower() for item in responses} new_targets = {item.lower() for item in targets} 
elif isinstance(list(targets)[0], tuple): new_responses = set() new_targets = set() try: for res in responses: - new_res = tuple( - [ - item.lower() - .replace(" ", "") - .replace("-", "") - .replace("\n", "") - .replace("\t", "") - .replace("_", "") - .replace(".", "") - for item in res - ] - ) + new_res = tuple([item.lower().replace(" ", "").replace("-", "").replace("\n", "").replace("\t", "").replace("_", "").replace(".", "") for item in res]) new_responses.add(new_res) except: # the data type of the response might be wrong, return 0 in this case return 0 for tgt in targets: - new_tgt = tuple( - [ - item.lower() - .replace(" ", "") - .replace("-", "") - .replace("\n", "") - .replace("\t", "") - .replace("_", "") - .replace(".", "") - for item in tgt - ] - ) + new_tgt = tuple([item.lower().replace(" ", "").replace("-", "").replace("\n", "").replace("\t", "").replace("_", "").replace(".", "") for item in tgt]) new_targets.add(new_tgt) else: return 0 diff --git a/lmms_eval/tasks/megabench/metrics/scoring/latex_expr_equality.py b/lmms_eval/tasks/megabench/metrics/scoring/latex_expr_equality.py index 8b28aee8..5c460efb 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/latex_expr_equality.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/latex_expr_equality.py @@ -1,10 +1,11 @@ import re -from sympy.parsing.latex import parse_latex -from sympy.parsing.latex.errors import LaTeXParsingError -from sympy.core.sympify import SympifyError +import signal + from metrics.scoring.common.transformations import normalize_latex from metrics.scoring.simple_str_match import SimpleStrMatch -import signal +from sympy.core.sympify import SympifyError +from sympy.parsing.latex import parse_latex +from sympy.parsing.latex.errors import LaTeXParsingError class TimeoutException(Exception): diff --git a/lmms_eval/tasks/megabench/metrics/scoring/mse.py b/lmms_eval/tasks/megabench/metrics/scoring/mse.py index 3489d610..fb504e21 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/mse.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/mse.py @@ -1,8 +1,9 @@ import ast -import numpy as np import math -from metrics.scoring.common.metrics import mse + +import numpy as np from metrics.scoring.common.conversions import str_to_list +from metrics.scoring.common.metrics import mse class MSE: diff --git a/lmms_eval/tasks/megabench/metrics/scoring/multi_ref_phrase.py b/lmms_eval/tasks/megabench/metrics/scoring/multi_ref_phrase.py index 8bbbf3b8..d63b13c1 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/multi_ref_phrase.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/multi_ref_phrase.py @@ -1,4 +1,5 @@ from numbers import Number + from metrics.scoring.common.conversions import str_to_iterable from metrics.scoring.simple_str_match import SimpleStrMatch diff --git a/lmms_eval/tasks/megabench/metrics/scoring/nbbox_iou.py b/lmms_eval/tasks/megabench/metrics/scoring/nbbox_iou.py index 5a122621..cd183d24 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/nbbox_iou.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/nbbox_iou.py @@ -1,8 +1,9 @@ -import logging import ast +import logging + +import numpy as np from metrics.scoring.common.conversions import str_to_bboxes from metrics.scoring.common.metrics import calculate_iou -import numpy as np class NbboxIouTuple: diff --git a/lmms_eval/tasks/megabench/metrics/scoring/near_str_match.py b/lmms_eval/tasks/megabench/metrics/scoring/near_str_match.py index 727898ef..92332b22 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/near_str_match.py +++ 
b/lmms_eval/tasks/megabench/metrics/scoring/near_str_match.py @@ -18,6 +18,4 @@ def match(response, correct_answer: str, threshold=0.9) -> int: return 0 response = approximate(response) correct_answer = approximate(correct_answer) - return rapidfuzz.distance.DamerauLevenshtein.normalized_similarity( - response, correct_answer, score_cutoff=threshold - ) + return rapidfuzz.distance.DamerauLevenshtein.normalized_similarity(response, correct_answer, score_cutoff=threshold) diff --git a/lmms_eval/tasks/megabench/metrics/scoring/nli_entailment.py b/lmms_eval/tasks/megabench/metrics/scoring/nli_entailment.py index 71a29042..0e69bf82 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/nli_entailment.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/nli_entailment.py @@ -1,11 +1,8 @@ import torch from transformers import pipeline - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -pipe = pipeline( - "text-classification", model="microsoft/deberta-large-mnli", device=device -) +pipe = pipeline("text-classification", model="microsoft/deberta-large-mnli", device=device) class NliEntailment: diff --git a/lmms_eval/tasks/megabench/metrics/scoring/normalized_similarity_damerau_levenshtein.py b/lmms_eval/tasks/megabench/metrics/scoring/normalized_similarity_damerau_levenshtein.py index 110ed846..cc1b6de8 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/normalized_similarity_damerau_levenshtein.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/normalized_similarity_damerau_levenshtein.py @@ -9,6 +9,4 @@ def match(response, correct_answer) -> int: """Normalized indel similarityuiio do between targets and responses.""" if not isinstance(response, str) and isinstance(correct_answer, str): return 0 - return rapidfuzz.distance.DamerauLevenshtein.normalized_similarity( - response, correct_answer - ) + return rapidfuzz.distance.DamerauLevenshtein.normalized_similarity(response, correct_answer) diff --git a/lmms_eval/tasks/megabench/metrics/scoring/program_judge.py b/lmms_eval/tasks/megabench/metrics/scoring/program_judge.py index e9cb2ab6..50412923 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/program_judge.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/program_judge.py @@ -1,15 +1,16 @@ import io -import pathlib import json import multiprocessing -from unittest.mock import patch +import pathlib from multiprocessing.queues import Empty +from unittest.mock import patch BIG_BENCH_PATH = pathlib.Path(__file__).resolve().parent.parent.parent class ProgramJudge: """Python Program Judging.""" + @staticmethod def match(response: str, eval_context: str) -> int: # Load all test cases from the benchmark_tasks directory @@ -36,9 +37,7 @@ def __init__(self, user_code, test_cases, timeout=2, verbose=True): def run_user_code(self, input_data): input_str = "\n".join(input_data) + "\n" output_queue = multiprocessing.Queue() - process = multiprocessing.Process( - target=self.target, args=(output_queue, input_str) - ) + process = multiprocessing.Process(target=self.target, args=(output_queue, input_str)) process.start() process.join(self.timeout) @@ -85,9 +84,7 @@ def run_tests(self): results = [] for i, test_case in enumerate(self.test_cases, 1): - result, output = self.evaluate_test_case( - test_case["input"], test_case["expected"] - ) + result, output = self.evaluate_test_case(test_case["input"], test_case["expected"]) test_result = { "response": self.user_code, @@ -104,9 +101,7 @@ def run_tests(self): passed_tests += 1 else: if self.verbose: - print( - f"Test case {i}: Failed - 
Expected {test_case['expected']} but got {output}" - ) + print(f"Test case {i}: Failed - Expected {test_case['expected']} but got {output}") score = passed_tests / total_tests if total_tests > 0 else 0 return score, results diff --git a/lmms_eval/tasks/megabench/metrics/scoring/sacrebleu_bleu.py b/lmms_eval/tasks/megabench/metrics/scoring/sacrebleu_bleu.py index 63a2a265..ce57118d 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/sacrebleu_bleu.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/sacrebleu_bleu.py @@ -1,4 +1,5 @@ from numbers import Number + import sacrebleu @@ -11,9 +12,7 @@ def match(response, correct_answer) -> Number: if isinstance(response, str) and isinstance(correct_answer, str): resp = [response] corr = [correct_answer] - elif isinstance(response, (list, tuple)) and isinstance( - correct_answer, (list, tuple) - ): + elif isinstance(response, (list, tuple)) and isinstance(correct_answer, (list, tuple)): resp = tuple(response) corr = tuple(correct_answer) else: diff --git a/lmms_eval/tasks/megabench/metrics/scoring/sequence_equality.py b/lmms_eval/tasks/megabench/metrics/scoring/sequence_equality.py index 2df1493e..d8906f5b 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/sequence_equality.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/sequence_equality.py @@ -1,6 +1,7 @@ -from metrics.scoring.common.conversions import str_to_list from numbers import Number +from metrics.scoring.common.conversions import str_to_list + class SequenceEquality: """Determines how much of the first part of the list @@ -30,9 +31,7 @@ def match(cls, responses, targets) -> int: responses = str_to_list(responses) targets = str_to_list(targets) - responses = [ - item.lower() if isinstance(item, str) else str(item) for item in responses - ] + responses = [item.lower() if isinstance(item, str) else str(item) for item in responses] targets = [item.lower() for item in targets] return 1 if responses == targets else 0 diff --git a/lmms_eval/tasks/megabench/metrics/scoring/set_equality.py b/lmms_eval/tasks/megabench/metrics/scoring/set_equality.py index 2ca06fb9..5a26b746 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/set_equality.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/set_equality.py @@ -37,12 +37,8 @@ def match(cls, responses, targets) -> int: responses = responses.replace("\\n", "\n") responses_set = set(responses.split("\n")) targets_set = set(targets.split("\n")) - responses_set = { - item.lower() if isinstance(item, str) else item for item in responses_set - } - targets_set = { - item.lower() if isinstance(item, str) else item for item in targets_set - } + responses_set = {item.lower() if isinstance(item, str) else item for item in responses_set} + targets_set = {item.lower() if isinstance(item, str) else item for item in targets_set} return 1 if responses_set == targets_set else 0 @@ -56,10 +52,6 @@ class StringSetEqualityCommaSplit: def match(cls, responses, targets) -> int: responses_set = str_to_set(responses) targets_set = str_to_set(targets) - responses_set = { - item.lower() if isinstance(item, str) else item for item in responses_set - } - targets_set = { - item.lower() if isinstance(item, str) else item for item in targets_set - } + responses_set = {item.lower() if isinstance(item, str) else item for item in responses_set} + targets_set = {item.lower() if isinstance(item, str) else item for item in targets_set} return 1 if responses_set == targets_set else 0 diff --git a/lmms_eval/tasks/megabench/metrics/scoring/simple_str_match.py 
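
The `sacrebleu_bleu` hunk above shows only the input normalization, not the scoring call itself. As a hedged illustration of how sacrebleu is commonly used for this kind of metric (the committed code may call a different entry point), note that its BLEU scores come back on a 0-100 scale:

import sacrebleu

# Single pair, as in the str/str branch of the metric above.
single = sacrebleu.sentence_bleu("the cat sat on the mat", ["the cat is sitting on the mat"])
print(single.score)  # BLEU on a 0-100 scale

# Multiple segments, as in the list/tuple branch (one reference stream here).
corpus = sacrebleu.corpus_bleu(
    ["the cat sat on the mat", "hello there"],
    [["the cat is sitting on the mat", "hello there"]],
)
print(corpus.score / 100)  # rescale to [0, 1] if a unit-interval score is wanted
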
b/lmms_eval/tasks/megabench/metrics/scoring/simple_str_match.py index 5c0f24e7..33fe2c9e 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/simple_str_match.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/simple_str_match.py @@ -9,21 +9,7 @@ def match(response, correct_answer: str) -> int: """Simple string match between response and correct_answer.""" if not isinstance(response, str): response = str(response) # If it is JSON-like - response = ( - response.replace(" ", "") - .replace("-", "") - .replace("\n", "") - .replace("\t", "") - .replace(".", "") - .lower() - ) - correct_answer = ( - correct_answer.replace(" ", "") - .replace("-", "") - .replace("\n", "") - .replace("\t", "") - .replace(".", "") - .lower() - ) + response = response.replace(" ", "").replace("-", "").replace("\n", "").replace("\t", "").replace(".", "").lower() + correct_answer = correct_answer.replace(" ", "").replace("-", "").replace("\n", "").replace("\t", "").replace(".", "").lower() return ExactStrMatch.match(response, correct_answer) diff --git a/lmms_eval/tasks/megabench/metrics/scoring/symbolic_planning.py b/lmms_eval/tasks/megabench/metrics/scoring/symbolic_planning.py index 99900f11..f321fc6e 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/symbolic_planning.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/symbolic_planning.py @@ -52,9 +52,7 @@ def parse_pddl_attr_from_string( if len(s_attr) == 1: return "", [] elif len(s_attr) == 2: - outer_str, inner_str, _ = parse_outer_inner_str( - s_attr[1], attr_ender, inner_starter, inner_ender - ) + outer_str, inner_str, _ = parse_outer_inner_str(s_attr[1], attr_ender, inner_starter, inner_ender) return attr_starter + outer_str, inner_str else: matched_dict = {} @@ -63,18 +61,14 @@ def parse_pddl_attr_from_string( while len(s.split(attr_starter)) > 1: s = s.split(attr_starter, 1)[1] name = re.split(r"\s+", s.strip())[0] - outer_str, inner_str, end_point = parse_outer_inner_str( - s, attr_ender, inner_starter, inner_ender - ) + outer_str, inner_str, end_point = parse_outer_inner_str(s, attr_ender, inner_starter, inner_ender) outer_list.append(attr_starter + outer_str) matched_dict[name] = inner_str s = s[end_point:] else: for seg in s_attr[1:]: name = re.split(r"\s+", seg.strip())[0] - outer_str, inner_str, _ = parse_outer_inner_str( - seg, attr_ender, inner_starter, inner_ender - ) + outer_str, inner_str, _ = parse_outer_inner_str(seg, attr_ender, inner_starter, inner_ender) outer_list.append(attr_starter + outer_str) matched_dict[name] = inner_str return outer_list, matched_dict @@ -120,19 +114,13 @@ class Domain: def __init__(self, domain_pddl): # Domain files self.domain_pddl = domain_pddl - self.action_name, self.action_params, self.action_params_dict = ( - self.get_domain_action() - ) + self.action_name, self.action_params, self.action_params_dict = self.get_domain_action() self.gt_cond_dict = self.parse_gt_pre_post_cond() def get_domain_action(self): - action_pddl_str_list, all_actions = parse_pddl_attr_from_string( - self.domain_pddl, attr_starter="(:action" - ) + action_pddl_str_list, all_actions = parse_pddl_attr_from_string(self.domain_pddl, attr_starter="(:action") action_name, action_params, action_params_dict = [], [], [] - for action_pddl_str, (name, action_attr) in zip( - action_pddl_str_list, all_actions.items() - ): + for action_pddl_str, (name, action_attr) in zip(action_pddl_str_list, all_actions.items()): assert len(action_attr) == 3 param_str, pre_cond_str, post_cond_str = action_attr action_name.append(name) @@ -171,9 +159,7 @@ def 
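
The reformatted `SimpleStrMatch` above chains `str.replace` calls to strip spaces, dashes, newlines, tabs, and periods before lowercasing and comparing. A compact equivalent of that normalization, shown here with `str.translate` purely for illustration (`normalize` and `simple_match` are not names from the repository):

# Characters stripped before comparison: space, dash, newline, tab, period.
_STRIP = str.maketrans("", "", " -\n\t.")


def normalize(text) -> str:
    return str(text).translate(_STRIP).lower()


def simple_match(response, correct_answer: str) -> int:
    return 1 if normalize(response) == normalize(correct_answer) else 0


# simple_match("New York.", "new-york") -> 1
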
construct_param_to_obj(domain, action): def state_transition(current_state, effects, param_to_obj): for obj_cond in effects: for param in param_to_obj: - obj_cond = re.sub( - r"\?{}(?=[^\w-])".format(param), param_to_obj[param], obj_cond - ) + obj_cond = re.sub(r"\?{}(?=[^\w-])".format(param), param_to_obj[param], obj_cond) _, reversed_cond = parse_pddl_attr_from_string(obj_cond, attr_starter="(not ") if reversed_cond: assert len(reversed_cond) == 1 @@ -187,12 +173,8 @@ def state_transition(current_state, effects, param_to_obj): def check_pre_conds_satisfy(current_state, pre_conds, param_to_obj): for obj_cond in pre_conds: for param in param_to_obj: - obj_cond = re.sub( - r"\?{}(?=[^\w-])".format(param), param_to_obj[param], obj_cond - ) - if (obj_cond.startswith("(not ") and obj_cond in current_state) or ( - not obj_cond.startswith("(not ") and obj_cond not in current_state - ): + obj_cond = re.sub(r"\?{}(?=[^\w-])".format(param), param_to_obj[param], obj_cond) + if (obj_cond.startswith("(not ") and obj_cond in current_state) or (not obj_cond.startswith("(not ") and obj_cond not in current_state): return False return True @@ -217,9 +199,7 @@ def match(cls, response, eval_context): case tuple() | list(): candidates = list(response) case _: - raise ValueError( - f"`response` has unsupported type: {type(response)=}, {response=}" - ) + raise ValueError(f"`response` has unsupported type: {type(response)=}, {response=}") cand_traj = [cand_a.strip() for cand_a in candidates if cand_a.startswith("(")] try: task_pddl = eval_context["task_pddl"] @@ -234,22 +214,16 @@ def match(cls, response, eval_context): ## State transitions and check if satisfy the preconditions for cand_a in cand_traj: param_to_obj, a_name = construct_param_to_obj(domain, cand_a) - if not check_pre_conds_satisfy( - cur_state, domain.gt_cond_dict[f"{a_name}_pre"], param_to_obj - ): + if not check_pre_conds_satisfy(cur_state, domain.gt_cond_dict[f"{a_name}_pre"], param_to_obj): print(f"precondition of the action {cand_a} is not satisfied!") score = 0 break - cur_state = state_transition( - cur_state, domain.gt_cond_dict[f"{a_name}_post"], param_to_obj - ) + cur_state = state_transition(cur_state, domain.gt_cond_dict[f"{a_name}_post"], param_to_obj) ## Check if goal conditions are reached in the final state if score == 1: for g_state in goal_state: - if (g_state.startswith("(not ") and g_state in cur_state) or ( - not g_state.startswith("(not ") and g_state not in cur_state - ): + if (g_state.startswith("(not ") and g_state in cur_state) or (not g_state.startswith("(not ") and g_state not in cur_state): print(f"goal state {g_state} is not reached!") score = 0 break diff --git a/lmms_eval/tasks/megabench/metrics/scoring/vlm_as_judge.py b/lmms_eval/tasks/megabench/metrics/scoring/vlm_as_judge.py index 4a67c002..9d804281 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/vlm_as_judge.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/vlm_as_judge.py @@ -1,10 +1,11 @@ import abc import base64 import os -import requests import re from io import BytesIO from mimetypes import guess_type + +import requests from PIL import Image @@ -30,22 +31,22 @@ def __init__( self.api_key = os.getenv("OPENAI_API_KEY") self.model = model self.resize = resize - self.max_side = max_side - + self.max_side = max_side + if os.getenv("MEGABENCH_OPEN_API_KEY") is not None: self.api_key = os.getenv("MEGABENCH_OPEN_API_KEY") self.url = os.getenv("MEGABENCH_OPEN_API_URL") if os.getenv("MEGABENCH_OPEN_API_MODEL") is not None: self.model = 
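
The symbolic-planning metric validates a predicted plan by grounding each action's `?parameters`, checking its preconditions against the current state, and then applying its effects. Below is a toy, self-contained version of that loop with hypothetical blocks-world predicates and simplified negation handling; the real metric parses the domain and task PDDL from `eval_context` and uses its own `(not ...)` parsing:

import re


def ground(cond: str, param_to_obj: dict) -> str:
    """Substitute ?parameters with concrete objects, mirroring the re.sub above."""
    for param, obj in param_to_obj.items():
        cond = re.sub(r"\?{}(?=[^\w-])".format(param), obj, cond)
    return cond


def preconditions_hold(state: set, pre_conds, param_to_obj) -> bool:
    for cond in pre_conds:
        cond = ground(cond, param_to_obj)
        if cond.startswith("(not "):
            if cond[len("(not "):-1].strip() in state:  # simplified negation check
                return False
        elif cond not in state:
            return False
    return True


def apply_effects(state: set, effects, param_to_obj) -> set:
    state = set(state)
    for cond in effects:
        cond = ground(cond, param_to_obj)
        if cond.startswith("(not "):
            state.discard(cond[len("(not "):-1].strip())
        else:
            state.add(cond)
    return state


# Toy blocks-world pick-up step:
state = {"(clear b)", "(ontable b)", "(handempty)"}
params = {"x": "b"}
if preconditions_hold(state, ["(clear ?x)", "(ontable ?x)", "(handempty)"], params):
    state = apply_effects(state, ["(holding ?x)", "(not (ontable ?x))", "(not (handempty))"], params)
print(state)  # {'(clear b)', '(holding b)'}
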
os.getenv("MEGABENCH_OPEN_API_MODEL") assert self.url, "You must set up the API URL for evaluating the Open tasks using your own API" - + @staticmethod def _update_image_path(image_path): hf_home = os.getenv("HF_HOME", "~/.cache/huggingface") base_cache_dir = os.path.expanduser(hf_home) - image_path = image_path.replace('./data/', f'{base_cache_dir}/megabench_data/data/') + image_path = image_path.replace("./data/", f"{base_cache_dir}/megabench_data/data/") return image_path - + def create_image_content(self, image_path): image_path = self._update_image_path(image_path) base64_image, mime_type = self.encode_image(image_path) @@ -53,17 +54,17 @@ def create_image_content(self, image_path): "type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{base64_image}"}, } - + @property def url(self) -> str: - """The server URL. We use OpenAI API by default. """ - return self._url if hasattr(self, '_url') else "https://api.openai.com/v1/chat/completions" + """The server URL. We use OpenAI API by default.""" + return self._url if hasattr(self, "_url") else "https://api.openai.com/v1/chat/completions" @url.setter def url(self, value: str) -> None: """Set the server URL.""" self._url = value - + @staticmethod def _rgba_to_rgb(image): background = Image.new("RGBA", image.size, (255, 255, 255, 255)) @@ -76,13 +77,13 @@ def _resize_image(self, image): int(image.size[1] * resize_scale), ) return image.resize(new_size) - + def _encode_image(self, image, image_format): with BytesIO() as output: image.convert("RGB").save(output, format=image_format) base64_encoded_data = base64.b64encode(output.getvalue()).decode("utf-8") return base64_encoded_data - + def encode_image(self, image_path, max_side=None): mime_type, _ = guess_type(image_path) if mime_type is None: @@ -102,9 +103,7 @@ def encode_image(self, image_path, max_side=None): return encoded_image, mime_type - def prepare_eval_prompt( - self, reference, response, images, question, eval_context=None - ): + def prepare_eval_prompt(self, reference, response, images, question, eval_context=None): content = [] if self.judge_model_type == "with image": for image_path in images: @@ -134,9 +133,7 @@ def query(self, reference_info, response, images, question, eval_context=None): "Authorization": f"Bearer {self.api_key}", } - context = self.prepare_eval_prompt( - reference_info, response, images, question, eval_context - ) + context = self.prepare_eval_prompt(reference_info, response, images, question, eval_context) query_payload = { "model": self.model, @@ -154,19 +151,14 @@ def query(self, reference_info, response, images, question, eval_context=None): ) response_ = response.json() except (requests.exceptions.JSONDecodeError, requests.exceptions.ConnectionError) as e: - print(f'Error in requests: {e}') - print('Retry...') + print(f"Error in requests: {e}") + print("Retry...") continue if "error" in response_: error_info = response_["error"] - print( - f"Got error with type: {error_info['type']}. Message: {error_info['message']}" - ) - if ( - error_info["message"] - == "Sorry! We've encountered an issue with repetitive patterns in your prompt. Please try again with a different prompt." - ): + print(f"Got error with type: {error_info['type']}. Message: {error_info['message']}") + if error_info["message"] == "Sorry! We've encountered an issue with repetitive patterns in your prompt. Please try again with a different prompt.": print(query_payload) # If the model's response has too many repetitive tokens, then we give it a score of 0. 
print(f"gpt-4o judge query failed...") @@ -183,9 +175,7 @@ def query(self, reference_info, response, images, question, eval_context=None): choices = response_data["choices"] if choices and "message" in choices[0]: message_content = choices[0]["message"]["content"] - print( - f"gpt-4o judge results: {message_content}; tokens:{total_tokens}" - ) + print(f"gpt-4o judge results: {message_content}; tokens:{total_tokens}") else: print(f"gpt-4o judge query failed...") message_content = "" @@ -220,11 +210,7 @@ def parse_results(self, eval_results): return score / 10.0, info_str - def match( - self, response, reference_dict, images, question, eval_context=None - ) -> int: - eval_results = self.model.query( - reference_dict, response, images, question, eval_context - ) + def match(self, response, reference_dict, images, question, eval_context=None) -> int: + eval_results = self.model.query(reference_dict, response, images, question, eval_context) score = self.parse_results(eval_results) return score diff --git a/lmms_eval/tasks/megabench/metrics/scoring/xml_nbbox_iou.py b/lmms_eval/tasks/megabench/metrics/scoring/xml_nbbox_iou.py index 15c38853..d8b71d80 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/xml_nbbox_iou.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/xml_nbbox_iou.py @@ -1,8 +1,9 @@ import logging -from metrics.scoring.common.metrics import calculate_iou -from metrics.scoring.common.conversions import parse_bboxes_from_xml from numbers import Number +from metrics.scoring.common.conversions import parse_bboxes_from_xml +from metrics.scoring.common.metrics import calculate_iou + class XmlNbboxIouSingle: """Calculates the IoU of bounding box. @@ -13,7 +14,6 @@ class XmlNbboxIouSingle: @classmethod def match(cls, responses, targets) -> float: - logging.debug(f"{responses=}, {targets=}") if not isinstance(responses, (tuple | list)): responses = parse_bboxes_from_xml(responses) diff --git a/lmms_eval/tasks/megabench/metrics/scoring/xml_norm_point_in_bbox.py b/lmms_eval/tasks/megabench/metrics/scoring/xml_norm_point_in_bbox.py index 7c0129ca..6c348ea8 100644 --- a/lmms_eval/tasks/megabench/metrics/scoring/xml_norm_point_in_bbox.py +++ b/lmms_eval/tasks/megabench/metrics/scoring/xml_norm_point_in_bbox.py @@ -12,12 +12,8 @@ class XmlNormPointInBbox: def match(cls, responses, eval_context) -> int: """Determine if the point is in the bounding box and return which bounding box was matched, if any.""" - bounding_box_has_match = { - bbox: False for bbox in eval_context["bounding_boxes"] - } - bounding_boxes = [ - str_to_bboxes(bbox_str)[0] for bbox_str in eval_context["bounding_boxes"] - ] + bounding_box_has_match = {bbox: False for bbox in eval_context["bounding_boxes"]} + bounding_boxes = [str_to_bboxes(bbox_str)[0] for bbox_str in eval_context["bounding_boxes"]] assert bounding_boxes if not isinstance(responses, (tuple | list)): diff --git a/lmms_eval/tasks/megabench/utils.py b/lmms_eval/tasks/megabench/utils.py index 04a617b8..589b4a3a 100644 --- a/lmms_eval/tasks/megabench/utils.py +++ b/lmms_eval/tasks/megabench/utils.py @@ -1,16 +1,19 @@ +import json import os -import yaml -from pathlib import Path -from itertools import chain from ast import literal_eval from collections import defaultdict -import json - +from itertools import chain +from pathlib import Path +import yaml from loguru import logger as eval_logger -from lmms_eval.tasks._task_utils.file_utils import generate_submission_file -from lmms_eval.tasks.megabench.image_video_utils import read_image, is_video_file, 
process_text_and_mixed_media +from lmms_eval.tasks._task_utils.file_utils import generate_submission_file +from lmms_eval.tasks.megabench.image_video_utils import ( + is_video_file, + process_text_and_mixed_media, + read_image, +) hf_home = os.getenv("HF_HOME", "~/.cache/huggingface") base_cache_dir = os.path.expanduser(hf_home) @@ -66,7 +69,7 @@ def megabench_doc_to_visual(doc, lmms_eval_specific_kwargs=None): else: # mixed video and image input, convert video to image frames cache_dir = os.path.join(base_cache_dir, cache_name) _, medias = process_text_and_mixed_media(doc, lmms_eval_specific_kwargs["max_video_subsample_frame"], cache_dir) - + return medias @@ -111,7 +114,7 @@ def megabench_aggregate_results_for_submission(results, args): all_query_response.append(sample_response) task_result["query_response"] = all_query_response submission_results.append(task_result) - + submission_path = generate_submission_file(f"{args.tasks}_all_query_responses.json", args) with open(submission_path, "w", encoding="utf-8") as fd: json.dump(submission_results, fd, indent=4) diff --git a/lmms_eval/tasks/mmmu/mmmu_val_interleaved.yaml b/lmms_eval/tasks/mmmu/mmmu_val_interleaved.yaml new file mode 100755 index 00000000..a94e6e5f --- /dev/null +++ b/lmms_eval/tasks/mmmu/mmmu_val_interleaved.yaml @@ -0,0 +1,16 @@ +dataset_path: lmms-lab/MMMU +task: "mmmu_val_interleaved" +test_split: validation +output_type: generate_until +doc_to_visual: !function utils.mmmu_doc_to_visual +doc_to_text: !function utils.mmmu_doc_to_text_interleave +doc_to_target: "answer" +# The return value of process_results will be used by metrics +process_results: !function utils.mmmu_process_results + +metric_list: + - metric: mmmu_acc + aggregation: !function utils.mmmu_aggregate_results + higher_is_better: true + +include: _default_template_yaml \ No newline at end of file diff --git a/lmms_eval/tasks/mmmu/utils.py b/lmms_eval/tasks/mmmu/utils.py index 789bd5a6..6e967f02 100755 --- a/lmms_eval/tasks/mmmu/utils.py +++ b/lmms_eval/tasks/mmmu/utils.py @@ -60,6 +60,12 @@ def mmmu_doc_to_text(doc): return question +def mmmu_doc_to_text_interleave(doc): + question = construct_prompt(doc) + question = re.sub(r"", lambda m: f"", question) + return question + + def mmmu_doc_to_visual(doc): prompt = construct_prompt(doc) image_tokens = re.findall(r"", prompt)
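
The regex literals in the `mmmu` hunk above appear to have been stripped from this copy of the diff (the angle-bracket tokens were likely eaten as HTML), so the exact patterns should be read from the repository. Purely as a hypothetical sketch of the idea, assuming MMMU's usual `<image 1>`-style placeholders and an invented `<media_k>` marker that a downstream splitter could pair with the k-th visual:

import re


def to_interleaved_text(question: str) -> str:
    # Hypothetical rewrite: "<image k>" placeholders become zero-based "<media_{k-1}>" markers.
    # Both the placeholder and marker formats here are illustrative, not the committed patterns.
    return re.sub(r"<image (\d+)>", lambda m: f"<media_{int(m.group(1)) - 1}>", question)


print(to_interleaved_text("Look at <image 1> and <image 2>. Which is larger?"))
# -> "Look at <media_0> and <media_1>. Which is larger?"
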