diff --git a/README.md b/README.md
index 5f9d0be19..ab9ef5e2e 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,8 @@
 
 ## 🆕 News
 
-- **[2024-03-20]** We support users to use a `.env` file to manage all environment variables used in VLMEvalKit, see [**Quickstart**](\Quickstart.md) for more details
+- **[2024-03-21]** We have supported [**DeepSeek-VL**](https://github.com/deepseek-ai/DeepSeek-VL/tree/main) 🔥🔥🔥
+- **[2024-03-20]** We now support using a `.env` file to manage all environment variables used in VLMEvalKit. See [**Quickstart**](/Quickstart.md) for more details
 - **[2024-03-17]** We have added an API wrapper for [**Step-1V**](https://www.stepfun.com/#step1v) 🔥🔥🔥
 - **[2024-03-15]** We have updated to be compatible with the latest version of LLaVA. All LLaVA series models have been re-evaluated with temperature=0, and the new results have been updated to the leaderboard 🔥🔥🔥
 - **[2024-02-27]** We have fixed the evaluation results of [**Yi-VL-34B**](https://huggingface.co/01-ai/Yi-VL-34B), check the updated results [**here**](https://huggingface.co/spaces/opencompass/open_vlm_leaderboard) 🔥🔥🔥
@@ -29,7 +30,6 @@
 - **[2024-02-07]** We have supported two new models: [**MiniCPM-V**](https://huggingface.co/openbmb/MiniCPM-V) and [**OmniLMM-12B**](https://huggingface.co/openbmb/OmniLMM-12B). 🔥🔥🔥
 - **[2024-01-30]** We have supported three new models: [**QwenVLMax**](https://huggingface.co/spaces/Qwen/Qwen-VL-Max), [**InternLM-XComposer2-7B**](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b), [**MMAlaya**](https://huggingface.co/DataCanvas/MMAlaya) 🔥🔥🔥
 - **[2024-01-30]** We have merged all performance numbers on our leaderboards into a single json file: [**OpenVLM.json**](http://opencompass.openxlab.space/utils/OpenVLM.json)
-- **[2024-01-27]** We have supported the evaluation of [**MMMU_TEST**](https://mmmu-benchmark.github.io) 🔥🔥🔥
 
 ## 📊 Datasets, Models, and Evaluation Results
 
@@ -72,10 +72,11 @@
 | [**IDEFICS-[9B/80B]-Instruct**](https://huggingface.co/HuggingFaceM4/idefics-9b-instruct)🎞️🚅 | [**InstructBLIP-[7B/13B]**](https://github.com/salesforce/LAVIS/blob/main/projects/instructblip/README.md) | [**LLaVA-[v1-7B/v1.5-7B/v1.5-13B]**](https://github.com/haotian-liu/LLaVA) | [**MiniGPT-4-[v1-7B/v1-13B/v2-7B]**](https://github.com/Vision-CAIR/MiniGPT-4) |
 | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
 | [**mPLUG-Owl2**](https://github.com/X-PLUG/mPLUG-Owl/tree/main/mPLUG-Owl2)🎞️ | [**OpenFlamingo-v2**](https://github.com/mlfoundations/open_flamingo)🎞️ | [**PandaGPT-13B**](https://github.com/yxuansu/PandaGPT) | [**Qwen-VL**](https://huggingface.co/Qwen/Qwen-VL)🎞️🚅, [**Qwen-VL-Chat**](https://huggingface.co/Qwen/Qwen-VL-Chat)🎞️**🚅** |
-| [**VisualGLM-6B**](https://huggingface.co/THUDM/visualglm-6b)🚅 | [**InternLM-XComposer-7B**](https://huggingface.co/internlm/internlm-xcomposer-7b)🚅🎞️ | [**ShareGPT4V-7B**](https://sharegpt4v.github.io)🚅 | [**TransCore-M**](https://github.com/PCIResearch/TransCore-M) |
+| [**VisualGLM-6B**](https://huggingface.co/THUDM/visualglm-6b)🚅 | [**InternLM-XComposer-7B**](https://huggingface.co/internlm/internlm-xcomposer-7b)🚅🎞️ | [**ShareGPT4V-[7B/13B]**](https://sharegpt4v.github.io)🚅 | [**TransCore-M**](https://github.com/PCIResearch/TransCore-M) |
 | [**LLaVA (XTuner)**](https://huggingface.co/xtuner/llava-internlm-7b)🚅 | [**CogVLM-17B-Chat**](https://huggingface.co/THUDM/cogvlm-chat-hf)🚅 | [**SharedCaptioner**](https://huggingface.co/spaces/Lin-Chen/Share-Captioner)🚅 | [**CogVLM-Grounding-Generalist**](https://huggingface.co/THUDM/cogvlm-grounding-generalist-hf)🚅 |
 | [**Monkey**](https://github.com/Yuliang-Liu/Monkey)🚅 | [**EMU2 / EMU2-Chat**](https://github.com/baaivision/Emu)🚅🎞️ | [**Yi-VL-[6B/34B]**](https://huggingface.co/01-ai/Yi-VL-6B) | [**MMAlaya**](https://huggingface.co/DataCanvas/MMAlaya)🚅 |
 | [**InternLM-XComposer2-7B**](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b)🚅🎞️ | [**MiniCPM-V**](https://huggingface.co/openbmb/MiniCPM-V)🚅 | [**OmniLMM-12B**](https://huggingface.co/openbmb/OmniLMM-12B) | [**InternVL-Chat Series**](https://github.com/OpenGVLab/InternVL)🚅 |
+| [**DeepSeek-VL**](https://github.com/deepseek-ai/DeepSeek-VL/tree/main)🎞️ | | | |
 
 🎞️: Support multiple images as inputs, via the `interleave_generate` interface.
 
@@ -83,8 +84,8 @@
 **Transformers Version Recommendation: **
 
 Note that some VLMs may not be able to run under certain transformer versions, we recommend the following settings to evaluate each VLM:
-- **Please use** `transformers==4.33.0` **for**: Qwen series, Monkey series, InternVL series, InternLM-XComposer Series, mPLUG-Owl2, OpenFlamingo v2, IDEFICS series, VisualGLM, MMAlaya, SharedCaptioner, MiniGPT4 series, InstructBLIP series
-- **Please use** `transformers==4.37.0 ` **for**: Other VLMs.
+- **Please use** `transformers==4.33.0` **for**: `Qwen series`, `Monkey series`, `InternVL series`, `InternLM-XComposer Series`, `mPLUG-Owl2`, `OpenFlamingo v2`, `IDEFICS series`, `VisualGLM`, `MMAlaya`, `SharedCaptioner`, `MiniGPT-4 series`, `InstructBLIP series`, `PandaGPT`.
+- **Please use** `transformers==4.37.0` **for**: `LLaVA series`, `ShareGPT4V series`, `TransCore-M`, `LLaVA (XTuner)`, `CogVLM Series`, `EMU2 Series`, `Yi-VL Series`, `MiniCPM-V`, `OmniLMM-12B`, `DeepSeek-VL series`.
 
 ```python
 # Demo
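Since the two `transformers` pins above are easy to get wrong in practice, here is a minimal fail-fast sketch (illustrative only, not part of this diff; the `REQUIRED_TRANSFORMERS` mapping and `check_transformers` helper are hypothetical names condensed from the two bullets):

```python
# Illustrative sketch: abort early if the installed transformers version does not
# match the README recommendation for the model family being evaluated.
import transformers

# Hypothetical mapping, condensed from the two recommendation bullets above.
REQUIRED_TRANSFORMERS = {
    'qwen_chat': '4.33.0',       # Qwen / Monkey / InternVL / XComposer / ... families
    'deepseek_vl_7b': '4.37.0',  # LLaVA / ShareGPT4V / DeepSeek-VL / ... families
}

def check_transformers(model_name):
    required = REQUIRED_TRANSFORMERS.get(model_name)
    if required is not None and transformers.__version__ != required:
        raise RuntimeError(
            f'{model_name} is recommended with transformers=={required}, '
            f'but transformers=={transformers.__version__} is installed.')
```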
diff --git a/vlmeval/config.py b/vlmeval/config.py
index 0f48e4bd0..49a233ba7 100644
--- a/vlmeval/config.py
+++ b/vlmeval/config.py
@@ -9,46 +9,20 @@
 OmniLMM_ROOT = None
 LLAVA_V1_7B_MODEL_PTH = 'Please set your local path to LLaVA-7B-v1.1 here, the model weight is obtained by merging LLaVA delta weight based on vicuna-7b-v1.1 in https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md with vicuna-7b-v1.1. '
 
-models = {
-    'qwen_base': partial(QwenVL, model_path='Qwen/Qwen-VL'),
+ungrouped = {
     'TransCore_M': partial(TransCoreM, root=TransCore_ROOT),
-    'qwen_chat': partial(QwenVLChat, model_path='Qwen/Qwen-VL-Chat'),
     'PandaGPT_13B': partial(PandaGPT, name='PandaGPT_13B', root=PandaGPT_ROOT),
     'flamingov2': partial(OpenFlamingo, name='v2', mpt_pth='anas-awadalla/mpt-7b', ckpt_pth='openflamingo/OpenFlamingo-9B-vitl-mpt7b'),
-    'flamingov2_fs': partial(OpenFlamingo, name='v2', with_context=True, mpt_pth='anas-awadalla/mpt-7b', ckpt_pth='openflamingo/OpenFlamingo-9B-vitl-mpt7b'),
-    'idefics_9b_instruct': partial(IDEFICS, model_pth='HuggingFaceM4/idefics-9b-instruct'),
-    'idefics_80b_instruct': partial(IDEFICS, model_pth='HuggingFaceM4/idefics-80b-instruct'),
-    'idefics_9b_instruct_fs': partial(IDEFICS, model_pth='HuggingFaceM4/idefics-9b-instruct', with_context=True),
-    'idefics_80b_instruct_fs': partial(IDEFICS, model_pth='HuggingFaceM4/idefics-80b-instruct', with_context=True),
-    'llava_v1.5_7b': partial(LLaVA, model_pth='liuhaotian/llava-v1.5-7b'),
-    'llava_v1.5_13b': partial(LLaVA, model_pth='liuhaotian/llava-v1.5-13b'),
-    'llava_v1_7b': partial(LLaVA, model_pth=LLAVA_V1_7B_MODEL_PTH),
-    'sharegpt4v_7b': partial(LLaVA, model_pth='Lin-Chen/ShareGPT4V-7B'),
-    'sharegpt4v_13b': partial(LLaVA, model_pth='Lin-Chen/ShareGPT4V-13B'),
-    'instructblip_7b': partial(InstructBLIP, name='instructblip_7b'),
-    'instructblip_13b': partial(InstructBLIP, name='instructblip_13b'),
     'VisualGLM_6b': partial(VisualGLM, model_path='THUDM/visualglm-6b'),
-    'MiniGPT-4-v2': partial(MiniGPT4, mode='v2', root=MiniGPT4_ROOT),
-    'MiniGPT-4-v1-7B': partial(MiniGPT4, mode='v1_7b', root=MiniGPT4_ROOT),
-    'MiniGPT-4-v1-13B': partial(MiniGPT4, mode='v1_13b', root=MiniGPT4_ROOT),
-    'XComposer': partial(XComposer, model_path='internlm/internlm-xcomposer-vl-7b'),
-    'XComposer2': partial(XComposer2, model_path='internlm/internlm-xcomposer2-vl-7b'),
     'mPLUG-Owl2': partial(mPLUG_Owl2, model_path='MAGAer13/mplug-owl2-llama2-7b'),
     'cogvlm-grounding-generalist':partial(CogVlm, name='cogvlm-grounding-generalist',tokenizer_name ='lmsys/vicuna-7b-v1.5'),
     'cogvlm-chat':partial(CogVlm, name='cogvlm-chat',tokenizer_name ='lmsys/vicuna-7b-v1.5'),
     'sharedcaptioner':partial(SharedCaptioner, model_path='Lin-Chen/ShareCaptioner'),
     'emu2':partial(Emu, name='emu2'),
     'emu2_chat':partial(Emu, name='emu2_chat'),
-    'monkey':partial(Monkey, model_path='echo840/Monkey'),
-    'monkey-chat':partial(MonkeyChat, model_path='echo840/Monkey-Chat'),
-    'Yi_VL_6B':partial(Yi_VL, model_path='01-ai/Yi-VL-6B', root=Yi_ROOT),
-    'Yi_VL_34B':partial(Yi_VL, model_path='01-ai/Yi-VL-34B', root=Yi_ROOT),
     'MMAlaya':partial(MMAlaya, model_path='DataCanvas/MMAlaya'),
     'MiniCPM-V':partial(MiniCPM_V, model_path='openbmb/MiniCPM-V'),
     'OmniLMM_12B':partial(OmniLMM12B, model_path='openbmb/OmniLMM-12B', root=OmniLMM_ROOT),
-    'InternVL-Chat-V1-1':partial(InternVLChat, model_path='OpenGVLab/InternVL-Chat-Chinese-V1-1'),
-    'InternVL-Chat-V1-2': partial(InternVLChat, model_path='OpenGVLab/InternVL-Chat-Chinese-V1-2'),
-    'InternVL-Chat-V1-2-Plus': partial(InternVLChat, model_path='OpenGVLab/InternVL-Chat-Chinese-V1-2-Plus'),
 }
 
 api_models = {
@@ -65,7 +39,6 @@
     'GeminiProVision': partial(GeminiProVision, temperature=0, retry=10),
     'QwenVLPlus': partial(QwenVLAPI, model='qwen-vl-plus', temperature=0, retry=10),
     'QwenVLMax': partial(QwenVLAPI, model='qwen-vl-max', temperature=0, retry=10),
-    # Internal Only
     'Step1V': partial(Step1V, temperature=0, retry=10),
     # Internal Only
     'Claude3V_Opus': partial(Claude3V, model='claude-3-opus-20240229', temperature=0, retry=10),
@@ -73,7 +46,7 @@
     'Claude3V_Haiku': partial(Claude3V, model='claude-3-haiku-20240307', temperature=0, retry=10),
 }
 
-xtuner_models = {
+xtuner_series = {
     'llava-internlm2-7b': partial(LLaVA_XTuner, llm_path='internlm/internlm2-chat-7b', llava_path='xtuner/llava-internlm2-7b', visual_select_layer=-2, prompt_template='internlm2_chat'),
     'llava-internlm2-20b': partial(LLaVA_XTuner, llm_path='internlm/internlm2-chat-20b', llava_path='xtuner/llava-internlm2-20b', visual_select_layer=-2, prompt_template='internlm2_chat'),
     'llava-internlm-7b': partial(LLaVA_XTuner, llm_path='internlm/internlm-chat-7b', llava_path='xtuner/llava-internlm-7b', visual_select_layer=-2, prompt_template='internlm_chat'),
@@ -81,6 +54,66 @@
     'llava-v1.5-13b-xtuner': partial(LLaVA_XTuner, llm_path='lmsys/vicuna-13b-v1.5', llava_path='xtuner/llava-v1.5-13b-xtuner', visual_select_layer=-2, prompt_template='vicuna'),
 }
 
+qwen_series = {
+    'qwen_base': partial(QwenVL, model_path='Qwen/Qwen-VL'),
+    'qwen_chat': partial(QwenVLChat, model_path='Qwen/Qwen-VL-Chat'),
+    'monkey':partial(Monkey, model_path='echo840/Monkey'),
+    'monkey-chat':partial(MonkeyChat, model_path='echo840/Monkey-Chat')
+}
+
+llava_series = {
+    'llava_v1.5_7b': partial(LLaVA, model_pth='liuhaotian/llava-v1.5-7b'),
+    'llava_v1.5_13b': partial(LLaVA, model_pth='liuhaotian/llava-v1.5-13b'),
+    'llava_v1_7b': partial(LLaVA, model_pth=LLAVA_V1_7B_MODEL_PTH),
+    'sharegpt4v_7b': partial(LLaVA, model_pth='Lin-Chen/ShareGPT4V-7B'),
+    'sharegpt4v_13b': partial(LLaVA, model_pth='Lin-Chen/ShareGPT4V-13B'),
+}
+
+internvl_series = {
+    'InternVL-Chat-V1-1':partial(InternVLChat, model_path='OpenGVLab/InternVL-Chat-Chinese-V1-1'),
+    'InternVL-Chat-V1-2': partial(InternVLChat, model_path='OpenGVLab/InternVL-Chat-Chinese-V1-2'),
+    'InternVL-Chat-V1-2-Plus': partial(InternVLChat, model_path='OpenGVLab/InternVL-Chat-Chinese-V1-2-Plus'),
+}
+
+yivl_series = {
+    'Yi_VL_6B':partial(Yi_VL, model_path='01-ai/Yi-VL-6B', root=Yi_ROOT),
+    'Yi_VL_34B':partial(Yi_VL, model_path='01-ai/Yi-VL-34B', root=Yi_ROOT),
+}
+
+xcomposer_series = {
+    'XComposer': partial(XComposer, model_path='internlm/internlm-xcomposer-vl-7b'),
+    'XComposer2': partial(XComposer2, model_path='internlm/internlm-xcomposer2-vl-7b'),
+}
+
+minigpt4_series = {
+    'MiniGPT-4-v2': partial(MiniGPT4, mode='v2', root=MiniGPT4_ROOT),
+    'MiniGPT-4-v1-7B': partial(MiniGPT4, mode='v1_7b', root=MiniGPT4_ROOT),
+    'MiniGPT-4-v1-13B': partial(MiniGPT4, mode='v1_13b', root=MiniGPT4_ROOT),
+}
+
+idefics_series = {
+    'idefics_9b_instruct': partial(IDEFICS, model_pth='HuggingFaceM4/idefics-9b-instruct'),
+    'idefics_80b_instruct': partial(IDEFICS, model_pth='HuggingFaceM4/idefics-80b-instruct'),
+}
+
+instructblip_series = {
+    'instructblip_7b': partial(InstructBLIP, name='instructblip_7b'),
+    'instructblip_13b': partial(InstructBLIP, name='instructblip_13b'),
+}
+
+deepseekvl_series = {
+    'deepseek_vl_7b': partial(DeepSeekVL, model_path='deepseek-ai/deepseek-vl-7b-chat'),
+    'deepseek_vl_1.3b': partial(DeepSeekVL, model_path='deepseek-ai/deepseek-vl-1.3b-chat'),
+}
+
 supported_VLM = {}
-for model_set in [models, api_models, xtuner_models]:
-    supported_VLM.update(model_set)
+
+model_groups = [
+    ungrouped, api_models,
+    xtuner_series, qwen_series, llava_series, internvl_series, yivl_series,
+    xcomposer_series, minigpt4_series, idefics_series, instructblip_series,
+    deepseekvl_series
+]
+
+for grp in model_groups:
+    supported_VLM.update(grp)
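With the grouping above, every per-series dict is still merged into `supported_VLM`, so downstream lookups are unchanged. A minimal usage sketch (the model name is taken from the dicts in this diff; the image path and prompt are placeholders):

```python
# Illustrative sketch: build a model from the registry and run one image-text query.
from vlmeval.config import supported_VLM

model = supported_VLM['deepseek_vl_1.3b']()  # the stored partial constructs the wrapper
print(model.generate('path/to/image.jpg', 'Describe this image.'))
```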
diff --git a/vlmeval/evaluate/OCRBench.py b/vlmeval/evaluate/OCRBench.py
index 06cf767c3..c37ad0872 100644
--- a/vlmeval/evaluate/OCRBench.py
+++ b/vlmeval/evaluate/OCRBench.py
@@ -1,20 +1,20 @@
 from vlmeval.smp import *
 
-OCRBench_score = {
-    'Regular Text Recognition': 0,
-    'Irregular Text Recognition': 0,
-    'Artistic Text Recognition': 0,
-    'Handwriting Recognition': 0,
-    'Digit String Recognition': 0,
-    'Non-Semantic Text Recognition': 0,
-    'Scene Text-centric VQA': 0,
-    'Doc-oriented VQA': 0,
-    'Key Information Extraction': 0,
-    'Handwritten Mathematical Expression Recognition': 0
-}
-
 def OCRBench_eval(eval_file):
+    OCRBench_score = {
+        'Regular Text Recognition': 0,
+        'Irregular Text Recognition': 0,
+        'Artistic Text Recognition': 0,
+        'Handwriting Recognition': 0,
+        'Digit String Recognition': 0,
+        'Non-Semantic Text Recognition': 0,
+        'Scene Text-centric VQA': 0,
+        'Doc-oriented VQA': 0,
+        'Key Information Extraction': 0,
+        'Handwritten Mathematical Expression Recognition': 0
+    }
+
     logger = get_logger('Evaluation')
     data = load(eval_file)
diff --git a/vlmeval/smp/misc.py b/vlmeval/smp/misc.py
index 8a0a021cd..1b98b3976 100644
--- a/vlmeval/smp/misc.py
+++ b/vlmeval/smp/misc.py
@@ -168,3 +168,18 @@ def load_env():
         os.environ[k] = v
     print(f'API Keys successfully loaded from {pth}')
     return
+
+def pip_install_robust(package):
+    import sys
+    import subprocess
+    retry = 3
+    while retry > 0:
+        try:
+            # Import by the bare package name, ignoring any pinned version (e.g. 'pkg==1.0')
+            package_base = package.split('=')[0]
+            __import__(package_base)
+            return True
+        except ImportError:
+            subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
+            retry -= 1
+    return False
diff --git a/vlmeval/vlm/__init__.py b/vlmeval/vlm/__init__.py
index 077d70184..4987dca1b 100644
--- a/vlmeval/vlm/__init__.py
+++ b/vlmeval/vlm/__init__.py
@@ -24,3 +24,4 @@
 from .xcomposer2 import XComposer2
 from .yi_vl import Yi_VL
 from .internvl_chat import InternVLChat
+from .deepseek_vl import DeepSeekVL
diff --git a/vlmeval/vlm/deepseek_vl.py b/vlmeval/vlm/deepseek_vl.py
new file mode 100644
index 000000000..b420e9f5a
--- /dev/null
+++ b/vlmeval/vlm/deepseek_vl.py
@@ -0,0 +1,71 @@
+import sys
+import torch
+from transformers import AutoModelForCausalLM
+import warnings
+from vlmeval.smp import isimg
+
+
+class DeepSeekVL:
+
+    INSTALL_REQ = True
+
+    def check_install(self):
+        try:
+            import deepseek_vl
+        except ImportError:
+            warnings.warn(
+                'Please first install deepseek_vl from source codes in: https://github.com/deepseek-ai/DeepSeek-VL')
+            sys.exit(-1)
+
+    def __init__(self, model_path='deepseek-ai/deepseek-vl-1.3b-chat', **kwargs):
+        self.check_install()
+        assert model_path is not None
+        self.model_path = model_path
+        from deepseek_vl.models import VLChatProcessor
+
+        self.vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
+        self.tokenizer = self.vl_chat_processor.tokenizer
+
+        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
+        self.model = model.to(torch.bfloat16).cuda().eval()
+
+        torch.cuda.empty_cache()
+        default_kwargs = dict(max_new_tokens=512, do_sample=False, use_cache=True)
+        default_kwargs.update(kwargs)
+        self.kwargs = default_kwargs
+        warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
+
+    def prepare_inputs(self, msgs):
+        content, images = '', []
+        for s in msgs:
+            if isimg(s):
+                images.append(s)
+                content += '<image_placeholder>'
+            else:
+                content += s
+        conversation = [
+            dict(role='User', content=content, images=images),
+            dict(role='Assistant', content='')
+        ]
+        return conversation
+
+    def interleave_generate(self, ti_list, dataset=None):
+        conversation = self.prepare_inputs(ti_list)
+        from deepseek_vl.utils.io import load_pil_images
+        pil_images = load_pil_images(conversation)
+        prepare_inputs = self.vl_chat_processor(conversations=conversation, images=pil_images, force_batchify=True)
+        prepare_inputs = prepare_inputs.to(self.model.device)
+        inputs_embeds = self.model.prepare_inputs_embeds(**prepare_inputs)
+
+        outputs = self.model.language_model.generate(
+            inputs_embeds=inputs_embeds,
+            attention_mask=prepare_inputs.attention_mask,
+            pad_token_id=self.tokenizer.eos_token_id,
+            bos_token_id=self.tokenizer.bos_token_id,
+            eos_token_id=self.tokenizer.eos_token_id,
+            **self.kwargs)
+        answer = self.tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
+        return answer
+
+    def generate(self, image_path, prompt, dataset=None):
+        return self.interleave_generate([image_path, prompt], dataset=dataset)
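A minimal sketch of how the new wrapper is intended to be driven (the image paths are placeholders; `interleave_generate` takes a flat list that mixes image paths and text segments, and `generate` is the single-image convenience method defined above):

```python
# Illustrative usage of the DeepSeekVL wrapper added in this diff. Assumes a CUDA
# device and the deepseek_vl package installed from source.
from vlmeval.vlm.deepseek_vl import DeepSeekVL

model = DeepSeekVL(model_path='deepseek-ai/deepseek-vl-1.3b-chat', max_new_tokens=256)

# Single image + prompt
print(model.generate('demo.jpg', 'What is shown in this image?'))

# Interleaved inputs: images and text in the order they should appear in the prompt
print(model.interleave_generate(['page1.jpg', 'page2.jpg', 'Compare the two pages.']))
```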