diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 631340892..336bd9c3b 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -20,4 +20,4 @@ jobs:
pip install pre-commit
pre-commit install
- name: Linting
- run: pre-commit run --all-files
\ No newline at end of file
+ run: pre-commit run --all-files
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 3e1613917..a486c71c8 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,15 +1,15 @@
exclude: |
(?x)^(
- results/|
scripts/|
- assets/
+ assets/|
+ vlmeval/config.py
)
repos:
- repo: https://github.com/PyCQA/flake8
rev: 5.0.4
hooks:
- id: flake8
- args: ["--max-line-length=120", "--ignore=F401,F405,E402"]
+ args: ["--max-line-length=120", "--ignore=F401,F403,F405,E402,E722,E741,W503"]
exclude: ^configs/
- repo: https://github.com/pre-commit/mirrors-yapf
rev: v0.30.0
diff --git a/Custom_Benchmark_and_Model.md b/Custom_Benchmark_and_Model.md
index bdf56035b..9def84f89 100644
--- a/Custom_Benchmark_and_Model.md
+++ b/Custom_Benchmark_and_Model.md
@@ -1,8 +1,8 @@
-# 🛠️ How to implement a new Benchmark / VLM in VLMEvalKit?
+# 🛠️ How to implement a new Benchmark / VLM in VLMEvalKit?
## Implement a new benchmark
-Currently, we organize a benchmark as one single TSV file. During inference, the data file will be automatically downloaded to `$LMUData` (default path is `$HOME/LMUData`, if not set explicitly). All existing benchmark TSV files are handled by `TSVDataset` implemented in `vlmeval/utils/dataset_config.py`.
+Currently, we organize a benchmark as one single TSV file. During inference, the data file will be automatically downloaded to `$LMUData` (default path is `$HOME/LMUData`, if not set explicitly). All existing benchmark TSV files are handled by `TSVDataset` implemented in `vlmeval/utils/dataset_config.py`.
| Dataset Name \ Fields | index | image | image_path | question | hint | multi-choice options | answer | category | l2-category | split |
| ---------------------- | ----- | ----- | ---------- | -------- | ---- | ----------------------- | ------ | -------- | ----------- | ----- |
@@ -23,7 +23,7 @@ Currently, we organize a benchmark as one single TSV file. During inference, the
**Intro to some fields:**
- **index:** Integer, Unique for each line in `tsv`
-- **image:** the base64 of the image, you can use APIs implemented in `vlmeval/smp.py` for encoding and decoding:
+- **image:** the base64-encoded image; you can use APIs implemented in `vlmeval/smp.py` for encoding and decoding:
- Encoding: `encode_image_to_base64` (for PIL Image) / `encode_image_file_to_base64` (for image file path)
- Decoding: `decode_base64_to_image` (for PIL Image) / `decode_base64_to_image_file` (for image file path)
@@ -31,6 +31,6 @@ Besides, your dataset class **should implement the method `build_prompt(self, li
## Implement a new model
-All existing models are implemented in `vlmeval/vlm`. For a minimal model, your model class **should implement the method** `generate(image_path, prompt, dataset=None)`. In this function, you feed the image and prompt to your VLM and return the VLM prediction (which is a string). The optional argument `dataset` can be used as the flag for the model to switch among various inference strategies.
+All existing models are implemented in `vlmeval/vlm`. For a minimal model, your model class **should implement the method** `generate(image_path, prompt, dataset=None)`. In this function, you feed the image and prompt to your VLM and return the VLM prediction (which is a string). The optional argument `dataset` can be used as the flag for the model to switch among various inference strategies.
-Besides, your model can support custom prompt building by implementing an optional method `build_prompt(line, dataset=None)`. In this function, the line is a dictionary that includes the necessary information of a data sample, while `dataset` can be used as the flag for the model to switch among various prompt building strategies.
\ No newline at end of file
+Besides, your model can support custom prompt building by implementing an optional method `build_prompt(line, dataset=None)`. In this function, the line is a dictionary that includes the necessary information of a data sample, while `dataset` can be used as the flag for the model to switch among various prompt building strategies.
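For reference on the interface described above, here is a minimal sketch of such a model class. Only the `generate(image_path, prompt, dataset=None)` and optional `build_prompt(line, dataset=None)` signatures come from the documentation; the class name, method bodies, and return values are hypothetical placeholders.

```python
class MyVLM:
    """A hypothetical placeholder model; only the method signatures follow the doc above."""

    def generate(self, image_path, prompt, dataset=None):
        # Feed the image and prompt to your VLM and return the prediction as a string.
        # `dataset` can be used as a flag to switch among inference strategies.
        answer = 'a placeholder answer'
        return answer

    def build_prompt(self, line, dataset=None):
        # Optional custom prompt building: `line` is a dict holding one data sample;
        # `dataset` can switch among prompt building strategies. Returning a plain
        # prompt string here is purely illustrative.
        return 'Question: ' + line['question']
```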
diff --git a/README.md b/README.md
index 82e7ce650..dca5b09db 100644
--- a/README.md
+++ b/README.md
@@ -16,7 +16,7 @@
-**VLMEvalKit** (the python package name is **vlmeval**) is an **open-source evaluation toolkit** of **large vision-language models (LVLMs)**. It enables **one-command evaluation** of LVLMs on various benchmarks, without the heavy workload of data preparation under multiple repositories. In VLMEvalKit, we adopt **generation-based evaluation** for all LVLMs (obtain the answer via `generate` / `chat` interface), and provide the evaluation results obtained with both **exact matching** and **LLM(ChatGPT)-based answer extraction**.
+**VLMEvalKit** (the python package name is **vlmeval**) is an **open-source evaluation toolkit** of **large vision-language models (LVLMs)**. It enables **one-command evaluation** of LVLMs on various benchmarks, without the heavy workload of data preparation under multiple repositories. In VLMEvalKit, we adopt **generation-based evaluation** for all LVLMs (obtain the answer via `generate` / `chat` interface), and provide the evaluation results obtained with both **exact matching** and **LLM(ChatGPT)-based answer extraction**.
## 🆕 News
@@ -27,7 +27,7 @@
- **[2024-02-24]** We have supported [**InternVL-Chat Series**](https://github.com/OpenGVLab/InternVL). The models achieve over 80% Top-1 accuracies on MMBench v1.0 [[**Blog**](https://github.com/OpenGVLab/InternVL/blob/main/BLOG.md)]. 🔥🔥🔥
- **[2024-02-07]** We have supported two new models: [**MiniCPM-V**](https://huggingface.co/openbmb/MiniCPM-V) and [**OmniLMM-12B**](https://huggingface.co/openbmb/OmniLMM-12B). 🔥🔥🔥
- **[2024-01-30]** We have supported three new models: [**QwenVLMax**](https://huggingface.co/spaces/Qwen/Qwen-VL-Max), [**InternLM-XComposer2-7B**](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b), [**MMAlaya**](https://huggingface.co/DataCanvas/MMAlaya). 🔥🔥🔥
-- **[2024-01-30]** We have merged all performance numbers on our leaderboards into a single json file: [**OpenVLM.json**](http://opencompass.openxlab.space/utils/OpenVLM.json).
+- **[2024-01-30]** We have merged all performance numbers on our leaderboards into a single json file: [**OpenVLM.json**](http://opencompass.openxlab.space/utils/OpenVLM.json).
- **[2024-01-27]** We have supported the evaluation of [**MMMU_TEST**](https://mmmu-benchmark.github.io). 🔥🔥🔥
- **[2024-01-24]** We have supported [**Yi-VL**](https://huggingface.co/01-ai/Yi-VL-6B). 🔥🔥🔥
@@ -36,7 +36,7 @@
**The performance numbers on our official multi-modal leaderboards can be downloaded from here!**
-[**OpenCompass Multi-Modal Leaderboard**](https://rank.opencompass.org.cn/leaderboard-multimodal): [Download All DETAILED Results](http://opencompass.openxlab.space/utils/OpenVLM.json).
+[**OpenCompass Multi-Modal Leaderboard**](https://rank.opencompass.org.cn/leaderboard-multimodal): [Download All DETAILED Results](http://opencompass.openxlab.space/utils/OpenVLM.json).
**Supported Dataset**
@@ -60,7 +60,7 @@
| [**OCRBench**](https://github.com/Yuliang-Liu/MultimodalOCR) | OCRBench | ✅ | ✅ | **TBD.** |
| [**Core-MM**](https://github.com/core-mm/core-mm) | CORE_MM | ✅ | | **N/A** |
-**There are some known issues with VQA tasks like OCRVQA, TextVQA, ChartQA, etc. We will fix them asap.**
+**There are some known issues with VQA tasks like OCRVQA, TextVQA, ChartQA, etc. We will fix them asap.**
**Supported API Models**
@@ -77,9 +77,9 @@
| [**Monkey**](https://github.com/Yuliang-Liu/Monkey)🚅 | [**EMU2 / EMU2-Chat**](https://github.com/baaivision/Emu)🚅🎞️ | [**Yi-VL-[6B/34B]**](https://huggingface.co/01-ai/Yi-VL-6B) | [**MMAlaya**](https://huggingface.co/DataCanvas/MMAlaya)🚅 |
| [**InternLM-XComposer2-7B**](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b)🚅🎞️ | [**MiniCPM-V**](https://huggingface.co/openbmb/MiniCPM-V)🚅 | [**OmniLMM-12B**](https://huggingface.co/openbmb/OmniLMM-12B) | [**InternVL-Chat Series**](https://github.com/OpenGVLab/InternVL)🚅 |
-🎞️: Support multiple images as inputs, via the `interleave_generate` interface.
+🎞️: Support multiple images as inputs, via the `interleave_generate` interface.
-🚅: Model can be used without any additional configuration / operation.
+🚅: Model can be used without any additional configuration / operation.
**Transformers Version Recommendation:** Note that some VLMs may not be able to run under certain transformers versions; we recommend the following settings to evaluate each VLM:
@@ -100,9 +100,9 @@ print(ret) # There are two apples in the provided images.
## 🏗️ QuickStart
-Before running the evaluation script, you need to **configure** the VLMs and set the model_paths properly.
+Before running the evaluation script, you need to **configure** the VLMs and set the model_paths properly.
-After that, you can use a single script `run.py` to inference and evaluate multiple VLMs and benchmarks at a same time.
+After that, you can use a single script `run.py` to run inference and evaluation for multiple VLMs and benchmarks at the same time.
### Step0. Installation
@@ -118,18 +118,18 @@ pip install -e .
Following VLMs require the configuration step:
-**Code Preparation & Installation**: InstructBLIP ([LAVIS](https://github.com/salesforce/LAVIS)), LLaVA ([LLaVA](https://github.com/haotian-liu/LLaVA)), MiniGPT-4 ([MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4)), mPLUG-Owl2 ([mPLUG-Owl2](https://github.com/X-PLUG/mPLUG-Owl/tree/main/mPLUG-Owl2)), OpenFlamingo-v2 ([OpenFlamingo](https://github.com/mlfoundations/open_flamingo)), PandaGPT-13B ([PandaGPT](https://github.com/yxuansu/PandaGPT)), TransCore-M ([TransCore-M](https://github.com/PCIResearch/TransCore-M)).
+**Code Preparation & Installation**: InstructBLIP ([LAVIS](https://github.com/salesforce/LAVIS)), LLaVA ([LLaVA](https://github.com/haotian-liu/LLaVA)), MiniGPT-4 ([MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4)), mPLUG-Owl2 ([mPLUG-Owl2](https://github.com/X-PLUG/mPLUG-Owl/tree/main/mPLUG-Owl2)), OpenFlamingo-v2 ([OpenFlamingo](https://github.com/mlfoundations/open_flamingo)), PandaGPT-13B ([PandaGPT](https://github.com/yxuansu/PandaGPT)), TransCore-M ([TransCore-M](https://github.com/PCIResearch/TransCore-M)).
**Manual Weight Preparation & Configuration**: InstructBLIP, LLaVA-v1-7B, MiniGPT-4, PandaGPT-13B
-### Step2. Evaluation
+### Step2. Evaluation
We use `run.py` for evaluation. To use the script, you can use `$VLMEvalKit/run.py` or create a soft-link of the script (to use the script anywhere):
**Arguments**
-- `--data (list[str])`: Set the dataset names that are supported in VLMEvalKit (defined in `vlmeval/utils/dataset_config.py`).
-- `--model (list[str])`: Set the VLM names that are supported in VLMEvalKit (defined in `supported_VLM` in `vlmeval/config.py`).
+- `--data (list[str])`: Set the dataset names that are supported in VLMEvalKit (defined in `vlmeval/utils/dataset_config.py`).
+- `--model (list[str])`: Set the VLM names that are supported in VLMEvalKit (defined in `supported_VLM` in `vlmeval/config.py`).
- `--mode (str, default to 'all', choices are ['all', 'infer'])`: When `mode` set to "all", will perform both inference and evaluation; when set to "infer", will only perform the inference.
- `--nproc (int, default to 4)`: The number of threads for OpenAI API calling.
@@ -138,24 +138,24 @@ We use `run.py` for evaluation. To use the script, you can use `$VLMEvalKit/run.
You can run the script with `python` or `torchrun`:
```bash
-# When running with `python`, only one VLM instance is instantiated, and it might use multiple GPUs (depending on its default behavior).
+# When running with `python`, only one VLM instance is instantiated, and it might use multiple GPUs (depending on its default behavior).
# That is recommended for evaluating very large VLMs (like IDEFICS-80B-Instruct).
# IDEFICS-80B-Instruct on MMBench_DEV_EN, MME, and SEEDBench_IMG, Inference and Evaluation
-python run.py --data MMBench_DEV_EN MME SEEDBench_IMG --model idefics_80b_instruct --verbose
+python run.py --data MMBench_DEV_EN MME SEEDBench_IMG --model idefics_80b_instruct --verbose
# IDEFICS-80B-Instruct on MMBench_DEV_EN, MME, and SEEDBench_IMG, Inference only
python run.py --data MMBench_DEV_EN MME SEEDBench_IMG --model idefics_80b_instruct --verbose --mode infer
-# When running with `torchrun`, one VLM instance is instantiated on each GPU. It can speed up the inference.
-# However, that is only suitable for VLMs that consume small amounts of GPU memory.
+# When running with `torchrun`, one VLM instance is instantiated on each GPU. It can speed up the inference.
+# However, that is only suitable for VLMs that consume small amounts of GPU memory.
# IDEFICS-9B-Instruct, Qwen-VL-Chat, mPLUG-Owl2 on MMBench_DEV_EN, MME, and SEEDBench_IMG. On a node with 8 GPUs. Inference and Evaluation.
torchrun --nproc-per-node=8 run.py --data MMBench_DEV_EN MME SEEDBench_IMG --model idefics_80b_instruct qwen_chat mPLUG-Owl2 --verbose
-# Qwen-VL-Chat on MME. On a node with 2 GPU. Inference and Evaluation.
+# Qwen-VL-Chat on MME. On a node with 2 GPUs. Inference and Evaluation.
torchrun --nproc-per-node=2 run.py --data MME --model qwen_chat --verbose
```
-The evaluation results will be printed as logs, besides. **Result Files** will also be generated in the directory `$YOUR_WORKING_DIRECTORY/{model_name}`. Files ending with `.csv` contain the evaluated metrics.
+The evaluation results will be printed as logs. Besides, **Result Files** will also be generated in the directory `$YOUR_WORKING_DIRECTORY/{model_name}`. Files ending with `.csv` contain the evaluated metrics.
## 🛠️ Custom Benchmark or VLM
@@ -169,13 +169,13 @@ To implement a custom benchmark or VLM in **VLMEvalKit**, please refer to [Custo
**The codebase is designed to:**
1. Provide an **easy-to-use**, **opensource evaluation toolkit** to make it convenient for researchers & developers to evaluate existing LVLMs and make evaluation results **easy to reproduce**.
-2. Make it easy for VLM developers to evaluate their own models. To evaluate the VLM on multiple supported benchmarks, one just need to **implement a single `generate` function**, all other workloads (data downloading, data preprocessing, prediction inference, metric calculation) are handled by the codebase.
+2. Make it easy for VLM developers to evaluate their own models. To evaluate the VLM on multiple supported benchmarks, one just needs to **implement a single `generate` function**; all other workloads (data downloading, data preprocessing, prediction inference, metric calculation) are handled by the codebase.
**The codebase is not designed to:**
1. Reproduce the exact accuracy number reported in the original papers of all **3rd party benchmarks**. The reason can be two-fold:
- 1. VLMEvalKit uses **generation-based evaluation** for all VLMs (and optionally with **LLM-based answer extraction**). Meanwhile, some benchmarks may use different approaches (SEEDBench uses PPL-based evaluation, *eg.*). For those benchmarks, we compare both scores in the corresponding result. We encourage developers to support other evaluation paradigms in the codebase.
- 2. By default, we use the same prompt template for all VLMs to evaluate on a benchmark. Meanwhile, **some VLMs may have their specific prompt templates** (some may not covered by the codebase at this time). We encourage VLM developers to implement their own prompt template in VLMEvalKit, if that is not covered currently. That will help to improve the reproducibility.
+   1. VLMEvalKit uses **generation-based evaluation** for all VLMs (and optionally with **LLM-based answer extraction**). Meanwhile, some benchmarks may use different approaches (*e.g.*, SEEDBench uses PPL-based evaluation). For those benchmarks, we compare both scores in the corresponding result. We encourage developers to support other evaluation paradigms in the codebase.
+   2. By default, we use the same prompt template for all VLMs to evaluate on a benchmark. Meanwhile, **some VLMs may have their specific prompt templates** (some may not be covered by the codebase at this time). We encourage VLM developers to implement their own prompt templates in VLMEvalKit if they are not covered currently. That will help to improve the reproducibility.
## 🖊️ Citation
@@ -195,4 +195,4 @@ If you use VLMEvalKit in your research or wish to refer to the published OpenSou
- [opencompass](https://github.com/open-compass/opencompass/): An LLM evaluation platform, supporting a wide range of models (LLaMA, LLaMa2, ChatGLM2, ChatGPT, Claude, etc) over 50+ datasets.
- [MMBench](https://github.com/open-compass/MMBench/): Official Repo of "MMBench: Is Your Multi-modal Model an All-around Player?"
- [BotChat](https://github.com/open-compass/BotChat/): Evaluating LLMs' multi-round chatting capability.
-- [LawBench](https://github.com/open-compass/LawBench): Benchmarking Legal Knowledge of Large Language Models.
+- [LawBench](https://github.com/open-compass/LawBench): Benchmarking Legal Knowledge of Large Language Models.
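For reference on the `generate` interface and the `supported_VLM` registry mentioned above, here is a minimal sketch of programmatic usage. The model name `qwen_chat` and the image file `apple.jpg` are illustrative placeholders; only `supported_VLM` (from `vlmeval/config.py`) and the `generate(image_path, prompt)` call follow the README.

```python
# A minimal sketch, assuming `qwen_chat` is registered in supported_VLM and that
# apple.jpg exists locally; both are illustrative placeholders.
from vlmeval.config import supported_VLM

model = supported_VLM['qwen_chat']()  # entries are constructors (functools.partial objects)
ret = model.generate('apple.jpg', 'How many apples are there in this image?')
print(ret)
```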
diff --git a/requirements.txt b/requirements.txt
index 6057e742a..83350f42f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,28 +1,28 @@
+einops
+gradio==4.15.0
+huggingface_hub
+matplotlib
numpy>=1.23.4
+omegaconf
openai==1.3.5
-requests
-tqdm
+opencv-python>=4.4.0.46
+openpyxl
pandas>=1.5.3
-gradio==4.15.0
-tiktoken
-rich
+pillow
portalocker
-torch>=2.0.1
+pycocoevalcap
+requests
+rich
+seaborn
+sentencepiece
+sty
+tabulate
+tiktoken
timeout-decorator
+torch>=2.0.1
+tqdm
transformers
-opencv-python>=4.4.0.46
typing_extensions==4.7.1
-pillow
-omegaconf
-matplotlib
-einops
-sentencepiece
-sty
-xtuner
-huggingface_hub
visual_genome
-pycocoevalcap
-openpyxl
-seaborn
-tabulate
-xlsxwriter
\ No newline at end of file
+xlsxwriter
+xtuner
diff --git a/results/Caption.md b/results/Caption.md
deleted file mode 100644
index 75b4970c7..000000000
--- a/results/Caption.md
+++ /dev/null
@@ -1,48 +0,0 @@
-# Caption Results
-
-## COCO Caption
-
-> By default, we evaluate COCO Caption Validation set (5000 samples), and report the following metrics: `BLEU-1, BLEU-4, CIDEr, ROUGE-L
->
-> We use the following prompt to evaluate all VLMs: `Please describe this image in general. Directly provide the description, do not include prefix like "This image depicts". `
->
-> **No specific prompt is adopted for all VLMs.**
-
-### Evaluation Results
-
-| Model | BLEU-4 | BLEU-1 | ROUGE-L | CIDEr | Word_cnt mean. | Word_cnt std. |
-|:----------------------------|---------:|---------:|----------:|--------:|-----------------:|----------------:|
-| EMU2-Chat | 38.7 | 78.2 | 56.9 | 109.2 | 9.6 | 1.1 |
-| Qwen-VL-Chat | 34 | 75.8 | 54.9 | 98.9 | 10 | 1.7 |
-| IDEFICS-80B-Instruct | 32.5 | 76.1 | 54.1 | 94.9 | 9.7 | 3.2 |
-| IDEFICS-9B-Instruct | 29.4 | 72.7 | 53.4 | 90.4 | 10.5 | 4.4 |
-| InstructBLIP-7B | 20.9 | 56.8 | 39.9 | 58.1 | 11.6 | 5.9 |
-| InstructBLIP-13B | 16.9 | 50 | 37 | 52.4 | 11.8 | 12.8 |
-| InternLM-XComposer-VL | 12.4 | 38.3 | 37.9 | 41 | 26.3 | 22.2 |
-| GeminiProVision | 8.4 | 33.2 | 31.2 | 9.7 | 35.2 | 15.7 |
-| LLaVA-v1.5-7B (QLoRA) | 7.2 | 25 | 36.6 | 43.2 | 48.8 | 42.9 |
-| mPLUG-Owl2 | 7.1 | 25.8 | 33.6 | 35 | 45.8 | 32.1 |
-| LLaVA-v1-7B | 6.7 | 27.3 | 26.7 | 6.1 | 40.9 | 16.1 |
-| VisualGLM | 5.4 | 28.6 | 23.6 | 0.2 | 41.5 | 11.5 |
-| LLaVA-v1.5-13B (QLoRA) | 5.3 | 19.6 | 25.8 | 17.8 | 72.2 | 39.4 |
-| LLaVA-v1.5-13B | 5.1 | 20.7 | 21.2 | 0.3 | 70.6 | 22.3 |
-| LLaVA-v1.5-7B | 4.6 | 19.6 | 19.9 | 0.1 | 72.5 | 21.7 |
-| PandaGPT-13B | 4.6 | 19.9 | 19.3 | 0.1 | 65.4 | 16.6 |
-| MiniGPT-4-v1-13B | 4.4 | 20 | 19.8 | 1.3 | 64.4 | 30.5 |
-| MiniGPT-4-v1-7B | 4.3 | 19.6 | 17.5 | 0.8 | 61.9 | 30.6 |
-| LLaVA-InternLM-7B (QLoRA) | 4 | 17.3 | 17.2 | 0.1 | 82.3 | 21 |
-| LLaVA-InternLM2-20B (QLoRA) | 4 | 17.9 | 17.3 | 0 | 83.2 | 20.4 |
-| CogVLM-17B-Chat | 3.6 | 21.3 | 20 | 0.1 | 56.2 | 13.7 |
-| Qwen-VL | 3.5 | 11.6 | 30 | 41.1 | 46.6 | 105.2 |
-| GPT-4v (detail: low) | 3.3 | 18 | 18.1 | 0 | 77.8 | 20.4 |
-| TransCore-M | 2.1 | 14.2 | 13.8 | 0.2 | 92 | 6.7 |
-| ShareGPT4V-7B | 1.4 | 9.7 | 10.6 | 0.1 | 147.9 | 45.4 |
-| MiniGPT-4-v2 | 1.4 | 12.6 | 13.3 | 0.1 | 83 | 27.1 |
-| OpenFlamingo v2 | 1.3 | 6.4 | 15.8 | 14.9 | 60 | 81.9 |
-| SharedCaptioner | 1 | 8.8 | 9.2 | 0 | 164.2 | 31.6 |
-
-We noticed that, VLMs that generate long image descriptions tend to achieve inferior scores under different caption metrics.
-
-### Error Analysis & Case Study
-
-TBD.
\ No newline at end of file
diff --git a/results/ScienceQA.md b/results/ScienceQA.md
deleted file mode 100644
index 7d8cb55ed..000000000
--- a/results/ScienceQA.md
+++ /dev/null
@@ -1,40 +0,0 @@
-# ScienceQA Evaluation Results
-
-> We benchmark the **image** subset of ScienceQA validation and test set, and report the Top-1 accuracy.
->
-> During evaluation, we use `GPT-3.5-Turbo-0613` as the choice extractor for all VLMs if the choice can not be extracted via heuristic matching. **Zero-shot** inference is adopted.
-
-## ScienceQA Accuracy
-
-| Model | ScienceQA-Image Val | ScienceQA-Image Test |
-|:----------------------------|:----------------------|-----------------------:|
-| InternLM-XComposer-VL | 88.0 | 89.8 |
-| Human Performance | N/A | 87.5 |
-| SharedCaptioner | 81.0 | 82.3 |
-| GPT-4v (detail: low) | 84.6 | 82.1 |
-| GeminiProVision | 80.1 | 81.4 |
-| LLaVA-InternLM2-20B (QLoRA) | 72.7 | 73.7 |
-| Monkey | 68.2 | 72.1 |
-| LLaVA-v1.5-13B | 69.2 | 72 |
-| TransCore-M | 68.8 | 71.2 |
-| LLaVA-v1.5-13B (QLoRA) | 68.9 | 70.3 |
-| mPLUG-Owl2 | 69.5 | 69.5 |
-| ShareGPT4V-7B | 68.1 | 69.4 |
-| LLaVA-v1.5-7B | 66.6 | 68.9 |
-| Qwen-VL-Chat | 65.5 | 68.8 |
-| LLaVA-v1.5-7B (QLoRA) | 68.8 | 68.7 |
-| LLaVA-InternLM-7B (QLoRA) | 65.3 | 68.4 |
-| EMU2-Chat | 65.3 | 68.2 |
-| CogVLM-17B-Chat | 65.6 | 66.2 |
-| PandaGPT-13B | 60.9 | 63.2 |
-| IDEFICS-80B-Instruct | 59.9 | 61.8 |
-| Qwen-VL | 57.7 | 61.1 |
-| LLaVA-v1-7B | 59.9 | 60.5 |
-| InstructBLIP-13B | 53.3 | 58.3 |
-| VisualGLM | 53.4 | 56.1 |
-| MiniGPT-4-v2 | 54.1 | 54.7 |
-| InstructBLIP-7B | 54.7 | 54.1 |
-| IDEFICS-9B-Instruct | 51.6 | 53.5 |
-| MiniGPT-4-v1-13B | 44.3 | 46 |
-| OpenFlamingo v2 | 45.7 | 44.8 |
-| MiniGPT-4-v1-7B | 39.0 | 39.6 |
diff --git a/run.py b/run.py
index 80d1b7f5f..126870346 100644
--- a/run.py
+++ b/run.py
@@ -1,37 +1,39 @@
import torch
import torch.distributed as dist
from vlmeval.smp import *
-from vlmeval.evaluate import COCO_eval, YOrN_eval, MMVet_eval, multiple_choice_eval, VQAEval, MathVista_eval, LLaVABench_eval, OCRBench_eval
+from vlmeval.evaluate import *
from vlmeval.inference import infer_data_job, prefetch_acc
from vlmeval.config import supported_VLM
from vlmeval.utils import dataset_URLs, DATASET_TYPE, abbr2full
+
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, nargs='+', required=True)
- parser.add_argument("--model", type=str, nargs='+', required=True)
- parser.add_argument("--work-dir", type=str, default='.', help="select the output directory")
- parser.add_argument("--mode", type=str, default='all', choices=['all', 'infer'])
- parser.add_argument("--nproc", type=int, default=4, help="Parallel API calling")
- parser.add_argument("--retry", type=int, default=None, help="retry numbers for API VLMs")
- parser.add_argument("--ignore", action='store_true', help="Ignore failed indices. ")
- parser.add_argument("--verbose", action='store_true')
- parser.add_argument("--prefetch", action='store_true')
+ parser.add_argument('--model', type=str, nargs='+', required=True)
+ parser.add_argument('--work-dir', type=str, default='.', help='select the output directory')
+ parser.add_argument('--mode', type=str, default='all', choices=['all', 'infer'])
+ parser.add_argument('--nproc', type=int, default=4, help='Parallel API calling')
+ parser.add_argument('--retry', type=int, default=None, help='retry numbers for API VLMs')
+ parser.add_argument('--ignore', action='store_true', help='Ignore failed indices. ')
+ parser.add_argument('--verbose', action='store_true')
+ parser.add_argument('--prefetch', action='store_true')
args = parser.parse_args()
return args
+
def main():
logger = get_logger('RUN')
args = parse_args()
- assert len(args.data), "--data should be a list of data files"
+ assert len(args.data), '--data should be a list of data files'
if args.retry is not None:
for k, v in supported_VLM.items():
if hasattr(v, 'keywords') and 'retry' in v.keywords:
v.keywords['retry'] = args.retry
supported_VLM[k] = v
-
+
rank, world_size = get_rank_and_world_size()
if world_size > 1:
torch.cuda.set_device(rank)
@@ -48,7 +50,7 @@ def main():
if dataset_name not in dataset_URLs:
dataset_name = abbr2full(dataset_name)
-
+
if dataset_name not in dataset_URLs:
logger.warning(f'Dataset {dataset_name} is not officially supported. ')
file_path = osp.join(LMUDataRoot(), f'{dataset_name}.tsv')
@@ -59,24 +61,40 @@ def main():
custom_flag = True
result_file = f'{pred_root}/{model_name}_{dataset_name}.xlsx'
-
+
if model is None:
- model = model_name # which is only a name
+ model = model_name # which is only a name
# CHECKER
if dataset_name == 'CORE_MM':
MULTI_IMG = getattr(supported_VLM[model_name].func, 'interleave_generate', None)
if MULTI_IMG is None:
- logger.error(f'Model {model_name} does not support the `interleave_generate` interface, which is required for testing CORE_MM, skip it. ')
+ logger.error(
+ f'Model {model_name} does not support the `interleave_generate` interface, '
+ 'which is required for testing CORE_MM, skip it. '
+ )
continue
if args.mode == 'all':
- logger.error(f'Dataset {dataset_name} does not support `evaluation` now, will skip the evaluation. ')
+ logger.error(
+ f'Dataset {dataset_name} does not support `evaluation` now, '
+ 'will skip the evaluation. '
+ )
- model = infer_data_job(model, work_dir=pred_root, model_name=model_name, dataset_name=dataset_name, verbose=args.verbose, api_nproc=args.nproc, ignore_failed=args.ignore)
+ model = infer_data_job(
+ model,
+ work_dir=pred_root,
+ model_name=model_name,
+ dataset_name=dataset_name,
+ verbose=args.verbose,
+ api_nproc=args.nproc,
+ ignore_failed=args.ignore)
- if dataset_name in ['MMBench_TEST_CN', 'MMBench_TEST_EN', "MMMU_TEST"]:
+ if dataset_name in ['MMBench_TEST_CN', 'MMBench_TEST_EN', 'MMMU_TEST']:
if not MMBenchOfficialServer():
- logger.error(f'Can not evaluate {dataset_name} on non-official servers, will skip the evaluation. ')
+ logger.error(
+ f'Can not evaluate {dataset_name} on non-official servers, '
+ 'will skip the evaluation. '
+ )
continue
if rank == 0 and args.prefetch:
@@ -90,13 +108,25 @@ def main():
logger.info(f'{model_name} prefetching: ')
logger.info(res)
dump(res, result_file.replace('.xlsx', '_prefetch.xlsx'))
-
+
if rank == 0 and args.mode == 'all':
if DATASET_TYPE(dataset_name) == 'multi-choice':
- dataset_name = "default" if custom_flag else dataset_name
- multiple_choice_eval(result_file, dataset=dataset_name, model='chatgpt-0613', nproc=args.nproc, verbose=args.verbose)
+ dataset_name = 'default' if custom_flag else dataset_name
+ multiple_choice_eval(
+ result_file,
+ dataset=dataset_name,
+ model='chatgpt-0613',
+ nproc=args.nproc,
+ verbose=args.verbose
+ )
elif DATASET_TYPE(dataset_name) == 'Y/N':
- YOrN_eval(result_file, model='chatgpt-0613', nproc=args.nproc, verbose=args.verbose, dataset=dataset_name)
+ YOrN_eval(
+ result_file,
+ model='chatgpt-0613',
+ nproc=args.nproc,
+ verbose=args.verbose,
+ dataset=dataset_name
+ )
elif DATASET_TYPE(dataset_name) == 'Caption':
COCO_eval(result_file)
elif dataset_name == 'MMVet':
@@ -111,6 +141,7 @@ def main():
LLaVABench_eval(result_file, model='gpt-4-turbo', nproc=args.nproc, verbose=args.verbose)
else:
logger.error(f'Dataset {dataset_name} is not handled by evaluator, will be skipped. ')
-
+
+
if __name__ == '__main__':
- main()
\ No newline at end of file
+ main()
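For reference on the evaluators dispatched above, here is a minimal sketch of re-running the multiple-choice evaluation on an existing prediction file outside of `run.py`. The result-file path is a hypothetical placeholder; the `multiple_choice_eval` import and keyword arguments follow the code in this diff.

```python
# A minimal sketch, assuming a prediction file from a previous run exists at the
# hypothetical path below ($YOUR_WORKING_DIRECTORY/{model_name}/{model_name}_{dataset_name}.xlsx).
from vlmeval.evaluate import multiple_choice_eval

multiple_choice_eval(
    'qwen_chat/qwen_chat_MMBench_DEV_EN.xlsx',  # hypothetical result file
    dataset='MMBench_DEV_EN',
    model='chatgpt-0613',
    nproc=4,
    verbose=True)
```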
diff --git a/setup.py b/setup.py
index e5e8df483..72b742e07 100644
--- a/setup.py
+++ b/setup.py
@@ -3,6 +3,7 @@
from os.path import exists
from setuptools import find_packages, setup
+
def parse_requirements(fname='requirements.txt', with_version=True):
"""Parse the package dependencies listed in a requirements file but strips
specific versioning information.
@@ -88,7 +89,7 @@ def do_setup():
name='vlmeval',
version='0.1.0',
description='OpenCompass VLM Evaluation Kit',
- author="Haodong Duan",
+ author='Haodong Duan',
author_email='dhd.efz@gmail.com',
maintainer='Haodong Duan',
maintainer_email='dhd.efz@gmail.com',
@@ -104,7 +105,7 @@ def do_setup():
]),
keywords=['AI', 'NLP', 'in-context learning'],
entry_points={
- "console_scripts": []
+ 'console_scripts': []
},
classifiers=[
'Programming Language :: Python :: 3.7',
diff --git a/vlmeval/__init__.py b/vlmeval/__init__.py
index 1d7a44a29..b8e96c287 100644
--- a/vlmeval/__init__.py
+++ b/vlmeval/__init__.py
@@ -8,4 +8,4 @@
from .evaluate import *
from .utils import *
from .vlm import *
-from .config import *
\ No newline at end of file
+from .config import *
diff --git a/vlmeval/api/__init__.py b/vlmeval/api/__init__.py
index 6982e1248..63a2e3101 100644
--- a/vlmeval/api/__init__.py
+++ b/vlmeval/api/__init__.py
@@ -5,10 +5,10 @@
from .qwen_vl_api import QwenVLWrapper, QwenVLAPI
from .qwen_api import QwenAPI
from .stepai import Step1V
-from .claude import Claude_Wrapper,Claude3V
+from .claude import Claude_Wrapper, Claude3V
__all__ = [
'OpenAIWrapper', 'HFChatModel', 'OpenAIWrapperInternal', 'GeminiWrapper',
- 'GPT4V', 'GPT4V_Internal', 'GeminiProVision','QwenVLWrapper', 'QwenVLAPI',
- 'QwenAPI', 'Step1V','Claude3V','Claude_Wrapper'
-]
\ No newline at end of file
+ 'GPT4V', 'GPT4V_Internal', 'GeminiProVision', 'QwenVLWrapper', 'QwenVLAPI',
+ 'QwenAPI', 'Step1V', 'Claude3V', 'Claude_Wrapper'
+]
diff --git a/vlmeval/api/base.py b/vlmeval/api/base.py
index b80ebec5b..f53979d0b 100644
--- a/vlmeval/api/base.py
+++ b/vlmeval/api/base.py
@@ -3,16 +3,17 @@
from abc import abstractmethod
from ..smp import get_logger
+
class BaseAPI:
-
- def __init__(self,
- retry=10,
- wait=3,
- system_prompt=None,
+
+ def __init__(self,
+ retry=10,
+ wait=3,
+ system_prompt=None,
verbose=True,
fail_msg='Failed to obtain answer via API.',
**kwargs):
- self.wait = wait
+ self.wait = wait
self.retry = retry
self.system_prompt = system_prompt
self.kwargs = kwargs
@@ -21,16 +22,16 @@ def __init__(self,
self.logger = get_logger('ChatAPI')
if len(kwargs):
self.logger.info(f'BaseAPI received the following kwargs: {kwargs}')
- self.logger.info(f'Will try to use them as kwargs for `generate`. ')
+ self.logger.info('Will try to use them as kwargs for `generate`. ')
@abstractmethod
def generate_inner(self, inputs, **kwargs):
- self.logger.warning(f'For APIBase, generate_inner is an abstract method. ')
+ self.logger.warning('For APIBase, generate_inner is an abstract method. ')
assert 0, 'generate_inner not defined'
ret_code, answer, log = None, None, None
# if ret_code is 0, means succeed
return ret_code, answer, log
-
+
def working(self):
retry = 3
while retry > 0:
@@ -54,7 +55,7 @@ def generate(self, inputs, **kwargs):
# a very small random delay [0s - 0.5s]
T = rd.random() * 0.5
time.sleep(T)
-
+
for i in range(self.retry):
try:
ret_code, answer, log = self.generate_inner(inputs, **kwargs)
@@ -67,8 +68,8 @@ def generate(self, inputs, **kwargs):
try:
log = log.text
except:
- self.logger.warning(f"Failed to parse {log} as an http response. ")
- self.logger.info(f"RetCode: {ret_code}\nAnswer: {answer}\nLog: {log}")
+ self.logger.warning(f'Failed to parse {log} as an http response. ')
+ self.logger.info(f'RetCode: {ret_code}\nAnswer: {answer}\nLog: {log}')
except Exception as err:
if self.verbose:
self.logger.error(f'An error occurred during try {i}:')
@@ -76,5 +77,5 @@ def generate(self, inputs, **kwargs):
# delay before each retry
T = rd.random() * self.wait * 2
time.sleep(T)
-
+
return self.fail_msg if answer in ['', None] else answer
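For context on the `BaseAPI` contract reformatted above: subclasses implement `generate_inner(inputs, **kwargs)` and return a `(ret_code, answer, log)` triple, where `ret_code == 0` means success; `generate` then handles the retry loop, random delays, and the `fail_msg` fallback. The sketch below is a hypothetical subclass meant only to illustrate that contract.

```python
from vlmeval.api.base import BaseAPI


class EchoAPI(BaseAPI):
    """Hypothetical subclass used only to illustrate the BaseAPI contract."""

    is_api: bool = True

    def generate_inner(self, inputs, **kwargs):
        # Return (ret_code, answer, log): ret_code == 0 signals success, any other
        # value (or a raised exception) triggers the retry loop in BaseAPI.generate.
        answer = f'echo: {inputs}'
        return 0, answer, 'Succeeded.'
```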
diff --git a/vlmeval/api/claude.py b/vlmeval/api/claude.py
index 7ee4b44cc..2aaa86b9a 100644
--- a/vlmeval/api/claude.py
+++ b/vlmeval/api/claude.py
@@ -3,27 +3,28 @@
from time import sleep
import base64
-url = "https://openxlab.org.cn/gw/alles-apin-hub/v1/claude/v1/text/chat"
+url = 'https://openxlab.org.cn/gw/alles-apin-hub/v1/claude/v1/text/chat'
headers = {
'alles-apin-token': '',
'Content-Type': 'application/json'
}
+
class Claude_Wrapper(BaseAPI):
is_api: bool = True
- def __init__(self,
+ def __init__(self,
model: str = 'claude-3-opus-20240229',
key: str = None,
retry: int = 10,
wait: int = 3,
system_prompt: str = None,
verbose: bool = True,
- temperature: float = 0,
+ temperature: float = 0,
max_tokens: int = 1024,
**kwargs):
-
+
self.model = model
self.headers = headers
self.temperature = temperature
@@ -35,28 +36,28 @@ def __init__(self,
self.headers['alles-apin-token'] = self.key
super().__init__(retry=retry, wait=wait, verbose=verbose, system_prompt=system_prompt, **kwargs)
-
+
@staticmethod
def build_msgs(msgs_raw):
messages = []
- message = {"role": "user", "content": []}
+ message = {'role': 'user', 'content': []}
for msg in msgs_raw:
if isimg(msg):
media_type_map = {
- 'jpg': 'image/jpeg',
- 'jpeg': 'image/jpeg',
- 'png': 'image/png',
- 'gif': 'image/gif',
+ 'jpg': 'image/jpeg',
+ 'jpeg': 'image/jpeg',
+ 'png': 'image/png',
+ 'gif': 'image/gif',
'webp': 'image/webp'
}
media_type = media_type_map[msg.split('.')[-1].lower()]
- with open(msg, "rb") as file:
- image_data = base64.b64encode(file.read()).decode("utf-8")
+ with open(msg, 'rb') as file:
+ image_data = base64.b64encode(file.read()).decode('utf-8')
item = {
- 'type': 'image',
+ 'type': 'image',
'source': {'type': 'base64', 'media_type': media_type, 'data': image_data}
}
-
+
else:
item = {'type': 'text', 'text': msg}
message['content'].append(item)
@@ -66,21 +67,21 @@ def build_msgs(msgs_raw):
def generate_inner(self, inputs, **kwargs) -> str:
payload = json.dumps({
- "model": self.model,
- "max_tokens": self.max_tokens,
- "messages": self.build_msgs(msgs_raw=inputs),
- **kwargs
+ 'model': self.model,
+ 'max_tokens': self.max_tokens,
+ 'messages': self.build_msgs(msgs_raw=inputs),
+ **kwargs
})
- response = requests.request("POST", url, headers=headers, data=payload)
+ response = requests.request('POST', url, headers=headers, data=payload)
ret_code = response.status_code
retry = self.retry
while ret_code == 429 and retry > 0:
sleep(15)
- response = requests.request("POST", url, headers=headers, data=payload)
+ response = requests.request('POST', url, headers=headers, data=payload)
ret_code = response.status_code
retry -= 1
-
+
ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code
answer = self.fail_msg
@@ -97,6 +98,6 @@ class Claude3V(Claude_Wrapper):
def generate(self, image_path, prompt, dataset=None):
return super(Claude_Wrapper, self).generate([image_path, prompt])
-
+
def interleave_generate(self, ti_list, dataset=None):
- return super(Claude_Wrapper, self).generate(ti_list)
\ No newline at end of file
+ return super(Claude_Wrapper, self).generate(ti_list)
diff --git a/vlmeval/api/gemini.py b/vlmeval/api/gemini.py
index 99c701340..afba2f28e 100644
--- a/vlmeval/api/gemini.py
+++ b/vlmeval/api/gemini.py
@@ -3,16 +3,17 @@
headers = 'Content-Type: application/json'
+
class GeminiWrapper(BaseAPI):
is_api: bool = True
- def __init__(self,
+ def __init__(self,
retry: int = 5,
- wait: int = 5,
+ wait: int = 5,
key: str = None,
- verbose: bool = True,
- temperature: float = 0.0,
+ verbose: bool = True,
+ temperature: float = 0.0,
system_prompt: str = None,
max_tokens: int = 1024,
proxy: str = None,
@@ -28,10 +29,10 @@ def __init__(self,
if proxy is not None:
proxy_set(proxy)
super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
-
+
@staticmethod
def build_msgs(msgs_raw, system_prompt=None):
- msgs = cp.deepcopy(msgs_raw)
+ msgs = cp.deepcopy(msgs_raw)
assert len(msgs) % 2 == 1
if system_prompt is not None:
@@ -68,7 +69,7 @@ def generate_inner(self, inputs, **kwargs) -> str:
shutil.remove(pth)
else:
messages.append(s)
- gen_config = dict(max_output_tokens=self.max_tokens, temperature=self.temperature)
+ gen_config = dict(max_output_tokens=self.max_tokens, temperature=self.temperature)
gen_config.update(self.kwargs)
try:
answer = model.generate_content(messages, generation_config=genai.types.GenerationConfig(**gen_config)).text
@@ -76,16 +77,15 @@ def generate_inner(self, inputs, **kwargs) -> str:
except Exception as err:
if self.verbose:
self.logger.error(err)
- self.logger.error(f"The input messages are {inputs}.")
+ self.logger.error(f'The input messages are {inputs}.')
return -1, '', ''
-
class GeminiProVision(GeminiWrapper):
def generate(self, image_path, prompt, dataset=None):
return super(GeminiProVision, self).generate([image_path, prompt])
-
+
def interleave_generate(self, ti_list, dataset=None):
- return super(GeminiProVision, self).generate(ti_list)
\ No newline at end of file
+ return super(GeminiProVision, self).generate(ti_list)
diff --git a/vlmeval/api/gpt.py b/vlmeval/api/gpt.py
index e2bbec21c..dc05a4079 100644
--- a/vlmeval/api/gpt.py
+++ b/vlmeval/api/gpt.py
@@ -1,48 +1,50 @@
from ..smp import *
-import os, sys
+import os
+import sys
from .base import BaseAPI
APIBASES = {
- 'OFFICIAL': "https://api.openai.com/v1/chat/completions",
+ 'OFFICIAL': 'https://api.openai.com/v1/chat/completions',
}
def GPT_context_window(model):
length_map = {
- 'gpt-4-1106-preview': 128000,
- 'gpt-4-vision-preview': 128000,
+ 'gpt-4-1106-preview': 128000,
+ 'gpt-4-vision-preview': 128000,
'gpt-4': 8192,
'gpt-4-32k': 32768,
- 'gpt-4-0613': 8192,
+ 'gpt-4-0613': 8192,
'gpt-4-32k-0613': 32768,
- 'gpt-3.5-turbo-1106': 16385,
- 'gpt-3.5-turbo': 4096,
- 'gpt-3.5-turbo-16k': 16385,
- 'gpt-3.5-turbo-instruct': 4096,
- 'gpt-3.5-turbo-0613': 4096,
- 'gpt-3.5-turbo-16k-0613': 16385,
+ 'gpt-3.5-turbo-1106': 16385,
+ 'gpt-3.5-turbo': 4096,
+ 'gpt-3.5-turbo-16k': 16385,
+ 'gpt-3.5-turbo-instruct': 4096,
+ 'gpt-3.5-turbo-0613': 4096,
+ 'gpt-3.5-turbo-16k-0613': 16385,
}
if model in length_map:
return length_map[model]
else:
return 4096
+
class OpenAIWrapper(BaseAPI):
is_api: bool = True
- def __init__(self,
- model: str = 'gpt-3.5-turbo-0613',
+ def __init__(self,
+ model: str = 'gpt-3.5-turbo-0613',
retry: int = 5,
- wait: int = 5,
+ wait: int = 5,
key: str = None,
- verbose: bool = True,
+ verbose: bool = True,
system_prompt: str = None,
temperature: float = 0,
timeout: int = 60,
api_base: str = 'OFFICIAL',
max_tokens: int = 1024,
- img_size: int = 512,
+ img_size: int = 512,
img_detail: str = 'low',
**kwargs):
@@ -64,8 +66,11 @@ def __init__(self,
if model == 'gpt-4-vision-preview':
self.vision = True
self.timeout = timeout
-
- assert isinstance(openai_key, str) and openai_key.startswith('sk-'), f'Illegal openai_key {openai_key}. Please set the environment variable OPENAI_API_KEY to your openai key. '
+
+ assert isinstance(openai_key, str) and openai_key.startswith('sk-'), (
+ f'Illegal openai_key {openai_key}. '
+ 'Please set the environment variable OPENAI_API_KEY to your openai key. '
+ )
super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
if api_base in APIBASES:
@@ -73,11 +78,11 @@ def __init__(self,
elif api_base.startswith('http'):
self.api_base = api_base
else:
- self.logger.error("Unknown API Base. ")
+ self.logger.error('Unknown API Base. ')
sys.exit(-1)
if 'OPENAI_API_BASE' in os.environ and os.environ['OPENAI_API_BASE'] != '':
- self.logger.error("Environment variable OPENAI_API_BASE is set. Will override the api_base arg. ")
+ self.logger.error('Environment variable OPENAI_API_BASE is set. Will override the api_base arg. ')
self.api_base = os.environ['OPENAI_API_BASE']
# inputs can be a lvl-2 nested list: [content1, content2, content3, ...]
@@ -118,7 +123,7 @@ def prepare_inputs(self, inputs):
for role, msg in zip(roles, inputs):
input_msgs.append(dict(role=role, content=msg))
return input_msgs
- raise NotImplemented("list of list prompt not implemented now. ")
+ raise NotImplementedError('list of list prompt not implemented now. ')
def generate_inner(self, inputs, **kwargs) -> str:
input_msgs = self.prepare_inputs(inputs)
@@ -128,16 +133,19 @@ def generate_inner(self, inputs, **kwargs) -> str:
context_window = GPT_context_window(self.model)
max_tokens = min(max_tokens, context_window - self.get_token_len(inputs))
if 0 < max_tokens <= 100:
- self.logger.warning('Less than 100 tokens left, may exceed the context window with some additional meta symbols. ')
+ self.logger.warning(
+ 'Less than 100 tokens left, '
+ 'may exceed the context window with some additional meta symbols. '
+ )
if max_tokens <= 0:
return 0, self.fail_msg + 'Input string longer than context window. ', 'Length Exceeded. '
-
+
headers = {'Content-Type': 'application/json', 'Authorization': f'Bearer {self.openai_key}'}
payload = dict(
- model=self.model,
+ model=self.model,
messages=input_msgs,
max_tokens=max_tokens,
- n=1,
+ n=1,
temperature=temperature,
**kwargs)
response = requests.post(self.api_base, headers=headers, data=json.dumps(payload), timeout=self.timeout * 1.1)
@@ -170,14 +178,14 @@ def get_token_len(self, inputs) -> int:
for item in inputs:
res += self.get_token_len(item)
return res
-
+
+
class GPT4V(OpenAIWrapper):
def generate(self, image_path, prompt, dataset=None):
assert self.model == 'gpt-4-vision-preview'
return super(GPT4V, self).generate([image_path, prompt])
-
+
def interleave_generate(self, ti_list, dataset=None):
assert self.model == 'gpt-4-vision-preview'
return super(GPT4V, self).generate(ti_list)
-
\ No newline at end of file
diff --git a/vlmeval/api/gpt_int.py b/vlmeval/api/gpt_int.py
index c23593b83..2844bd6f5 100644
--- a/vlmeval/api/gpt_int.py
+++ b/vlmeval/api/gpt_int.py
@@ -4,29 +4,29 @@
from ..smp import *
from .gpt import GPT_context_window, OpenAIWrapper
-
-url = "http://ecs.sv.us.alles-apin.openxlab.org.cn/v1/openai/v2/text/chat"
+url = 'http://ecs.sv.us.alles-apin.openxlab.org.cn/v1/openai/v2/text/chat'
headers = {
- "Content-Type": "application/json"
+ 'Content-Type': 'application/json'
}
+
class OpenAIWrapperInternal(OpenAIWrapper):
is_api: bool = True
- def __init__(self,
- model: str = 'gpt-3.5-turbo-0613',
+ def __init__(self,
+ model: str = 'gpt-3.5-turbo-0613',
retry: int = 5,
wait: int = 3,
verbose: bool = True,
- system_prompt: str = None,
- temperature: float = 0,
+ system_prompt: str = None,
+ temperature: float = 0,
timeout: int = 60,
max_tokens: int = 1024,
- img_size: int = 512,
+ img_size: int = 512,
img_detail: str = 'low',
**kwargs):
-
+
self.model = model
if 'KEYS' in os.environ and osp.exists(os.environ['KEYS']):
keys = load(os.environ['KEYS'])
@@ -52,7 +52,7 @@ def __init__(self,
def generate_inner(self, inputs, **kwargs) -> str:
input_msgs = self.prepare_inputs(inputs)
-
+
temperature = kwargs.pop('temperature', self.temperature)
max_tokens = kwargs.pop('max_tokens', self.max_tokens)
@@ -65,15 +65,15 @@ def generate_inner(self, inputs, **kwargs) -> str:
return 0, self.fail_msg + 'Input string longer than context window. ', 'Length Exceeded. '
payload = dict(
- model=self.model,
- messages=input_msgs,
+ model=self.model,
+ messages=input_msgs,
max_tokens=max_tokens,
n=1,
stop=None,
timeout=self.timeout,
temperature=temperature,
**kwargs)
-
+
response = requests.post(url, headers=headers, data=json.dumps(payload), timeout=self.timeout * 1.1)
ret_code = response.status_code
ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code
@@ -86,14 +86,14 @@ def generate_inner(self, inputs, **kwargs) -> str:
except:
pass
return ret_code, answer, response
-
+
class GPT4V_Internal(OpenAIWrapperInternal):
def generate(self, image_path, prompt, dataset=None):
assert self.model == 'gpt-4-vision-preview'
return super(GPT4V_Internal, self).generate([image_path, prompt])
-
+
def interleave_generate(self, ti_list, dataset=None):
assert self.model == 'gpt-4-vision-preview'
- return super(GPT4V_Internal, self).generate(ti_list)
\ No newline at end of file
+ return super(GPT4V_Internal, self).generate(ti_list)
diff --git a/vlmeval/api/hf_chat_model.py b/vlmeval/api/hf_chat_model.py
index b12b81336..9f2ae3c73 100644
--- a/vlmeval/api/hf_chat_model.py
+++ b/vlmeval/api/hf_chat_model.py
@@ -1,8 +1,10 @@
-import os, sys
+import os
+import sys
import os.path as osp
import torch
from ..smp import *
+
def get_gpu_num(model_name):
model_name = model_name.lower()
kws = {
@@ -17,9 +19,10 @@ def get_gpu_num(model_name):
return k
return 8
+
validated_llms = [
'internlm/internlm-chat-7b', 'internlm/internlm-chat-7b-8k', 'internlm/internlm-chat-20b',
- 'Qwen/Qwen-7B-Chat', 'Qwen/Qwen-14B-Chat',
+ 'Qwen/Qwen-7B-Chat', 'Qwen/Qwen-14B-Chat',
'THUDM/chatglm2-6b', 'THUDM/chatglm2-6b-32k', 'THUDM/chatglm3-6b', 'THUDM/chatglm3-6b-32k',
'baichuan-inc/Baichuan2-7B-Chat', 'baichuan-inc/Baichuan2-13B-Chat',
'lmsys/vicuna-7b-v1.5', 'lmsys/vicuna-13b-v1.5',
@@ -27,6 +30,7 @@ def get_gpu_num(model_name):
]
Auto_model = ['chatglm']
+
class HFChatModel:
def _get_context_length(self, model, model_path):
@@ -42,71 +46,72 @@ def _get_context_length(self, model, model_path):
# chatglm & qwen
context_window = model.config.seq_length
return context_window
-
+
def _get_context_length_robust(self, model, model_path):
try:
context_window = self._get_context_length(model, model_path)
return context_window
except:
self.logger.critical(
- "Failed to extract context_window information from config / generation_config. "
- "Please read the above code and check if the logic works for you model path"
+ 'Failed to extract context_window information from config / generation_config. '
+ 'Please read the above code and check if the logic works for you model path'
)
raise NotImplementedError
-
- def __init__(self,
- model_path,
- system_prompt: str=None,
+
+ def __init__(self,
+ model_path,
+ system_prompt: str = None,
**kwargs):
-
+
self.logger = get_logger('HFChatModel')
if 'vicuna' in model_path.lower():
try:
from fastchat.model import get_conversation_template
except:
- self.logger.critical("Please install fastchat first to use vicuna. ")
+ self.logger.critical('Please install fastchat first to use vicuna. ')
sys.exit(-1)
self.explicit_device = kwargs.pop('device', None)
if self.explicit_device is None:
# If CUDA_VISIBLE_DEVICES is not properly set
- if 'CUDA_VISIBLE_DEVICES' not in os.environ or os.environ['CUDA_VISIBLE_DEVICES'] in ['', '0,1,2,3,4,5,6,7']:
+ if 'CUDA_VISIBLE_DEVICES' not in os.environ or os.environ['CUDA_VISIBLE_DEVICES'] == '0,1,2,3,4,5,6,7':
num_gpu = get_gpu_num(model_path)
gpu_offset = kwargs.pop('gpu_offset', 0)
- cuda_visible_devices = ','.join([str(i) for i in range(gpu_offset, gpu_offset+num_gpu)])
+ cuda_visible_devices = ','.join([str(i) for i in range(gpu_offset, gpu_offset + num_gpu)])
os.environ['CUDA_VISIBLE_DEVICES'] = cuda_visible_devices
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
from transformers.generation import GenerationConfig
-
+
if model_path not in validated_llms:
- self.logger.warning(f"{model_path} not in validated LLMs, may have inference troubles. ")
+ self.logger.warning(f'{model_path} not in validated LLMs, may have inference troubles. ')
self.model_path = model_path
if listinstr(Auto_model, model_path):
LoadModel = AutoModel
else:
LoadModel = AutoModelForCausalLM
-
+
assert osp.exists(model_path) or len(model_path.split('/')) == 2
- device = self.explicit_device if self.explicit_device else "auto"
-
+ device = self.explicit_device if self.explicit_device else 'auto'
+
precision = {}
if 'internlm-chat-7b' in model_path:
precision = {'torch_dtype': torch.float16}
elif 'internlm-chat-20b' in model_path:
precision = {'torch_dtype': torch.bfloat16}
-
+
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = LoadModel.from_pretrained(model_path, trust_remote_code=True, device_map='cpu', **precision)
model = model.eval()
-
+
if device != 'cpu':
model = model.to(f'cuda:{device}' if isinstance(device, int) else 'cuda')
try:
- model.generation_config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True, device_map=device)
+ model.generation_config = GenerationConfig.from_pretrained(
+ model_path, trust_remote_code=True, device_map=device)
except:
pass
@@ -116,21 +121,21 @@ def __init__(self,
self.answer_buffer = 192
self.system_prompt = system_prompt
for k, v in kwargs.items():
- self.logger.info(f'Following args are passed and will be used as generation hyper-paras (If not set specifically), {k}: {v}. ')
+ self.logger.info(f'Following args will be used for generation (If not set specifically), {k}: {v}. ')
self.kwargs = kwargs
-
+
def generate_str(self, input, **kwargs):
if 'baichuan' in self.model_path.lower():
- messages=[]
- messages.append({"role": "user", "content": input})
- resp= self.model.chat(self.tokenizer, messages, **kwargs)
+ messages = []
+ messages.append({'role': 'user', 'content': input})
+ resp = self.model.chat(self.tokenizer, messages, **kwargs)
elif 'vicuna' in self.model_path.lower():
from fastchat.model import get_conversation_template
conv = get_conversation_template('vicuna')
conv.append_message(conv.roles[0], input)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
- inputs = self.tokenizer([prompt], return_tensors="pt")
+ inputs = self.tokenizer([prompt], return_tensors='pt')
if torch.cuda.is_available():
for k in inputs:
inputs[k] = inputs[k].cuda()
@@ -139,7 +144,10 @@ def generate_str(self, input, **kwargs):
params.update(self.kwargs)
params.update(kwargs)
outputs = self.model.generate(**inputs, **params)
- resp = self.tokenizer.decode(outputs[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True, spaces_between_special_tokens=False)
+ resp = self.tokenizer.decode(
+ outputs[0][len(inputs['input_ids'][0]):],
+ skip_special_tokens=True,
+ spaces_between_special_tokens=False)
else:
params = self.kwargs
@@ -153,16 +161,16 @@ def length_ok(self, inputs):
for s in inputs:
tot += len(self.tokenizer.encode(s))
return tot + self.answer_buffer < self.context_length
-
+
def generate_list(self, full_inputs, offset=0, **kwargs):
assert isinstance(full_inputs, list)
inputs = full_inputs[offset:]
if not self.length_ok(inputs):
return self.chat(full_inputs, offset + 1)
-
+
model_path = self.model_path.lower()
-
+
if sum([x in model_path for x in ['baichuan']]):
input_msgs = []
if self.system_prompt is not None:
@@ -181,7 +189,7 @@ def generate_list(self, full_inputs, offset=0, **kwargs):
if len(inputs) % 2 == 1:
if self.system_prompt is not None:
conv.append_message(conv.roles[0], self.system_prompt)
- for i in range(len(inputs)//2):
+ for i in range(len(inputs) // 2):
conv.append_message(conv.roles[0], inputs[2 * i])
conv.append_message(conv.roles[1], inputs[2 * i + 1])
else:
@@ -194,7 +202,7 @@ def generate_list(self, full_inputs, offset=0, **kwargs):
conv.append_message(conv.roles[0], inputs[-1])
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
- inputs = self.tokenizer([prompt], return_tensors="pt")
+ inputs = self.tokenizer([prompt], return_tensors='pt')
if torch.cuda.is_available():
for k in inputs:
inputs[k] = inputs[k].cuda()
@@ -204,7 +212,10 @@ def generate_list(self, full_inputs, offset=0, **kwargs):
params.update(kwargs)
outputs = self.model.generate(**inputs, **params)
- response = self.tokenizer.decode(outputs[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True, spaces_between_special_tokens=False)
+ response = self.tokenizer.decode(
+ outputs[0][len(inputs['input_ids'][0]):],
+ skip_special_tokens=True,
+ spaces_between_special_tokens=False)
response = response.lstrip('\n')
else:
# The default option, support internlm, chatglm, qwen
@@ -212,8 +223,8 @@ def generate_list(self, full_inputs, offset=0, **kwargs):
if len(inputs) % 2 == 1:
if self.system_prompt is not None:
history = [(self.system_prompt, '')]
- for i in range(len(inputs)//2):
- history.append((inputs[2 * i], inputs[2 * i + 1]))
+ for i in range(len(inputs) // 2):
+ history.append((inputs[2 * i], inputs[2 * i + 1]))
else:
assert self.system_prompt is not None
history = [(self.system_prompt, inputs[0])]
@@ -224,11 +235,11 @@ def generate_list(self, full_inputs, offset=0, **kwargs):
params = self.kwargs
params.update(kwargs)
response, _ = self.model.chat(self.tokenizer, msg, history=history, **params)
-
+
return response, offset
-
+
def generate(self, inputs, **kwargs):
if isinstance(inputs, str):
return self.generate_str(inputs, **kwargs)
elif isinstance(inputs, list):
- return self.generate_list(inputs, **kwargs)
\ No newline at end of file
+ return self.generate_list(inputs, **kwargs)
diff --git a/vlmeval/api/qwen_api.py b/vlmeval/api/qwen_api.py
index e1bf07505..8d83da2a2 100644
--- a/vlmeval/api/qwen_api.py
+++ b/vlmeval/api/qwen_api.py
@@ -3,19 +3,20 @@
from vlmeval.api.base import BaseAPI
from vlmeval.smp import *
+
class QwenAPI(BaseAPI):
is_api: bool = True
- def __init__(self,
+ def __init__(self,
model: str = 'qwen-max-1201',
retry: int = 5,
- wait: int = 5,
- verbose: bool = True,
- seed: int = 2680,
- temperature: float = 0.0,
+ wait: int = 5,
+ verbose: bool = True,
+ seed: int = 2680,
+ temperature: float = 0.0,
system_prompt: str = None,
- key: str = None,
+ key: str = None,
max_tokens: int = 1024,
proxy: str = None,
**kwargs):
@@ -29,15 +30,18 @@ def __init__(self,
self.seed = seed
if key is None:
key = os.environ.get('DASHSCOPE_API_KEY', None)
- assert key is not None, "Please set the API Key (obtain it here: https://help.aliyun.com/zh/dashscope/developer-reference/vl-plus-quick-start)"
+ assert key is not None, (
+ 'Please set the API Key (obtain it here: '
+ 'https://help.aliyun.com/zh/dashscope/developer-reference/vl-plus-quick-start)'
+ )
dashscope.api_key = key
if proxy is not None:
proxy_set(proxy)
super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
-
+
@staticmethod
def build_msgs(msgs_raw, system_prompt=None):
- msgs = cp.deepcopy(msgs_raw)
+ msgs = cp.deepcopy(msgs_raw)
ret = []
if system_prompt is not None:
ret.append(dict(role='system', content=system_prompt))
@@ -45,7 +49,7 @@ def build_msgs(msgs_raw, system_prompt=None):
role = 'user' if i % 2 == 0 else 'assistant'
ret.append(dict(role=role, content=msg))
return ret
-
+
def generate_inner(self, inputs, **kwargs) -> str:
from dashscope import MultiModalConversation
assert isinstance(inputs, str) or isinstance(inputs, list)
@@ -63,8 +67,8 @@ def generate_inner(self, inputs, **kwargs) -> str:
)
if response.status_code != HTTPStatus.OK:
return -1, 'Error: Bad Response Status Code. ', f'The response status code is {response.status_code}. '
-
+
try:
return 0, response['output']['choices'][0]['message']['content'].strip(), 'Succeeded! '
except Exception as err:
- return -1, f'Error: Failed to parse the response. {err}', response
\ No newline at end of file
+ return -1, f'Error: Failed to parse the response. {err}', response
diff --git a/vlmeval/api/qwen_vl_api.py b/vlmeval/api/qwen_vl_api.py
index da9122be7..b413bd42e 100644
--- a/vlmeval/api/qwen_vl_api.py
+++ b/vlmeval/api/qwen_vl_api.py
@@ -1,17 +1,18 @@
from vlmeval.smp import *
from vlmeval.api.base import BaseAPI
+
class QwenVLWrapper(BaseAPI):
is_api: bool = True
- def __init__(self,
+ def __init__(self,
model: str = 'qwen-vl-plus',
retry: int = 5,
- wait: int = 5,
+ wait: int = 5,
key: str = None,
- verbose: bool = True,
- temperature: float = 0.0,
+ verbose: bool = True,
+ temperature: float = 0.0,
system_prompt: str = None,
max_tokens: int = 1024,
proxy: str = None,
@@ -25,15 +26,18 @@ def __init__(self,
self.temperature = temperature
if key is None:
key = os.environ.get('DASHSCOPE_API_KEY', None)
- assert key is not None, "Please set the API Key (obtain it here: https://help.aliyun.com/zh/dashscope/developer-reference/vl-plus-quick-start)"
+ assert key is not None, (
+ 'Please set the API Key (obtain it here: '
+ 'https://help.aliyun.com/zh/dashscope/developer-reference/vl-plus-quick-start)'
+ )
dashscope.api_key = key
if proxy is not None:
proxy_set(proxy)
super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
-
+
@staticmethod
def build_msgs(msgs_raw, system_prompt=None):
- msgs = cp.deepcopy(msgs_raw)
+ msgs = cp.deepcopy(msgs_raw)
ret = []
if system_prompt is not None:
content = list(dict(text=system_prompt))
@@ -48,7 +52,7 @@ def build_msgs(msgs_raw, system_prompt=None):
content.append(dict(text=msg))
ret.append(dict(role='user', content=content))
return ret
-
+
def generate_inner(self, inputs, **kwargs) -> str:
from dashscope import MultiModalConversation
assert isinstance(inputs, str) or isinstance(inputs, list)
@@ -59,26 +63,26 @@ def generate_inner(self, inputs, **kwargs) -> str:
pure_text = False
assert not pure_text
messages = self.build_msgs(msgs_raw=inputs, system_prompt=self.system_prompt)
- gen_config = dict(max_output_tokens=self.max_tokens, temperature=self.temperature)
+ gen_config = dict(max_output_tokens=self.max_tokens, temperature=self.temperature)
gen_config.update(self.kwargs)
try:
response = MultiModalConversation.call(model=self.model, messages=messages)
if self.verbose:
- print(response)
+ print(response)
answer = response.output.choices[0]['message']['content'][0]['text']
return 0, answer, 'Succeeded! '
except Exception as err:
if self.verbose:
self.logger.error(err)
- self.logger.error(f"The input messages are {inputs}.")
+ self.logger.error(f'The input messages are {inputs}.')
return -1, '', ''
+
class QwenVLAPI(QwenVLWrapper):
def generate(self, image_path, prompt, dataset=None):
return super(QwenVLAPI, self).generate([image_path, prompt])
-
+
def interleave_generate(self, ti_list, dataset=None):
return super(QwenVLAPI, self).generate(ti_list)
-
\ No newline at end of file
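
For reference, a minimal usage sketch of the QwenVLAPI wrapper touched above (illustrative only, not part of the patch; assumes DASHSCOPE_API_KEY is exported and that 'demo.jpg' is a hypothetical local image):

from vlmeval.api.qwen_vl_api import QwenVLAPI

# Mirrors the registry entry in vlmeval/config.py: temperature=0, retry=10.
model = QwenVLAPI(model='qwen-vl-plus', temperature=0, retry=10)
# A single image + prompt goes through generate(); interleaved lists through interleave_generate().
print(model.generate('demo.jpg', 'Describe the image in one sentence.'))
print(model.interleave_generate(['demo.jpg', 'What text appears in the image?']))
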
diff --git a/vlmeval/api/stepai.py b/vlmeval/api/stepai.py
index f0d0c517d..45c6b3b26 100644
--- a/vlmeval/api/stepai.py
+++ b/vlmeval/api/stepai.py
@@ -1,28 +1,30 @@
from vlmeval.smp import *
from vlmeval.api.base import BaseAPI
-url = "https://b-openapi.basemind.com/openapi/v1/chat/completions"
+url = 'https://b-openapi.basemind.com/openapi/v1/chat/completions'
headers = {
'X-Request-Orgcode': 'companyA',
'Authorization': 'Bearer {}',
'Content-Type': 'application/json'
}
+
def convert_image_to_base64(image_path):
- with open(image_path, "rb") as image_file:
+ with open(image_path, 'rb') as image_file:
encoded_string = base64.b64encode(image_file.read()).decode()
return encoded_string
+
class StepAPI(BaseAPI):
is_api: bool = True
- def __init__(self,
+ def __init__(self,
model: str = 'stepapi-rankboard',
retry: int = 10,
wait: int = 3,
key: str = None,
- temperature: float = 0,
+ temperature: float = 0,
max_tokens: int = 300,
verbose: bool = True,
system_prompt: str = None,
@@ -40,35 +42,35 @@ def __init__(self,
headers['Authorization'] = headers['Authorization'].format(self.key)
super().__init__(retry=retry, wait=wait, verbose=verbose, system_prompt=system_prompt, **kwargs)
-
+
@staticmethod
def build_msgs(msgs_raw):
messages = []
- message = {"role": "user", "content": []}
-
+ message = {'role': 'user', 'content': []}
+
for msg in msgs_raw:
if isimg(msg):
image_b64 = convert_image_to_base64(msg)
message['content'].append({
- "image_b64": {'b64_json': image_b64},
- "type": "image_b64"
+ 'image_b64': {'b64_json': image_b64},
+ 'type': 'image_b64'
})
else:
message['content'].append({
'text': msg,
- "type": 'text'
+ 'type': 'text'
})
messages.append(message)
return messages
-
+
def generate_inner(self, inputs, **kwargs) -> str:
print(inputs, '\n')
payload = dict(
- model=self.model,
+ model=self.model,
max_tokens=self.max_tokens,
temperature=self.temperature,
- messages= self.build_msgs(msgs_raw=inputs), #需要构建message
+ messages=self.build_msgs(msgs_raw=inputs),
**kwargs)
response = requests.post(url, headers=headers, data=json.dumps(payload))
# print('response is here!!:',response.text,'\n')
@@ -87,12 +89,12 @@ def generate_inner(self, inputs, **kwargs) -> str:
pass
# print('finial answer is',answer)
return ret_code, answer, response
-
+
class Step1V(StepAPI):
def generate(self, image_path, prompt, dataset=None):
return super(StepAPI, self).generate([image_path, prompt])
-
+
def interleave_generate(self, ti_list, dataset=None):
- return super(StepAPI, self).generate(ti_list)
\ No newline at end of file
+ return super(StepAPI, self).generate(ti_list)
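
A small sketch of the per-message payload that build_msgs above assembles for one image plus one text turn (illustrative only; 'demo.jpg' is a hypothetical local file):

from vlmeval.api.stepai import convert_image_to_base64

b64 = convert_image_to_base64('demo.jpg')
message = {
    'role': 'user',
    'content': [
        {'image_b64': {'b64_json': b64}, 'type': 'image_b64'},
        {'text': 'What is shown in the image?', 'type': 'text'},
    ],
}
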
diff --git a/vlmeval/config.py b/vlmeval/config.py
index 28288d868..0f48e4bd0 100644
--- a/vlmeval/config.py
+++ b/vlmeval/config.py
@@ -16,10 +16,10 @@
'PandaGPT_13B': partial(PandaGPT, name='PandaGPT_13B', root=PandaGPT_ROOT),
'flamingov2': partial(OpenFlamingo, name='v2', mpt_pth='anas-awadalla/mpt-7b', ckpt_pth='openflamingo/OpenFlamingo-9B-vitl-mpt7b'),
'flamingov2_fs': partial(OpenFlamingo, name='v2', with_context=True, mpt_pth='anas-awadalla/mpt-7b', ckpt_pth='openflamingo/OpenFlamingo-9B-vitl-mpt7b'),
- 'idefics_9b_instruct': partial(IDEFICS, model_pth="HuggingFaceM4/idefics-9b-instruct"),
- 'idefics_80b_instruct': partial(IDEFICS, model_pth="HuggingFaceM4/idefics-80b-instruct"),
- 'idefics_9b_instruct_fs': partial(IDEFICS, model_pth="HuggingFaceM4/idefics-9b-instruct", with_context=True),
- 'idefics_80b_instruct_fs': partial(IDEFICS, model_pth="HuggingFaceM4/idefics-80b-instruct", with_context=True),
+ 'idefics_9b_instruct': partial(IDEFICS, model_pth='HuggingFaceM4/idefics-9b-instruct'),
+ 'idefics_80b_instruct': partial(IDEFICS, model_pth='HuggingFaceM4/idefics-80b-instruct'),
+ 'idefics_9b_instruct_fs': partial(IDEFICS, model_pth='HuggingFaceM4/idefics-9b-instruct', with_context=True),
+ 'idefics_80b_instruct_fs': partial(IDEFICS, model_pth='HuggingFaceM4/idefics-80b-instruct', with_context=True),
'llava_v1.5_7b': partial(LLaVA, model_pth='liuhaotian/llava-v1.5-7b'),
'llava_v1.5_13b': partial(LLaVA, model_pth='liuhaotian/llava-v1.5-13b'),
'llava_v1_7b': partial(LLaVA, model_pth=LLAVA_V1_7B_MODEL_PTH),
@@ -27,13 +27,13 @@
'sharegpt4v_13b': partial(LLaVA, model_pth='Lin-Chen/ShareGPT4V-13B'),
'instructblip_7b': partial(InstructBLIP, name='instructblip_7b'),
'instructblip_13b': partial(InstructBLIP, name='instructblip_13b'),
- 'VisualGLM_6b': partial(VisualGLM, model_path="THUDM/visualglm-6b"),
+ 'VisualGLM_6b': partial(VisualGLM, model_path='THUDM/visualglm-6b'),
'MiniGPT-4-v2': partial(MiniGPT4, mode='v2', root=MiniGPT4_ROOT),
'MiniGPT-4-v1-7B': partial(MiniGPT4, mode='v1_7b', root=MiniGPT4_ROOT),
'MiniGPT-4-v1-13B': partial(MiniGPT4, mode='v1_13b', root=MiniGPT4_ROOT),
- "XComposer": partial(XComposer, model_path='internlm/internlm-xcomposer-vl-7b'),
- "XComposer2": partial(XComposer2, model_path='internlm/internlm-xcomposer2-vl-7b'),
- "mPLUG-Owl2": partial(mPLUG_Owl2, model_path='MAGAer13/mplug-owl2-llama2-7b'),
+ 'XComposer': partial(XComposer, model_path='internlm/internlm-xcomposer-vl-7b'),
+ 'XComposer2': partial(XComposer2, model_path='internlm/internlm-xcomposer2-vl-7b'),
+ 'mPLUG-Owl2': partial(mPLUG_Owl2, model_path='MAGAer13/mplug-owl2-llama2-7b'),
'cogvlm-grounding-generalist':partial(CogVlm, name='cogvlm-grounding-generalist',tokenizer_name ='lmsys/vicuna-7b-v1.5'),
'cogvlm-chat':partial(CogVlm, name='cogvlm-chat',tokenizer_name ='lmsys/vicuna-7b-v1.5'),
'sharedcaptioner':partial(SharedCaptioner, model_path='Lin-Chen/ShareCaptioner'),
@@ -56,12 +56,12 @@
# Internal Only
'GPT4V_INT': partial(GPT4V_Internal, model='gpt-4-vision-preview', temperature=0, img_size=512, img_detail='low', retry=10),
'GPT4V_SHORT': partial(
- GPT4V, model='gpt-4-vision-preview', temperature=0, img_size=512, img_detail='low', retry=10,
- system_prompt="Please responde to the following question / request in a short reply. "),
+ GPT4V, model='gpt-4-vision-preview', temperature=0, img_size=512, img_detail='low', retry=10,
+        system_prompt='Please respond to the following question / request in a short reply. '),
# Internal Only
'GPT4V_SHORT_INT': partial(
GPT4V_Internal, model='gpt-4-vision-preview', temperature=0, img_size=512, img_detail='low', retry=10,
- system_prompt="Please responde to the following question / request in a short reply. "),
+ system_prompt='Please responde to the following question / request in a short reply. '),
'GeminiProVision': partial(GeminiProVision, temperature=0, retry=10),
'QwenVLPlus': partial(QwenVLAPI, model='qwen-vl-plus', temperature=0, retry=10),
'QwenVLMax': partial(QwenVLAPI, model='qwen-vl-max', temperature=0, retry=10),
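
Each entry in vlmeval/config.py is a functools.partial, so resolving a model name is a dict lookup plus a call; a minimal sketch (the registry name supported_VLM is an assumption, it is not shown in this hunk, and 'demo.jpg' is a hypothetical image):

from vlmeval.config import supported_VLM  # assumed registry name

constructor = supported_VLM['QwenVLPlus']  # partial(QwenVLAPI, model='qwen-vl-plus', temperature=0, retry=10)
model = constructor()                      # calling the partial builds the wrapper
print(model.generate('demo.jpg', 'Describe the image.'))
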
diff --git a/vlmeval/evaluate/OCRBench.py b/vlmeval/evaluate/OCRBench.py
index 4f3a6a3c0..06cf767c3 100644
--- a/vlmeval/evaluate/OCRBench.py
+++ b/vlmeval/evaluate/OCRBench.py
@@ -1,8 +1,18 @@
from vlmeval.smp import *
-OCRBench_score = {"Regular Text Recognition":0,"Irregular Text Recognition":0,"Artistic Text Recognition":0,"Handwriting Recognition":0,
-"Digit String Recognition":0,"Non-Semantic Text Recognition":0,"Scene Text-centric VQA":0,"Doc-oriented VQA":0,
-"Key Information Extraction":0,"Handwritten Mathematical Expression Recognition":0}
+OCRBench_score = {
+ 'Regular Text Recognition': 0,
+ 'Irregular Text Recognition': 0,
+ 'Artistic Text Recognition': 0,
+ 'Handwriting Recognition': 0,
+ 'Digit String Recognition': 0,
+ 'Non-Semantic Text Recognition': 0,
+ 'Scene Text-centric VQA': 0,
+ 'Doc-oriented VQA': 0,
+ 'Key Information Extraction': 0,
+ 'Handwritten Mathematical Expression Recognition': 0
+}
+
def OCRBench_eval(eval_file):
logger = get_logger('Evaluation')
@@ -15,32 +25,41 @@ def OCRBench_eval(eval_file):
predict = str(line['prediction'])
answers = eval(line['answer'])
category = line['category']
- if category == "Handwritten Mathematical Expression Recognition":
+ if category == 'Handwritten Mathematical Expression Recognition':
for j in range(len(answers)):
- answer = answers[j].strip().replace("\n"," ").replace(" ","")
- predict = predict.strip().replace("\n"," ").replace(" ","")
+ answer = answers[j].strip().replace('\n', ' ').replace(' ', '')
+ predict = predict.strip().replace('\n', ' ').replace(' ', '')
if answer in predict:
- OCRBench_score[category]+= 1
+ OCRBench_score[category] += 1
break
else:
for j in range(len(answers)):
- answer = answers[j].lower().strip().replace("\n"," ")
- predict = predict.lower().strip().replace("\n"," ")
+ answer = answers[j].lower().strip().replace('\n', ' ')
+ predict = predict.lower().strip().replace('\n', ' ')
if answer in predict:
- OCRBench_score[category]+= 1
+ OCRBench_score[category] += 1
break
+
final_score_dict = {}
- final_score_dict['Text Recognition']=OCRBench_score['Regular Text Recognition']+OCRBench_score['Irregular Text Recognition']+OCRBench_score['Artistic Text Recognition']+OCRBench_score['Handwriting Recognition']+OCRBench_score['Digit String Recognition']+OCRBench_score['Non-Semantic Text Recognition']
+ final_score_dict['Text Recognition'] = (
+ OCRBench_score['Regular Text Recognition'] + OCRBench_score['Irregular Text Recognition']
+ + OCRBench_score['Artistic Text Recognition'] + OCRBench_score['Handwriting Recognition']
+ + OCRBench_score['Digit String Recognition'] + OCRBench_score['Non-Semantic Text Recognition']
+ )
final_score_dict['Scene Text-centric VQA'] = OCRBench_score['Scene Text-centric VQA']
final_score_dict['Doc-oriented VQA'] = OCRBench_score['Doc-oriented VQA']
final_score_dict['Key Information Extraction'] = OCRBench_score['Key Information Extraction']
- final_score_dict['Handwritten Mathematical Expression Recognition'] = OCRBench_score['Handwritten Mathematical Expression Recognition']
- final_score_dict['Final Score'] = final_score_dict['Text Recognition'] + final_score_dict['Scene Text-centric VQA'] + final_score_dict['Doc-oriented VQA'] + final_score_dict['Key Information Extraction'] + final_score_dict['Handwritten Mathematical Expression Recognition']
- final_score_dict['Final Score Norm'] = float(final_score_dict['Final Score'])/10
- score_pth = eval_file.replace('.xlsx','_score.json')
+ final_score_dict['Handwritten Mathematical Expression Recognition'] = \
+ OCRBench_score['Handwritten Mathematical Expression Recognition']
+ final_score_dict['Final Score'] = (
+ final_score_dict['Text Recognition'] + final_score_dict['Scene Text-centric VQA']
+ + final_score_dict['Doc-oriented VQA'] + final_score_dict['Key Information Extraction']
+ + final_score_dict['Handwritten Mathematical Expression Recognition']
+ )
+ final_score_dict['Final Score Norm'] = float(final_score_dict['Final Score']) / 10
+ score_pth = eval_file.replace('.xlsx', '_score.json')
dump(final_score_dict, score_pth)
logger.info(f'OCRBench_eval successfully finished evaluating {eval_file}, results saved in {score_pth}')
- logger.info(f'Score: ')
+ logger.info('Score: ')
for key, value in final_score_dict.items():
logger.info('{}:{}'.format(key, value))
-
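
The scoring rule in OCRBench_eval above reduces to a normalized substring check; a self-contained sketch of that rule on toy data (illustrative only):

def ocrbench_hit(prediction, answers, category):
    # Handwritten math expressions are compared with all whitespace removed;
    # every other category uses a lower-cased substring match.
    if category == 'Handwritten Mathematical Expression Recognition':
        pred = prediction.strip().replace('\n', ' ').replace(' ', '')
        return any(a.strip().replace('\n', ' ').replace(' ', '') in pred for a in answers)
    pred = prediction.lower().strip().replace('\n', ' ')
    return any(a.lower().strip().replace('\n', ' ') in pred for a in answers)

assert ocrbench_hit('The sign reads "STOP".', ['stop'], 'Scene Text-centric VQA')
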
diff --git a/vlmeval/evaluate/__init__.py b/vlmeval/evaluate/__init__.py
index a12673f20..10248c4b6 100644
--- a/vlmeval/evaluate/__init__.py
+++ b/vlmeval/evaluate/__init__.py
@@ -6,4 +6,4 @@
from .mathvista_eval import MathVista_eval
from .llavabench import LLaVABench_eval
from .misc import build_judge
-from .OCRBench import OCRBench_eval
\ No newline at end of file
+from .OCRBench import OCRBench_eval
diff --git a/vlmeval/evaluate/coco_eval.py b/vlmeval/evaluate/coco_eval.py
index 3fc5570b3..9ae8f4f85 100644
--- a/vlmeval/evaluate/coco_eval.py
+++ b/vlmeval/evaluate/coco_eval.py
@@ -10,13 +10,13 @@ def __init__(self, ref, gt):
self.gt = gt
print('setting up scorers...')
self.scorers = [
- (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
+ (Bleu(4), ['Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4']),
# (Meteor(), "METEOR"), # need java version 11.0.16+
- (Rouge(), "ROUGE_L"),
- (Cider(), "CIDEr"),
+ (Rouge(), 'ROUGE_L'),
+ (Cider(), 'CIDEr'),
# (Spice(), "SPICE"), # need java version 11.0.16+
]
-
+
def compute_scores(self):
total_scores = {}
for scorer, method in self.scorers:
@@ -24,49 +24,51 @@ def compute_scores(self):
score, scores = scorer.compute_score(self.gt, self.ref)
if type(method) == list:
for sc, scs, m in zip(score, scores, method):
- print("%s: %0.3f" % (m, sc * 100))
- total_scores["Bleu"] = [x * 100 for x in score]
+ print('%s: %0.3f' % (m, sc * 100))
+ total_scores['Bleu'] = [x * 100 for x in score]
else:
- print("%s: %0.3f" % (method, score * 100))
+ print('%s: %0.3f' % (method, score * 100))
total_scores[method] = score * 100
-
+
print('*****DONE*****')
for key, value in total_scores.items():
print('{}:{}'.format(key, value))
return total_scores
+
def COCO_eval(eval_file, nproc=4, verbose=False):
logger = get_logger('Evaluation')
data = load(eval_file)
-
+
lt = len(data)
lines = [data.iloc[i] for i in range(lt)]
ref = {}
gt = {}
- for i,(line) in enumerate(lines):
+ for i, line in enumerate(lines):
ref[str(i)] = [str(line['prediction'])]
gt[str(i)] = eval(line['answer'])
- scorer = COCO_Caption_Scorer(ref,gt)
+ scorer = COCO_Caption_Scorer(ref, gt)
coco_caption_score_dict = scorer.compute_scores()
-
- score_pth = eval_file.replace('.xlsx','_score.json')
+
+ score_pth = eval_file.replace('.xlsx', '_score.json')
dump(coco_caption_score_dict, score_pth)
logger.info(f'COCO_eval successfully finished evaluating {eval_file}, results saved in {score_pth}')
- logger.info(f'Score: ')
+ logger.info('Score: ')
for key, value in coco_caption_score_dict.items():
logger.info('{}:{}'.format(key, value))
+
def parse_args():
- parser = argparse.ArgumentParser(description="Inference LLM Answers. ")
- parser.add_argument("--data", type=str, help="The question set for inference, in excel / tsv / json format. ")
- parser.add_argument("--nproc", type=int, default=4)
- parser.add_argument("--verbose", action='store_true')
+ parser = argparse.ArgumentParser(description='Inference LLM Answers. ')
+ parser.add_argument('--data', type=str, help='The question set for inference, in excel / tsv / json format. ')
+ parser.add_argument('--nproc', type=int, default=4)
+ parser.add_argument('--verbose', action='store_true')
args = parser.parse_args()
return args
+
if __name__ == '__main__':
args = parse_args()
COCO_eval(eval_file=args.data, nproc=args.nproc, verbose=args.verbose)
-
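
COCO_eval above feeds one candidate caption per index and a list of references per index; a usage sketch (assumes pycocoevalcap, which provides Bleu / Rouge / Cider, is installed):

from vlmeval.evaluate.coco_eval import COCO_Caption_Scorer

ref = {'0': ['a dog running on the beach']}                        # model predictions
gt = {'0': ['a dog runs along the beach', 'a brown dog on sand']}  # reference captions
scores = COCO_Caption_Scorer(ref, gt).compute_scores()             # {'Bleu': [...], 'ROUGE_L': ..., 'CIDEr': ...}
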
diff --git a/vlmeval/evaluate/llavabench.py b/vlmeval/evaluate/llavabench.py
index e8e291256..a462ed482 100644
--- a/vlmeval/evaluate/llavabench.py
+++ b/vlmeval/evaluate/llavabench.py
@@ -7,14 +7,16 @@
from vlmeval.utils import track_progress_rich
rule_dict = {
- "llava_bench_conv": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
- "llava_bench_detail": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
- "llava_bench_complex": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}
+ 'llava_bench_conv': {'role': 'Assistant', 'prompt': 'We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment.'}, # noqa: E501
+ 'llava_bench_detail': {'role': 'Assistant', 'prompt': 'We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment.'}, # noqa: E501
+ 'llava_bench_complex': {'role': 'Assistant', 'prompt': 'We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment.'} # noqa: E501
}
+
def get_eval(judge, content):
return judge.generate(content)
+
def parse_score(review):
logger = get_logger('Evaluation')
try:
@@ -30,6 +32,7 @@ def parse_score(review):
logger.error(e, 'error', review)
return [-1, -1]
+
def build_prompt(line):
cap_str = line['caption']
question = line['question']
@@ -38,7 +41,7 @@ def build_prompt(line):
category = 'llava_bench_' + line['category']
rule = rule_dict[category]
role, prompt = rule['role'], rule['prompt']
-
+
content = (f'[Context]\n{cap_str}\n\n'
f'[Question]\n{question}\n\n'
f'[{role} 1]\n{ans1}\n\n[End of {role} 1]\n\n'
@@ -46,16 +49,17 @@ def build_prompt(line):
f'[System]\n{prompt}\n\n')
return content
+
def LLaVABench_atomeval(model, prompt):
review = get_eval(model, prompt)
- scores = parse_score(review)
+ scores = parse_score(review)
return scores
-
+
+
def LLaVABench_score(data):
cates = ['overall'] + list(set(data['category']))
ret = defaultdict(list)
-
-
+
for c in cates:
ret['split'].append(c)
sub = data[data['category'] == c] if c != 'overall' else data
@@ -64,8 +68,9 @@ def LLaVABench_score(data):
ret['GPT4 Score'].append(np.mean(sub['gpt4_score']) * 10)
return pd.DataFrame(ret)
+
def LLaVABench_eval(eval_file, model='gpt-4-0314', nproc=4, verbose=False):
- suffix = '.' + eval_file.split('.')[-1]
+ suffix = '.' + eval_file.split('.')[-1]
record_file = eval_file.replace(suffix, '_openai_result' + suffix)
score_file = eval_file.replace(suffix, '_score.csv')
@@ -73,7 +78,7 @@ def LLaVABench_eval(eval_file, model='gpt-4-0314', nproc=4, verbose=False):
data = load(eval_file)
lines = [data.iloc[i] for i in range(len(data))]
model = build_judge(
- model, temperature=0.2, retry=10, verbose=verbose,
+ model, temperature=0.2, retry=10, verbose=verbose,
system_prompt='You are a helpful and precise assistant for checking the quality of the answer.')
prompts = [build_prompt(line) for line in lines]
tups = [(model, prompt) for prompt in prompts]
@@ -81,24 +86,26 @@ def LLaVABench_eval(eval_file, model='gpt-4-0314', nproc=4, verbose=False):
data['gpt4_score'] = [x[0] for x in scores]
data['score'] = [x[1] for x in scores]
dump(data, record_file)
-
+
data = load(record_file)
ret = LLaVABench_score(data).round(1)
print(ret)
dump(ret, score_file)
return ret
-
+
+
def parse_args():
- parser = argparse.ArgumentParser(description="LLaVABench Evaluation. ")
- parser.add_argument("data", type=str, help="The question set for inference, in excel / tsv / json format. ")
+ parser = argparse.ArgumentParser(description='LLaVABench Evaluation. ')
+ parser.add_argument('data', type=str, help='The question set for inference, in excel / tsv / json format. ')
parser.add_argument(
- "--model", type=str, help="The LLM (GPT) used for inference. ", default="gpt-4-turbo",
+ '--model', type=str, help='The LLM (GPT) used for inference. ', default='gpt-4-turbo',
choices=['gpt-4-0613', 'gpt-4-turbo', 'chatgpt-1106', 'chatgpt-0613', 'gpt-4-0314'])
- parser.add_argument("--nproc", type=int, default=4)
- parser.add_argument("--verbose", action='store_true')
+ parser.add_argument('--nproc', type=int, default=4)
+ parser.add_argument('--verbose', action='store_true')
args = parser.parse_args()
return args
+
if __name__ == '__main__':
args = parse_args()
LLaVABench_eval(eval_file=args.data, model=args.model, nproc=args.nproc, verbose=args.verbose)
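
The judge reply format requested above puts the two scores on the first line, separated by a space; a minimal parsing sketch (not the repo's parse_score, shown only to illustrate the expected format):

def parse_two_scores(review: str):
    # First line: "<score for Assistant 1> <score for Assistant 2>"
    first_line = review.strip().split('\n')[0]
    a, b = first_line.split()[:2]
    return [float(a), float(b)]

assert parse_two_scores('8 7\nAssistant 2 omits several details.') == [8.0, 7.0]
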
diff --git a/vlmeval/evaluate/mathvista_eval.py b/vlmeval/evaluate/mathvista_eval.py
index d67f4916f..177962e4c 100644
--- a/vlmeval/evaluate/mathvista_eval.py
+++ b/vlmeval/evaluate/mathvista_eval.py
@@ -3,47 +3,57 @@
from vlmeval.utils import track_progress_rich
from vlmeval.utils.matching_util import can_infer
+
def get_gpt4_ICE():
example_1 = """
- Hint: Please answer the question requiring an integer answer and provide the final value, e.g., 1, 2, 3, at the end.\n
- Question: Which number is missing?\n
- Model response: The number missing in the sequence is 14.\n
- Extracted answer: 14
- """
-
+Hint: Please answer the question requiring an integer answer and provide the final value,
+e.g., 1, 2, 3, at the end.\n
+Question: Which number is missing?\n
+Model response: The number missing in the sequence is 14.\n
+Extracted answer: 14
+"""
+
example_2 = """
- Hint: Please answer the question requiring a floating-point number with one decimal place and provide the final value, e.g., 1.2, 1.3, 1.4, at the end.\n
- Question: What is the fraction of females facing the camera?\n
- Model response: The fraction of females facing the camera is 0.6, which means that six out of ten females in the group are facing the camera.\n
- Extracted answer: 0.6
- """
+Hint: Please answer the question requiring a floating-point number with one decimal place and provide the final value,
+e.g., 1.2, 1.3, 1.4, at the end.\n
+Question: What is the fraction of females facing the camera?\n
+Model response: The fraction of females facing the camera is 0.6,
+which means that six out of ten females in the group are facing the camera.\n
+Extracted answer: 0.6
+"""
example_3 = """
- Hint: Please answer the question requiring a floating-point number with two decimal places and provide the final value, e.g., 1.23, 1.34, 1.45, at the end.\n
- Question: How much money does Luca need to buy a sour apple candy and a butter-scotch candy? (Unit: $)\n
- Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy.\n
- Extracted answer: 1.45
- """
-
+Hint: Please answer the question requiring a floating-point number with two decimal places and provide the final value,
+e.g., 1.23, 1.34, 1.45, at the end.\n
+Question: How much money does Luca need to buy a sour apple candy and a butter-scotch candy? (Unit: $)\n
+Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy.\n
+Extracted answer: 1.45
+"""
+
example_4 = """
- Hint: Please answer the question requiring a Python list as an answer and provide the final list, e.g., [1, 2, 3], [1.2, 1.3, 1.4], at the end.\n
- Question: Between which two years does the line graph saw its maximum peak?\n
- Model response: The line graph saw its maximum peak between 2007 and 2008.\n
- Extracted answer: [2007, 2008]
- """
-
+Hint: Please answer the question requiring a Python list as an answer and provide the final list,
+e.g., [1, 2, 3], [1.2, 1.3, 1.4], at the end.\n
+Question: Between which two years does the line graph saw its maximum peak?\n
+Model response: The line graph saw its maximum peak between 2007 and 2008.\n
+Extracted answer: [2007, 2008]
+"""
+
example_5 = """
- Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.\n
- Question: What fraction of the shape is blue?\n
- Choices: (A) 3/11 (B) 8/11 (C) 6/11 (D) 3/5\n
- Model response: The correct answer is (B) 8/11.\n
- Extracted answer: B
- """
- return [example_1,example_2,example_3,example_4,example_5]
-
-
+Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.\n
+Question: What fraction of the shape is blue?\n
+Choices: (A) 3/11 (B) 8/11 (C) 6/11 (D) 3/5\n
+Model response: The correct answer is (B) 8/11.\n
+Extracted answer: B
+"""
+
+ return [example_1, example_2, example_3, example_4, example_5]
+
+
def build_mathvista_gpt4_prompt(line):
- task_description = """ Please read the following example. Then extract the answer from the model response and type it at the end of the prompt.\n"""
+ task_description = """
+Please read the following example.
+Then extract the answer from the model response and type it at the end of the prompt.\n
+"""
question = line['question']
prediction = str(line['prediction'])
prompt = task_description
@@ -55,9 +65,11 @@ def build_mathvista_gpt4_prompt(line):
prompt += 'Extracted answer:'
return prompt
+
def list_to_dict(lst):
return {chr(65 + i): val for i, val in enumerate(lst)}
+
def post_check(line, prefetch=False):
res = None
ans = line['answer']
@@ -81,12 +93,13 @@ def post_check(line, prefetch=False):
ans = str(ans)
except ValueError:
pass
-
+
if res == ans:
return res
else:
return False
+
def MathVista_auxeval(model, line):
prompt = build_mathvista_gpt4_prompt(line)
log = ''
@@ -101,10 +114,11 @@ def MathVista_auxeval(model, line):
log += f'Try {i}: output is {prediction}, failed to parse.\n'
else:
log += 'Succeed'
- return dict(log=log, res= res)
+ return dict(log=log, res=res)
log += 'All 5 retries failed.\n'
return dict(log=log, res='')
+
def MathVista_acc(result_file):
data = load(result_file)
tot = defaultdict(lambda: 0)
@@ -114,7 +128,6 @@ def MathVista_acc(result_file):
skill_list = []
for i in range(lt):
item = data.iloc[i]
- index = item['index']
cate = item['task']
tot['Overall'] += 1
try:
@@ -136,7 +149,7 @@ def MathVista_acc(result_file):
hit[cate] += 1
for skill in skills:
hit[skill] += 1
-
+
res = defaultdict(list)
for k in tot.keys():
res['Task&Skill'].append(k)
@@ -148,19 +161,20 @@ def MathVista_acc(result_file):
res = pd.DataFrame(res)
return res
+
def MathVista_eval(eval_file, model='gpt-4-turbo', nproc=4, verbose=False):
- logger = get_logger('Evaluation')
+ logger = get_logger('Evaluation')
suffix = eval_file.split('.')[-1]
storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
if osp.exists(storage):
- logger.warning(f"GPT scoring file {storage} already exists, will reuse it in MathVista_eval. ")
+ logger.warning(f'GPT scoring file {storage} already exists, will reuse it in MathVista_eval. ')
else:
data = load(eval_file)
gpt_version = model
model = build_judge(gpt_version, verbose=verbose, max_tokens=128, retry=10)
-
+
lt = len(data)
lines = [data.iloc[i] for i in range(lt)]
tups = [(model, line) for line in lines]
@@ -171,16 +185,16 @@ def MathVista_eval(eval_file, model='gpt-4-turbo', nproc=4, verbose=False):
ans = load(tmp_file)
tups = [x for x, i in zip(tups, indices) if i not in ans]
indices = [i for i in indices if i not in ans]
-
+
if len(indices):
new_results = track_progress_rich(
MathVista_auxeval, tups, nproc=nproc, chunksize=nproc,
keys=indices, save=tmp_file)
ans = load(tmp_file)
for k, v in zip(indices, new_results):
- assert k in ans
+ assert k in ans
assert ans[k]['log'] == v['log'] and ans[k]['res'] == v['res']
-
+
log_map, res_map = {}, {}
all_inds = [line['index'] for line in lines]
for k in all_inds:
@@ -189,30 +203,31 @@ def MathVista_eval(eval_file, model='gpt-4-turbo', nproc=4, verbose=False):
data['res'] = [res_map[idx] for idx in data['index']]
data['log'] = [log_map[idx] for idx in data['index']]
dump(data, storage)
-
+
score = MathVista_acc(storage)
- score_pth = storage.replace('.xlsx','_score.csv')
-
- dump(score,score_pth)
+ score_pth = storage.replace('.xlsx', '_score.csv')
+
+ dump(score, score_pth)
logger.info(f'MathVista_eval successfully finished evaluating {eval_file}, results saved in {score_pth}')
- logger.info(f'Score: ')
+ logger.info('Score: ')
logger.info(score)
-
+
+
def parse_args():
- parser = argparse.ArgumentParser(description="Inference LLM Answers. ")
- parser.add_argument("data", type=str, help="The question set for inference, in excel / tsv / json format. ")
+ parser = argparse.ArgumentParser(description='Inference LLM Answers. ')
+ parser.add_argument('data', type=str, help='The question set for inference, in excel / tsv / json format. ')
parser.add_argument(
- "--model",
- type=str,
- help="The LLM (GPT) used for inference. ",
- default="gpt-4-turbo",
+ '--model',
+ type=str,
+ help='The LLM (GPT) used for inference. ',
+ default='gpt-4-turbo',
choices=['gpt-4-0613', 'gpt-4-turbo', 'chatgpt-1106', 'chatgpt-0613'])
- parser.add_argument("--nproc", type=int, default=4)
- parser.add_argument("--verbose", action='store_true')
+ parser.add_argument('--nproc', type=int, default=4)
+ parser.add_argument('--verbose', action='store_true')
args = parser.parse_args()
return args
+
if __name__ == '__main__':
args = parse_args()
MathVista_eval(eval_file=args.data, model=args.model, nproc=args.nproc, verbose=args.verbose)
-
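
A quick illustration of list_to_dict above, which is how post_check turns a choice list into the letter-keyed mapping it compares extracted answers against:

def list_to_dict(lst):
    return {chr(65 + i): val for i, val in enumerate(lst)}

assert list_to_dict(['3/11', '8/11', '6/11', '3/5']) == {'A': '3/11', 'B': '8/11', 'C': '6/11', 'D': '3/5'}
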
diff --git a/vlmeval/evaluate/misc.py b/vlmeval/evaluate/misc.py
index 9893a4b14..423dce13c 100644
--- a/vlmeval/evaluate/misc.py
+++ b/vlmeval/evaluate/misc.py
@@ -3,12 +3,13 @@
INTERNAL = os.environ.get('INTERNAL', 0)
+
def build_judge(version, **kwargs):
model_map = {
- 'gpt-4-turbo': 'gpt-4-1106-preview',
+ 'gpt-4-turbo': 'gpt-4-1106-preview',
'gpt-4-0613': 'gpt-4-0613',
'gpt-4-0314': 'gpt-4-0314',
- 'gpt-4-0125': 'gpt-4-0125-preview',
+ 'gpt-4-0125': 'gpt-4-0125-preview',
'chatgpt-1106': 'gpt-3.5-turbo-1106',
'chatgpt-0613': 'gpt-3.5-turbo-0613',
'chatgpt-0125': 'gpt-3.5-turbo-0125'
@@ -18,4 +19,4 @@ def build_judge(version, **kwargs):
model = OpenAIWrapperInternal(model_version, **kwargs)
else:
model = OpenAIWrapper(model_version, **kwargs)
- return model
\ No newline at end of file
+ return model
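
A usage sketch of build_judge above: the short alias is mapped to a concrete model version and wrapped in OpenAIWrapper (or the internal wrapper when INTERNAL is set); assumes OPENAI_API_KEY is configured:

from vlmeval.evaluate.misc import build_judge

judge = build_judge('gpt-4-turbo', retry=10, verbose=False)  # resolves to 'gpt-4-1106-preview'
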
diff --git a/vlmeval/evaluate/mmvet_eval.py b/vlmeval/evaluate/mmvet_eval.py
index f4739ed55..137125525 100644
--- a/vlmeval/evaluate/mmvet_eval.py
+++ b/vlmeval/evaluate/mmvet_eval.py
@@ -2,25 +2,48 @@
from vlmeval.smp import *
from vlmeval.utils import track_progress_rich
+
def build_mmvet_gpt4_prompt(line):
question = line['question']
gt = str(line['answer'])
prediction = str(line['prediction'])
- prompt = """Compare the ground truth and prediction from AI models, to give a correctness score for the prediction. in the ground truth means it is totally right only when all elements in the ground truth are present in the prediction, and means it is totally right when any one element in the ground truth is present in the prediction. The correctness score is 0.0 (totally wrong), 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, or 1.0 (totally right). Just complete the last space of the correctness score.
-
- Question | Ground truth | Prediction | Correctness
- --- | --- | --- | ---
-    What is x in the equation? | -1 <OR> -5 | x = 3 | 0.0
-    What is x in the equation? | -1 <OR> -5 | x = -1 | 0.5
-    What is x in the equation? | -1 <OR> -5 | x = -5 | 0.5
-    What is x in the equation? | -1 <OR> -5 | x = -5 or 5 | 0.5
-    What is x in the equation? | -1 <OR> -5 | x = -1 or x = -5 | 1.0
- Can you explain this meme? | This meme is poking fun at the fact that the names of the countries Iceland and Greenland are misleading. Despite its name, Iceland is known for its beautiful green landscapes, while Greenland is mostly covered in ice and snow. The meme is saying that the person has trust issues because the names of these countries do not accurately represent their landscapes. | The meme talks about Iceland and Greenland. It's pointing out that despite their names, Iceland is not very icy and Greenland isn't very green. | 0.4
- Can you explain this meme? | This meme is poking fun at the fact that the names of the countries Iceland and Greenland are misleading. Despite its name, Iceland is known for its beautiful green landscapes, while Greenland is mostly covered in ice and snow. The meme is saying that the person has trust issues because the names of these countries do not accurately represent their landscapes. | The meme is using humor to point out the misleading nature of Iceland's and Greenland's names. Iceland, despite its name, has lush green landscapes while Greenland is mostly covered in ice and snow. The text 'This is why I have trust issues' is a playful way to suggest that these contradictions can lead to distrust or confusion. The humor in this meme is derived from the unexpected contrast between the names of the countries and their actual physical characteristics. | 1.0
- """
-    gpt4_prompt = prompt + '\n' + ' | '.join([question, gt.replace("<AND>", " ").replace("<OR>", " "), prediction, ""])
+ prompt = """
+Compare the ground truth and prediction from AI models, to give a correctness score for the prediction.
+<AND> in the ground truth means it is totally right
+only when all elements in the ground truth are present in the prediction,
+and <OR> means it is totally right when any one element in the ground truth is present in the prediction.
+The correctness score is 0.0 (totally wrong), 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, or 1.0 (totally right).
+Just complete the last space of the correctness score.
+
+Question | Ground truth | Prediction | Correctness
+--- | --- | --- | ---
+What is x in the equation? | -1 <OR> -5 | x = 3 | 0.0
+What is x in the equation? | -1 <OR> -5 | x = -1 | 0.5
+What is x in the equation? | -1 <OR> -5 | x = -5 | 0.5
+What is x in the equation? | -1 <OR> -5 | x = -5 or 5 | 0.5
+What is x in the equation? | -1 <OR> -5 | x = -1 or x = -5 | 1.0
+Can you explain this meme? | This meme is poking fun at the fact that the names of the countries
+Iceland and Greenland are misleading. Despite its name, Iceland is known for its beautiful green landscapes,
+while Greenland is mostly covered in ice and snow. The meme is saying that the person has trust issues
+because the names of these countries do not accurately represent their landscapes. |
+The meme talks about Iceland and Greenland. It's pointing out that despite their names,
+Iceland is not very icy and Greenland isn't very green. | 0.4
+Can you explain this meme? | This meme is poking fun at the fact that the names of the countries
+Iceland and Greenland are misleading. Despite its name, Iceland is known for its beautiful green landscapes,
+while Greenland is mostly covered in ice and snow. The meme is saying that the person has trust issues
+because the names of these countries do not accurately represent their landscapes. |
+The meme is using humor to point out the misleading nature of Iceland's and Greenland's names.
+Iceland, despite its name, has lush green landscapes while Greenland is mostly covered in ice and snow.
+The text 'This is why I have trust issues' is a playful way to suggest
+that these contradictions can lead to distrust or confusion.
+The humor in this meme is derived from the unexpected contrast between the names of the countries
+and their actual physical characteristics. | 1.0
+"""
+ gpt4_prompt = prompt + '\n' + ' | '.join(
+        [question, gt.replace('<AND>', ' ').replace('<OR>', ' '), prediction, ''])
return gpt4_prompt
+
def MMVet_auxeval(model, line):
def float_cvt(s):
try:
@@ -44,6 +67,7 @@ def float_cvt(s):
log += 'All 5 retries failed.\n'
return dict(log=log, score=0.0)
+
def MMVet_acc(result_file):
data = load(result_file)
tot = defaultdict(lambda: 0)
@@ -53,11 +77,11 @@ def MMVet_acc(result_file):
for i in range(lt):
item = data.iloc[i]
cate = item['category']
- cate2 = cate.replace(',','_')
+ cate2 = cate.replace(',', '_')
if cate2 not in cate2_list:
cate2_list.append(cate2)
grade = float(item['score'])
- cate_list = ['rec','ocr','know','gen','spat','math']
+ cate_list = ['rec', 'ocr', 'know', 'gen', 'spat', 'math']
for capa in cate_list:
if capa in cate:
tot[capa] += 1
@@ -83,6 +107,7 @@ def MMVet_acc(result_file):
res2 = pd.DataFrame(res2)
return res, res2
+
def MMVet_eval(eval_file, model='gpt-4-turbo', nproc=4, verbose=False):
logger = get_logger('Evaluation')
@@ -90,12 +115,12 @@ def MMVet_eval(eval_file, model='gpt-4-turbo', nproc=4, verbose=False):
storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
if osp.exists(storage):
- logger.warning(f"GPT scoring file {storage} already exists, will reuse it in MMVet_eval. ")
+ logger.warning(f'GPT scoring file {storage} already exists, will reuse it in MMVet_eval. ')
else:
data = load(eval_file)
gpt_version = model
model = build_judge(gpt_version, verbose=verbose, max_tokens=3, retry=10)
-
+
lt = len(data)
lines = [data.iloc[i] for i in range(lt)]
tups = [(model, line) for line in lines]
@@ -106,16 +131,16 @@ def MMVet_eval(eval_file, model='gpt-4-turbo', nproc=4, verbose=False):
ans = load(tmp_file)
tups = [x for x, i in zip(tups, indices) if i not in ans]
indices = [i for i in indices if i not in ans]
-
+
if len(indices):
new_results = track_progress_rich(
MMVet_auxeval, tups, nproc=nproc, chunksize=nproc,
keys=indices, save=tmp_file)
ans = load(tmp_file)
for k, v in zip(indices, new_results):
- assert k in ans
+ assert k in ans
assert ans[k]['log'] == v['log'] and ans[k]['score'] == v['score']
-
+
log_map, score_map = {}, {}
all_inds = [line['index'] for line in lines]
for k in all_inds:
@@ -131,25 +156,29 @@ def MMVet_eval(eval_file, model='gpt-4-turbo', nproc=4, verbose=False):
dump(score, score_pth)
dump(score_fine, score_fine_pth)
- logger.info(f'MMVet_eval successfully finished evaluating {eval_file}, results saved in {score_pth} and {score_fine_pth}')
- logger.info(f'Score: ')
+ logger.info(
+ f'MMVet_eval successfully finished evaluating {eval_file}, '
+ f'results saved in {score_pth} and {score_fine_pth}'
+ )
+ logger.info('Score: ')
logger.info(score)
+
def parse_args():
- parser = argparse.ArgumentParser(description="Inference LLM Answers. ")
- parser.add_argument("data", type=str, help="The question set for inference, in excel / tsv / json format. ")
+ parser = argparse.ArgumentParser(description='Inference LLM Answers. ')
+ parser.add_argument('data', type=str, help='The question set for inference, in excel / tsv / json format. ')
parser.add_argument(
- "--model",
- type=str,
- help="The LLM (GPT) used for inference. ",
- default="gpt-4-turbo",
+ '--model',
+ type=str,
+ help='The LLM (GPT) used for inference. ',
+ default='gpt-4-turbo',
choices=['gpt-4-0613', 'gpt-4-turbo', 'chatgpt-1106', 'chatgpt-0613'])
- parser.add_argument("--nproc", type=int, default=4)
- parser.add_argument("--verbose", action='store_true')
+ parser.add_argument('--nproc', type=int, default=4)
+ parser.add_argument('--verbose', action='store_true')
args = parser.parse_args()
return args
+
if __name__ == '__main__':
args = parse_args()
MMVet_eval(eval_file=args.data, model=args.model, nproc=args.nproc, verbose=args.verbose)
-
\ No newline at end of file
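
build_mmvet_gpt4_prompt above only reads the question / answer / prediction fields, so a plain dict can stand in for a pandas row when inspecting the grading prompt (illustrative only):

from vlmeval.evaluate.mmvet_eval import build_mmvet_gpt4_prompt

line = {'question': 'What is x in the equation?', 'answer': '-1 <OR> -5', 'prediction': 'x = -1'}
print(build_mmvet_gpt4_prompt(line))  # ends with an empty Correctness cell for the judge to fill
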
diff --git a/vlmeval/evaluate/multiple_choice.py b/vlmeval/evaluate/multiple_choice.py
index a5e78def1..c413af14b 100644
--- a/vlmeval/evaluate/multiple_choice.py
+++ b/vlmeval/evaluate/multiple_choice.py
@@ -9,14 +9,15 @@
INTERNAL = os.environ.get('INTERNAL', 0)
abbrs = {
- 'coarse_perception': 'CP',
- 'finegrained_perception (instance-level)': 'FP-S',
- 'finegrained_perception (cross-instance)': 'FP-C',
+ 'coarse_perception': 'CP',
+ 'finegrained_perception (instance-level)': 'FP-S',
+ 'finegrained_perception (cross-instance)': 'FP-C',
'logic_reasoning': 'LR',
'relation_reasoning': 'RR',
'attribute_reasoning': 'AR'
}
+
def MMMU_preproc(data):
logger = get_logger('Evaluation')
cnt = 0
@@ -32,6 +33,7 @@ def MMMU_preproc(data):
data['B'] = Bs
return data
+
def report_acc(df):
# assert group in [None, 'category', 'l2-category']
res = defaultdict(list)
@@ -57,36 +59,43 @@ def report_acc(df):
res[ab_name] = [np.mean(sub_df[sub_df['split'] == sp]['hit']) for sp in res['split']]
return pd.DataFrame(res)
+
def build_prompt(question, options, prediction):
tmpl = (
- "You are an AI assistant who will help me to match an answer with several options of a single-choice question. "
- "You are provided with a question, several options, and an answer, and you need to find which option is most similar to the answer. "
- "If the meaning of all options are significantly different from the answer, output Z. "\
- "Your should output a single uppercase character in A, B, C, D (if they are valid options), and Z. \n"
- "Example 1: \n"
- "Question: What is the main object in image?\nOptions: A. teddy bear B. rabbit C. cat D. dog\nAnswer: a cute teddy bear\nYour output: A\n"
- "Example 2: \n"
- "Question: What is the main object in image?\nOptions: A. teddy bear B. rabbit C. cat D. dog\nAnswer: Spider\nYour output: Z\n"
- "Example 3: \n"
- "Question: {}?\nOptions: {}\nAnswer: {}\nYour output: "
+ 'You are an AI assistant who will help me to match '
+ 'an answer with several options of a single-choice question. '
+ 'You are provided with a question, several options, and an answer, '
+ 'and you need to find which option is most similar to the answer. '
+ 'If the meaning of all options are significantly different from the answer, output Z. '
+        'You should output a single uppercase character in A, B, C, D (if they are valid options), and Z. \n'
+ 'Example 1: \n'
+ 'Question: What is the main object in image?\nOptions: A. teddy bear B. rabbit C. cat D. dog\n'
+ 'Answer: a cute teddy bear\nYour output: A\n'
+ 'Example 2: \n'
+ 'Question: What is the main object in image?\nOptions: A. teddy bear B. rabbit C. cat D. dog\n'
+ 'Answer: Spider\nYour output: Z\n'
+ 'Example 3: \n'
+ 'Question: {}?\nOptions: {}\nAnswer: {}\nYour output: '
)
return tmpl.format(question, options, prediction)
+
def build_prompt_cn(question, options, prediction):
tmpl = (
- "你是一个帮助我匹配答案与单选题中多个选项的 AI 助手。"
- "你会被提供:一个问题,多个选项,一个答案。你的任务是找到与答案意义最相近的选项。"
- "如果所有选项的意义都与答案显著不同,则输出 Z。"
- "你应该输出一个单个的大写字母,例如 A, B, C, D(如果它们是有效选项),或 Z。"
- "例 1:"
- "问题: 图中最主要的物体是什么?\n选项: A. 泰迪熊 B. 兔子 C. 猫 D. 狗\n答案: 一只可爱的泰迪熊\n输出: A\n"
- "例 2: \n"
- "问题: 图中最主要的物体是什么?\n选项: A. 泰迪熊 B. 兔子 C. 猫 D. 狗\n答案: 蜘蛛\n输出: Z\n"
- "例 3: \n"
- "问题: {}?\n选项: {}\n答案: {}\n输出: "
+ '你是一个帮助我匹配答案与单选题中多个选项的 AI 助手。'
+ '你会被提供:一个问题,多个选项,一个答案。你的任务是找到与答案意义最相近的选项。'
+ '如果所有选项的意义都与答案显著不同,则输出 Z。'
+ '你应该输出一个单个的大写字母,例如 A, B, C, D(如果它们是有效选项),或 Z。'
+ '例 1:'
+ '问题: 图中最主要的物体是什么?\n选项: A. 泰迪熊 B. 兔子 C. 猫 D. 狗\n答案: 一只可爱的泰迪熊\n输出: A\n'
+ '例 2: \n'
+ '问题: 图中最主要的物体是什么?\n选项: A. 泰迪熊 B. 兔子 C. 猫 D. 狗\n答案: 蜘蛛\n输出: Z\n'
+ '例 3: \n'
+ '问题: {}?\n选项: {}\n答案: {}\n输出: '
)
return tmpl.format(question, options, prediction)
+
def build_choices(item):
ret = {}
for ch in string.ascii_uppercase:
@@ -94,10 +103,12 @@ def build_choices(item):
ret[ch] = item[ch]
return ret
+
def prefetch_answer(item):
choices = build_choices(item)
return can_infer(item['prediction'], choices)
+
def extract_answer_from_item(model, item):
logger = get_logger('Evaluation')
# It will return: (pred, raw, llm_time)
@@ -111,9 +122,9 @@ def extract_answer_from_item(model, item):
retry = 3
ret = can_infer(item['prediction'], choices)
- if ret:
+ if ret:
return dict(opt=ret, log=item['prediction'])
-
+
while retry:
ans = model.generate(prompt)
if 'Failed to obtain answer via API' in ans:
@@ -129,7 +140,8 @@ def extract_answer_from_item(model, item):
if retry == 0:
options = list(choices) + ['Z'] if 'Z' not in choices else []
return dict(opt=rd.choice(options), log='Failed to predict, thus randomly generate one. ')
-
+
+
def prefetch_sub_data(sub_data, answer_map, verbose=False):
lt = len(sub_data)
GT, PRED = [], []
@@ -139,21 +151,25 @@ def prefetch_sub_data(sub_data, answer_map, verbose=False):
GT.append(answer_map[idx])
PRED.append(prefetch_answer(item))
if PRED[-1] and (GT[-1] != PRED[-1]):
- log = f"Failed in Prefetching Rolling {i}: Answer is {GT[-1]}, Prediction is {item['prediction']}, Pre-fetched is {PRED[-1]}. "
+ log = (
+ f'Failed in Prefetching Rolling {i}: Answer is {GT[-1]}, '
+ f"Prediction is {item['prediction']}, Pre-fetched is {PRED[-1]}. "
+ )
return dict(hit=0, log=log)
flag = True
for g, p in zip(GT, PRED):
if g != p:
- flag = False
- ret = (dict(hit=1, log="Succeed During Pre-fetching"), ) if flag else (None, )
+ flag = False
+ ret = (dict(hit=1, log='Succeed During Pre-fetching'), ) if flag else (None, )
ret = ret + (GT, PRED) if verbose else ret
return ret if len(ret) > 1 else ret[0]
-
+
+
def eval_sub_data(model, sub_data, answer_map):
res, GT, PRED = prefetch_sub_data(sub_data, answer_map, verbose=True)
if res is not None:
return res
-
+
lt = len(sub_data)
log = ''
for i in range(lt):
@@ -164,13 +180,20 @@ def eval_sub_data(model, sub_data, answer_map):
opt, match_log = res['opt'], res['log']
PRED[i] = opt
if PRED[i] != GT[i]:
- log += f"Failed in Rolling {i}: Answer is {GT[i]}; Prediction is {sub_data.iloc[i]['prediction']}; Pre-fetched is {PRED[i]}; Match Log is {match_log}.\n"
+ log += (
+ f"Failed in Rolling {i}: Answer is {GT[i]}; Prediction is {sub_data.iloc[i]['prediction']}; "
+ f'Pre-fetched is {PRED[i]}; Match Log is {match_log}.\n'
+ )
return dict(hit=0, log=log)
else:
- log += f"Rolling {i}: Answer is {GT[i]}, Prediction is {sub_data.iloc[i]['prediction']}, Pre-fetched is {PRED[i]}.\n"
+ log += (
+ f"Rolling {i}: Answer is {GT[i]}, Prediction is {sub_data.iloc[i]['prediction']}, "
+ f'Pre-fetched is {PRED[i]}.\n'
+ )
return dict(hit=1, log=log)
+
def eval_data_groups(model, data_groups, answer_map, result, result_file, nproc=16):
prefetched = [prefetch_sub_data(g, answer_map, verbose=False) for g in data_groups]
remain = []
@@ -184,21 +207,22 @@ def eval_data_groups(model, data_groups, answer_map, result, result_file, nproc=
keys = [x.iloc[0]['index'] % 1e6 for x in remain]
if len(tups) == 0:
return
-
+
if model is None:
logger = get_logger('Evaluation')
- logger.warning("Exact Matching mode, will not do GPT-based answer matching. ")
+ logger.warning('Exact Matching mode, will not do GPT-based answer matching. ')
for k in keys:
- result[k] = dict(hit=0, log="Failed in Prefetch, no GPT-based answer matching under `exact_matching` policy.")
+ result[k] = dict(
+ hit=0, log='Failed in Prefetch, no GPT-based answer matching under `exact_matching` policy.')
dump(result, result_file)
return
res = track_progress_rich(
eval_sub_data,
- tups,
+ tups,
nproc=nproc,
- chunksize=nproc,
- save=result_file,
+ chunksize=nproc,
+ save=result_file,
keys=keys)
result = load(result_file)
for k, v in zip(keys, res):
@@ -208,7 +232,8 @@ def eval_data_groups(model, data_groups, answer_map, result, result_file, nproc=
result[k] = v
dump(result, result_file)
-def multiple_choice_eval(eval_file, dataset="default", model='chatgpt-0613', nproc=4, verbose=False):
+
+def multiple_choice_eval(eval_file, dataset='default', model='chatgpt-0613', nproc=4, verbose=False):
logger = get_logger('Evaluation')
# assert dataset is not None
@@ -224,7 +249,7 @@ def multiple_choice_eval(eval_file, dataset="default", model='chatgpt-0613', npr
rd.seed(2680)
suffix = eval_file.split('.')[-1]
- assert model in ['chatgpt-0613', "exact_matching", "gpt-4-0125"]
+ assert model in ['chatgpt-0613', 'exact_matching', 'gpt-4-0125']
name_str_map = {
'chatgpt-0613': 'openai',
'gpt-4-0125': 'gpt4'
@@ -234,20 +259,18 @@ def multiple_choice_eval(eval_file, dataset="default", model='chatgpt-0613', npr
if model == 'exact_matching':
model = None
else:
- model_name = 'chatgpt-0613'
-
if INTERNAL or gpt_key_set():
model = build_judge(model, verbose=verbose, retry=10)
else:
logger.error('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
model = None
-
+
logger.info(f'Evaluating {eval_file}')
result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_result.pkl')
result = {}
if osp.exists(result_file):
result = load(result_file)
-
+
data = load(eval_file)
data = data.sort_values(by='index')
data['prediction'] = [str(x) for x in data['prediction']]
@@ -259,7 +282,8 @@ def multiple_choice_eval(eval_file, dataset="default", model='chatgpt-0613', npr
else:
logger.warning('Dataset is not provided, try to use the original `eval_file` as meta data. ')
meta = load(eval_file)
- assert 'category' in meta and 'index' in meta and 'answer' in meta, "Essentail columns missing in the eval_file."
+ assert 'category' in meta and 'index' in meta and 'answer' in meta, (
+            'Essential columns missing in the eval_file.')
cate_map = {i: c for i, c in zip(meta['index'], meta['category'])}
answer_map = {i: c for i, c in zip(meta['index'], meta['answer'])}
@@ -279,7 +303,7 @@ def multiple_choice_eval(eval_file, dataset="default", model='chatgpt-0613', npr
data_main = data[data['index'] < int(1e6)]
meta_idx_set = set(meta['index'])
data_main = data_main[data_main['index'].isin(meta_idx_set)]
-
+
lt = len(data_main)
hit, tot = 0, 0
@@ -288,26 +312,26 @@ def multiple_choice_eval(eval_file, dataset="default", model='chatgpt-0613', npr
# Dealing with the normal part
item_main = data_main.iloc[i]
idx = item_main['index']
-
+
if idx in result:
correct = result[idx]['hit']
assert correct in [0, 1]
hit += correct
tot += 1
continue
-
+
sub_data = data[data['index'] % int(1e6) == idx]
data_groups.append(sub_data)
if len(data_groups):
eval_data_groups(
- model=model,
- data_groups=data_groups,
+ model=model,
+ data_groups=data_groups,
answer_map=answer_map,
- nproc=nproc,
- result=result,
+ nproc=nproc,
+ result=result,
result_file=result_file)
-
+
tmp_pth = f'/tmp/{timestr()}.xlsx'
dump(data_main, tmp_pth)
data_main = load(tmp_pth)
@@ -324,34 +348,41 @@ def multiple_choice_eval(eval_file, dataset="default", model='chatgpt-0613', npr
data_main['l2-category'] = [l2_cate_map[i] for i in main_idx]
if split_map is not None:
data_main['split'] = [split_map[i] for i in indices]
-
+
# load split
dump(data_main, eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
data_main = load(eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
-
+
acc = report_acc(data_main)
- score_file = eval_file.replace(f'.{suffix}', f'_acc.csv')
+ score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
dump(acc, score_file)
logger.info(f'multiple_choice_eval successfully finished evaluating {eval_file}, results saved in {score_file}')
- logger.info(f'Score: ')
+ logger.info('Score: ')
logger.info(acc)
return acc
+
def parse_args():
- parser = argparse.ArgumentParser(description="Inference LLM Answers. ")
- parser.add_argument("data", type=str, help="The question set for inference, in excel / tsv / json format. ")
- parser.add_argument("--model", type=str, help="The LLM (GPT) used for inference. ", default='chatgpt-0613', choices=['chatgpt-0613', 'exact_matching', 'gpt-4-0125'])
+ parser = argparse.ArgumentParser(description='Inference LLM Answers. ')
+ parser.add_argument('data', type=str, help='The question set for inference, in excel / tsv / json format. ')
+ parser.add_argument(
+ '--model',
+ type=str,
+ help='The LLM (GPT) used for inference. ',
+ default='chatgpt-0613',
+ choices=['chatgpt-0613', 'exact_matching', 'gpt-4-0125'])
parser.add_argument(
- "--dataset",
- type=str,
- default="default",
+ '--dataset',
+ type=str,
+ default='default',
help='The dataset to evaluate')
- parser.add_argument("--nproc", type=int, default=6)
- parser.add_argument("--verbose", action='store_true')
+ parser.add_argument('--nproc', type=int, default=6)
+ parser.add_argument('--verbose', action='store_true')
args = parser.parse_args()
return args
+
if __name__ == '__main__':
args = parse_args()
- acc = multiple_choice_eval(eval_file=args.data, model=args.model, dataset=args.dataset, nproc=args.nproc, verbose=args.verbose)
-
\ No newline at end of file
+ acc = multiple_choice_eval(
+ eval_file=args.data, model=args.model, dataset=args.dataset, nproc=args.nproc, verbose=args.verbose)
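
An illustration of the choice-matching prompt built above: the judge receives the question, the option list, and the free-form answer, and must reply with a single letter (or Z when nothing matches):

from vlmeval.evaluate.multiple_choice import build_prompt

prompt = build_prompt(
    'What is the main object in image',
    'A. teddy bear B. rabbit C. cat D. dog',
    'a cute teddy bear')
print(prompt)  # the expected judge output here would be the single letter A
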
diff --git a/vlmeval/evaluate/vqa_eval.py b/vlmeval/evaluate/vqa_eval.py
index 440861b7e..8f6695091 100644
--- a/vlmeval/evaluate/vqa_eval.py
+++ b/vlmeval/evaluate/vqa_eval.py
@@ -7,6 +7,7 @@
from typing import Optional
from functools import partial
+
def _process_digit_article(inText):
outText = []
tempText = inText.lower().split()
@@ -158,17 +159,18 @@ def _process_digit_article(inText):
return outText
-def hit_calculate(result, dataset_name, vqa_score_threshold = 3, anls_threshold = 0.5):
+def hit_calculate(result, dataset_name, anls_threshold=0.5):
if listinstr(['TextVQA'], dataset_name):
return [np.mean(x['match']) for x in result]
- elif listinstr(['DocVQA'], dataset_name):
+ elif listinstr(['DocVQA'], dataset_name):
# return [1 - np.min(x['match']) >= anls_threshold for x in result]
- return [0.0 if 1 - np.min(x['match']) < anls_threshold else 1 - np.min(x['match']) for x in result ]
- elif listinstr(['ChartQA','OCRVQA'], dataset_name):
+ return [0.0 if 1 - np.min(x['match']) < anls_threshold else 1 - np.min(x['match']) for x in result]
+ elif listinstr(['ChartQA', 'OCRVQA'], dataset_name):
return [np.max(x['match']) for x in result]
- else: #default using vqa_score to calculate score
+ else: # default using vqa_score to calculate score
return [np.mean(x['match']) for x in result]
+
# https://github.com/google-research/pix2struct/blob/main/pix2struct/metrics.py#L81
def relaxed_correctness(target: str,
prediction: str,
@@ -205,19 +207,19 @@ def _to_float(text: str) -> Optional[float]:
prediction_float = _to_float(prediction)
target_float = _to_float(target)
if prediction_float is not None and target_float:
- relative_change = abs(prediction_float -
- target_float) / abs(target_float)
+ relative_change = abs(prediction_float - target_float) / abs(target_float)
return relative_change <= max_relative_change
else:
return prediction.lower() == target.lower()
+
def levenshtein_distance(s1, s2):
if len(s1) > len(s2):
s1, s2 = s2, s1
distances = range(len(s1) + 1)
for i2, c2 in enumerate(s2):
- distances_ = [i2+1]
+ distances_ = [i2 + 1]
for i1, c1 in enumerate(s1):
if c1 == c2:
distances_.append(distances[i1])
@@ -226,12 +228,13 @@ def levenshtein_distance(s1, s2):
distances = distances_
return distances[-1]
+
def anls_compute(groundtruth, prediction):
gt_answer = ' '.join(groundtruth.strip().lower().split())
det_answer = ' '.join(prediction.strip().lower().split())
- dist = levenshtein_distance(gt_answer,det_answer)
- length = max( len(groundtruth.upper()), len(prediction.upper()) )
- values = 0.0 if length == 0 else float(dist) / float(length)
+ dist = levenshtein_distance(gt_answer, det_answer)
+ length = max(len(groundtruth.upper()), len(prediction.upper()))
+ values = 0.0 if length == 0 else float(dist) / float(length)
return values
@@ -244,7 +247,7 @@ def process_answer(answer):
return answer
-def process_line(line, method = 'vqa_score'):
+def process_line(line, method='vqa_score'):
ret = {}
if istype(line['answer'], list):
answers = eval(line['answer'])
@@ -271,19 +274,20 @@ def process_line(line, method = 'vqa_score'):
elif method == 'relaxed_accuracy':
ret['gt'] = answers
ret['pred'] = line['prediction'].strip()
- ret['match'] = [relaxed_correctness(ret['pred'],x) for x in ret['gt']]
+ ret['match'] = [relaxed_correctness(ret['pred'], x) for x in ret['gt']]
elif method == 'accuracy':
ret['gt'] = answers
ret['pred'] = line['prediction'].strip()
ret['match'] = [(1.0 if (x.strip().lower() == ret['pred'].strip().lower()) else 0.0) for x in ret['gt']]
- else: #default using vqa_score to calculate score
+ else: # default using vqa_score to calculate score
ret['gt'] = [process_answer(x) for x in answers]
ret['pred'] = process_answer(line['prediction'])
ret['match'] = [x == ret['pred'] for x in ret['gt']]
-
+
return ret
-
-def VQAEval(eval_file, dataset_name, **kwargs):
+
+
+def VQAEval(eval_file, dataset_name, **kwargs):
logger = get_logger('Evaluation')
data = load(eval_file)
assert 'answer' in data and 'prediction' in data
@@ -293,22 +297,24 @@ def VQAEval(eval_file, dataset_name, **kwargs):
pool = mp.Pool(16)
lines = [data.iloc[i] for i in range(lt)]
if listinstr(['TextVQA'], dataset_name):
- res = pool.map(partial(process_line, method = 'vqa_score'), lines)
- elif listinstr(['ChartQA'], dataset_name):
- res = pool.map(partial(process_line, method = 'relaxed_accuracy'), lines)
+ res = pool.map(partial(process_line, method='vqa_score'), lines)
+ elif listinstr(['ChartQA'], dataset_name):
+ res = pool.map(partial(process_line, method='relaxed_accuracy'), lines)
elif listinstr(['OCRVQA'], dataset_name):
- res = pool.map(partial(process_line, method = 'accuracy'), lines)
+ res = pool.map(partial(process_line, method='accuracy'), lines)
elif listinstr(['DocVQA'], dataset_name):
- res = pool.map(partial(process_line, method = 'anls'), lines)
- else: #default using vqa_score to calculate score
+ res = pool.map(partial(process_line, method='anls'), lines)
+ else: # default using vqa_score to calculate score
res = pool.map(process_line, lines)
- hit = hit_calculate(res, dataset_name)#[np.mean(x['match']) >= full_score_weight for x in res]
+ # [np.mean(x['match']) >= full_score_weight for x in res]
+ hit = hit_calculate(res, dataset_name)
ret = dict()
if 'split' in data:
splits = set(data['split'])
for sp in splits:
sub = [r for l, r in zip(lines, res) if l['split'] == sp]
- hit = hit_calculate(sub, dataset_name) #[np.mean(x['match']) >= full_score_weight for x in sub]
+ # [np.mean(x['match']) >= full_score_weight for x in sub]
+ hit = hit_calculate(sub, dataset_name)
ret[sp] = np.mean(hit) * 100
else:
ret['Overall'] = np.mean(hit) * 100
@@ -317,7 +323,8 @@ def VQAEval(eval_file, dataset_name, **kwargs):
cates.sort()
for c in cates:
sub = [r for l, r in zip(lines, res) if l['category'] == c]
- hit = hit_calculate(sub, dataset_name) #[np.mean(x['match']) >= full_score_weight for x in sub]
+ # [np.mean(x['match']) >= full_score_weight for x in sub]
+ hit = hit_calculate(sub, dataset_name)
ret[c] = np.mean(hit) * 100
ret = d2df(ret)
ret.round(2)
diff --git a/vlmeval/evaluate/yes_or_no.py b/vlmeval/evaluate/yes_or_no.py
index ffe3adbc5..d5e467e67 100644
--- a/vlmeval/evaluate/yes_or_no.py
+++ b/vlmeval/evaluate/yes_or_no.py
@@ -4,6 +4,7 @@
INTERNAL = os.environ.get('INTERNAL', 0)
+
def MME_rating(data_file):
data = load(data_file)
stats = defaultdict(dict)
@@ -16,7 +17,7 @@ def MME_rating(data_file):
if image_path not in stats[category]:
stats[category][image_path] = []
stats[category][image_path].append(score)
-
+
def acc(key, mode='normal'):
res = stats[key]
values = []
@@ -26,26 +27,30 @@ def acc(key, mode='normal'):
elif mode == 'plus':
values.append(val[0] * val[1])
return np.mean(values) * 100
-
+
scores = {}
for k in stats:
scores[k] = acc(k) + acc(k, 'plus')
super_cates = dict(
- perception=['OCR', 'artwork', 'celebrity', 'color', 'count', 'existence', 'landmark', 'position', 'posters', 'scene'],
+ perception=[
+ 'OCR', 'artwork', 'celebrity', 'color', 'count', 'existence',
+ 'landmark', 'position', 'posters', 'scene'
+ ],
reasoning=['code_reasoning', 'commonsense_reasoning', 'numerical_calculation', 'text_translation']
)
-
+
ret = {}
for sc, cate_list in super_cates.items():
base = 0
for c in cate_list:
base += scores[c]
- ret[sc] = base
+ ret[sc] = base
ret.update(scores)
ret = d2df(ret)
return ret
+
def Hallusion_rating(data_file):
def calc_fAcc(data):
res = defaultdict(list)
@@ -54,7 +59,7 @@ def calc_fAcc(data):
line = data.iloc[i]
res[f"{line['l2-category']}_{line['set_id']}_{line['figure_id']}"].append(line['score'])
return np.mean([np.all(x) for x in res.values()]) * 100
-
+
def calc_qAcc(data):
res = defaultdict(list)
lt = len(data)
@@ -62,10 +67,10 @@ def calc_qAcc(data):
line = data.iloc[i]
res[f"{line['l2-category']}_{line['set_id']}_{line['question_id']}"].append(line['score'])
return np.mean([np.all(x) for x in res.values()]) * 100
-
+
def calc_aAcc(data):
return np.mean(data['score']) * 100
-
+
data = load(data_file)
data['set_id'] = [x.split('_')[3] for x in data['index']]
data['figure_id'] = [x.split('_')[4] for x in data['index']]
@@ -76,7 +81,7 @@ def calc_aAcc(data):
res['aAcc'].append(calc_aAcc(data))
res['fAcc'].append(calc_fAcc(data))
res['qAcc'].append(calc_qAcc(data))
-
+
if 'category' in data:
cates = list(set(data['category']))
for c in cates:
@@ -97,6 +102,7 @@ def calc_aAcc(data):
ret = pd.DataFrame(res)
return ret
+
def default_rating(data_file):
data = load(data_file)
res = {}
@@ -117,23 +123,27 @@ def default_rating(data_file):
res[c] = np.mean(sub['score']) * 100
ret = d2df(res)
return ret
-
+
+
def YOrN_match_prompt(line):
tmpl = (
- "You are an AI assistant who will help me to match an answer with two options of a question. "
- "The options are only Yes / No. "
- "You are provided with a question and an answer, and you need to find which option (Yes / No) is most similar to the answer. "
- "If the meaning of all options are significantly different from the answer, output Unknown. "\
- "Your should output a single word among the following 3 choices: Yes, No, Unknown.\n"
- "Example 1: \n"
+ 'You are an AI assistant who will help me to match an answer with two options of a question. '
+ 'The options are only Yes / No. '
+ 'You are provided with a question and an answer, '
+ 'and you need to find which option (Yes / No) is most similar to the answer. '
+ 'If the meanings of all options are significantly different from the answer, output Unknown. '
+ 'You should output a single word among the following 3 choices: Yes, No, Unknown.\n'
+ 'Example 1: \n'
"Question: Is the word in this image 'Hello'?\nAnswer: The word in this image is 'Hello'.\nYour output: Yes\n"
- "Example 2: \n"
- "Question: Is the word in this image 'Hello'?\nAnswer: The word in this image is not 'Hello'.\nYour output: No\n"
- "Example 3: \n"
- "Question: {}?\nAnswer: {}\nYour output: "
+ 'Example 2: \n'
+ "Question: Is the word in this image 'Hello'?\n"
+ "Answer: The word in this image is not 'Hello'.\nYour output: No\n"
+ 'Example 3: \n'
+ 'Question: {}?\nAnswer: {}\nYour output: '
)
return tmpl.format(line['question'], line['prediction'])
+
def YOrN_Extraction(output):
s = output.lower()
words = process_punctuation(s).split()
@@ -143,6 +153,7 @@ def YOrN_Extraction(output):
return 'No'
return 'Unknown'
+
def YOrN_auxeval(model, line):
prompt = YOrN_match_prompt(line)
retry = 5
@@ -153,6 +164,7 @@ def YOrN_auxeval(model, line):
return ans
return 'Unknown'
+
def YOrN_eval(eval_file, model='chatgpt-0613', nproc=4, verbose=False, dataset=None):
logger = get_logger('Evaluation')
data = load(eval_file)
@@ -170,7 +182,7 @@ def YOrN_eval(eval_file, model='chatgpt-0613', nproc=4, verbose=False, dataset=N
data['extracted'] = [ans_map[x] for x in data['index']]
unknown = data[data['extracted'] == 'Unknown']
-
+
model_name = 'chatgpt-0613'
if INTERNAL or gpt_key_set():
@@ -185,19 +197,20 @@ def YOrN_eval(eval_file, model='chatgpt-0613', nproc=4, verbose=False, dataset=N
tups = [(model, line) for line in lines]
indices = list(unknown['index'])
if len(tups):
- res = track_progress_rich(YOrN_auxeval, tups, nproc=nproc, chunksize=nproc, keys=indices, save=tmp_file)
+ res = track_progress_rich(
+ YOrN_auxeval, tups, nproc=nproc, chunksize=nproc, keys=indices, save=tmp_file)
for k, v in zip(indices, res):
ans_map[k] = v
data['extracted'] = [ans_map[x] for x in data['index']]
dump(data, storage)
else:
- logger.warning(f"GPT matching file {storage} already exists, will reuse it in YOrN_eval. ")
-
+ logger.warning(f'GPT matching file {storage} already exists, will reuse it in YOrN_eval. ')
+
data = load(storage)
- data["score"] = (data["answer"] == data["extracted"])
+ data['score'] = (data['answer'] == data['extracted'])
dump(data, storage)
-
+
if dataset is not None and listinstr(['MME'], dataset):
score = MME_rating(storage)
elif dataset is not None and listinstr(['Hallusion'], dataset):
@@ -213,16 +226,23 @@ def YOrN_eval(eval_file, model='chatgpt-0613', nproc=4, verbose=False, dataset=N
logger.info(score)
return score
+
def parse_args():
- parser = argparse.ArgumentParser(description="Inference LLM Answers. ")
- parser.add_argument("data", type=str, help="The question set for inference, in excel / tsv / json format. ")
- parser.add_argument("--model", type=str, help="The LLM (GPT) used for inference. ", default="chatgpt-0613", choices=['chatgpt-0613'])
- parser.add_argument("--nproc", type=int, default=4)
- parser.add_argument("--dataset", type=str, default=None)
- parser.add_argument("--verbose", action='store_true')
+ parser = argparse.ArgumentParser(description='Inference LLM Answers. ')
+ parser.add_argument('data', type=str, help='The question set for inference, in excel / tsv / json format. ')
+ parser.add_argument(
+ '--model',
+ type=str,
+ help='The LLM (GPT) used for inference. ',
+ default='chatgpt-0613',
+ choices=['chatgpt-0613'])
+ parser.add_argument('--nproc', type=int, default=4)
+ parser.add_argument('--dataset', type=str, default=None)
+ parser.add_argument('--verbose', action='store_true')
args = parser.parse_args()
return args
+
if __name__ == '__main__':
args = parse_args()
acc = YOrN_eval(eval_file=args.data, model=args.model, nproc=args.nproc, verbose=args.verbose, dataset=args.dataset)
diff --git a/vlmeval/inference.py b/vlmeval/inference.py
index be68a82c4..23202fd56 100644
--- a/vlmeval/inference.py
+++ b/vlmeval/inference.py
@@ -1,4 +1,4 @@
-import torch
+import torch
import torch.distributed as dist
import datetime
from vlmeval.config import supported_VLM, api_models
@@ -7,18 +7,20 @@
FAIL_MSG = 'Failed to obtain answer via API.'
+
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, nargs='+', required=True)
- parser.add_argument("--model", type=str, nargs='+', required=True)
- parser.add_argument("--nproc", type=int, default=4, required=True)
- parser.add_argument("--verbose", action='store_true')
+ parser.add_argument('--model', type=str, nargs='+', required=True)
+ parser.add_argument('--nproc', type=int, default=4, required=True)
+ parser.add_argument('--verbose', action='store_true')
args = parser.parse_args()
return args
+
# Only API model is accepted
def infer_data_api(work_dir, model_name, dataset_name, index_set=None, api_nproc=4, ignore_failed=False):
- rank, world_size = get_rank_and_world_size()
+ rank, world_size = get_rank_and_world_size()
assert rank == 0 and world_size == 1
dataset = TSVDataset(dataset_name)
data = dataset.data
@@ -27,20 +29,20 @@ def infer_data_api(work_dir, model_name, dataset_name, index_set=None, api_nproc
model = supported_VLM[model_name]() if isinstance(model_name, str) else model_name
assert getattr(model, 'is_api', False)
-
+
lt, indices = len(data), list(data['index'])
structs = [dataset.build_prompt(data.iloc[i]) for i in range(lt)]
-
+
out_file = f'{work_dir}/{model_name}_{dataset_name}_supp.pkl'
res = {}
if osp.exists(out_file):
res = load(out_file)
if ignore_failed:
res = {k: v for k, v in res.items() if FAIL_MSG not in v}
-
+
structs = [s for i, s in zip(indices, structs) if i not in res]
indices = [i for i in indices if i not in res]
-
+
gen_func = None
if listinstr(['MMMU', 'CORE_MM'], dataset_name):
assert hasattr(model, 'interleave_generate')
@@ -55,21 +57,23 @@ def infer_data_api(work_dir, model_name, dataset_name, index_set=None, api_nproc
if len(structs):
track_progress_rich(gen_func, structs, nproc=api_nproc, chunksize=api_nproc, save=out_file, keys=indices)
-
+
res = load(out_file)
if index_set is not None:
res = {k: v for k, v in res.items() if k in index_set}
os.remove(out_file)
return res
+
def infer_data(model_name, work_dir, dataset_name, out_file, verbose=False, api_nproc=4):
prev_file = f'{work_dir}/{model_name}_{dataset_name}_PREV.pkl'
res = load(prev_file) if osp.exists(prev_file) else {}
if osp.exists(out_file):
res.update(load(out_file))
- rank, world_size = get_rank_and_world_size()
- if rank == 0:dataset = TSVDataset(dataset_name)
+ rank, world_size = get_rank_and_world_size()
+ if rank == 0:
+ dataset = TSVDataset(dataset_name)
if world_size > 1:
dist.barrier()
dataset = TSVDataset(dataset_name)
@@ -89,7 +93,7 @@ def infer_data(model_name, work_dir, dataset_name, out_file, verbose=False, api_
res = {k: res[k] for k in data_indices}
dump(res, out_file)
return
-
+
# Data need to be inferred
data = data[~data['index'].isin(res)]
lt = len(data)
@@ -99,7 +103,12 @@ def infer_data(model_name, work_dir, dataset_name, out_file, verbose=False, api_
is_api = getattr(model, 'is_api', False)
if is_api:
lt, indices = len(data), list(data['index'])
- supp = infer_data_api(work_dir=work_dir, model_name=model_name, dataset_name=dataset_name, index_set=set(indices), api_nproc=api_nproc)
+ supp = infer_data_api(
+ work_dir=work_dir,
+ model_name=model_name,
+ dataset_name=dataset_name,
+ index_set=set(indices),
+ api_nproc=api_nproc)
for idx in indices:
assert idx in supp
res.update(supp)
@@ -109,7 +118,8 @@ def infer_data(model_name, work_dir, dataset_name, out_file, verbose=False, api_
for i in tqdm(range(lt)):
idx = data.iloc[i]['index']
- if idx in res: continue
+ if idx in res:
+ continue
if hasattr(model, 'use_custom_prompt') and model.use_custom_prompt(dataset_name):
struct = model.build_prompt(data.iloc[i], dataset=dataset_name)
@@ -125,11 +135,13 @@ def infer_data(model_name, work_dir, dataset_name, out_file, verbose=False, api_
elif len(struct['image']) == 1:
response = model.generate(prompt=struct['text'], image_path=struct['image'][0], dataset=dataset_name)
else:
- response = '[MMMU] Failed, multiple images exist while the model only support single-image generate API. '
+ response = (
+ '[MMMU] Failed, multiple images exist while the model only supports the single-image generate API. '
+ )
else:
response = model.generate(prompt=struct['text'], image_path=struct['image'], dataset=dataset_name)
torch.cuda.empty_cache()
-
+
if verbose:
print(response, flush=True)
@@ -141,6 +153,7 @@ def infer_data(model_name, work_dir, dataset_name, out_file, verbose=False, api_
dump(res, out_file)
return model
+
def prefetch_acc(result_file):
data = load(result_file)
from vlmeval.evaluate.multiple_choice import build_choices, can_infer
@@ -175,8 +188,9 @@ def prefetch_acc(result_file):
res = pd.DataFrame(res)
return res
+
def infer_data_job(model, work_dir, model_name, dataset_name, verbose=False, api_nproc=4, ignore_failed=False):
- rank, world_size = get_rank_and_world_size()
+ rank, world_size = get_rank_and_world_size()
result_file = osp.join(work_dir, f'{model_name}_{dataset_name}.xlsx')
prev_file = f'{work_dir}/{model_name}_{dataset_name}_PREV.pkl'
@@ -187,13 +201,16 @@ def infer_data_job(model, work_dir, model_name, dataset_name, verbose=False, api
if not ignore_failed:
results = {k: v for k, v in results.items() if FAIL_MSG not in str(v)}
dump(results, prev_file)
- if world_size > 1: dist.barrier()
+ if world_size > 1:
+ dist.barrier()
tmpl = osp.join(work_dir, '{}' + f'{world_size}_{dataset_name}.pkl')
out_file = tmpl.format(rank)
- model = infer_data(model, work_dir=work_dir, dataset_name=dataset_name, out_file=out_file, verbose=verbose, api_nproc=api_nproc)
- if world_size > 1: dist.barrier()
+ model = infer_data(
+ model, work_dir=work_dir, dataset_name=dataset_name, out_file=out_file, verbose=verbose, api_nproc=api_nproc)
+ if world_size > 1:
+ dist.barrier()
if rank == 0:
data_all = {}
@@ -206,7 +223,7 @@ def infer_data_job(model, work_dir, model_name, dataset_name, verbose=False, api
data['prediction'] = [str(data_all[x]) for x in data['index']]
data.pop('image')
- dump(data, result_file)
+ dump(data, result_file)
for i in range(world_size):
os.remove(tmpl.format(i))
- return model
\ No newline at end of file
+ return model
diff --git a/vlmeval/smp/__init__.py b/vlmeval/smp/__init__.py
index 14d240df4..46e89687d 100644
--- a/vlmeval/smp/__init__.py
+++ b/vlmeval/smp/__init__.py
@@ -1,4 +1,4 @@
from .file import *
from .vlm import *
from .misc import *
-from .log import *
\ No newline at end of file
+from .log import *
diff --git a/vlmeval/smp/file.py b/vlmeval/smp/file.py
index dd201f474..e6a0e050d 100644
--- a/vlmeval/smp/file.py
+++ b/vlmeval/smp/file.py
@@ -8,6 +8,7 @@
import time
import numpy as np
+
class NumpyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, (np.int_, np.intc, np.intp, np.int8,
@@ -22,10 +23,11 @@ def default(self, obj):
return obj.tolist()
elif isinstance(obj, (np.bool_)):
return bool(obj)
- elif isinstance(obj, (np.void)):
+ elif isinstance(obj, (np.void)):
return None
return json.JSONEncoder.default(self, obj)
+
# LOAD & DUMP
def dump(data, f, **kwargs):
def dump_pkl(data, pth, **kwargs):
@@ -52,6 +54,7 @@ def dump_tsv(data, f, quoting=csv.QUOTE_ALL):
suffix = f.split('.')[-1]
return handlers[suffix](data, f, **kwargs)
+
def load(f):
def load_pkl(pth):
return pickle.load(open(pth, 'rb'))
@@ -78,7 +81,8 @@ def load_tsv(f):
handlers = dict(pkl=load_pkl, json=load_json, jsonl=load_jsonl, xlsx=load_xlsx, csv=load_csv, tsv=load_tsv)
suffix = f.split('.')[-1]
- return handlers[suffix](f)
+ return handlers[suffix](f)
+
def download_file(url, filename=None):
import urllib.request
@@ -89,7 +93,7 @@ def update_to(self, b=1, bsize=1, tsize=None):
if tsize is not None:
self.total = tsize
self.update(b * bsize - self.n)
-
+
if filename is None:
filename = url.split('/')[-1]
@@ -98,6 +102,7 @@ def update_to(self, b=1, bsize=1, tsize=None):
urllib.request.urlretrieve(url, filename=filename, reporthook=t.update_to)
return filename
+
def ls(dirname='.', match=[], mode='all', level=1):
if dirname == '.':
ans = os.listdir(dirname)
@@ -109,11 +114,12 @@ def ls(dirname='.', match=[], mode='all', level=1):
if isinstance(match, str):
match = [match]
for m in match:
- if len(m) == 0: continue
- if m[0] != '!':
+ if len(m) == 0:
+ continue
+ if m[0] != '!':
ans = [x for x in ans if m in x]
else:
- ans = [x for x in ans if m[1:] not in x]
+ ans = [x for x in ans if m[1:] not in x]
if mode == 'dir':
ans = [x for x in ans if osp.isdir(x)]
elif mode == 'file':
@@ -122,20 +128,23 @@ def ls(dirname='.', match=[], mode='all', level=1):
ans = [x for x in ans if osp.isdir(x)]
res = []
for d in ans:
- res.extend(ls(d, match=match, mode=mode, level=level-1))
+ res.extend(ls(d, match=match, mode=mode, level=level - 1))
ans = res
return ans
+
def mrlines(fname, sp='\n'):
f = open(fname).read().split(sp)
while f != [] and f[-1] == '':
f = f[:-1]
return f
+
def mwlines(lines, fname):
with open(fname, 'w') as fout:
fout.write('\n'.join(lines))
+
def md5(file_pth):
with open(file_pth, 'rb') as f:
hash = hashlib.new('md5')
@@ -143,6 +152,7 @@ def md5(file_pth):
hash.update(chunk)
return str(hash.hexdigest())
+
def last_modified(pth):
stamp = osp.getmtime(pth)
m_ti = time.ctime(stamp)
diff --git a/vlmeval/smp/log.py b/vlmeval/smp/log.py
index 53194dd99..95804d5ed 100644
--- a/vlmeval/smp/log.py
+++ b/vlmeval/smp/log.py
@@ -2,11 +2,12 @@
logger_initialized = {}
+
def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
logger = logging.getLogger(name)
if name in logger_initialized:
return logger
-
+
for logger_name in logger_initialized:
if name.startswith(logger_name):
return logger
@@ -40,4 +41,4 @@ def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
logger.setLevel(logging.ERROR)
logger_initialized[name] = True
- return logger
\ No newline at end of file
+ return logger
diff --git a/vlmeval/smp/misc.py b/vlmeval/smp/misc.py
index d685b40ec..2a6fa7fc0 100644
--- a/vlmeval/smp/misc.py
+++ b/vlmeval/smp/misc.py
@@ -84,7 +84,7 @@ def get_cache_path(repo_id):
rev2keep, last_modified = None, 0
for rev in revs:
if rev.last_modified > last_modified:
- rev2keep, last_modified = rev, rev.last_modified
+ rev2keep, last_modified = rev, rev.last_modified
if rev2keep is None:
return None
return str(rev2keep.snapshot_path)
@@ -95,8 +95,8 @@ def proxy_set(s):
os.environ[key] = s
def get_rank_and_world_size():
- local_rank = int(os.environ.get("LOCAL_RANK", 0))
- world_size = int(os.environ.get("WORLD_SIZE", 1))
+ local_rank = int(os.environ.get('LOCAL_RANK', 0))
+ world_size = int(os.environ.get('WORLD_SIZE', 1))
return local_rank, world_size
def splitlen(s, sym='/'):
diff --git a/vlmeval/smp/vlm.py b/vlmeval/smp/vlm.py
index 3ab3f1d89..a5525748c 100644
--- a/vlmeval/smp/vlm.py
+++ b/vlmeval/smp/vlm.py
@@ -1,4 +1,5 @@
-import os, io
+import os
+import io
import pandas as pd
import numpy as np
import string
@@ -8,6 +9,7 @@
from PIL import Image
from .file import load, dump
+
def mmqa_display(question):
question = {k.lower(): v for k, v in question.items()}
keys = list(question.keys())
@@ -22,21 +24,22 @@ def mmqa_display(question):
for im in images:
image = decode_base64_to_image(im, target_size=512)
- display(image)
-
+ display(image) # noqa: F821
+
for k in keys:
- try:
+ try:
if not pd.isna(question[k]):
print(f'{k.upper()}. {question[k]}')
except ValueError:
if False in pd.isna(question[k]):
print(f'{k.upper()}. {question[k]}')
+
def encode_image_to_base64(img, target_size=-1):
# if target_size == -1, will not do resizing
# else, will set the max_size ot (target_size, target_size)
- if img.mode in ("RGBA", "P"):
- img = img.convert("RGB")
+ if img.mode in ('RGBA', 'P'):
+ img = img.convert('RGB')
tmp = osp.join('/tmp', str(uuid4()) + '.jpg')
if target_size > 0:
img.thumbnail((target_size, target_size))
@@ -47,10 +50,12 @@ def encode_image_to_base64(img, target_size=-1):
os.remove(tmp)
return ret
+
def encode_image_file_to_base64(image_path, target_size=-1):
image = Image.open(image_path)
return encode_image_to_base64(image, target_size=target_size)
-
+
+
def decode_base64_to_image(base64_string, target_size=-1):
image_data = base64.b64decode(base64_string)
image = Image.open(io.BytesIO(image_data))
@@ -60,10 +65,12 @@ def decode_base64_to_image(base64_string, target_size=-1):
image.thumbnail((target_size, target_size))
return image
+
def decode_base64_to_image_file(base64_string, image_path, target_size=-1):
image = decode_base64_to_image(base64_string, target_size=target_size)
image.save(image_path)
+
def LMUDataRoot():
if 'LMUData' in os.environ and osp.exists(os.environ['LMUData']):
return os.environ['LMUData']
@@ -72,6 +79,7 @@ def LMUDataRoot():
os.makedirs(root, exist_ok=True)
return root
+
def build_option_str(option_dict):
s = 'There are several options: \n'
for c, content in option_dict.items():
@@ -79,9 +87,11 @@ def build_option_str(option_dict):
s += f'{c}. {content}\n'
return s
+
def isimg(s):
return osp.exists(s) or s.startswith('http')
+
def read_ok(img_path):
if not osp.exists(img_path):
return False
@@ -92,17 +102,20 @@ def read_ok(img_path):
except:
return False
+
def gpt_key_set():
openai_key = os.environ.get('OPENAI_API_KEY', None)
return isinstance(openai_key, str) and openai_key.startswith('sk-')
+
def apiok(wrapper):
- s = wrapper.generate("Hello!")
+ s = wrapper.generate('Hello!')
return wrapper.fail_msg not in s
+
def circular_pred(df, extract_func=None):
if extract_func is None:
- extract_func = lambda x: x
+ extract_func = lambda x: x # noqa: E731
df = df.sort_values('index')
from vlmeval.utils import can_infer_option
shift = int(1e6)
@@ -113,7 +126,11 @@ def circular_pred(df, extract_func=None):
valid_map = {i: True for i in pred_map if i < 1e6}
for i in df['index']:
if i >= shift and pred_map[i] and pred_map[i - shift]:
- if pred_map[i] not in list(string.ascii_uppercase) or pred_map[i - shift] not in list(string.ascii_uppercase):
+ if (
+ pred_map[i] not in list(string.ascii_uppercase) or # noqa: W504
+ pred_map[i - shift] not in list(string.ascii_uppercase)
+ ):
+
valid_map[i % shift] = False
continue
if (ord(pred_map[i]) - ord(pred_map[i - shift])) % 4 == 1:
@@ -124,6 +141,7 @@ def circular_pred(df, extract_func=None):
flags = list(flag_map.values())
return np.mean(flags)
+
def MMBenchOfficialServer():
root = LMUDataRoot()
for dataset in ['MMBench', 'MMBench_CN', 'MMBench_TEST_EN', 'MMBench_TEST_CN']:
@@ -131,4 +149,4 @@ def MMBenchOfficialServer():
data = load(f'{root}/{dataset}.tsv')
if 'answer' in data and sum([pd.isna(x) for x in data['answer']]) == 0:
return True
- return False
\ No newline at end of file
+ return False
diff --git a/vlmeval/utils/__init__.py b/vlmeval/utils/__init__.py
index fb993c12b..66d8c6765 100644
--- a/vlmeval/utils/__init__.py
+++ b/vlmeval/utils/__init__.py
@@ -6,7 +6,7 @@
__all__ = [
- 'can_infer', 'can_infer_option', 'can_infer_text', 'track_progress_rich',
+ 'can_infer', 'can_infer_option', 'can_infer_text', 'track_progress_rich',
'TSVDataset', 'dataset_URLs', 'img_root_map', 'DATASET_TYPE', 'CustomPrompt',
'split_MMMU', 'abbr2full'
-]
\ No newline at end of file
+]
diff --git a/vlmeval/utils/custom_prompt.py b/vlmeval/utils/custom_prompt.py
index a75cfd75a..5841fbcec 100644
--- a/vlmeval/utils/custom_prompt.py
+++ b/vlmeval/utils/custom_prompt.py
@@ -2,22 +2,23 @@
from .dataset_config import img_root_map
from abc import abstractmethod
+
class CustomPrompt:
@abstractmethod
def use_custom_prompt(self, dataset):
raise NotImplementedError
-
+
@abstractmethod
def build_prompt(self, line, dataset):
raise NotImplementedError
-
+
def dump_image(self, line, dataset):
ROOT = LMUDataRoot()
assert isinstance(dataset, str)
img_root = osp.join(ROOT, 'images', img_root_map[dataset] if dataset in img_root_map else dataset)
os.makedirs(img_root, exist_ok=True)
- if isinstance(line['image'], list):
+ if isinstance(line['image'], list):
tgt_path = []
assert 'image_path' in line
for img, im_name in zip(line['image'], line['image_path']):
@@ -29,4 +30,4 @@ def dump_image(self, line, dataset):
tgt_path = osp.join(img_root, f"{line['index']}.jpg")
if not read_ok(tgt_path):
decode_base64_to_image_file(line['image'], tgt_path)
- return tgt_path
\ No newline at end of file
+ return tgt_path
diff --git a/vlmeval/utils/dataset.py b/vlmeval/utils/dataset.py
index 6658c8f0f..d3507ee9e 100644
--- a/vlmeval/utils/dataset.py
+++ b/vlmeval/utils/dataset.py
@@ -4,9 +4,11 @@
from .dataset_config import dataset_URLs, dataset_md5_dict, DATASET_TYPE
from .custom_prompt import CustomPrompt
+
def isliststr(s):
return (s[0] == '[') and (s[-1] == ']')
+
def check_md5(data_path, dataset):
try:
with open(data_path, 'rb') as f:
@@ -20,7 +22,8 @@ def check_md5(data_path, dataset):
return False
except:
return False
-
+
+
def split_MMMU(struct):
assert 'image' in struct and 'text' in struct
text, images = struct['text'], struct['image']
@@ -35,8 +38,9 @@ def split_MMMU(struct):
segs.append(seg[2:])
return segs
+
class TSVDataset(CustomPrompt):
-
+
def __init__(self, dataset='MMBench', skip_noimg=True):
self.data_root = LMUDataRoot()
@@ -50,10 +54,12 @@ def __init__(self, dataset='MMBench', skip_noimg=True):
file_name = url.split('/')[-1]
data_path = osp.join(self.data_root, file_name)
- if osp.exists(data_path) and (md5(data_path) == dataset_md5_dict[dataset] if dataset in dataset_md5_dict else True):
+ if osp.exists(data_path) and (
+ md5(data_path) == dataset_md5_dict[dataset] if dataset in dataset_md5_dict else True
+ ):
pass
else:
- warnings.warn("The dataset tsv is not downloaded")
+ warnings.warn('The dataset tsv is not downloaded')
download_file(url, data_path)
else:
data_path = osp.join(self.data_root, dataset + '.tsv')
@@ -66,20 +72,23 @@ def __init__(self, dataset='MMBench', skip_noimg=True):
# Prompt for Captioning
if listinstr(['COCO'], dataset):
- data['question'] = ['Please describe this image in general. Directly provide the description, do not include prefix like "This image depicts". '] * len(data)
+ data['question'] = [(
+ 'Please describe this image in general. Directly provide the description, '
+ 'do not include prefix like "This image depicts". '
+ )] * len(data)
data['index'] = [str(x) for x in data['index']]
data['image'] = [str(x) for x in data['image']]
-
+
image_map = {x: y for x, y in zip(data['index'], data['image'])}
for k in image_map:
if len(image_map[k]) <= 64:
idx = image_map[k]
assert idx in image_map and len(image_map[idx]) > 64
image_map[k] = image_map[idx]
-
+
data['image'] = [
- eval(image_map[k]) if isliststr(image_map[k]) else image_map[k]
+ eval(image_map[k]) if isliststr(image_map[k]) else image_map[k]
for k in data['index']
]
if 'image_path' in data:
@@ -88,7 +97,7 @@ def __init__(self, dataset='MMBench', skip_noimg=True):
]
if np.all([istype(x, int) for x in data['index']]):
data['index'] = [int(x) for x in data['index']]
-
+
self.data = data
def __len__(self):
@@ -122,11 +131,10 @@ def build_prompt(self, line, dataset=None):
if len(options):
prompt += options_prompt
prompt += 'Please select the correct answer from the options above. \n'
-
+
return dict(image=tgt_path, text=prompt)
-
+
def display(self, line):
if isinstance(line, int):
line = self.data.iloc[line]
mmqa_display(line)
-
diff --git a/vlmeval/utils/dataset_config.py b/vlmeval/utils/dataset_config.py
index 6e00cca64..add1a9b0f 100644
--- a/vlmeval/utils/dataset_config.py
+++ b/vlmeval/utils/dataset_config.py
@@ -1,87 +1,88 @@
from ..smp import listinstr
dataset_URLs = {
- 'MMBench_DEV_EN': "https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_EN.tsv",
- 'MMBench_TEST_EN': "https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_EN.tsv",
- 'MMBench_DEV_CN': "https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_CN.tsv",
- 'MMBench_TEST_CN': "https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_CN.tsv",
- "MMBench": "https://opencompass.openxlab.space/utils/VLMEval/MMBench.tsv", # Link Invalid, Internal Only
- "MMBench_CN": "https://opencompass.openxlab.space/utils/VLMEval/MMBench_CN.tsv", # Link Invalid, Internal Only
- 'CCBench': "https://opencompass.openxlab.space/utils/VLMEval/CCBench.tsv",
- 'MME': "https://opencompass.openxlab.space/utils/VLMEval/MME.tsv",
- 'SEEDBench_IMG': "https://opencompass.openxlab.space/utils/VLMEval/SEEDBench_IMG.tsv",
- "CORE_MM": "https://opencompass.openxlab.space/utils/VLMEval/CORE_MM.tsv",
- "MMVet": "https://opencompass.openxlab.space/utils/VLMEval/MMVet.tsv",
- "COCO_VAL": "https://opencompass.openxlab.space/utils/VLMEval/COCO_VAL.tsv",
- "OCRVQA_TEST": "https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TEST.tsv",
- "OCRVQA_TESTCORE": "https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TESTCORE.tsv",
- 'TextVQA_VAL': "https://opencompass.openxlab.space/utils/VLMEval/TextVQA_VAL.tsv",
- "MMMU_DEV_VAL": "https://opencompass.openxlab.space/utils/VLMEval/MMMU_DEV_VAL.tsv",
- "MMMU_TEST": "https://opencompass.openxlab.space/utils/VLMEval/MMMU_TEST.tsv",
- "MathVista_MINI": "https://opencompass.openxlab.space/utils/VLMEval/MathVista_MINI.tsv",
- 'ChartQA_VALTEST_HUMAN': "https://opencompass.openxlab.space/utils/VLMEval/ChartQA_VALTEST_HUMAN.tsv",
- 'ScienceQA_VAL': "https://opencompass.openxlab.space/utils/VLMEval/ScienceQA_VAL.tsv",
- 'ScienceQA_TEST': "https://opencompass.openxlab.space/utils/VLMEval/ScienceQA_TEST.tsv",
- 'HallusionBench': "https://opencompass.openxlab.space/utils/VLMEval/HallusionBench.tsv",
- "DocVQA_VAL": "https://opencompass.openxlab.space/utils/VLMEval/DocVQA_VAL.tsv",
- 'AI2D_TEST': "https://opencompass.openxlab.space/utils/VLMEval/AI2D_TEST.tsv",
- "LLaVABench": "https://opencompass.openxlab.space/utils/VLMEval/LLaVABench.tsv",
- "OCRBench": 'https://opencompass.openxlab.space/utils/VLMEval/OCRBench.tsv',
+ 'MMBench_DEV_EN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_EN.tsv',
+ 'MMBench_TEST_EN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_EN.tsv',
+ 'MMBench_DEV_CN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_DEV_CN.tsv',
+ 'MMBench_TEST_CN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_TEST_CN.tsv',
+ 'MMBench': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench.tsv', # Link Invalid, Internal Only
+ 'MMBench_CN': 'https://opencompass.openxlab.space/utils/VLMEval/MMBench_CN.tsv', # Link Invalid, Internal Only
+ 'CCBench': 'https://opencompass.openxlab.space/utils/VLMEval/CCBench.tsv',
+ 'MME': 'https://opencompass.openxlab.space/utils/VLMEval/MME.tsv',
+ 'SEEDBench_IMG': 'https://opencompass.openxlab.space/utils/VLMEval/SEEDBench_IMG.tsv',
+ 'CORE_MM': 'https://opencompass.openxlab.space/utils/VLMEval/CORE_MM.tsv',
+ 'MMVet': 'https://opencompass.openxlab.space/utils/VLMEval/MMVet.tsv',
+ 'COCO_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/COCO_VAL.tsv',
+ 'OCRVQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TEST.tsv',
+ 'OCRVQA_TESTCORE': 'https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TESTCORE.tsv',
+ 'TextVQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/TextVQA_VAL.tsv',
+ 'MMMU_DEV_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/MMMU_DEV_VAL.tsv',
+ 'MMMU_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/MMMU_TEST.tsv',
+ 'MathVista_MINI': 'https://opencompass.openxlab.space/utils/VLMEval/MathVista_MINI.tsv',
+ 'ChartQA_VALTEST_HUMAN': 'https://opencompass.openxlab.space/utils/VLMEval/ChartQA_VALTEST_HUMAN.tsv',
+ 'ScienceQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/ScienceQA_VAL.tsv',
+ 'ScienceQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/ScienceQA_TEST.tsv',
+ 'HallusionBench': 'https://opencompass.openxlab.space/utils/VLMEval/HallusionBench.tsv',
+ 'DocVQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/DocVQA_VAL.tsv',
+ 'AI2D_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/AI2D_TEST.tsv',
+ 'LLaVABench': 'https://opencompass.openxlab.space/utils/VLMEval/LLaVABench.tsv',
+ 'OCRBench': 'https://opencompass.openxlab.space/utils/VLMEval/OCRBench.tsv',
}
dataset_md5_dict = {
- 'MMBench_DEV_EN': "b6caf1133a01c6bb705cf753bb527ed8",
- 'MMBench_TEST_EN': "6939fadb0ce626fefc0bdc9c64efc528",
- 'MMBench_DEV_CN': "08b8fc3324a5ed74155350f57be69fbd",
- 'MMBench_TEST_CN': "7e1239baf0ee4c8b513e19705a0f317e",
- "MMBench": "4115aea3383f3dd0083be6a633e0f820", # Link Invalid, Internal Only
- "MMBench_CN": "2e053ffc90ea598b1feae13c36dc13ee", # Link Invalid, Internal Only
- 'CCBench': "1de88b4257e7eee3f60b18d45eda6f07",
- 'MME': "b36b43c3f09801f5d368627fb92187c3",
- 'SEEDBench_IMG': "68017231464752261a2526d6ca3a10c0",
- "CORE_MM": "8a8da2f2232e79caf98415bfdf0a202d",
- "MMVet": "f400d7f513a585a0f218cbd6882e0671",
- 'COCO_VAL': "72a5079dead060269ac222c5aa5128af",
+ 'MMBench_DEV_EN': 'b6caf1133a01c6bb705cf753bb527ed8',
+ 'MMBench_TEST_EN': '6939fadb0ce626fefc0bdc9c64efc528',
+ 'MMBench_DEV_CN': '08b8fc3324a5ed74155350f57be69fbd',
+ 'MMBench_TEST_CN': '7e1239baf0ee4c8b513e19705a0f317e',
+ 'MMBench': '4115aea3383f3dd0083be6a633e0f820', # Link Invalid, Internal Only
+ 'MMBench_CN': '2e053ffc90ea598b1feae13c36dc13ee', # Link Invalid, Internal Only
+ 'CCBench': '1de88b4257e7eee3f60b18d45eda6f07',
+ 'MME': 'b36b43c3f09801f5d368627fb92187c3',
+ 'SEEDBench_IMG': '68017231464752261a2526d6ca3a10c0',
+ 'CORE_MM': '8a8da2f2232e79caf98415bfdf0a202d',
+ 'MMVet': 'f400d7f513a585a0f218cbd6882e0671',
+ 'COCO_VAL': '72a5079dead060269ac222c5aa5128af',
'OCRVQA_TEST': 'ca46a6d74b403e9d6c0b670f6fc00db9',
'OCRVQA_TESTCORE': 'c5239fe77db8bdc1f2ad8e55e0d1fe97',
'TextVQA_VAL': 'b233b31f551bbf4056f2f955da3a92cd',
- 'MMMU_DEV_VAL': "521afc0f3bf341e6654327792781644d",
- 'MMMU_TEST': "c19875d11a2d348d07e5eb4bdf33166d",
+ 'MMMU_DEV_VAL': '521afc0f3bf341e6654327792781644d',
+ 'MMMU_TEST': 'c19875d11a2d348d07e5eb4bdf33166d',
'MathVista_MINI': 'f199b98e178e5a2a20e7048f5dcb0464',
- 'ChartQA_VALTEST_HUMAN':'2c90a4133408a21d57fb2ea26f77bbfc',
+ 'ChartQA_VALTEST_HUMAN': '2c90a4133408a21d57fb2ea26f77bbfc',
'ScienceQA_VAL': '96320d05e142e585e7204e72affd29f3',
'ScienceQA_TEST': 'e42e9e00f9c59a80d8a5db35bc32b71f',
'HallusionBench': '0c23ac0dc9ef46832d7a24504f2a0c7c',
- "DocVQA_VAL": 'd5ee77e1926ff10690d469c56b73eabf',
- "AI2D_TEST": "0f593e0d1c7df9a3d69bf1f947e71975",
- "LLaVABench": "d382a093f749a697820d3dadd61c8428",
- "OCRBench": 'e953d98a987cc6e26ef717b61260b778',
+ 'DocVQA_VAL': 'd5ee77e1926ff10690d469c56b73eabf',
+ 'AI2D_TEST': '0f593e0d1c7df9a3d69bf1f947e71975',
+ 'LLaVABench': 'd382a093f749a697820d3dadd61c8428',
+ 'OCRBench': 'e953d98a987cc6e26ef717b61260b778',
}
img_root_map = {k: k for k in dataset_URLs}
img_root_map.update({
- 'MMBench_DEV_EN': "MMBench",
- 'MMBench_TEST_EN': "MMBench",
- 'MMBench_DEV_CN': "MMBench",
- 'MMBench_TEST_CN': "MMBench",
- "MMBench_CN": "MMBench", # Link Invalid, Internal Only
- 'COCO_VAL':'COCO',
+ 'MMBench_DEV_EN': 'MMBench',
+ 'MMBench_TEST_EN': 'MMBench',
+ 'MMBench_DEV_CN': 'MMBench',
+ 'MMBench_TEST_CN': 'MMBench',
+ 'MMBench_CN': 'MMBench', # Link Invalid, Internal Only
+ 'COCO_VAL': 'COCO',
'OCRVQA_TEST': 'OCRVQA',
'OCRVQA_TESTCORE': 'OCRVQA',
'TextVQA_VAL': 'TextVQA',
'MMMU_DEV_VAL': 'MMMU',
- "MMMU_TEST": "MMMU",
+ 'MMMU_TEST': 'MMMU',
'MathVista_MINI': 'MathVista',
'ChartQA_VALTEST_HUMAN': 'ChartQA',
'HallusionBench': 'Hallusion',
'DocVQA_VAL': 'DocVQA',
- "OCRBench": 'OCRBench',
+ 'OCRBench': 'OCRBench',
})
assert set(dataset_URLs) == set(img_root_map) == set(dataset_md5_dict)
+
def DATASET_TYPE(dataset):
- # Dealing with Custom Dataset
+ # Dealing with Custom Dataset
dataset = dataset.lower()
if listinstr(['mmbench', 'seedbench', 'ccbench', 'mmmu', 'scienceqa', 'ai2d'], dataset):
return 'multi-choice'
@@ -99,6 +100,7 @@ def DATASET_TYPE(dataset):
else:
return 'QA'
+
def abbr2full(s):
datasets = [x for x in img_root_map]
ins = [s in d for d in datasets]
diff --git a/vlmeval/utils/matching_util.py b/vlmeval/utils/matching_util.py
index 2d1005118..ccdf1b7d5 100644
--- a/vlmeval/utils/matching_util.py
+++ b/vlmeval/utils/matching_util.py
@@ -3,22 +3,23 @@
import os
from ..smp import *
+
def can_infer_option(answer, choices):
verbose = os.environ.get('VERBOSE', 0)
# Choices is a dictionary
if 'Failed to obtain answer via API' in answer:
return False
-
+
reject_to_answer = [
"Sorry, I can't help with images of people yet.",
"I can't process this file.",
"I'm sorry, but without the image provided",
- "Cannot determine the answer"
+ 'Cannot determine the answer'
]
for err in reject_to_answer:
if err in answer:
return 'Z'
-
+
def count_choice(splits, choices, prefix='', suffix=''):
cnt = 0
for c in choices:
@@ -30,7 +31,7 @@ def count_choice(splits, choices, prefix='', suffix=''):
chars = '.()[],:;!*#{}'
for c in chars:
answer_mod = answer_mod.replace(c, ' ')
-
+
splits = [x.strip() for x in answer_mod.split()]
count = count_choice(splits, choices)
@@ -46,6 +47,7 @@ def count_choice(splits, choices, prefix='', suffix=''):
return 'Z'
return False
+
def can_infer_text(answer, choices):
answer = answer.lower()
assert isinstance(choices, dict)
@@ -60,7 +62,8 @@ def can_infer_text(answer, choices):
return cands[0]
return False
+
def can_infer(answer, choices):
answer = str(answer)
copt = can_infer_option(answer, choices)
- return copt if copt else can_infer_text(answer, choices)
\ No newline at end of file
+ return copt if copt else can_infer_text(answer, choices)
diff --git a/vlmeval/utils/mp_util.py b/vlmeval/utils/mp_util.py
index 5cef74455..2322e2d0e 100644
--- a/vlmeval/utils/mp_util.py
+++ b/vlmeval/utils/mp_util.py
@@ -50,6 +50,7 @@ def _tasks_with_index(tasks):
for idx, task in enumerate(tasks):
yield task, idx
+
def track_progress_rich(func: Callable,
tasks: Iterable = tuple(),
task_num: int = None,
@@ -156,7 +157,7 @@ def track_progress_rich(func: Callable,
dump(ans, save)
fh.flush()
os.fsync(fh.fileno())
-
+
prog_bar.update(task_id, advance=1, refresh=True)
else:
with Pool(nproc) as pool:
@@ -187,4 +188,4 @@ def track_progress_rich(func: Callable,
raise e
for result, idx in unordered_results:
results[idx] = result
- return results
\ No newline at end of file
+ return results
diff --git a/vlmeval/vlm/cogvlm.py b/vlmeval/vlm/cogvlm.py
index 2cee3f6ea..9d86f5068 100644
--- a/vlmeval/vlm/cogvlm.py
+++ b/vlmeval/vlm/cogvlm.py
@@ -7,16 +7,18 @@
from ..utils import DATASET_TYPE, CustomPrompt
from transformers import AutoModelForCausalLM, LlamaTokenizer
+
class CogVlm(CustomPrompt):
INSTALL_REQ = True
- def __init__(self,
- name='cogvlm-chat',tokenizer_name ='lmsys/vicuna-7b-v1.5',
- **kwargs):
+ def __init__(self,
+ name='cogvlm-chat',
+ tokenizer_name='lmsys/vicuna-7b-v1.5',
+ **kwargs):
self.tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name)
self.model = AutoModelForCausalLM.from_pretrained(
- f"THUDM/{name}-hf",
+ f'THUDM/{name}-hf',
torch_dtype=torch.bfloat16,
trust_remote_code=True,
).to('cuda').eval()
@@ -26,12 +28,12 @@ def use_custom_prompt(self, dataset):
if DATASET_TYPE(dataset) == 'multi-choice':
return True
return False
-
+
def build_prompt(self, line, dataset=None):
assert dataset is None or isinstance(dataset, str)
assert self.use_custom_prompt(dataset)
tgt_path = self.dump_image(line, dataset)
-
+
if dataset is not None and DATASET_TYPE(dataset) == 'multi-choice':
question = line['question']
hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
@@ -49,25 +51,25 @@ def build_prompt(self, line, dataset=None):
prompt = question
if not cn_string(prompt):
- prompt = prompt + "\n" + "Answer with the option's letter from the given choices directly."
+ prompt = prompt + '\n' + "Answer with the option's letter from the given choices directly."
else:
- prompt = prompt + "\n" + "请直接回答选项字母。"
+ prompt = prompt + '\n' + '请直接回答选项字母。'
else:
prompt = line['question']
return {'image': tgt_path, 'text': prompt}
def generate(self, image_path, prompt, dataset=None):
-
image = Image.open(image_path).convert('RGB')
- inputs = self.model.build_conversation_input_ids(self.tokenizer, query=prompt, history=[], images=[image]) # chat mode
+ inputs = self.model.build_conversation_input_ids(
+ self.tokenizer, query=prompt, history=[], images=[image]) # chat mode
inputs = {
'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'),
'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'),
'images': [[inputs['images'][0].to('cuda').to(torch.bfloat16)]],
}
- gen_kwargs = {"max_length": 2048, "do_sample": False}
+ gen_kwargs = {'max_length': 2048, 'do_sample': False}
with torch.no_grad():
outputs = self.model.generate(**inputs, **gen_kwargs)
@@ -75,4 +77,4 @@ def generate(self, image_path, prompt, dataset=None):
# print(tokenizer.decode(outputs[0]))
response = self.tokenizer.decode(outputs[0])
# output = response[len(prompt):]
- return response
\ No newline at end of file
+ return response
diff --git a/vlmeval/vlm/emu.py b/vlmeval/vlm/emu.py
index 3b7113d42..d74413e49 100644
--- a/vlmeval/vlm/emu.py
+++ b/vlmeval/vlm/emu.py
@@ -2,74 +2,78 @@
from PIL import Image
from abc import abstractproperty
import os.path as osp
-import os
+import os
from ..smp import *
class Emu:
- def __init__(self,
- name,
- model_path_map={
- "emu2":"BAAI/Emu2",
- "emu2_chat":"BAAI/Emu2-Chat"
- },
+ def __init__(self,
+ name,
+ model_path_map={'emu2': 'BAAI/Emu2', 'emu2_chat': 'BAAI/Emu2-Chat'},
**kwargs):
-
+
self.model_path_map = model_path_map
assert name in self.model_path_map or osp.exists(name) or splitlen(name) == 2
if name in self.model_path_map:
model_path = self.model_path_map[name]
else:
- model_path = name
+ model_path = name
assert osp.exists(model_path) or splitlen(model_path) == 2
-
+
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model
-
- local_rank,world_size = get_rank_and_world_size()
-
+
+ local_rank, world_size = get_rank_and_world_size()
+
device_num = torch.cuda.device_count()
assert world_size * 2 <= device_num, 'The number of devices does not match the world size'
-
+
device_1 = local_rank
- device_2 = local_rank+world_size
+ device_2 = local_rank + world_size
torch.cuda.set_device(device_1)
torch.cuda.set_device(device_2)
-
- tokenizer = AutoTokenizer.from_pretrained(model_path) # "BAAI/Emu2-Chat"
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path) # "BAAI/Emu2-Chat"
self.tokenizer = tokenizer
with init_empty_weights():
model = AutoModelForCausalLM.from_pretrained(
- model_path, # "BAAI/Emu2-Chat"
+ model_path, # "BAAI/Emu2-Chat"
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
- trust_remote_code=True)
+ trust_remote_code=True)
+
+ device_map = infer_auto_device_map(
+ model,
+ max_memory={
+ device_1: '38GiB',
+ device_2: '38GiB'
+ },
+ no_split_module_classes=['Block', 'LlamaDecoderLayer'])
- device_map = infer_auto_device_map(model, max_memory={device_1:'38GiB',device_2:'38GiB',}, no_split_module_classes=['Block','LlamaDecoderLayer'])
# input and output logits should be on same device
- device_map["model.decoder.lm.lm_head"] = device_1
-
+ device_map['model.decoder.lm.lm_head'] = device_1
+
model = dispatch_model(
- model,
+ model,
device_map=device_map).eval()
-
+
self.model = model
- kwargs_default = dict(max_new_tokens= 64, length_penalty= -1)
+ kwargs_default = dict(max_new_tokens=64, length_penalty=-1)
kwargs_default.update(kwargs)
self.kwargs = kwargs_default
- warnings.warn(f"Following kwargs received: {self.kwargs}, will use as generation config. ")
-
+ warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
+
def interleave_generate(self, ti_list, dataset=None):
- query, images = '',[]
+ query, images = '', []
for item in ti_list:
if isimg(item):
images.append(Image.open(item).convert('RGB'))
query += '[<IMG_PLH>]'
else:
query += item
-
+
inputs = self.model.build_input_ids(
text=[query],
tokenizer=self.tokenizer,
@@ -78,15 +82,15 @@ def interleave_generate(self, ti_list, dataset=None):
with torch.no_grad():
outputs = self.model.generate(
- input_ids=inputs["input_ids"],
- attention_mask=inputs["attention_mask"],
- image=inputs["image"].to(torch.bfloat16),
+ input_ids=inputs['input_ids'],
+ attention_mask=inputs['attention_mask'],
+ image=inputs['image'].to(torch.bfloat16),
**self.kwargs)
output_text = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
return output_text[0]
-
+
def generate(self, image_path, prompt, dataset=None):
- tl_list = [image_path,prompt]
+ tl_list = [image_path, prompt]
output = self.interleave_generate(tl_list, dataset)
return output
diff --git a/vlmeval/vlm/idefics.py b/vlmeval/vlm/idefics.py
index b096dfd1e..9187d953e 100644
--- a/vlmeval/vlm/idefics.py
+++ b/vlmeval/vlm/idefics.py
@@ -4,65 +4,72 @@
import warnings
from ..smp import splitlen
+
class IDEFICS:
INSTALL_REQ = False
- def __init__(self,
- model_pth="HuggingFaceM4/idefics-9b-instruct",
- with_context=False,
+ def __init__(self,
+ model_pth='HuggingFaceM4/idefics-9b-instruct',
+ with_context=False,
**kwargs):
assert osp.exists(model_pth) or splitlen(model_pth) == 2
from transformers import IdeficsForVisionText2Text, AutoProcessor
- self.model = IdeficsForVisionText2Text.from_pretrained(model_pth, torch_dtype=torch.bfloat16, device_map='auto')
+ self.model = IdeficsForVisionText2Text.from_pretrained(
+ model_pth, torch_dtype=torch.bfloat16, device_map='auto')
self.processor = AutoProcessor.from_pretrained(model_pth)
self.with_context = with_context
kwargs_default = {'max_length': 128}
kwargs_default.update(kwargs)
self.kwargs = kwargs_default
self.file_root = osp.dirname(__file__)
- warnings.warn(f"Following kwargs received: {self.kwargs}, will use as generation config. ")
-
+ warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
+
def interleave_generate(self, ti_list, dataset=None):
prompts = ['Users:'] + ti_list + ['<end_of_utterance>', '\nAssistant: ']
- inputs = self.processor(prompts, add_end_of_utterance_token=False, return_tensors="pt").to("cuda")
- exit_condition = self.processor.tokenizer("", add_special_tokens=False).input_ids
- bad_words_ids = self.processor.tokenizer(["", ""], add_special_tokens=False).input_ids
+ inputs = self.processor(prompts, add_end_of_utterance_token=False, return_tensors='pt').to('cuda')
+ exit_condition = self.processor.tokenizer('<end_of_utterance>', add_special_tokens=False).input_ids
+ bad_words_ids = self.processor.tokenizer(
+ ['<image>', '<fake_token_around_image>'],
+ add_special_tokens=False).input_ids
- generated_ids = self.model.generate(**inputs, eos_token_id=exit_condition, bad_words_ids=bad_words_ids, **self.kwargs)
+ generated_ids = self.model.generate(
+ **inputs, eos_token_id=exit_condition, bad_words_ids=bad_words_ids, **self.kwargs)
generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
- text = generated_text[0].split("\nAssistant: ")[-1]
+ text = generated_text[0].split('\nAssistant: ')[-1]
return text
-
+
def generate(self, image_path, prompt, dataset=None):
if self.with_context:
prompts = [
[
- "User: What is in this image?",
+ 'User: What is in this image?',
Image.open(osp.join(self.file_root, 'misc/Idefics.jpg')),
- "",
- "\nAssistant: This picture depicts Idefix, the dog of Obelix in Asterix and Obelix. Idefix is running on the ground.",
- "\nUser: " + prompt,
- Image.open(image_path),
- "",
- "\nAssistant: "
+ '<end_of_utterance>',
+ '\nAssistant: This picture depicts Idefix, the dog of Obelix in Asterix and Obelix. ',
+ 'Idefix is running on the ground.<end_of_utterance>',
+ '\nUser: ' + prompt,
+ Image.open(image_path),
+ '<end_of_utterance>',
+ '\nAssistant: '
]
]
else:
prompts = [
[
- "User: " + prompt,
- Image.open(image_path),
- "",
- "\nAssistant: "
+ 'User: ' + prompt,
+ Image.open(image_path),
+ '<end_of_utterance>',
+ '\nAssistant: '
]
]
- inputs = self.processor(prompts, add_end_of_utterance_token=False, return_tensors="pt").to("cuda")
- exit_condition = self.processor.tokenizer("", add_special_tokens=False).input_ids
- bad_words_ids = self.processor.tokenizer(["", ""], add_special_tokens=False).input_ids
+ inputs = self.processor(prompts, add_end_of_utterance_token=False, return_tensors='pt').to('cuda')
+ exit_condition = self.processor.tokenizer('<end_of_utterance>', add_special_tokens=False).input_ids
+ bad_words_ids = self.processor.tokenizer(
+ ['<image>', '<fake_token_around_image>'], add_special_tokens=False).input_ids
- generated_ids = self.model.generate(**inputs, eos_token_id=exit_condition, bad_words_ids=bad_words_ids, **self.kwargs)
+ generated_ids = self.model.generate(
+ **inputs, eos_token_id=exit_condition, bad_words_ids=bad_words_ids, **self.kwargs)
generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
- text = generated_text[0].split("\nAssistant: ")[-1]
+ text = generated_text[0].split('\nAssistant: ')[-1]
return text
-
diff --git a/vlmeval/vlm/instructblip.py b/vlmeval/vlm/instructblip.py
index 33da83a7d..f72418540 100644
--- a/vlmeval/vlm/instructblip.py
+++ b/vlmeval/vlm/instructblip.py
@@ -2,7 +2,7 @@
from PIL import Image
from abc import abstractproperty
import os.path as osp
-import os, sys
+import sys
from ..smp import *
@@ -12,19 +12,19 @@ class InstructBLIP:
def __init__(self, name):
self.config_map = {
- 'instructblip_7b': f'misc/blip2_instruct_vicuna7b.yaml',
- 'instructblip_13b': f'misc/blip2_instruct_vicuna13b.yaml',
+ 'instructblip_7b': 'misc/blip2_instruct_vicuna7b.yaml',
+ 'instructblip_13b': 'misc/blip2_instruct_vicuna13b.yaml',
}
self.file_path = __file__
config_root = osp.dirname(self.file_path)
-
+
try:
from lavis.models import load_preprocess
from omegaconf import OmegaConf
from lavis.common.registry import registry
except:
- warnings.warn("Please install lavis before using InstructBLIP. ")
+ warnings.warn('Please install lavis before using InstructBLIP. ')
sys.exit(-1)
assert name in self.config_map
@@ -33,11 +33,11 @@ def __init__(self, name):
model_cfg = cfg.model
assert osp.exists(model_cfg.llm_model) or splitlen(model_cfg.llm_model) == 2
- model_cls = registry.get_model_class(name="blip2_vicuna_instruct")
+ model_cls = registry.get_model_class(name='blip2_vicuna_instruct')
model = model_cls.from_config(model_cfg)
model.eval()
- self.device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
+ self.device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'
device = self.device
model.to(device)
self.model = model
@@ -50,6 +50,6 @@ def __init__(self, name):
def generate(self, image_path, prompt, dataset=None):
vis_processors = self.vis_processors
raw_image = Image.open(image_path).convert('RGB')
- image_tensor = vis_processors["eval"](raw_image).unsqueeze(0).to(self.device)
+ image_tensor = vis_processors['eval'](raw_image).unsqueeze(0).to(self.device)
outputs = self.model.generate(dict(image=image_tensor, prompt=prompt))
- return outputs[0]
\ No newline at end of file
+ return outputs[0]
diff --git a/vlmeval/vlm/internvl_chat.py b/vlmeval/vlm/internvl_chat.py
index 862d511b4..7f9b8b94d 100644
--- a/vlmeval/vlm/internvl_chat.py
+++ b/vlmeval/vlm/internvl_chat.py
@@ -31,7 +31,7 @@ def __init__(self, model_path='OpenGVLab/InternVL-Chat-Chinese-V1-1', **kwargs):
kwargs_default = dict(do_sample=False, max_new_tokens=512, top_p=None, num_beams=1)
kwargs_default.update(kwargs)
self.kwargs = kwargs_default
- warnings.warn(f"Following kwargs received: {self.kwargs}, will use as generation config. ")
+ warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
def use_custom_prompt(self, dataset):
assert dataset is not None
@@ -43,12 +43,12 @@ def build_prompt(self, line, dataset=None):
assert self.use_custom_prompt(dataset)
assert dataset is None or isinstance(dataset, str)
tgt_path = self.dump_image(line, dataset)
-
+
question = line['question']
hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
if hint is not None:
question = hint + '\n' + question
-
+
options = {
cand: line[cand]
for cand in string.ascii_uppercase
@@ -57,13 +57,13 @@ def build_prompt(self, line, dataset=None):
for key, item in options.items():
question += f'\n{key}. {item}'
prompt = question
-
+
if len(options):
- prompt += "\n请直接回答选项字母。" if cn_string(
+ prompt += '\n请直接回答选项字母。' if cn_string(
prompt) else "\nAnswer with the option's letter from the given choices directly."
else:
- prompt += "\n请直接回答问题。" if cn_string(prompt) else "\nAnswer the question directly."
-
+ prompt += '\n请直接回答问题。' if cn_string(prompt) else '\nAnswer the question directly.'
+
return {'image': tgt_path, 'text': prompt}
def generate(self, image_path, prompt, dataset=None):
@@ -74,4 +74,4 @@ def generate(self, image_path, prompt, dataset=None):
pixel_values = pixel_values.to(torch.bfloat16).to(self.device)
response = self.model.chat(self.tokenizer, pixel_values=pixel_values,
question=prompt, generation_config=self.kwargs)
- return response
\ No newline at end of file
+ return response
diff --git a/vlmeval/vlm/llava.py b/vlmeval/vlm/llava.py
index e7ae51fb3..68e9b5d97 100644
--- a/vlmeval/vlm/llava.py
+++ b/vlmeval/vlm/llava.py
@@ -1,28 +1,29 @@
import torch
from PIL import Image
from abc import abstractproperty
-import os, sys
+import sys
import os.path as osp
from ..smp import *
from ..utils import DATASET_TYPE, CustomPrompt
+
class LLaVA(CustomPrompt):
INSTALL_REQ = True
- def __init__(self,
+ def __init__(self,
model_pth='liuhaotian/llava_v1.5_7b',
- **kwargs):
+ **kwargs):
try:
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
except:
- warnings.warn("Please install llava before using LLaVA")
+ warnings.warn('Please install llava before using LLaVA')
sys.exit(-1)
-
- warnings.warn("Please install the latest version of llava from github before you evaluate the LLaVA model. ")
+
+ warnings.warn('Please install the latest version of llava from github before you evaluate the LLaVA model. ')
assert osp.exists(model_pth) or splitlen(model_pth) == 2
-
+
if model_pth == 'Lin-Chen/ShareGPT4V-7B':
model_name = 'llava-v1.5-7b'
elif model_pth == 'Lin-Chen/ShareGPT4V-13B':
@@ -32,36 +33,37 @@ def __init__(self,
try:
self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
- model_path=model_pth,
- model_base=None,
- model_name=model_name,
- device='cpu',
+ model_path=model_pth,
+ model_base=None,
+ model_name=model_name,
+ device='cpu',
device_map='cpu'
)
except:
if 'ShareGPT4V' in model_pth:
import llava
warnings.warn(
- f'Please manually remove the encoder type check in {llava.__path__[0]}/model/multimodal_encoder/builder.py '
+ 'Please manually remove the encoder type check in '
+ f'{llava.__path__[0]}/model/multimodal_encoder/builder.py '
'Line 8 to use the ShareGPT4V model. ')
else:
warnings.warn('Unknown error when loading LLaVA model.')
exit(-1)
-
+
self.model = self.model.cuda()
- self.conv_mode = 'llava_v1'
+ self.conv_mode = 'llava_v1'
kwargs_default = dict(do_sample=False, temperature=0, max_new_tokens=512, top_p=None, num_beams=1)
kwargs_default.update(kwargs)
self.kwargs = kwargs_default
- warnings.warn(f"Following kwargs received: {self.kwargs}, will use as generation config. ")
+ warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
def use_custom_prompt(self, dataset):
assert dataset is not None
if DATASET_TYPE(dataset) == 'multi-choice':
return True
return False
-
+
def build_prompt(self, line, dataset=None):
assert self.use_custom_prompt(dataset)
assert dataset is None or isinstance(dataset, str)
@@ -82,15 +84,19 @@ def build_prompt(self, line, dataset=None):
prompt = question
if len(options):
- prompt += "\n请直接回答选项字母。" if cn_string(prompt) else "\nAnswer with the option's letter from the given choices directly."
+ prompt += (
+ '\n请直接回答选项字母。' if cn_string(prompt) else
+ "\nAnswer with the option's letter from the given choices directly."
+ )
else:
- prompt += "\n请直接回答问题。" if cn_string(prompt) else "\nAnswer the question directly."
+ prompt += '\n请直接回答问题。' if cn_string(prompt) else '\nAnswer the question directly.'
return {'image': tgt_path, 'text': prompt}
def generate(self, image_path, prompt, dataset=None):
from llava.mm_utils import process_images, tokenizer_image_token, KeywordsStoppingCriteria
- from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+ from llava.constants import (
+ IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN)
from llava.conversation import conv_templates, SeparatorStyle
image = Image.open(image_path).convert('RGB')
args = abstractproperty()
@@ -106,12 +112,14 @@ def generate(self, image_path, prompt, dataset=None):
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
- input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
+ input_ids = tokenizer_image_token(
+ prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
with torch.inference_mode():
- output_ids = self.model.generate(input_ids, images=image_tensor, stopping_criteria=[stopping_criteria], **self.kwargs)
+ output_ids = self.model.generate(
+ input_ids, images=image_tensor, stopping_criteria=[stopping_criteria], **self.kwargs)
output = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
return output
diff --git a/vlmeval/vlm/minicpm_v.py b/vlmeval/vlm/minicpm_v.py
index 70259a363..3ec78840c 100644
--- a/vlmeval/vlm/minicpm_v.py
+++ b/vlmeval/vlm/minicpm_v.py
@@ -24,7 +24,7 @@ def __init__(self, model_path='openbmb/MiniCPM-V', **kwargs):
self.kwargs = kwargs
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
torch.cuda.empty_cache()
-
+
def generate(self, image_path, prompt, dataset=None):
image = Image.open(image_path).convert('RGB')
msgs = [{'role': 'user', 'content': prompt}]
@@ -46,4 +46,3 @@ def generate(self, image_path, prompt, dataset=None):
**default_kwargs
)
return res
-
\ No newline at end of file
diff --git a/vlmeval/vlm/minigpt4.py b/vlmeval/vlm/minigpt4.py
index b154f035a..5a7dba1e5 100644
--- a/vlmeval/vlm/minigpt4.py
+++ b/vlmeval/vlm/minigpt4.py
@@ -6,18 +6,22 @@
from transformers import StoppingCriteriaList
from PIL import Image
+
class MiniGPT4:
INSTALL_REQ = True
-
- def __init__(self,
- mode='v2',
- root='/mnt/petrelfs/share_data/duanhaodong/MiniGPT-4/',
- temperature=1,
+
+ def __init__(self,
+ mode='v2',
+ root='/mnt/petrelfs/share_data/duanhaodong/MiniGPT-4/',
+ temperature=1,
max_out_len=512):
-
+
if root is None:
- warnings.warn('Please set root to the directory of MiniGPT-4, which is cloned from here: https://github.com/Vision-CAIR/MiniGPT-4. ')
+ warnings.warn(
+ 'Please set root to the directory of MiniGPT-4, which is cloned from here: '
+ 'https://github.com/Vision-CAIR/MiniGPT-4. '
+ )
if mode == 'v2':
cfg = 'minigptv2_eval.yaml'
@@ -27,11 +31,11 @@ def __init__(self,
cfg = 'minigpt4_13b_eval.yaml'
else:
raise NotImplementedError
-
+
self.mode = mode
- self.temperature = temperature
+ self.temperature = temperature
self.max_out_len = max_out_len
- self.root = root
+ self.root = root
this_dir = osp.dirname(__file__)
self.cfg = osp.join(this_dir, 'misc', cfg)
@@ -43,12 +47,12 @@ def __init__(self,
device = torch.cuda.current_device()
self.device = device
-
+
cfg_path = self.cfg
cfg = OmegaConf.load(cfg_path)
-
+
model_cfg = cfg.model
- model_cfg.device_8bit = device
+ model_cfg.device_8bit = device
model_cls = registry.get_model_class(model_cfg.arch)
model = model_cls.from_config(model_cfg)
model = model.to(device)
@@ -62,7 +66,7 @@ def __init__(self,
stop_words_ids = [[835], [2277, 29937]]
stop_words_ids = [torch.tensor(ids).to(device) for ids in stop_words_ids]
self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
-
+
def generate(self, image_path, prompt, dataset=None):
from minigpt4.conversation.conversation import Chat
if self.mode == 'v2':
diff --git a/vlmeval/vlm/misc/minigptv2_eval.yaml b/vlmeval/vlm/misc/minigptv2_eval.yaml
index 94394b4fb..32815e1b9 100644
--- a/vlmeval/vlm/misc/minigptv2_eval.yaml
+++ b/vlmeval/vlm/misc/minigptv2_eval.yaml
@@ -19,7 +19,7 @@ model:
# generation configs
prompt: ""
- # LLM
+ # LLM
llama_model: "please set this value to the path of llama2-chat-7b"
datasets:
diff --git a/vlmeval/vlm/mmalaya.py b/vlmeval/vlm/mmalaya.py
index 365708466..f775d9324 100644
--- a/vlmeval/vlm/mmalaya.py
+++ b/vlmeval/vlm/mmalaya.py
@@ -19,18 +19,18 @@ def __init__(self, model_path='DataCanvas/MMAlaya', **kwargs):
# need initialize tokenizer
model.initialize_tokenizer(self.tokenizer)
self.model = model.cuda()
-
+
self.kwargs = kwargs
- warnings.warn(f"Following kwargs received: {self.kwargs}, will use as generation config. ")
+ warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
torch.cuda.empty_cache()
def generate(self, image_path, prompt, dataset=None):
# read image
- image = Image.open(image_path).convert("RGB")
+ image = Image.open(image_path).convert('RGB')
# tokenize prompt, and proprecess image
input_ids, image_tensor, stopping_criteria = self.model.prepare_for_inference(
- prompt,
- self.tokenizer,
+ prompt,
+ self.tokenizer,
image,
return_tensors='pt')
with torch.inference_mode():
@@ -42,23 +42,23 @@ def generate(self, image_path, prompt, dataset=None):
num_beams=1,
use_cache=True,
stopping_criteria=[stopping_criteria],
- )
+ )
# truncate input_ids in generate_ids and then decode to text
input_token_len = input_ids.shape[1]
response = self.tokenizer.batch_decode(
- output_ids[:, input_token_len:].cpu(),
- skip_special_tokens=True,
+ output_ids[:, input_token_len:].cpu(),
+ skip_special_tokens=True,
clean_up_tokenization_spaces=False
- )[0].strip()
+ )[0].strip()
return response
-if __name__ == "__main__":
+if __name__ == '__main__':
model = MMAlaya()
response = model.generate(
image_path='./assets/apple.jpg',
prompt='请详细描述一下这张图片。',
- )
+ )
print(response)
"""
diff --git a/vlmeval/vlm/monkey.py b/vlmeval/vlm/monkey.py
index dac914462..da8cb3ffd 100644
--- a/vlmeval/vlm/monkey.py
+++ b/vlmeval/vlm/monkey.py
@@ -5,6 +5,7 @@
from vlmeval.smp import isimg
from ..utils import DATASET_TYPE, CustomPrompt
+
class Monkey:
INSTALL_REQ = False
@@ -16,29 +17,33 @@ def __init__(self, model_path='echo840/Monkey', **kwargs):
model = AutoModelForCausalLM.from_pretrained(model_path, device_map='cpu', trust_remote_code=True).eval()
self.model = model.cuda()
self.kwargs = kwargs
- warnings.warn(f"Following kwargs received: {self.kwargs}, will use as generation config. ")
+ warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
torch.cuda.empty_cache()
+
def generate_vanilla(self, image_path, prompt):
cur_prompt = f'{image_path} {prompt} Answer: '
input_ids = self.tokenizer(cur_prompt, return_tensors='pt', padding='longest')
attention_mask = input_ids.attention_mask
input_ids = input_ids.input_ids
-
+
output_ids = self.model.generate(
- input_ids=input_ids.cuda(),
- attention_mask=attention_mask.cuda(),
- do_sample=False,
- num_beams=1,
- max_new_tokens=512,
- min_new_tokens=1,
- length_penalty=1,
- num_return_sequences=1,
- output_hidden_states=True,
- use_cache=True,
- pad_token_id=self.tokenizer.eod_id,
- eos_token_id=self.tokenizer.eod_id,
- )
- response = self.tokenizer.decode(output_ids[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
+ input_ids=input_ids.cuda(),
+ attention_mask=attention_mask.cuda(),
+ do_sample=False,
+ num_beams=1,
+ max_new_tokens=512,
+ min_new_tokens=1,
+ length_penalty=1,
+ num_return_sequences=1,
+ output_hidden_states=True,
+ use_cache=True,
+ pad_token_id=self.tokenizer.eod_id,
+ eos_token_id=self.tokenizer.eod_id,
+ )
+ response = self.tokenizer.decode(
+ output_ids[0][input_ids.size(1):].cpu(),
+ skip_special_tokens=True
+ ).strip()
return response
def generate_multichoice(self, image_path, prompt):
@@ -46,33 +51,35 @@ def generate_multichoice(self, image_path, prompt):
input_ids = self.tokenizer(cur_prompt, return_tensors='pt', padding='longest')
attention_mask = input_ids.attention_mask
input_ids = input_ids.input_ids
-
+
output_ids = self.model.generate(
- input_ids=input_ids.cuda(),
- attention_mask=attention_mask.cuda(),
- do_sample=False,
- num_beams=1,
- max_new_tokens=10,
- min_new_tokens=1,
- length_penalty=1,
- num_return_sequences=1,
- output_hidden_states=True,
- use_cache=True,
- pad_token_id=self.tokenizer.eod_id,
- eos_token_id=self.tokenizer.eod_id,
- )
- response = self.tokenizer.decode(output_ids[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
+ input_ids=input_ids.cuda(),
+ attention_mask=attention_mask.cuda(),
+ do_sample=False,
+ num_beams=1,
+ max_new_tokens=10,
+ min_new_tokens=1,
+ length_penalty=1,
+ num_return_sequences=1,
+ output_hidden_states=True,
+ use_cache=True,
+ pad_token_id=self.tokenizer.eod_id,
+ eos_token_id=self.tokenizer.eod_id,
+ )
+ response = self.tokenizer.decode(
+ output_ids[0][input_ids.size(1):].cpu(),
+ skip_special_tokens=True
+ ).strip()
return response
-
+
def generate(self, image_path, prompt, dataset=None):
if dataset is None:
return self.generate_vanilla(image_path, prompt)
assert isinstance(dataset, str)
- if DATASET_TYPE(dataset) == 'multi-choice' or DATASET_TYPE(dataset) == 'Y/N' or dataset=="HallusionBench":
+ if DATASET_TYPE(dataset) == 'multi-choice' or DATASET_TYPE(dataset) == 'Y/N' or dataset == 'HallusionBench':
return self.generate_multichoice(image_path, prompt)
else:
return self.generate_vanilla(image_path, prompt)
-
class MonkeyChat:
@@ -86,34 +93,37 @@ def __init__(self, model_path='echo840/Monkey-Chat', **kwargs):
model = AutoModelForCausalLM.from_pretrained(model_path, device_map='cpu', trust_remote_code=True).eval()
self.model = model.cuda()
self.kwargs = kwargs
-
+
self.tokenizer.padding_side = 'left'
self.tokenizer.pad_token_id = self.tokenizer.eod_id
- warnings.warn(f"Following kwargs received: {self.kwargs}, will use as generation config. ")
+ warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
torch.cuda.empty_cache()
-
+
def generate_vanilla(self, image_path, prompt):
cur_prompt = f'{image_path} {prompt} Answer: '
input_ids = self.tokenizer(cur_prompt, return_tensors='pt', padding='longest')
attention_mask = input_ids.attention_mask
input_ids = input_ids.input_ids
-
+
output_ids = self.model.generate(
- input_ids=input_ids.cuda(),
- attention_mask=attention_mask.cuda(),
- do_sample=False,
- num_beams=1,
- max_new_tokens=512,
- min_new_tokens=1,
- length_penalty=1,
- num_return_sequences=1,
- output_hidden_states=True,
- use_cache=True,
- pad_token_id=self.tokenizer.eod_id,
- eos_token_id=self.tokenizer.eod_id,
- )
- response = self.tokenizer.decode(output_ids[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
+ input_ids=input_ids.cuda(),
+ attention_mask=attention_mask.cuda(),
+ do_sample=False,
+ num_beams=1,
+ max_new_tokens=512,
+ min_new_tokens=1,
+ length_penalty=1,
+ num_return_sequences=1,
+ output_hidden_states=True,
+ use_cache=True,
+ pad_token_id=self.tokenizer.eod_id,
+ eos_token_id=self.tokenizer.eod_id,
+ )
+ response = self.tokenizer.decode(
+ output_ids[0][input_ids.size(1):].cpu(),
+ skip_special_tokens=True
+ ).strip()
return response
def generate_multichoice(self, image_path, prompt):
@@ -121,29 +131,32 @@ def generate_multichoice(self, image_path, prompt):
input_ids = self.tokenizer(cur_prompt, return_tensors='pt', padding='longest')
attention_mask = input_ids.attention_mask
input_ids = input_ids.input_ids
-
+
output_ids = self.model.generate(
- input_ids=input_ids.cuda(),
- attention_mask=attention_mask.cuda(),
- do_sample=False,
- num_beams=1,
- max_new_tokens=10,
- min_new_tokens=1,
- length_penalty=1,
- num_return_sequences=1,
- output_hidden_states=True,
- use_cache=True,
- pad_token_id=self.tokenizer.eod_id,
- eos_token_id=self.tokenizer.eod_id,
- )
- response = self.tokenizer.decode(output_ids[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
+ input_ids=input_ids.cuda(),
+ attention_mask=attention_mask.cuda(),
+ do_sample=False,
+ num_beams=1,
+ max_new_tokens=10,
+ min_new_tokens=1,
+ length_penalty=1,
+ num_return_sequences=1,
+ output_hidden_states=True,
+ use_cache=True,
+ pad_token_id=self.tokenizer.eod_id,
+ eos_token_id=self.tokenizer.eod_id,
+ )
+ response = self.tokenizer.decode(
+ output_ids[0][input_ids.size(1):].cpu(),
+ skip_special_tokens=True
+ ).strip()
return response
-
+
def generate(self, image_path, prompt, dataset=None):
if dataset is None:
return self.generate_vanilla(image_path, prompt)
assert isinstance(dataset, str)
- if DATASET_TYPE(dataset) == 'multi-choice' or DATASET_TYPE(dataset) == 'Y/N' or dataset=="HallusionBench":
+ if DATASET_TYPE(dataset) == 'multi-choice' or DATASET_TYPE(dataset) == 'Y/N' or dataset == 'HallusionBench':
return self.generate_multichoice(image_path, prompt)
else:
return self.generate_vanilla(image_path, prompt)
diff --git a/vlmeval/vlm/mplug_owl2.py b/vlmeval/vlm/mplug_owl2.py
index f65ad9dd7..d7e6df35c 100644
--- a/vlmeval/vlm/mplug_owl2.py
+++ b/vlmeval/vlm/mplug_owl2.py
@@ -1,4 +1,5 @@
-import os, torch, sys
+import sys
+import torch
from PIL import Image
from ..smp import *
from ..utils import DATASET_TYPE, CustomPrompt
@@ -8,16 +9,17 @@ class mPLUG_Owl2(CustomPrompt):
INSTALL_REQ = True
- def __init__(self, model_path='MAGAer13/mplug-owl2-llama2-7b', **kwargs):
+ def __init__(self, model_path='MAGAer13/mplug-owl2-llama2-7b', **kwargs):
try:
from mplug_owl2.model.builder import load_pretrained_model
from mplug_owl2.mm_utils import get_model_name_from_path
except:
warnings.warn('Please install mPLUG_Owl2 before using mPLUG_Owl2. ')
sys.exit(-1)
-
+
model_name = get_model_name_from_path(model_path)
- tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, load_8bit=False, load_4bit=False, device="cpu")
+ tokenizer, model, image_processor, context_len = load_pretrained_model(
+ model_path, None, model_name, load_8bit=False, load_4bit=False, device='cpu')
self.model = model.cuda()
self.device = self.model.device
@@ -28,11 +30,11 @@ def __init__(self, model_path='MAGAer13/mplug-owl2-llama2-7b', **kwargs):
self.context_len = context_len
kwargs_default = dict(
- max_new_tokens=10, do_sample=False, num_beams=1,
+ max_new_tokens=10, do_sample=False, num_beams=1,
min_new_tokens=1, length_penalty=1, num_return_sequences=1)
kwargs_default.update(kwargs)
self.kwargs = kwargs_default
- warnings.warn(f"Following kwargs received: {self.kwargs}, will use as generation config. ")
+ warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
def use_custom_prompt(self, dataset):
assert dataset is not None
@@ -41,17 +43,20 @@ def use_custom_prompt(self, dataset):
if DATASET_TYPE(dataset) == 'multi-choice' or dataset == 'MMVet':
return True
return False
-
+
def build_prompt(self, line, dataset=None):
assert dataset is None or isinstance(dataset, str)
assert self.use_custom_prompt(dataset)
tgt_path = self.dump_image(line, dataset)
if dataset == 'MMVet':
- prompt_tmpl = "USER: <|image|>{}\nAnswer the question directly. ASSISTANT:"
+ prompt_tmpl = 'USER: <|image|>{}\nAnswer the question directly. ASSISTANT:'
prompt = prompt_tmpl.format(line['question'])
elif DATASET_TYPE(dataset) == 'multi-choice':
- prompt_tmpl = "USER: <|image|>{}\n{}\n{}\nAnswer with the option’s letter from the given choices directly. ASSISTANT:"
+ prompt_tmpl = (
+ 'USER: <|image|>{}\n{}\n{}\nAnswer with the option’s letter from the given choices directly. '
+ 'ASSISTANT:'
+ )
options = {
cand: line[cand]
for cand in string.ascii_uppercase
@@ -60,25 +65,29 @@ def build_prompt(self, line, dataset=None):
options_prompt = ''
for key, item in options.items():
options_prompt += f'{key}. {item}\n'
-
+
hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else 'N/A'
if len(options):
- prompt = f"USER: <|image|>{hint}\n{line['question']}\n{options_prompt}\nAnswer with the option’s letter from the given choices directly. ASSISTANT:"
+ question = line['question']
+ prompt = (
+ f'USER: <|image|>{hint}\n{question}\n{options_prompt}\n'
+ 'Answer with the option’s letter from the given choices directly. ASSISTANT:'
+ )
else:
prompt = f"USER: <|image|>{hint}\n{line['question']}\nAnswer the question directly. ASSISTANT:"
else:
raise NotImplementedError
return {'image': tgt_path, 'text': prompt}
-
+
def generate_vanilla(self, image_path, prompt, **kwargs):
from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from mplug_owl2.conversation import conv_templates
from mplug_owl2.mm_utils import process_images, tokenizer_image_token, KeywordsStoppingCriteria
- conv = conv_templates["mplug_owl2"].copy()
+ conv = conv_templates['mplug_owl2'].copy()
image = Image.open(image_path).convert('RGB')
- max_edge = max(image.size) # We recommand you to resize to squared image for BEST performance.
+        max_edge = max(image.size)  # We recommend you resize to a squared image for best performance.
image = image.resize((max_edge, max_edge))
image_tensor = process_images([image], self.image_processor)
@@ -89,7 +98,8 @@ def generate_vanilla(self, image_path, prompt, **kwargs):
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
- input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
+ input_ids = tokenizer_image_token(
+ prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
stop_str = conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
@@ -111,46 +121,48 @@ def generate_multichoice(self, image_path, prompt):
from mplug_owl2.constants import IMAGE_TOKEN_INDEX
from mplug_owl2.mm_utils import process_images, tokenizer_image_token
image = Image.open(image_path).convert('RGB')
- max_edge = max(image.size) # We recommand you to resize to squared image for BEST performance.
+        max_edge = max(image.size)  # We recommend you resize to a squared image for best performance.
image = image.resize((max_edge, max_edge))
image_tensor = process_images([image], self.image_processor)
image_tensor = image_tensor.to(self.device, dtype=torch.float16)
- input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
+ input_ids = tokenizer_image_token(
+ prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
with torch.inference_mode():
output_ids = self.model.generate(
- input_ids=input_ids,
- images=image_tensor,
- output_hidden_states=True,
- use_cache=True,
+ input_ids=input_ids,
+ images=image_tensor,
+ output_hidden_states=True,
+ use_cache=True,
**self.kwargs)
- answer = self.tokenizer.decode(output_ids[0, input_ids.shape[1]: ]).strip()
+ answer = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
        return answer.split('</s>')[0]
-
+
def generate_mmvet(self, image_path, prompt):
from mplug_owl2.constants import IMAGE_TOKEN_INDEX
from mplug_owl2.mm_utils import process_images, tokenizer_image_token
image = Image.open(image_path).convert('RGB')
- max_edge = max(image.size) # We recommand you to resize to squared image for BEST performance.
+        max_edge = max(image.size)  # We recommend you resize to a squared image for best performance.
image = image.resize((max_edge, max_edge))
image_tensor = process_images([image], self.image_processor)
image_tensor = image_tensor.to(self.device, dtype=torch.float16)
- input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
+ input_ids = tokenizer_image_token(
+ prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
kwargs = cp.deepcopy(self.kwargs)
kwargs['max_new_tokens'] = 64
kwargs['length_penalty'] = 0
with torch.inference_mode():
output_ids = self.model.generate(
- input_ids=input_ids,
- images=image_tensor,
- output_hidden_states=True,
- use_cache=True,
+ input_ids=input_ids,
+ images=image_tensor,
+ output_hidden_states=True,
+ use_cache=True,
**kwargs)
- answer = self.tokenizer.decode(output_ids[0, input_ids.shape[1]: ]).strip()
+ answer = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
        return answer.split('</s>')[0]
def generate(self, image_path, prompt, dataset=None):
@@ -164,11 +176,11 @@ def generate(self, image_path, prompt, dataset=None):
return self.generate_vanilla(image_path, prompt, **gen_config)
else:
return self.generate_vanilla(image_path, prompt)
-
+
def interleave_generate(self, ti_list, dataset=None):
from mplug_owl2.constants import IMAGE_TOKEN_INDEX
from mplug_owl2.mm_utils import process_images, tokenizer_image_token
- prompt_full = "USER: "
+ prompt_full = 'USER: '
images = []
for s in ti_list:
if isimg(s):
@@ -176,20 +188,21 @@ def interleave_generate(self, ti_list, dataset=None):
max_edge = max(image.size)
image = image.resize((max_edge, max_edge))
images.append(image)
- prompt_full += f"<|image|>"
+ prompt_full += '<|image|>'
else:
prompt_full += s
- prompt_full += "\nASSISTANT: "
+ prompt_full += '\nASSISTANT: '
image_tensor = process_images(images, self.image_processor)
image_tensor = image_tensor.to(self.device, dtype=torch.float16)
- input_ids = tokenizer_image_token(prompt_full, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
+ input_ids = tokenizer_image_token(
+ prompt_full, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
with torch.inference_mode():
output_ids = self.model.generate(
- input_ids=input_ids,
- images=image_tensor,
- output_hidden_states=True,
- use_cache=True,
+ input_ids=input_ids,
+ images=image_tensor,
+ output_hidden_states=True,
+ use_cache=True,
**self.kwargs)
- answer = self.tokenizer.decode(output_ids[0, input_ids.shape[1]: ]).strip()
-        return answer.split('</s>')[0]
\ No newline at end of file
+ answer = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
+        return answer.split('</s>')[0]
diff --git a/vlmeval/vlm/omnilmm.py b/vlmeval/vlm/omnilmm.py
index f1659b291..8da2fe903 100644
--- a/vlmeval/vlm/omnilmm.py
+++ b/vlmeval/vlm/omnilmm.py
@@ -7,10 +7,11 @@
from ..utils import DATASET_TYPE, CustomPrompt
-DEFAULT_IMAGE_TOKEN = "<image>"
-DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
-DEFAULT_IM_START_TOKEN = "<im_start>"
-DEFAULT_IM_END_TOKEN = "<im_end>"
+DEFAULT_IMAGE_TOKEN = '<image>'
+DEFAULT_IMAGE_PATCH_TOKEN = '<im_patch>'
+DEFAULT_IM_START_TOKEN = '<im_start>'
+DEFAULT_IM_END_TOKEN = '<im_end>'
+
def init_omni_lmm(model_path):
from omnilmm.model.omnilmm import OmniLMMForCausalLM
@@ -26,7 +27,7 @@ def init_omni_lmm(model_path):
image_processor = build_transform(is_train=False, input_size=model.model.config.image_size, std_mode='OPENAI_CLIP')
- mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
+ mm_use_im_start_end = getattr(model.config, 'mm_use_im_start_end', False)
assert mm_use_im_start_end
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
@@ -41,6 +42,7 @@ def init_omni_lmm(model_path):
return model, image_processor, image_token_len, tokenizer
+
def expand_question_into_multimodal(question_text, image_token_len, im_st_token, im_ed_token, im_patch_token):
    if '<image>' in question_text[0]['content']:
question_text[0]['content'] = question_text[0]['content'].replace(
@@ -50,20 +52,19 @@ def expand_question_into_multimodal(question_text, image_token_len, im_st_token,
image_token_len + im_ed_token + '\n' + question_text[0]['content']
return question_text
+
def wrap_question_for_omni_lmm(question, image_token_len, tokenizer):
from omnilmm.train.train_utils import omni_preprocess
question = expand_question_into_multimodal(
question, image_token_len, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN)
conversation = question
- data_dict = omni_preprocess(sources=[conversation],
- tokenizer=tokenizer,
- generation=True)
+ data_dict = omni_preprocess(sources=[conversation], tokenizer=tokenizer, generation=True)
- data_dict = dict(input_ids=data_dict["input_ids"][0],
- labels=data_dict["labels"][0])
+ data_dict = dict(input_ids=data_dict['input_ids'][0], labels=data_dict['labels'][0])
return data_dict
+
class OmniLMM12B(CustomPrompt):
INSTALL_REQ = True
@@ -89,12 +90,12 @@ def __init__(self, model_path, root, **kwargs) -> None:
def generate(self, image_path, prompt, dataset=None):
try:
image = Image.open(image_path).convert('RGB')
- except Exception as e:
+ except:
logger = get_logger('OmniLMM Inference')
- logger.error("Image Decode Error")
- return "Image Decode Error"
+ logger.error('Image Decode Error')
+ return 'Image Decode Error'
- msgs = [dict(role="user", content=prompt)]
+ msgs = [dict(role='user', content=prompt)]
input_ids = wrap_question_for_omni_lmm(
msgs, self.image_token_len, self.tokenizer)['input_ids']
input_ids = torch.as_tensor(input_ids)
@@ -110,13 +111,13 @@ def generate(self, image_path, prompt, dataset=None):
output.sequences[0], skip_special_tokens=True)
response = response.strip()
return response
-
+
def use_custom_prompt(self, dataset):
assert dataset is not None
if DATASET_TYPE(dataset) == 'multi-choice':
return True
return False
-
+
def build_prompt(self, line, dataset=None):
assert dataset is None or isinstance(dataset, str)
assert self.use_custom_prompt(dataset)
@@ -138,6 +139,9 @@ def build_prompt(self, line, dataset=None):
prompt += f'{question}\n'
if len(options):
prompt += options_prompt
- prompt = 'Study the image carefully and pick the option associated with the correct answer. Focus solely on selecting the option and avoid including any other content.\n' + prompt
+ prompt = """
+Study the image carefully and pick the option associated with the correct answer.
+Focus solely on selecting the option and avoid including any other content.\n
+""" + prompt
- return {'image': tgt_path, 'text': prompt}
\ No newline at end of file
+ return {'image': tgt_path, 'text': prompt}
diff --git a/vlmeval/vlm/open_flamingo.py b/vlmeval/vlm/open_flamingo.py
index e8f9a329b..5f2ec272f 100644
--- a/vlmeval/vlm/open_flamingo.py
+++ b/vlmeval/vlm/open_flamingo.py
@@ -7,22 +7,29 @@
from ..smp import splitlen, get_cache_path
from huggingface_hub import snapshot_download
-class OpenFlamingo:
+
+class OpenFlamingo:
INSTALL_REQ = True
- def __init__(self,
- name,
+ def __init__(self,
+ name,
with_context=False,
mpt_pth=None,
- ckpt_pth=None,
+ ckpt_pth=None,
**kwargs):
-
+
if mpt_pth is None:
- warnings.warn('Please set `mpt_pth` to the directory of MPT-7B, which is cloned from here: https://huggingface.co/mosaicml/mpt-7b. ')
+ warnings.warn(
+ 'Please set `mpt_pth` to the directory of MPT-7B, which is cloned from here: '
+ 'https://huggingface.co/mosaicml/mpt-7b. '
+ )
sys.exit(-1)
if ckpt_pth is None:
- warnings.warn('Please set `ckpt_pth` to the openflamingo ckpt, which is the `checkpoint.pt` file downloaded from: https://huggingface.co/openflamingo/OpenFlamingo-9B-vitl-mpt7b/tree/main. ' )
+ warnings.warn(
+ 'Please set `ckpt_pth` to the openflamingo ckpt, which is the `checkpoint.pt` file downloaded '
+ 'from: https://huggingface.co/openflamingo/OpenFlamingo-9B-vitl-mpt7b/tree/main. '
+ )
sys.exit(-1)
else:
if osp.exists(ckpt_pth):
@@ -41,17 +48,17 @@ def __init__(self,
sys.exit(-1)
else:
ckpt_pth = osp.join(cache_path, 'checkpoint.pt')
-
+
self.name = name
assert name in ['v2']
self.mpt_pth = mpt_pth
try:
from open_flamingo import create_model_and_transforms
except:
- raise ImportError("Please first install open_flamingo to use OpenFlamingo")
+ raise ImportError('Please first install open_flamingo to use OpenFlamingo')
model, image_processor, tokenizer = create_model_and_transforms(
- clip_vision_encoder_path="ViT-L-14",
- clip_vision_encoder_pretrained="openai",
+ clip_vision_encoder_path='ViT-L-14',
+ clip_vision_encoder_pretrained='openai',
lang_encoder_path=mpt_pth,
tokenizer_path=mpt_pth,
cross_attn_every_n_layers=4)
@@ -61,19 +68,19 @@ def __init__(self,
torch.cuda.empty_cache()
self.model = model.eval().cuda()
self.tokenizer = tokenizer
- self.tokenizer.padding_side = "left"
+ self.tokenizer.padding_side = 'left'
this_dir = osp.dirname(__file__)
-
- self.demo1 = Image.open(f"{this_dir}/misc/000000039769.jpg")
- self.demo2 = Image.open(f"{this_dir}/misc/000000028137.jpg")
+
+ self.demo1 = Image.open(f'{this_dir}/misc/000000039769.jpg')
+ self.demo2 = Image.open(f'{this_dir}/misc/000000028137.jpg')
self.image_proc = image_processor
kwargs_default = dict(max_new_tokens=128, num_beams=3)
kwargs_default.update(kwargs)
self.kwargs = kwargs_default
- warnings.warn(f"Following kwargs received: {self.kwargs}, will use as generation config. ")
-
+ warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
+
def generate(self, image_path, prompt, dataset=None):
if self.with_context:
vision_x = [self.image_proc(x).unsqueeze(0) for x in [self.demo1, self.demo2, Image.open(image_path)]]
@@ -83,20 +90,19 @@ def generate(self, image_path, prompt, dataset=None):
vision_x = vision_x.unsqueeze(1).unsqueeze(0)
if self.with_context:
prompt = (
- "Please describe the above image in a sentence. Answer: An image of two cats.<|endofchunk|>" +
- "Please describe the above image in a sentence. Answer: An image of a bathroom sink.<|endofchunk|>" +
- "" + prompt + 'Answer: '
+ 'Please describe the above image in a sentence. Answer: An image of two cats.<|endofchunk|>'
+ 'Please describe the above image in a sentence. '
+ 'Answer: An image of a bathroom sink.<|endofchunk|>'
+ '' + prompt + 'Answer: '
)
else:
- prompt = "" + prompt + 'Answer: '
- lang_x = self.tokenizer([prompt], return_tensors="pt")
+            prompt = '<image>' + prompt + 'Answer: '
+ lang_x = self.tokenizer([prompt], return_tensors='pt')
generated_text = self.model.generate(
- vision_x=vision_x.cuda(),
- lang_x=lang_x['input_ids'].cuda(),
- attention_mask=lang_x['attention_mask'].cuda(),
+ vision_x=vision_x.cuda(),
+ lang_x=lang_x['input_ids'].cuda(),
+ attention_mask=lang_x['attention_mask'].cuda(),
**self.kwargs)
generated_text = self.tokenizer.decode(generated_text[0])
- text = generated_text[len(prompt): ].split('<|endofchunk|>')[0]
- return text
-
-
\ No newline at end of file
+ text = generated_text[len(prompt):].split('<|endofchunk|>')[0]
+ return text
diff --git a/vlmeval/vlm/pandagpt.py b/vlmeval/vlm/pandagpt.py
index 450381bdd..8bd6615fe 100644
--- a/vlmeval/vlm/pandagpt.py
+++ b/vlmeval/vlm/pandagpt.py
@@ -3,6 +3,7 @@
import os.path as osp
import warnings
+
class PandaGPT:
INSTALL_REQ = True
@@ -18,7 +19,10 @@ def __init__(self, name, root=None, **kwargs):
try:
from model.openllama import OpenLLAMAPEFTModel
except:
- raise ImportError('Please first install PandaGPT and set the root path to use PandaGPT, which is cloned from here: https://github.com/yxuansu/PandaGPT. ')
+ raise ImportError(
+ 'Please first install PandaGPT and set the root path to use PandaGPT, '
+ 'which is cloned from here: https://github.com/yxuansu/PandaGPT. '
+ )
self.args = {
'model': 'openllama_peft',
'imagebind_ckpt_path': osp.join(root, 'pretrained_ckpt/imagebind_ckpt'),
@@ -38,12 +42,12 @@ def __init__(self, name, root=None, **kwargs):
kwargs_default = {'top_p': 0.9, 'do_sample': False, 'max_tgt_len': 128, 'temperature': 0.001}
kwargs_default.update(kwargs)
self.kwargs = kwargs_default
- warnings.warn(f"Following kwargs received: {self.kwargs}, will use as generation config. ")
-
+ warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
+
def generate(self, image_path, prompt, dataset=None):
struct = {
- 'prompt': prompt,
- 'image_paths': [image_path],
+ 'prompt': prompt,
+ 'image_paths': [image_path],
'audio_paths': [],
'video_paths': [],
'thermal_paths': [],
@@ -51,4 +55,4 @@ def generate(self, image_path, prompt, dataset=None):
}
struct.update(self.kwargs)
resp = self.model.generate(struct)
- return resp
\ No newline at end of file
+ return resp
diff --git a/vlmeval/vlm/qwen_vl.py b/vlmeval/vlm/qwen_vl.py
index 0194144cb..39ac3abf8 100644
--- a/vlmeval/vlm/qwen_vl.py
+++ b/vlmeval/vlm/qwen_vl.py
@@ -5,6 +5,7 @@
from vlmeval.smp import isimg, listinstr
import re
+
class QwenVL:
INSTALL_REQ = False
@@ -15,9 +16,9 @@ def __init__(self, model_path='Qwen/Qwen-VL', **kwargs):
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
self.model = AutoModelForCausalLM.from_pretrained(model_path, device_map='cuda', trust_remote_code=True).eval()
self.kwargs = kwargs
- warnings.warn(f"Following kwargs received: {self.kwargs}, will use as generation config. ")
+ warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
torch.cuda.empty_cache()
-
+
def generate(self, image_path, prompt, dataset=None):
vl_pair = [{'image': image_path}, {'text': prompt}]
query = self.tokenizer.from_list_format(vl_pair)
@@ -28,7 +29,7 @@ def generate(self, image_path, prompt, dataset=None):
response = self.tokenizer.decode(pred.cpu()[0], skip_special_tokens=False)
response = response.split(prompt)[1].split('<|endoftext|>')[0]
return response
-
+
def interleave_generate(self, ti_list, dataset=None):
vl_list = [{'image': s} if isimg(s) else {'text': s} for s in ti_list]
query = self.tokenizer.from_list_format(vl_list)
@@ -39,7 +40,8 @@ def interleave_generate(self, ti_list, dataset=None):
response = self.tokenizer.decode(pred.cpu()[0], skip_special_tokens=False)
response = response.split(query)[1].split('<|endoftext|>')[0]
return response
-
+
+
class QwenVLChat:
INSTALL_REQ = False
@@ -51,17 +53,17 @@ def __init__(self, model_path='Qwen/Qwen-VL-Chat', **kwargs):
self.model = AutoModelForCausalLM.from_pretrained(model_path, device_map='cuda', trust_remote_code=True).eval()
torch.cuda.empty_cache()
self.kwargs = kwargs
- warnings.warn(f"Following kwargs received: {self.kwargs}, will use as generation config. ")
-
+ warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
+
def generate(self, image_path, prompt, dataset=None):
vl_pair = [{'image': image_path}, {'text': prompt}]
query = self.tokenizer.from_list_format(vl_pair)
response, _ = self.model.chat(self.tokenizer, query=query, history=None, **self.kwargs)
return response
-
+
def interleave_generate(self, ti_list, dataset=None):
vl_list = [{'image': s} if isimg(s) else {'text': s} for s in ti_list]
query = self.tokenizer.from_list_format(vl_list)
response, _ = self.model.chat(self.tokenizer, query=query, history=None, **self.kwargs)
- return response
\ No newline at end of file
+ return response
diff --git a/vlmeval/vlm/sharedcaptioner.py b/vlmeval/vlm/sharedcaptioner.py
index ab3eb47ab..88d739d43 100644
--- a/vlmeval/vlm/sharedcaptioner.py
+++ b/vlmeval/vlm/sharedcaptioner.py
@@ -5,16 +5,16 @@
from ..smp import *
from ..utils import DATASET_TYPE, CustomPrompt
+
class SharedCaptioner(CustomPrompt):
INSTALL_REQ = False
def __init__(self, model_path='Lin-Chen/ShareCaptioner', **kwargs):
assert model_path is not None
- tokenizer = AutoTokenizer.from_pretrained(
- model_path, trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
self.model = AutoModelForCausalLM.from_pretrained(
- model_path, device_map="cuda", trust_remote_code=True).eval()
+ model_path, device_map='cuda', trust_remote_code=True).eval()
self.model.tokenizer = tokenizer
self.model.cuda()
self.model.half()
@@ -29,7 +29,7 @@ def build_prompt(self, line, dataset=None):
assert dataset is None or isinstance(dataset, str)
assert self.use_custom_prompt(dataset)
tgt_path = self.dump_image(line, dataset)
-
+
if dataset is not None and DATASET_TYPE(dataset) == 'multi-choice':
question = line['question']
hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
@@ -47,22 +47,22 @@ def build_prompt(self, line, dataset=None):
prompt = question
if not cn_string(prompt):
- prompt = prompt + "\n" + "Answer with the option's letter from the given choices directly."
+ prompt = prompt + '\n' + "Answer with the option's letter from the given choices directly."
else:
- prompt = prompt + "\n" + "请直接回答选项字母。"
+ prompt = prompt + '\n' + '请直接回答选项字母。'
else:
prompt = line['question']
# prompt = 'Analyze the image in a comprehensive and detailed manner.'
return {'image': tgt_path, 'text': prompt}
-
+
def generate(self, image_path, prompt, dataset=None):
seg1 = '<|User|>:'
seg2 = f'{prompt}{self.model.eoh}\n<|Bot|>:'
self.seg_emb1 = self.model.encode_text(seg1, add_special_tokens=True)
self.seg_emb2 = self.model.encode_text(seg2, add_special_tokens=False)
- image = Image.open(image_path).convert("RGB")
+ image = Image.open(image_path).convert('RGB')
image = self.model.vis_processor(image).unsqueeze(0)
image = image.to(self.model.device)
tmp_bs = image.shape[0]
@@ -84,9 +84,8 @@ def generate(self, image_path, prompt, dataset=None):
temperature=1.,
eos_token_id=self.model.tokenizer.eos_token_id,
num_return_sequences=1)
-
+
for j, out in enumerate(out_embeds):
out[out == -1] = 2
response = self.model.decode_text([out])
return response
-
diff --git a/vlmeval/vlm/transcore_m.py b/vlmeval/vlm/transcore_m.py
index 752c0e391..f56f38092 100644
--- a/vlmeval/vlm/transcore_m.py
+++ b/vlmeval/vlm/transcore_m.py
@@ -6,6 +6,7 @@
from ..smp import *
from ..utils import DATASET_TYPE, CustomPrompt
+
class TransCoreM(CustomPrompt):
INSTALL_REQ = True
@@ -28,15 +29,15 @@ def __init__(self,
device_map='cpu'
)
self.model = self.model.cuda()
- print("==============conv_mode: default")
- self.conv_mode = "default"
+ print('==============conv_mode: default')
+ self.conv_mode = 'default'
kwargs_default = dict(do_sample=False, temperature=0.0, max_new_tokens=128, top_p=None, num_beams=1)
kwargs_default.update(kwargs)
self.kwargs = kwargs_default
- warnings.warn(f"Following kwargs received: {self.kwargs}, will use as generation config. ")
+ warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
- def get_options(self,row, options):
+ def get_options(self, row, options):
parsed_options = []
for option in options:
option_value = row[option]
@@ -45,7 +46,7 @@ def get_options(self,row, options):
parsed_options.append(option_value)
return parsed_options
- def is_none(self,value):
+ def is_none(self, value):
if value is None:
return True
if type(value) is float and math.isnan(value):
@@ -55,7 +56,7 @@ def is_none(self,value):
if type(value) is str and value.lower() == 'none':
return True
return False
-
+
def use_custom_prompt(self, dataset):
assert dataset is not None
if DATASET_TYPE(dataset) == 'multi-choice':
@@ -82,15 +83,19 @@ def build_prompt(self, line, dataset=None):
prompt = question
if len(options):
- prompt += "\n请直接回答选项字母。" if cn_string(prompt) else "\nAnswer with the option's letter from the given choices directly."
+ prompt += (
+ '\n请直接回答选项字母。' if cn_string(prompt)
+ else "\nAnswer with the option's letter from the given choices directly."
+ )
else:
- prompt += "\n请直接回答问题。" if cn_string(prompt) else "\nAnswer the question directly."
-
+ prompt += '\n请直接回答问题。' if cn_string(prompt) else '\nAnswer the question directly.'
+
return {'image': tgt_path, 'text': prompt}
def generate(self, image_path, prompt, dataset=None):
from transcorem.mm_utils import process_images, tokenizer_image_token, KeywordsStoppingCriteria
- from transcorem.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+ from transcorem.constants import (
+ IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN)
from transcorem.conversation import conv_templates, SeparatorStyle
image = Image.open(image_path).convert('RGB')
@@ -107,7 +112,8 @@ def generate(self, image_path, prompt, dataset=None):
conv.append_message(conv.roles[1], None)
prompt_conv = conv.get_prompt()
- input_ids = tokenizer_image_token(prompt_conv, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
+ input_ids = tokenizer_image_token(
+ prompt_conv, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
with torch.inference_mode():
output_ids = self.model.generate(
diff --git a/vlmeval/vlm/visualglm.py b/vlmeval/vlm/visualglm.py
index 6b7ecd8cd..3536443d1 100644
--- a/vlmeval/vlm/visualglm.py
+++ b/vlmeval/vlm/visualglm.py
@@ -7,29 +7,29 @@ class VisualGLM:
INSTALL_REQ = False
- def __init__(self, model_path="THUDM/visualglm-6b", **kwargs):
+ def __init__(self, model_path='THUDM/visualglm-6b', **kwargs):
try:
import sat
except:
- warnings.warn("Please install SwissArmyTransformer to use VisualGLM")
+ warnings.warn('Please install SwissArmyTransformer to use VisualGLM')
assert model_path is not None
self.model_path = model_path
-
+
from transformers import AutoModel
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
- model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()
self.model = model
self.kwargs = kwargs
- warnings.warn(f"Following kwargs received: {self.kwargs}, will use as generation config. ")
+ warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
def generate(self, image_path, prompt, dataset=None):
-
+
output, _ = self.model.chat(
- image_path = image_path,
- tokenizer = self.tokenizer,
- query = prompt,
- history = [],
+ image_path=image_path,
+ tokenizer=self.tokenizer,
+ query=prompt,
+ history=[],
**self.kwargs
)
- return output
\ No newline at end of file
+ return output
diff --git a/vlmeval/vlm/xcomposer.py b/vlmeval/vlm/xcomposer.py
index 34ac9df96..2de110288 100644
--- a/vlmeval/vlm/xcomposer.py
+++ b/vlmeval/vlm/xcomposer.py
@@ -6,6 +6,7 @@
from ..smp import *
from ..utils import CustomPrompt
+
class StoppingCriteriaSub(StoppingCriteria):
def __init__(self, stops=[], encounters=1):
super().__init__()
@@ -18,27 +19,29 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
return False
+
from ..utils import DATASET_TYPE
+
class XComposer(CustomPrompt):
INSTALL_REQ = False
-
+
def __init__(self, model_path='internlm/internlm-xcomposer-vl-7b', **kwargs):
assert model_path is not None
self.model_path = model_path
-
+
model = AutoModel.from_pretrained(self.model_path, device_map='cpu', trust_remote_code=True).cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
model.tokenizer = tokenizer
self.model = model
self.device = self.model.internlm_model.model.embed_tokens.weight.device
stop_words_ids = [
- torch.tensor([103027]).to(self.device), ### end of human
- torch.tensor([103028]).to(self.device), ### end of bot
+ torch.tensor([103027]).to(self.device), # end of human
+ torch.tensor([103028]).to(self.device), # end of bot
]
default_kwargs = {
- 'max_new_tokens': 128, 'num_beams': 5, 'do_sample': False,
+ 'max_new_tokens': 128, 'num_beams': 5, 'do_sample': False,
'min_length': 1, 'repetition_penalty': 1.5, 'length_penalty': 1.0
}
default_kwargs.update(kwargs)
@@ -47,9 +50,9 @@ def __init__(self, model_path='internlm/internlm-xcomposer-vl-7b', **kwargs):
def generate_vanilla(self, image_path, prompt):
return self.model.generate(prompt, image_path, **self.kwargs)
-
+
def generate_multichoice(self, image_path, prompt):
- image = Image.open(image_path).convert("RGB")
+ image = Image.open(image_path).convert('RGB')
image = self.model.vis_processor(image).unsqueeze(0).to(self.device)
img_embeds = self.model.encode_img(image)
        prompt_segs = prompt.split('<ImageHere>')
@@ -63,7 +66,7 @@ def generate_multichoice(self, image_path, prompt):
]
prompt_seg_embs = [prompt_seg_embs[0], img_embeds, prompt_seg_embs[1]]
prompt_embs = torch.cat(prompt_seg_embs, dim=1)
-
+
outputs = self.model.internlm_model.generate(
inputs_embeds=prompt_embs,
max_new_tokens=5,
@@ -86,7 +89,7 @@ def generate_multichoice(self, image_path, prompt):
output_text = output_text.split(self.model.eoa)[0]
output_text = output_text.split('<|Bot|>')[-1].strip()
return output_text
-
+
def generate(self, image_path, prompt, dataset=None):
if dataset is None:
return self.generate_vanilla(image_path, prompt)
@@ -95,7 +98,7 @@ def generate(self, image_path, prompt, dataset=None):
return self.generate_multichoice(image_path, prompt)
else:
return self.generate_vanilla(image_path, prompt)
-
+
def list_to_prompt_embs(self, ti_list):
assert isinstance(ti_list, list)
img_embeds = []
@@ -107,13 +110,13 @@ def list_to_prompt_embs(self, ti_list):
img_embeds.append(self.model.encode_img(image))
                prompt_full += f'Image {len(img_embeds)}: <ImageHere>'
else:
- prompt_full += s
+ prompt_full += s
prompt_full += self.model.eoh + ' <|Bot|>: '
        prompt_segs = prompt_full.split('<ImageHere>')
assert len(prompt_segs) == len(img_embeds) + 1
-
+
prompt_seg_tokens = [
- self.model.tokenizer(seg, return_tensors='pt', add_special_tokens=i==0).to(self.device).input_ids
+ self.model.tokenizer(seg, return_tensors='pt', add_special_tokens=(i == 0)).to(self.device).input_ids
for i, seg in enumerate(prompt_segs)
]
prompt_seg_embs = [self.model.internlm_model.model.embed_tokens(seg) for seg in prompt_seg_tokens]
@@ -123,7 +126,7 @@ def list_to_prompt_embs(self, ti_list):
all_embeddings.append(prompt_seg_embs[-1])
prompt_embs = torch.cat(all_embeddings, dim=1)
return prompt_embs
-
+
# def interleave_generate(self, ti_list, dataset=None):
# prompt_embs = self.list_to_prompt_embs(ti_list)
# outputs = self.model.internlm_model.generate(
@@ -140,13 +143,13 @@ def list_to_prompt_embs(self, ti_list):
# output_text = output_text.split(self.model.eoa)[0]
# output_text = output_text.split('<|Bot|>')[-1].strip()
# return output_text
-
+
def use_custom_prompt(self, dataset):
assert dataset is not None
if DATASET_TYPE(dataset) == 'multi-choice':
return True
return False
-
+
def build_prompt(self, line, dataset=None):
assert dataset is None or isinstance(dataset, str)
assert self.use_custom_prompt(dataset)
@@ -173,4 +176,4 @@ def build_prompt(self, line, dataset=None):
ans_prompt = ' <|Bot|>: Answer: The answer is'
prompt = img_prompt + txt_prompt + mid_prompt + '' + ans_prompt
- return {'image': tgt_path, 'text': prompt}
\ No newline at end of file
+ return {'image': tgt_path, 'text': prompt}
diff --git a/vlmeval/vlm/xcomposer2.py b/vlmeval/vlm/xcomposer2.py
index 430ae5fc9..a73b942a9 100644
--- a/vlmeval/vlm/xcomposer2.py
+++ b/vlmeval/vlm/xcomposer2.py
@@ -9,24 +9,31 @@
import re
pattern = re.compile(r'[A-Z]')
+
def __padding__(image):
width, height = image.size
tar = max(width, height)
- top_padding = int((tar - height)/2)
+ top_padding = int((tar - height) / 2)
bottom_padding = tar - height - top_padding
- left_padding = int((tar - width)/2)
+ left_padding = int((tar - width) / 2)
right_padding = tar - width - left_padding
image = torchvision.transforms.functional.pad(image, [left_padding, top_padding, right_padding, bottom_padding])
return image
-meta_instruction = """You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).
-- InternLM-XComposer (浦语·灵笔) is a multi-modality conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.
-- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the language chosen by the user such as English and 中文.
-- InternLM-XComposer (浦语·灵笔) is capable of comprehending and articulating responses effectively based on the provided image."""
+
+meta_instruction = """
+You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).
+- InternLM-XComposer (浦语·灵笔) is a multi-modality conversational language model that is developed by
+Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.
+- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the language
+chosen by the user such as English and 中文.
+- InternLM-XComposer (浦语·灵笔) is capable of comprehending and articulating responses effectively
+based on the provided image.
+"""
+
def model_gen(model, text, images, need_bos=True, padding=False, beams=3, max_token=500):
pt1 = 0
- #print (text)
embeds = []
im_mask = []
images = [images]
@@ -54,9 +61,14 @@ def model_gen(model, text, images, need_bos=True, padding=False, beams=3, max_to
im_mask = torch.cat(im_mask, dim=1)
im_mask = im_mask.bool()
- outputs = model.generate(inputs_embeds=embeds, im_mask=im_mask,
- temperature=1.0, max_new_tokens=max_token, num_beams=beams,
- do_sample=False, repetition_penalty=1.0)
+ outputs = model.generate(
+ inputs_embeds=embeds,
+ im_mask=im_mask,
+ temperature=1.0,
+ max_new_tokens=max_token,
+ num_beams=beams,
+ do_sample=False,
+ repetition_penalty=1.0)
output_token = outputs[0]
if output_token[0] == 0 or output_token[0] == 1:
@@ -69,18 +81,17 @@ def model_gen(model, text, images, need_bos=True, padding=False, beams=3, max_to
class XComposer2(CustomPrompt):
INSTALL_REQ = False
-
+
def __init__(self, model_path='internlm/internlm-xcomposer2-vl-7b', **kwargs):
assert model_path is not None
self.model_path = model_path
-
+
model = AutoModel.from_pretrained(self.model_path, device_map='cpu', trust_remote_code=True).cuda().eval()
model.half()
tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
model.tokenizer = tokenizer
self.model = model
self.device = self.model.model.tok_embeddings.weight.device
-
def generate_mme(self, image_path, text):
text = text.split('Please answer')[0].strip()
@@ -90,12 +101,13 @@ def generate_mme(self, image_path, text):
return model_gen(self.model, text, image_path, need_bos=True, padding=True, beams=5)
def generate_multichoice(self, image_path, text, dataset):
- out = model_gen(self.model, text, image_path, need_bos=True, padding=False, beams=5, max_token=5)
+ out = model_gen(self.model, text, image_path, need_bos=True, padding=False, beams=5, max_token=5)
if 'mmmu' in dataset.lower():
return out
res = pattern.findall(out)
if len(res) == 0:
- print('Error:', out); res = 'Z'
+ print('Error:', out)
+ res = 'Z'
return res[0]
def generate_vqa(self, image_path, text):
@@ -103,17 +115,23 @@ def generate_vqa(self, image_path, text):
return out
def generate_vanilla(self, image_path, text):
- text = '[UNUSED_TOKEN_146]system\n{}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]user\n{}Answer this question in detail.[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'.format(meta_instruction, text)
+ text = (
+ '[UNUSED_TOKEN_146]system\n{}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]user\n{}'
+ 'Answer this question in detail.[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
+ ).format(meta_instruction, text)
out = model_gen(self.model, text, image_path, need_bos=True, max_token=500)
return out
def generate_brief(self, image_path, text):
- text = "[UNUSED_TOKEN_146]user\nAnswer the question using a single word or phrase.{}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n".format(text)
+ text = (
+ '[UNUSED_TOKEN_146]user\nAnswer the question using a single word or phrase.{}'
+ '[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
+ ).format(text)
out = model_gen(self.model, text, image_path, need_bos=True, max_token=10)
return out
-
+
def generate_driectly(self, image_path, text):
- text = "[UNUSED_TOKEN_146]user\n{}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n".format(text)
+ text = '[UNUSED_TOKEN_146]user\n{}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'.format(text)
out = model_gen(self.model, text, image_path, need_bos=True, max_token=500)
return out
@@ -139,13 +157,13 @@ def generate(self, image_path, prompt, dataset=None):
else:
return self.generate_vanilla(image_path, prompt)
-
+
def use_custom_prompt(self, dataset):
assert dataset is not None
if DATASET_TYPE(dataset) == 'multi-choice' or DATASET_TYPE(dataset) == 'VQA':
return True
return False
-
+
def build_mcqa(self, line):
question = line['question']
options = {
@@ -172,19 +190,21 @@ def build_mcqa(self, line):
return prompt
-
def build_prompt(self, line, dataset=None):
assert dataset is None or isinstance(dataset, str)
assert self.use_custom_prompt(dataset)
tgt_path = self.dump_image(line, dataset)
-
+
if DATASET_TYPE(dataset) == 'multi-choice':
prompt = self.build_mcqa(line)
elif DATASET_TYPE(dataset) == 'VQA':
if 'mathvista' in dataset.lower():
q = line['question']
- prompt = f"[UNUSED_TOKEN_146]user\n{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n"
- else:
+ prompt = f'[UNUSED_TOKEN_146]user\n{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
+ else:
q = line['question']
- prompt = f"[UNUSED_TOKEN_146]user\nAnswer the question using a single word or phrase.{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n"
+ prompt = (
+ f'[UNUSED_TOKEN_146]user\nAnswer the question using a single word or phrase.{q}'
+ '[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
+ )
return {'image': tgt_path, 'text': prompt}
diff --git a/vlmeval/vlm/yi_vl.py b/vlmeval/vlm/yi_vl.py
index 9f3e70684..3c208c1f5 100644
--- a/vlmeval/vlm/yi_vl.py
+++ b/vlmeval/vlm/yi_vl.py
@@ -10,7 +10,7 @@
You can perform inference of Yi-VL through the following steps:
1. clone the repo https://github.com/01-ai/Yi to path-to-Yi
2. set up the environment and install the required packages in path-to-Yi/VL/requirements.txt
-3. set Yi_ROOT in vlmeval/config.py
+3. set Yi_ROOT in vlmeval/config.py
Yi_ROOT = path-to-Yi
You are all set now! To run a demo for Yi-VL:
@@ -22,6 +22,7 @@
To run evaluation for Yi-VL, use `python run.py --model Yi_VL_6B --data {dataset_list}`
"""
+
def edit_config(repo_id):
if not osp.exists(repo_id):
root = get_cache_path(repo_id)
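
As a reading aid for step 3 of the setup notes in the module docstring above, here is a minimal sketch of the Yi_ROOT assignment. The surrounding contents of vlmeval/config.py are not shown in this patch, and '/path/to/Yi' is a placeholder, not a real path:

# vlmeval/config.py (sketch -- only the assignment described in the docstring)
# Point Yi_ROOT at the local clone of https://github.com/01-ai/Yi so that
# vlmeval/vlm/yi_vl.py can append <Yi_ROOT>/VL to sys.path when it loads.
Yi_ROOT = '/path/to/Yi'
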
@@ -36,28 +37,33 @@ def edit_config(repo_id):
assert osp.exists(data['mm_vision_tower'])
dump(data, cfg)
+
def disable_torch_init():
"""
Disable the redundant torch default initialization to accelerate model creation.
"""
import torch
- setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
- setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
-
+ setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
+ setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)
+
+
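
For context, disable_torch_init turns the reset_parameters hooks of Linear and LayerNorm into no-ops, so constructing layers skips the random initialization that the subsequent checkpoint load would overwrite anyway. A small illustrative sketch of the same idea, not part of the patch:

import torch

# With the initializer no-op'd, the layer keeps whatever memory torch.empty
# allocated; real weights only arrive via load_state_dict.
torch.nn.Linear.reset_parameters = lambda self: None
layer = torch.nn.Linear(8, 8)
layer.load_state_dict({'weight': torch.zeros(8, 8), 'bias': torch.zeros(8)})
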
class Yi_VL:
INSTALL_REQ = True
-
- def __init__(self,
- model_path='01-ai/Yi-VL-6B',
+
+ def __init__(self,
+ model_path='01-ai/Yi-VL-6B',
root=None,
**kwargs):
-
+
if root is None:
- warnings.warn('Please set root to the directory of Yi, which is cloned from here: https://github.com/01-ai/Yi ')
-
- self.root = osp.join(root,'VL')
+ warnings.warn(
+ 'Please set root to the directory of Yi, '
+ 'which is cloned from here: https://github.com/01-ai/Yi.'
+ )
+
+ self.root = osp.join(root, 'VL')
sys.path.append(self.root)
if splitlen(model_path, '/') == 2 and not osp.exists(model_path):
@@ -69,54 +75,54 @@ def __init__(self,
from llava.mm_utils import get_model_name_from_path, load_pretrained_model
from llava.model.constants import key_info
-
+
disable_torch_init()
- key_info["model_path"] = model_path
+ key_info['model_path'] = model_path
get_model_name_from_path(model_path)
self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
model_path,
device_map='cpu')
self.model = self.model.cuda()
self.conv_mode = 'mm_default'
-
+
kwargs_default = dict(temperature=0.2,
num_beams=1,
- do_sample=False,
- max_new_tokens=1024,
+ do_sample=False,
+ max_new_tokens=1024,
top_p=None)
kwargs_default.update(kwargs)
self.kwargs = kwargs_default
- warnings.warn(f"Following kwargs received: {self.kwargs}, will use as generation config. ")
-
+ warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
+
def generate(self, image_path, prompt, dataset=None):
-
+
from llava.conversation import conv_templates
from llava.model.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.mm_utils import KeywordsStoppingCriteria, expand2square, tokenizer_image_token
-
- qs = DEFAULT_IMAGE_TOKEN + "\n" + prompt
+
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + prompt
conv = conv_templates[self.conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
-
+
input_ids = (
- tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
+ tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
.unsqueeze(0)
.cuda()
- )
-
+ )
+
image = Image.open(image_path)
- if getattr(self.model.config, "image_aspect_ratio", None) == "pad":
+ if getattr(self.model.config, 'image_aspect_ratio', None) == 'pad':
if image.mode == 'L':
background_color = int(sum([int(x * 255) for x in self.image_processor.image_mean]) / 3)
else:
background_color = tuple(int(x * 255) for x in self.image_processor.image_mean)
image = expand2square(image, background_color)
- image_tensor = self.image_processor.preprocess(image, return_tensors="pt")[
- "pixel_values"
- ][0]
-
+ image_tensor = self.image_processor.preprocess(image, return_tensors='pt')[
+ 'pixel_values'
+ ][0]
+
stop_str = conv.sep
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
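
For readers unfamiliar with the padding branch earlier in this hunk: when the model config requests 'pad', the image is letter-boxed onto a square canvas filled with the processor's mean pixel color before preprocessing. A rough re-implementation of an expand2square-style helper, for illustration only (the real helper lives in llava.mm_utils and may differ in detail):

from PIL import Image

def expand2square_sketch(img, background_color):
    # Pad the shorter side so the output is square, pasting the original
    # image centered on a canvas filled with background_color
    # (an int for 'L' images, an RGB tuple otherwise).
    width, height = img.size
    if width == height:
        return img
    side = max(width, height)
    canvas = Image.new(img.mode, (side, side), background_color)
    canvas.paste(img, ((side - width) // 2, (side - height) // 2))
    return canvas
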
@@ -126,21 +132,21 @@ def generate(self, image_path, prompt, dataset=None):
input_ids,
images=image_tensor.unsqueeze(0).to(dtype=torch.bfloat16).cuda(),
stopping_criteria=[stopping_criteria],
- use_cache=True,
+ use_cache=True,
**self.kwargs)
-
+
input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(
- f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids"
- )
+ f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids'
+ )
outputs = self.tokenizer.batch_decode(
output_ids[:, input_token_len:], skip_special_tokens=True
- )[0]
+ )[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[: -len(stop_str)]
outputs = outputs.strip()
- return outputs
\ No newline at end of file
+ return outputs
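
Finally, a hedged usage sketch for the class above. 'demo.jpg' and '/path/to/Yi' are placeholders; the model path is the default from __init__:

# Illustrative only: load the wrapper and run a single prediction.
model = Yi_VL(model_path='01-ai/Yi-VL-6B', root='/path/to/Yi')
answer = model.generate('demo.jpg', 'Describe this image in one sentence.')
print(answer)
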