From 74840f7883b4bea56d9595123be1107d882ebbd1 Mon Sep 17 00:00:00 2001 From: mertyg Date: Fri, 12 Jul 2024 09:00:49 -0700 Subject: [PATCH 1/6] Add multimodal tasks --- textgrad/tasks/multimodal/__init__.py | 16 ++ textgrad/tasks/multimodal/mathvista.py | 300 +++++++++++++++++++++++++ textgrad/tasks/multimodal/scienceqa.py | 227 +++++++++++++++++++ 3 files changed, 543 insertions(+) create mode 100644 textgrad/tasks/multimodal/__init__.py create mode 100644 textgrad/tasks/multimodal/mathvista.py create mode 100644 textgrad/tasks/multimodal/scienceqa.py diff --git a/textgrad/tasks/multimodal/__init__.py b/textgrad/tasks/multimodal/__init__.py new file mode 100644 index 0000000..267779f --- /dev/null +++ b/textgrad/tasks/multimodal/__init__.py @@ -0,0 +1,16 @@ +from textgrad.engine import EngineLM + + +def load_multimodal_instance_task(task_name: str, evaluation_api: EngineLM, *args, **kwargs): + if task_name == "mathvista": + from textgrad.tasks.multimodal.mathvista import MathVistaDataset + test_set = MathVistaDataset(evaluation_api=evaluation_api, split="testmini", *args, **kwargs) + return test_set + + elif task_name == "scienceqa": + from textgrad.tasks.multimodal.scienceqa import ScienceQADataset + test_set = ScienceQADataset(evaluation_api=evaluation_api, split="test", *args, **kwargs) + return test_set + + else: + raise ValueError(f"Instance task {task_name} not found.") diff --git a/textgrad/tasks/multimodal/mathvista.py b/textgrad/tasks/multimodal/mathvista.py new file mode 100644 index 0000000..ce88220 --- /dev/null +++ b/textgrad/tasks/multimodal/mathvista.py @@ -0,0 +1,300 @@ +import re +import io +import platformdirs +from PIL import Image + +from textgrad.tasks.base import Dataset +from textgrad.loss import ImageQALoss +from textgrad.variable import Variable +try: + from Levenshtein import distance +except ImportError: + raise ImportError("Please install the Levenshtein package using 'pip install python-Levenshtein' to use mathvista.") + +def compress_image(decoded_image, max_size_bytes=3.6*1024*1024): + # Convert image to RGB if it's in a mode that JPEG does not support + if decoded_image.mode not in ['RGB', 'L']: + decoded_image = decoded_image.convert('RGB') + + buffer = io.BytesIO() + decoded_image.save(buffer, format='JPEG') + size = buffer.tell() + + if size <= max_size_bytes: + buffer.seek(0) + return buffer.getvalue() + + width, height = decoded_image.size + while size > max_size_bytes: + print(f"Compressing image to {width}x{height}...") + width = int(width * 0.9) + height = int(height * 0.9) + resized_image = decoded_image.resize((width, height), Image.LANCZOS) + + buffer = io.BytesIO() + resized_image.save(buffer, format='JPEG') + size = buffer.tell() + + if width <= 1 or height <= 1: + raise ValueError("Unable to compress image to the desired size without excessive loss of resolution") + + buffer.seek(0) + return buffer.getvalue() + + +def verify_extraction(extraction): + extraction = extraction.strip() + if extraction == "" or extraction == None: + return False + return True + + +def create_test_prompt(demo_prompt, query, response): + demo_prompt = demo_prompt.strip() + test_prompt = f"{query}\n\n{response}" + full_prompt = f"{demo_prompt}\n\n{test_prompt}\n\nExtracted answer: " + return full_prompt + + +def extract_answer(response, problem, quick_extract=False): + question_type = problem['question_type'] + answer_type = problem['answer_type'] + choices = problem['choices'] + query = problem['query'] + pid = problem['pid'] + + if response == "": + return "" + + if question_type == 'multi_choice' and response in choices: + return response + + if answer_type == "integer": + try: + extraction = int(response) + return str(extraction) + except: + pass + + if answer_type == "float": + try: + extraction = str(float(response)) + return extraction + except: + pass + + # quick extraction + if quick_extract: + print("Quickly extracting answer...") + try: + result = re.search(r'The answer is "(.*)"\.', response) + if result: + extraction = result.group(1) + return extraction + except: + pass + + else: + raise NotImplementedError("Extraction using LLMs are to-be-implemented.") + + +def get_most_similar(prediction, choices): + """ + Use the Levenshtein distance (or edit distance) to determine which of the choices is most similar to the given prediction + """ + distances = [distance(prediction, choice) for choice in choices] + ind = distances.index(min(distances)) + return choices[ind] + # return min(choices, key=lambda choice: distance(prediction, choice)) + + +def normalize_extracted_answer(extraction, question_data): + """ + Normalize the extracted answer to match the answer type + """ + choices = question_data["choices"] + question_type = question_data["question_type"] + answer_type = question_data["answer_type"] + precision = question_data["precision"] + + if question_type == 'multi_choice': + # make sure the extraction is a string + if isinstance(extraction, str): + extraction = extraction.strip() + else: + try: + extraction = str(extraction) + except: + extraction = "" + + # extract "A" from "(A) text" + letter = re.findall(r'\(([a-zA-Z])\)', extraction) + if len(letter) > 0: + extraction = letter[0].upper() + + options = [chr(ord('A') + i) for i in range(len(choices))] + + if extraction in options: + # convert option letter to text, e.g. "A" -> "text" + ind = options.index(extraction) + extraction = choices[ind] + else: + # select the most similar option + extraction = get_most_similar(extraction, choices) + assert extraction in choices + + elif answer_type == 'integer': + try: + extraction = str(int(float(extraction))) + except: + extraction = None + + elif answer_type == 'float': + try: + extraction = str(round(float(extraction), int(precision))) + except: + extraction = None + + elif answer_type == 'list': + try: + extraction = str(extraction) + except: + extraction = None + + return extraction + + +def safe_equal(prediction, answer): + """ + Check if the prediction is equal to the answer, even if they are of different types + """ + try: + if prediction == answer: + return True + return False + except Exception as e: + print(e) + return False + + +class MathVistaDataset(Dataset): + def __init__(self, evaluation_api:str, root: str=None, split: str="testmini", task_instruction: str=None, evaluation_instruction: str=None, *args, **kwargs): + """MathVista dataset from HF.""" + from datasets import load_dataset + if root is None: + root = platformdirs.user_cache_dir("textgrad") + self.root = root + assert split in ["testmini", "test"] + self.data = load_dataset("AI4Math/MathVista", cache_dir=root, split=split) + self.split = split + self.evaluation_api = evaluation_api + self.anwer_extraction_openai_engine = "gpt-3.5-turbo" # robust enough for answer extraction + self.task_instruction = self.get_default_task_instruction(task_instruction) # NOTE: check the task instruction + self.evaluation_instruction = self.get_default_evaluation_instruction(evaluation_instruction) # NOTE: check the evaluation instruction + + def __getitem__(self, index): + row = self.data[index] + pid = row["pid"] + # question = row["question"] + # image = row["image"] + decoded_image = row["decoded_image"] + choices = row["choices"] + unit = row["unit"] + precision = row["precision"] + answer = row["answer"] + question_type = row["question_type"] + answer_type = row["answer_type"] + metadata = row["metadata"] + query = row["query"] + query = f"{self.task_instruction}\n{query}" # NOTE: Add the task description + + # NOTE: convert image to bytes + if "claude" in self.evaluation_api.model_string: + # TODO @lupantech This does not seem neat. + image_bytes = compress_image(decoded_image) + else: + buffer = io.BytesIO() + decoded_image.save(buffer, format='png') + image_bytes = buffer.getvalue() + buffer.close() + + # NOTE: ques_data stores other fields that might be useful later + ques_data = { + "pid": pid, + "query": query, + "choices": choices, + "unit": unit, + "precision": precision, + "answer": answer, + "question_type": question_type, + "answer_type": answer_type, + "metadata": metadata + } + test_time_objective = self._get_instance_test_time_objective(query, image_bytes) + instance_eval_fn = self._get_instance_eval_fn(query, answer, ques_data) + return image_bytes, query, answer, ques_data, test_time_objective, instance_eval_fn # NOTE: check the sample format + + def __len__(self): + return len(self.data) + + def get_default_task_instruction(self, instruction): + if instruction is not None: + print("Using user-defined task instruction:\n", instruction, "\n") + task_instruction = instruction + else: + task_instruction = "You will answer a mathematical reasoning question based on an image. Please ensure you accurately interpret the image and think step by step." + return task_instruction + + def get_default_evaluation_instruction(self, instruction): + if instruction is not None: + print("Using user-defined evaluation instruction:\n", instruction, "\n") + evaluation_instruction = instruction + else: + evaluation_instruction = "Please evaluate the existing answer to the visual math problem without solving it yourself. Verify that the answer provides accurate reasoning logic to address the question." + return evaluation_instruction + + def create_test_prompt(demo_prompt, query, response): + demo_prompt = demo_prompt.strip() + test_prompt = f"{query}\n\n{response}" + full_prompt = f"{demo_prompt}\n\n{test_prompt}\n\nExtracted answer: " + return full_prompt + + def _get_instance_test_time_objective(self, question: str, image: bytes): + """Define the loss function for the test time optimization.""" + eval_fn = ImageQALoss(evaluation_instruction=self.evaluation_instruction, engine=self.evaluation_api) + + def test_time_objective(instance: Variable): + var_image = Variable(image, role_description="image input", requires_grad=False) + var_question = Variable(question, role_description="question input", requires_grad=False) + return eval_fn(question=var_question, image=var_image, response=instance) + + return test_time_objective + + def eval_extraction_and_matching(self, response_text, correct_answer, question_data): + # Extract the precited answer text from the response + extracted_answer = extract_answer(response_text, question_data) + + # Normalize the extracted answer to match the answer type + normalized_answer = normalize_extracted_answer(extracted_answer, question_data) + + # Verify the prediction is true or false + true_false = safe_equal(normalized_answer, correct_answer) + + # Calculate the score and store the result data + # NOTE: check the result data format + score = 1 if true_false else 0 + result_data = { + "extracted_answer": extracted_answer, + "normalized_answer": normalized_answer, + "true_false": true_false + } + return score, result_data + + def _get_instance_eval_fn(self, question_prompt: str, answer: str, ques_data: dict): + """ + Define the evaluation function for scoring the response. + Extraxct the short answer from the response and compare it with the ground truth. + """ + # NOTE: check the evaluation function format + eval_extraction_based_fn = lambda response: self.eval_extraction_and_matching(response.value, answer, ques_data) + return eval_extraction_based_fn diff --git a/textgrad/tasks/multimodal/scienceqa.py b/textgrad/tasks/multimodal/scienceqa.py new file mode 100644 index 0000000..18a4633 --- /dev/null +++ b/textgrad/tasks/multimodal/scienceqa.py @@ -0,0 +1,227 @@ +import platformdirs +import re +import os +import io +from collections import Counter +from PIL import Image + +from textgrad.tasks.base import Dataset +from textgrad.loss import ImageQALoss +from textgrad.variable import Variable + + +def compress_image(decoded_image, max_size_bytes=3.6*1024*1024): + # First, try saving as PNG without any compression + buffer = io.BytesIO() + decoded_image.save(buffer, format='PNG') + size = buffer.tell() + + # If the original PNG is already small enough, return it + if size <= max_size_bytes: + buffer.seek(0) + return buffer.getvalue() + + # If PNG is too large, resize the image + width, height = decoded_image.size + while size > max_size_bytes: + print(f"Compressing image to {width}x{height}...") + width = int(width * 0.9) + height = int(height * 0.9) + resized_image = decoded_image.resize((width, height), Image.LANCZOS) + + buffer = io.BytesIO() + resized_image.save(buffer, format='PNG') + size = buffer.tell() + + if width <= 1 or height <= 1: + raise ValueError("Unable to compress image to the desired size without excessive loss of resolution") + + buffer.seek(0) + return buffer.getvalue() + +def preprocess_text(text): + # Remove punctuation and convert to lowercase + text = re.sub(r'[^\w\s]', '', text).lower() + return text + +def find_most_similar_choice(text, choices): + # Preprocess the given text + text = preprocess_text(text) + text_words = Counter(text.split()) + scores = [] + for choice in choices: + choice_text = preprocess_text(choice) + choice_words = Counter(choice_text.split()) + common_words = sum((text_words & choice_words).values()) # the number of common words + scores.append(common_words) + most_similar_index = scores.index(max(scores)) + return most_similar_index + +def extract_answer(response_text): + """ + Extract the answer from the response. + For example, "xxxxx. Answer: (A) a gas." -> "A" + If extraction fails, return the entire string after "Answer: ". + """ + # Attempt to match the format "Answer: (A)" or "Answer: (a)" in case-sensitive manner + match = re.search(r"Answer: \(([A-Z])\)", response_text) + if match: + return match.group(1).upper() # Return as uppercase + else: + # Fallback: match the format "Answer: " followed by any characters until the next period or end of line + fallback_match = re.search(r"Answer: ([^\.]+)", response_text) + if fallback_match: + return fallback_match.group(1).strip() + return response_text + +def normalize_extracted_answer(extracted_answer, question_data, options): + # Normalize the extracted answer + choices = question_data["choices"] + options = options[:len(choices)] + + # 'A' -> one of the choices + if extracted_answer in options: + normalized_answer = options.index(extracted_answer) + return normalized_answer + + # '(a) a gas.' + for choice in choices: + if choice.lower() in extracted_answer.lower(): + normalized_answer = choices.index(choice) + return normalized_answer + + # find the most similar choice + normalized_answer = find_most_similar_choice(extracted_answer, choices) + return normalized_answer + +def safe_equal(a, b): + # Check if two intergers are equal + return a == b + +class ScienceQADataset(Dataset): + def __init__(self, evaluation_api:str, root: str=None, split: str="test", task_instruction: str=None, evaluation_instruction: str=None, *args, **kwargs): + """ScienceQA dataset from HF.""" + if root is None: + root = platformdirs.user_cache_dir("textgrad") + self.root = root + self.split = split + assert split in ["test"] + self.data = self.load_scienceqa_data() + self.options = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"] + self.evaluation_api = evaluation_api + self.task_instruction = self.get_default_task_instruction(task_instruction) # NOTE: check the task instruction + self.evaluation_instruction = self.get_default_evaluation_instruction(evaluation_instruction) # NOTE: check the evaluation instruction + + def __getitem__(self, index): + row = self.data[index] + pid = row["pid"] + image = row["image"] + question = row["question"] + choices = row["choices"] + answer = row["answer"] + hint = row["hint"] + + query = f"{self.task_instruction}" # NOTE: Add the task description + if hint is not None and len(hint) > 0: + query += f"\nContext: {hint}" + query += f"\nQuestion: {question}" + if choices: + choice_list = [] + for i, c in enumerate(choices): + choice_list.append("({}) {}".format(self.options[i], c)) + choice_txt = " ".join(choice_list) + query += f"\nChoices: {choice_txt}" + + # NOTE: convert image to bytes + if "claude" in self.evaluation_api.model_string: + image_bytes = compress_image(image) + # print("Image size:", len(image_bytes)) + else: + buffer = io.BytesIO() + image.save(buffer, format='png') + image_bytes = buffer.getvalue() + buffer.close() + + # NOTE: ques_data stores other fields that might be useful later + ques_data = { + "pid": pid, + "question": question, + "choices": choices, + "hint": hint + } + test_time_objective = self._get_instance_test_time_objective(query, image_bytes) + instance_eval_fn = self._get_instance_eval_fn(query, answer, ques_data) + return image_bytes, query, answer, ques_data, test_time_objective, instance_eval_fn # NOTE: check the sample format + + def __len__(self): + return len(self.data) + + def get_default_task_instruction(self, instruction): + if instruction is not None: + print("Using user-defined task instruction:\n", instruction, "\n") + task_instruction = instruction + else: + task_instruction = "You will answer a scientific question based on an image. Please ensure you accurately interpret the image and think step by step. The last line of your answer should be formatted as follows: 'Answer: (X) Your Option.'" + return task_instruction + + def load_scienceqa_data(self): + scienceqa_dir = os.path.join(self.root, "scienceqa") + try: + from datasets import Dataset + data = Dataset.load_from_disk(scienceqa_dir) + print("Loaded ScienceQA dataset from cache.") + return data + except FileNotFoundError: + from datasets import load_dataset + data = load_dataset("derek-thomas/ScienceQA", split=self.split) + data_img = data.filter(lambda x: x['image'] is not None) # select examples with a non-empty question image + data_img = data_img.map(lambda x, i: {'pid': str(i), **x}, with_indices=True) # add index ID (string) for each example + data_img.save_to_disk(scienceqa_dir) + print("Loaded ScienceQA dataset from HF.") + return data_img + + def get_default_evaluation_instruction(self, instruction): + if instruction is not None: + print("Using user-defined evaluation instruction:\n", instruction, "\n") + evaluation_instruction = instruction + else: + evaluation_instruction = "Please evaluate the existing answer to the visual scientific problem without solving it yourself. Verify that the answer accurately understands the image, provides appropriate knowledge and reasoning logic to address the question." + return evaluation_instruction + + def _get_instance_test_time_objective(self, question: str, image: bytes): + """Define the loss function for the test time optimization.""" + eval_fn = ImageQALoss(evaluation_instruction=self.evaluation_instruction, engine=self.evaluation_api) + def test_time_objective(instance: Variable): + var_image = Variable(image, role_description="image input", requires_grad=False) + var_question = Variable(question, role_description="question input", requires_grad=False) + return eval_fn(question=var_question, image=var_image, response=instance) + return test_time_objective + + def eval_extraction_and_matching(self, response_text, correct_answer, question_data): + # Extract the precited answer text from the response + extracted_answer = extract_answer(response_text) + + # Normalize the extracted answer to match the answer type + normalized_answer = normalize_extracted_answer(extracted_answer, question_data, self.options) + + # Verify the prediction is true or false + true_false = safe_equal(normalized_answer, correct_answer) + + # Calculate the score and store the result data + # NOTE: check the result data format + score = 1 if true_false else 0 + result_data = { + "extracted_answer": extracted_answer, + "normalized_answer": normalized_answer, + "true_false": true_false + } + return score, result_data + + def _get_instance_eval_fn(self, question_prompt: str, answer: str, ques_data: dict): + """ + Define the evaluation function for scoring the response. + Extraxct the short answer from the response and compare it with the ground truth. + """ + # NOTE: check the evaluation function format + eval_extraction_based_fn = lambda response: self.eval_extraction_and_matching(response.value, answer, ques_data) + return eval_extraction_based_fn From 09d827ade25eab43db3ebd86e38cbd2ccdfff567 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 12 Jul 2024 16:01:09 +0000 Subject: [PATCH 2/6] docs(contributor): contrib-readme-action has updated readme --- README.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index bd46d38..84b30b2 100644 --- a/README.md +++ b/README.md @@ -327,6 +327,13 @@ We are grateful for all the help we got from our contributors! Mert Yuksekgonul + + + nihalnayak +
+ Nihal Nayak +
+ sugatoray @@ -347,13 +354,6 @@ We are grateful for all the help we got from our contributors!
David Ruan
- - - - nihalnayak -
- Nihal Nayak -
From 14a93f7b5b1f1b8e02657eacc79fc163e4652090 Mon Sep 17 00:00:00 2001 From: Pan Lu Date: Fri, 12 Jul 2024 10:44:28 -0700 Subject: [PATCH 3/6] Update mathvista.py Added prompts for answer extraction. --- textgrad/tasks/multimodal/mathvista.py | 74 ++++++++++++++++++++++---- 1 file changed, 63 insertions(+), 11 deletions(-) diff --git a/textgrad/tasks/multimodal/mathvista.py b/textgrad/tasks/multimodal/mathvista.py index ce88220..4cea5ed 100644 --- a/textgrad/tasks/multimodal/mathvista.py +++ b/textgrad/tasks/multimodal/mathvista.py @@ -1,15 +1,18 @@ +import platformdirs +from .base import Dataset +from textgrad.loss import ImageQALoss +from textgrad.variable import Variable + import re import io -import platformdirs +# !pip install python-Levenshtein +from Levenshtein import distance from PIL import Image -from textgrad.tasks.base import Dataset -from textgrad.loss import ImageQALoss -from textgrad.variable import Variable -try: - from Levenshtein import distance -except ImportError: - raise ImportError("Please install the Levenshtein package using 'pip install python-Levenshtein' to use mathvista.") +from textgrad.engine.openai import ChatOpenAI + +local_llm_engine = ChatOpenAI(model_string="gpt-3.5-turbo", is_multimodal=False) +print("Local OpenAI engine initialized.\n") def compress_image(decoded_image, max_size_bytes=3.6*1024*1024): # Convert image to RGB if it's in a mode that JPEG does not support @@ -42,6 +45,47 @@ def compress_image(decoded_image, max_size_bytes=3.6*1024*1024): return buffer.getvalue() +# Demos (pids = 852, 104, 824, 506, 540) from MathVista +demo_prompt = """ +Please read the following example. Then extract the answer from the model response and type it at the end of the prompt. + +Hint: Please answer the question requiring an integer answer and provide the final value, e.g., 1, 2, 3, at the end. +Question: Which number is missing? + +Model response: The number missing in the sequence is 14. + +Extracted answer: 14 + +Hint: Please answer the question requiring a floating-point number with one decimal place and provide the final value, e.g., 1.2, 1.3, 1.4, at the end. +Question: What is the fraction of females facing the camera? + +Model response: The fraction of females facing the camera is 0.6, which means that six out of ten females in the group are facing the camera. + +Extracted answer: 0.6 + +Hint: Please answer the question requiring a floating-point number with two decimal places and provide the final value, e.g., 1.23, 1.34, 1.45, at the end. +Question: How much money does Luca need to buy a sour apple candy and a butterscotch candy? (Unit: $) + +Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy. + +Extracted answer: 1.45 + +Hint: Please answer the question requiring a Python list as an answer and provide the final list, e.g., [1, 2, 3], [1.2, 1.3, 1.4], at the end. +Question: Between which two years does the line graph saw its maximum peak? + +Model response: The line graph saw its maximum peak between 2007 and 2008. + +Extracted answer: [2007, 2008] + +Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end. +Question: What fraction of the shape is blue?\nChoices:\n(A) 3/11\n(B) 8/11\n(C) 6/11\n(D) 3/5 + +Model response: The correct answer is (B) 8/11. + +Extracted answer: B +""" + + def verify_extraction(extraction): extraction = extraction.strip() if extraction == "" or extraction == None: @@ -86,6 +130,7 @@ def extract_answer(response, problem, quick_extract=False): # quick extraction if quick_extract: print("Quickly extracting answer...") + # The answer is "text". -> "text" try: result = re.search(r'The answer is "(.*)"\.', response) if result: @@ -94,8 +139,16 @@ def extract_answer(response, problem, quick_extract=False): except: pass - else: - raise NotImplementedError("Extraction using LLMs are to-be-implemented.") + # general extraction + try: + full_prompt = create_test_prompt(demo_prompt, query, response) + extraction = local_llm_engine(full_prompt) + return extraction + except Exception as e: + print(e) + print(f"Error in extracting answer for {pid}") + + return "" def get_most_similar(prediction, choices): @@ -210,7 +263,6 @@ def __getitem__(self, index): # NOTE: convert image to bytes if "claude" in self.evaluation_api.model_string: - # TODO @lupantech This does not seem neat. image_bytes = compress_image(decoded_image) else: buffer = io.BytesIO() From 368241d62b2f4a46bad22a702a4e57fc7f8ac128 Mon Sep 17 00:00:00 2001 From: Pan Lu Date: Fri, 12 Jul 2024 11:39:39 -0700 Subject: [PATCH 4/6] Update mathvista.py Added the error handling when importing packages. --- textgrad/tasks/multimodal/mathvista.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/textgrad/tasks/multimodal/mathvista.py b/textgrad/tasks/multimodal/mathvista.py index 4cea5ed..a43eb0b 100644 --- a/textgrad/tasks/multimodal/mathvista.py +++ b/textgrad/tasks/multimodal/mathvista.py @@ -1,14 +1,16 @@ -import platformdirs -from .base import Dataset -from textgrad.loss import ImageQALoss -from textgrad.variable import Variable - import re import io -# !pip install python-Levenshtein -from Levenshtein import distance +import platformdirs from PIL import Image +from textgrad.tasks.base import Dataset +from textgrad.loss import ImageQALoss +from textgrad.variable import Variable +try: + from Levenshtein import distance +except ImportError: + raise ImportError("Please install the Levenshtein package using 'pip install python-Levenshtein' to use mathvista.") + from textgrad.engine.openai import ChatOpenAI local_llm_engine = ChatOpenAI(model_string="gpt-3.5-turbo", is_multimodal=False) From 652e13cf4dd6b6c6dbc10aa309af7f22c7f7e409 Mon Sep 17 00:00:00 2001 From: mertyg Date: Fri, 12 Jul 2024 12:30:23 -0700 Subject: [PATCH 5/6] raise exception if cannot parse answer, do not return an arbitrary answer --- textgrad/tasks/multimodal/mathvista.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/textgrad/tasks/multimodal/mathvista.py b/textgrad/tasks/multimodal/mathvista.py index a43eb0b..9523774 100644 --- a/textgrad/tasks/multimodal/mathvista.py +++ b/textgrad/tasks/multimodal/mathvista.py @@ -11,10 +11,6 @@ except ImportError: raise ImportError("Please install the Levenshtein package using 'pip install python-Levenshtein' to use mathvista.") -from textgrad.engine.openai import ChatOpenAI - -local_llm_engine = ChatOpenAI(model_string="gpt-3.5-turbo", is_multimodal=False) -print("Local OpenAI engine initialized.\n") def compress_image(decoded_image, max_size_bytes=3.6*1024*1024): # Convert image to RGB if it's in a mode that JPEG does not support @@ -31,7 +27,6 @@ def compress_image(decoded_image, max_size_bytes=3.6*1024*1024): width, height = decoded_image.size while size > max_size_bytes: - print(f"Compressing image to {width}x{height}...") width = int(width * 0.9) height = int(height * 0.9) resized_image = decoded_image.resize((width, height), Image.LANCZOS) @@ -131,26 +126,24 @@ def extract_answer(response, problem, quick_extract=False): # quick extraction if quick_extract: - print("Quickly extracting answer...") - # The answer is "text". -> "text" try: result = re.search(r'The answer is "(.*)"\.', response) if result: extraction = result.group(1) return extraction - except: - pass + except Exception as e: + raise Exception(f"Error in extracting answer for {pid}: {e}. Remove this line responsibly.") # general extraction try: + from textgrad.engine.openai import ChatOpenAI + local_llm_engine = ChatOpenAI(model_string="gpt-3.5-turbo", is_multimodal=False) + full_prompt = create_test_prompt(demo_prompt, query, response) extraction = local_llm_engine(full_prompt) return extraction except Exception as e: - print(e) - print(f"Error in extracting answer for {pid}") - - return "" + raise Exception(f"Error in extracting answer for {pid}: {e}") def get_most_similar(prediction, choices): From 67b7417238166a694eeee234a0824c08d08b6d78 Mon Sep 17 00:00:00 2001 From: mertyg Date: Fri, 12 Jul 2024 12:32:04 -0700 Subject: [PATCH 6/6] remove unused comments --- textgrad/tasks/multimodal/mathvista.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/textgrad/tasks/multimodal/mathvista.py b/textgrad/tasks/multimodal/mathvista.py index 9523774..5f2c44b 100644 --- a/textgrad/tasks/multimodal/mathvista.py +++ b/textgrad/tasks/multimodal/mathvista.py @@ -124,7 +124,6 @@ def extract_answer(response, problem, quick_extract=False): except: pass - # quick extraction if quick_extract: try: result = re.search(r'The answer is "(.*)"\.', response) @@ -134,7 +133,6 @@ def extract_answer(response, problem, quick_extract=False): except Exception as e: raise Exception(f"Error in extracting answer for {pid}: {e}. Remove this line responsibly.") - # general extraction try: from textgrad.engine.openai import ChatOpenAI local_llm_engine = ChatOpenAI(model_string="gpt-3.5-turbo", is_multimodal=False) @@ -153,7 +151,6 @@ def get_most_similar(prediction, choices): distances = [distance(prediction, choice) for choice in choices] ind = distances.index(min(distances)) return choices[ind] - # return min(choices, key=lambda choice: distance(prediction, choice)) def normalize_extracted_answer(extraction, question_data): @@ -243,8 +240,6 @@ def __init__(self, evaluation_api:str, root: str=None, split: str="testmini", ta def __getitem__(self, index): row = self.data[index] pid = row["pid"] - # question = row["question"] - # image = row["image"] decoded_image = row["decoded_image"] choices = row["choices"] unit = row["unit"]