From 28f1140e503b2832070567103bd1e68bfa4d2919 Mon Sep 17 00:00:00 2001 From: xwjiang2010 <87673679+xwjiang2010@users.noreply.github.com> Date: Tue, 2 Jul 2024 00:57:09 -0700 Subject: [PATCH] [VLM] Remove `image_input_type` from VLM config (#5852) Signed-off-by: Xiaowei Jiang Co-authored-by: Cyrus Leung Co-authored-by: Roger Wang --- .buildkite/download-images.sh | 4 - docs/requirements-docs.txt | 16 +-- .../dev/multimodal/multimodal_index.rst | 8 +- docs/source/models/vlm.rst | 11 +- examples/llava_example.py | 56 ++------ examples/llava_next_example.py | 61 +++++---- examples/openai_vision_api_client.py | 1 - examples/phi3v_example.py | 6 +- tests/conftest.py | 38 ++---- tests/entrypoints/openai/test_vision.py | 2 - tests/models/test_llava.py | 22 ++- tests/models/test_llava_next.py | 23 ++-- tests/models/test_phi3v.py | 21 ++- tests/multimodal/test_mapper.py | 40 +----- tests/spec_decode/e2e/conftest.py | 4 +- tests/tokenization/test_image_processor.py | 20 --- vllm/config.py | 34 +---- vllm/engine/arg_utils.py | 56 +------- vllm/entrypoints/openai/api_server.py | 9 -- vllm/entrypoints/openai/serving_chat.py | 65 ++++----- vllm/inputs/data.py | 11 +- vllm/inputs/registry.py | 7 +- vllm/model_executor/model_loader/loader.py | 5 +- vllm/model_executor/models/clip.py | 20 +-- vllm/model_executor/models/llava.py | 102 +++----------- vllm/model_executor/models/llava_next.py | 126 ++++++------------ vllm/model_executor/models/phi3v.py | 25 ++-- vllm/multimodal/__init__.py | 8 +- vllm/multimodal/base.py | 53 ++++---- vllm/multimodal/image.py | 93 ++----------- vllm/multimodal/registry.py | 96 +++++++------ vllm/multimodal/utils.py | 13 +- vllm/sequence.py | 10 +- vllm/transformers_utils/image_processor.py | 4 - vllm/worker/model_runner.py | 2 +- 35 files changed, 325 insertions(+), 747 deletions(-) delete mode 100644 tests/tokenization/test_image_processor.py diff --git a/.buildkite/download-images.sh b/.buildkite/download-images.sh index 389a12956c3c..360a7584bccf 100644 --- a/.buildkite/download-images.sh +++ b/.buildkite/download-images.sh @@ -8,10 +8,6 @@ set -o pipefail # aws s3 sync s3://air-example-data-2/vllm_opensource_llava/ images/ mkdir -p images cd images -wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign_pixel_values.pt -wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign_image_features.pt -wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom_pixel_values.pt -wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom_image_features.pt wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign.jpg wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom.jpg diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index ed569816200e..db076b2d801d 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -1,13 +1,5 @@ -sphinx == 6.2.1 -sphinx-book-theme == 1.0.1 -sphinx-copybutton == 0.5.2 -myst-parser == 2.0.0 +sphinx==6.2.1 +sphinx-book-theme==1.0.1 +sphinx-copybutton==0.5.2 +myst-parser==2.0.0 sphinx-argparse - -# packages to install to build the documentation -pydantic --f https://download.pytorch.org/whl/cpu -torch -py-cpuinfo -transformers -openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst index f6fdfc1debff..4d5fb3246b68 100644 --- a/docs/source/dev/multimodal/multimodal_index.rst +++ b/docs/source/dev/multimodal/multimodal_index.rst @@ -9,8 +9,10 @@ vLLM provides experimental support for multi-modal models through the :mod:`vllm which allows you to pass in multi-modal input alongside text and token prompts. By default, vLLM models do not support multi-modal inputs. To enable multi-modal support for a model, -you must decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_dummy_data `, -as well as :meth:`MULTIMODAL_REGISTRY.register_input ` for each modality type to support. +you must decorate the model class with :meth:`InputRegistry.register_dummy_data `, +as well as :meth:`MULTIMODAL_REGISTRY.register_input_mapper ` for each modality type to support. + +# TODO: Add more instructions on how to do that once embeddings is in. Module Contents +++++++++++++++ @@ -29,7 +31,7 @@ Registry Base Classes ------------ -.. autoclass:: vllm.multimodal.MultiModalData +.. autoclass:: vllm.multimodal.MultiModalDataDict :members: :show-inheritance: diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst index 1837dd2aa89f..053f5b8609ce 100644 --- a/docs/source/models/vlm.rst +++ b/docs/source/models/vlm.rst @@ -36,7 +36,6 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM`` llm = LLM( model="llava-hf/llava-1.5-7b-hf", - image_input_type="pixel_values", image_token_id=32000, image_input_shape="1,3,336,336", image_feature_size=576, @@ -49,7 +48,12 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM`` To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`: * ``prompt``: The prompt should have a number of ```` tokens equal to ``image_feature_size``. -* ``multi_modal_data``: This should be an instance of :class:`~vllm.multimodal.image.ImagePixelData` or :class:`~vllm.multimodal.image.ImageFeatureData`. +* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`. + +.. note:: + + ``multi_modal_data`` can accept keys and values beyond the builtin ones, as long as a customized plugin is registered through + :class:`vllm.multimodal.MULTIMODAL_REGISTRY`. .. code-block:: python @@ -61,7 +65,7 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptS outputs = llm.generate({ "prompt": prompt, - "multi_modal_data": ImagePixelData(image), + "multi_modal_data": {"image": image}, }) for o in outputs: @@ -93,7 +97,6 @@ Below is an example on how to launch the same ``llava-hf/llava-1.5-7b-hf`` with python -m vllm.entrypoints.openai.api_server \ --model llava-hf/llava-1.5-7b-hf \ - --image-input-type pixel_values \ --image-token-id 32000 \ --image-input-shape 1,3,336,336 \ --image-feature-size 576 \ diff --git a/examples/llava_example.py b/examples/llava_example.py index 980d7bf9f8a3..7f3d84f99f76 100644 --- a/examples/llava_example.py +++ b/examples/llava_example.py @@ -1,38 +1,32 @@ -import argparse import os import subprocess -import torch from PIL import Image from vllm import LLM -from vllm.multimodal.image import ImageFeatureData, ImagePixelData # The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`. # You can use `.buildkite/download-images.sh` to download them -def run_llava_pixel_values(*, disable_image_processor: bool = False): +def run_llava(): llm = LLM( model="llava-hf/llava-1.5-7b-hf", - image_input_type="pixel_values", image_token_id=32000, image_input_shape="1,3,336,336", image_feature_size=576, - disable_image_processor=disable_image_processor, ) prompt = "" * 576 + ( "\nUSER: What is the content of this image?\nASSISTANT:") - if disable_image_processor: - image = torch.load("images/stop_sign_pixel_values.pt") - else: - image = Image.open("images/stop_sign.jpg") + image = Image.open("images/stop_sign.jpg") outputs = llm.generate({ "prompt": prompt, - "multi_modal_data": ImagePixelData(image), + "multi_modal_data": { + "image": image + }, }) for o in outputs: @@ -40,45 +34,11 @@ def run_llava_pixel_values(*, disable_image_processor: bool = False): print(generated_text) -def run_llava_image_features(): - llm = LLM( - model="llava-hf/llava-1.5-7b-hf", - image_input_type="image_features", - image_token_id=32000, - image_input_shape="1,576,1024", - image_feature_size=576, - ) - - prompt = "" * 576 + ( - "\nUSER: What is the content of this image?\nASSISTANT:") - - image: torch.Tensor = torch.load("images/stop_sign_image_features.pt") - - outputs = llm.generate({ - "prompt": prompt, - "multi_modal_data": ImageFeatureData(image), - }) - - for o in outputs: - generated_text = o.outputs[0].text - print(generated_text) - - -def main(args): - if args.type == "pixel_values": - run_llava_pixel_values() - else: - run_llava_image_features() +def main(): + run_llava() if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Demo on Llava") - parser.add_argument("--type", - type=str, - choices=["pixel_values", "image_features"], - default="pixel_values", - help="image input type") - args = parser.parse_args() # Download from s3 s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/" local_directory = "images" @@ -95,4 +55,4 @@ def main(args): local_directory, "--no-sign-request", ]) - main(args) + main() diff --git a/examples/llava_next_example.py b/examples/llava_next_example.py index e90a86abe41c..3c39590e7fb8 100644 --- a/examples/llava_next_example.py +++ b/examples/llava_next_example.py @@ -4,35 +4,44 @@ from PIL import Image from vllm import LLM, SamplingParams -from vllm.multimodal.image import ImagePixelData # Dynamic image input is currently not supported and therefore # a fixed image input shape and its corresponding feature size is required. # See https://github.com/vllm-project/vllm/pull/4199 for the complete # configuration matrix. -llm = LLM( - model="llava-hf/llava-v1.6-mistral-7b-hf", - image_input_type="pixel_values", - image_token_id=32000, - image_input_shape="1,3,336,336", - image_feature_size=1176, -) - -prompt = "[INST] " + "" * 1176 + "\nWhat is shown in this image? [/INST]" -url = "https://h2o-release.s3.amazonaws.com/h2ogpt/bigben.jpg" -image = Image.open(BytesIO(requests.get(url).content)) -sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=100) - -outputs = llm.generate( - { - "prompt": prompt, - "multi_modal_data": ImagePixelData(image), - }, - sampling_params=sampling_params) - -generated_text = "" -for o in outputs: - generated_text += o.outputs[0].text - -print(f"LLM output:{generated_text}") + +def run_llava_next(): + llm = LLM( + model="llava-hf/llava-v1.6-mistral-7b-hf", + image_token_id=32000, + image_input_shape="1,3,336,336", + image_feature_size=1176, + ) + + prompt = "[INST] " + "" * 1176 + ( + "\nWhat is shown in this image? [/INST]") + url = "https://h2o-release.s3.amazonaws.com/h2ogpt/bigben.jpg" + image = Image.open(BytesIO(requests.get(url).content)) + sampling_params = SamplingParams(temperature=0.8, + top_p=0.95, + max_tokens=100) + + outputs = llm.generate( + { + "prompt": prompt, + "multi_modal_data": { + "image": image + } + }, + sampling_params=sampling_params) + + generated_text = "" + for o in outputs: + generated_text += o.outputs[0].text + + print(f"LLM output:{generated_text}") + + +if __name__ == "__main__": + run_llava_next() diff --git a/examples/openai_vision_api_client.py b/examples/openai_vision_api_client.py index 26f2aa651fca..fcda1345f576 100644 --- a/examples/openai_vision_api_client.py +++ b/examples/openai_vision_api_client.py @@ -3,7 +3,6 @@ Launch the vLLM server with the following command: python -m vllm.entrypoints.openai.api_server \ --model llava-hf/llava-1.5-7b-hf \ - --image-input-type pixel_values \ --image-token-id 32000 \ --image-input-shape 1,3,336,336 \ --image-feature-size 576 \ diff --git a/examples/phi3v_example.py b/examples/phi3v_example.py index f0b9b0e4fc95..7d6c58d7fcd8 100644 --- a/examples/phi3v_example.py +++ b/examples/phi3v_example.py @@ -4,7 +4,6 @@ from PIL import Image from vllm import LLM, SamplingParams -from vllm.multimodal.image import ImagePixelData def run_phi3v(): @@ -17,7 +16,6 @@ def run_phi3v(): llm = LLM( model=model_path, trust_remote_code=True, - image_input_type="pixel_values", image_token_id=32044, image_input_shape="1,3,1008,1344", image_feature_size=1921, @@ -35,7 +33,9 @@ def run_phi3v(): outputs = llm.generate( { "prompt": prompt, - "multi_modal_data": ImagePixelData(image), + "multi_modal_data": { + "image": image + }, }, sampling_params=sampling_params) for o in outputs: diff --git a/tests/conftest.py b/tests/conftest.py index ac802d03b1c8..c3bd78263e4d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,19 +17,17 @@ AutoTokenizer, BatchEncoding) from vllm import LLM, SamplingParams -from vllm.config import TokenizerPoolConfig, VisionLanguageConfig +from vllm.config import TokenizerPoolConfig from vllm.distributed import (destroy_distributed_environment, destroy_model_parallel) from vllm.inputs import TextPrompt from vllm.logger import init_logger +from vllm.sequence import SampleLogprobs +from vllm.utils import cuda_device_count_stateless, is_cpu if TYPE_CHECKING: - from vllm.multimodal import MultiModalData -else: # it will call torch.cuda.device_count() - MultiModalData = None -from vllm.sequence import SampleLogprobs -from vllm.utils import cuda_device_count_stateless, is_cpu + from vllm.multimodal import MultiModalDataDict logger = init_logger(__name__) @@ -51,14 +49,6 @@ def _read_prompts(filename: str) -> List[str]: class ImageAsset: name: Literal["stop_sign", "cherry_blossom"] - @cached_property - def pixel_values(self) -> torch.Tensor: - return torch.load(_IMAGE_DIR / f"{self.name}_pixel_values.pt") - - @cached_property - def image_features(self) -> torch.Tensor: - return torch.load(_IMAGE_DIR / f"{self.name}_image_features.pt") - @cached_property def pil_image(self) -> Image.Image: return Image.open(_IMAGE_DIR / f"{self.name}.jpg") @@ -66,20 +56,8 @@ def pil_image(self) -> Image.Image: def for_hf(self) -> Image.Image: return self.pil_image - def for_vllm(self, vision_config: VisionLanguageConfig) -> MultiModalData: - # don't put this import at the top level - # it will call torch.cuda.device_count() - from vllm.multimodal.image import ImageFeatureData # noqa: F401 - from vllm.multimodal.image import ImagePixelData - image_input_type = vision_config.image_input_type - ImageInputType = VisionLanguageConfig.ImageInputType - - if image_input_type == ImageInputType.IMAGE_FEATURES: - return ImageFeatureData(self.image_features) - if image_input_type == ImageInputType.PIXEL_VALUES: - return ImagePixelData(self.pil_image) - - raise NotImplementedError + def for_vllm(self) -> Dict[str, Any]: + return {"image": self.pil_image} class _ImageAssetPrompts(TypedDict): @@ -453,7 +431,7 @@ def generate( self, prompts: List[str], sampling_params: SamplingParams, - images: Optional[List[MultiModalData]] = None, + images: Optional[List["MultiModalDataDict"]] = None, ) -> List[Tuple[List[List[int]], List[str]]]: if images is not None: assert len(prompts) == len(images) @@ -502,7 +480,7 @@ def generate_greedy( self, prompts: List[str], max_tokens: int, - images: Optional[List[MultiModalData]] = None, + images: Optional[List["MultiModalDataDict"]] = None, ) -> List[Tuple[List[int], str]]: greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) outputs = self.generate(prompts, greedy_params, images=images) diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index dbaaa349ad37..a7f7fdae8d16 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -39,8 +39,6 @@ def server(): "--max-model-len", "4096", "--enforce-eager", - "--image-input-type", - "pixel_values", "--image-token-id", "32000", "--image-input-shape", diff --git a/tests/models/test_llava.py b/tests/models/test_llava.py index b4220dc59955..c6313c52e4e3 100644 --- a/tests/models/test_llava.py +++ b/tests/models/test_llava.py @@ -25,17 +25,11 @@ def iter_llava_configs(model_name: str): } for (h, w), f in image_hw_to_feature_size.items(): - for input_type, input_shape in [ - (VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, h, w)), - (VisionLanguageConfig.ImageInputType.IMAGE_FEATURES, (1, f, 1024)), - ]: - yield (model_name, - VisionLanguageConfig(image_input_type=input_type, - image_feature_size=f, - image_token_id=32000, - image_input_shape=input_shape, - image_processor=model_name, - image_processor_revision=None)) + input_shape = (1, 3, h, w) + yield (model_name, + VisionLanguageConfig(image_feature_size=f, + image_token_id=32000, + image_input_shape=input_shape)) model_and_vl_config = [ @@ -81,8 +75,8 @@ def run_test( All the image fixtures for the test is under tests/images. For huggingface runner, we provide the PIL images as input. - For vllm runner, we provide MultiModalData objects and corresponding - vision language config as input. + For vllm runner, we provide MultiModalDataDict objects + and corresponding vision language config as input. Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. """ @@ -104,7 +98,7 @@ def run_test( # NOTE: `asset.for_vllm` will call `torch.cuda.device_count()` # we must put it inside the vllm_runner context manager # i.e. after creating vLLM instance. - vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets] + vllm_images = [asset.for_vllm() for asset in image_assets] vllm_image_prompts = [ p.replace("", "" * vlm_config.image_feature_size) diff --git a/tests/models/test_llava_next.py b/tests/models/test_llava_next.py index 940d5035ef3f..e9babba13c47 100644 --- a/tests/models/test_llava_next.py +++ b/tests/models/test_llava_next.py @@ -33,16 +33,13 @@ def iter_llava_next_configs(model_name: str): } for (h, w), f in image_hw_to_feature_size.items(): - for input_type, input_shape in [ - (VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, h, w)), - ]: - yield (model_name, - VisionLanguageConfig(image_input_type=input_type, - image_feature_size=f, - image_token_id=32000, - image_input_shape=input_shape, - image_processor=model_name, - image_processor_revision=None)) + input_shape = (1, 3, h, w) + yield (model_name, + VisionLanguageConfig( + image_feature_size=f, + image_token_id=32000, + image_input_shape=input_shape, + )) model_and_vl_config = [ @@ -85,14 +82,14 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config, All the image fixtures for the test is under tests/images. For huggingface runner, we provide the PIL images as input. - For vllm runner, we provide MultiModalData objects and corresponding - vision language config as input. + For vllm runner, we provide MultiModalDataDict objects + and corresponding vision language config as input. Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. """ model_id, vlm_config = model_and_config hf_images = [asset.for_hf() for asset in image_assets] - vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets] + vllm_images = [asset.for_vllm() for asset in image_assets] with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model: hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS, diff --git a/tests/models/test_phi3v.py b/tests/models/test_phi3v.py index ba71763f9610..917bdbf94ab9 100644 --- a/tests/models/test_phi3v.py +++ b/tests/models/test_phi3v.py @@ -27,16 +27,11 @@ def iter_phi3v_configs(model_name: str): } for (h, w), f in image_hw_to_feature_size.items(): - for input_type, input_shape in [ - (VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, h, w)), - ]: - yield (model_name, - VisionLanguageConfig(image_input_type=input_type, - image_feature_size=f, - image_token_id=32044, - image_input_shape=input_shape, - image_processor=model_name, - image_processor_revision=None)) + input_shape = (1, 3, h, w) + yield (model_name, + VisionLanguageConfig(image_feature_size=f, + image_token_id=32044, + image_input_shape=input_shape)) model_and_vl_config = [ @@ -89,8 +84,8 @@ def run_test( All the image fixtures for the test is under tests/images. For huggingface runner, we provide the PIL images as input. - For vllm runner, we provide MultiModalData objects and corresponding - vision language config as input. + For vllm runner, we provide MultiModalDataDict objects + and corresponding vision language config as input. Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. """ @@ -113,7 +108,7 @@ def run_test( # we must put it inside the vllm_runner context manager # i.e. after creating vLLM instance. - vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets] + vllm_images = [asset.for_vllm() for asset in image_assets] vllm_image_prompts = [ p.replace("<|image_1|>", diff --git a/tests/multimodal/test_mapper.py b/tests/multimodal/test_mapper.py index 2c05b0edb0c4..bdbbd9abfc5c 100644 --- a/tests/multimodal/test_mapper.py +++ b/tests/multimodal/test_mapper.py @@ -2,9 +2,8 @@ import pytest from transformers import CLIPImageProcessor, LlavaNextImageProcessor -from vllm.config import ModelConfig, VisionLanguageConfig +from vllm.config import ModelConfig from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.image import ImagePixelData from ..conftest import _STR_DTYPE_TO_TORCH_DTYPE @@ -12,7 +11,6 @@ @pytest.mark.parametrize("dtype", ["half", "float"]) def test_clip_image_processor(image_assets, dtype): MODEL_NAME = "llava-hf/llava-1.5-7b-hf" - IMAGE_HEIGHT = IMAGE_WIDTH = 560 hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME) assert isinstance(hf_processor, CLIPImageProcessor) @@ -25,14 +23,6 @@ def test_clip_image_processor(image_assets, dtype): seed=0, dtype=dtype, revision=None, - multimodal_config=VisionLanguageConfig( - image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES, - image_token_id=32000, - image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH), - image_feature_size=576, - image_processor=MODEL_NAME, - image_processor_revision=None, - ), ) for asset in image_assets: @@ -42,7 +32,7 @@ def test_clip_image_processor(image_assets, dtype): ).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype]) vllm_result = MULTIMODAL_REGISTRY.map_input( model_config, - ImagePixelData(asset.pil_image), + {"image": asset.pil_image}, ) assert hf_result.keys() == vllm_result.keys() @@ -60,7 +50,6 @@ def test_clip_image_processor(image_assets, dtype): @pytest.mark.parametrize("dtype", ["half", "float"]) def test_llava_next_image_processor(image_assets, dtype): MODEL_NAME = "llava-hf/llava-v1.6-34b-hf" - IMAGE_HEIGHT = IMAGE_WIDTH = 560 hf_processor = LlavaNextImageProcessor.from_pretrained(MODEL_NAME) assert isinstance(hf_processor, LlavaNextImageProcessor) @@ -73,14 +62,6 @@ def test_llava_next_image_processor(image_assets, dtype): seed=0, dtype=dtype, revision=None, - multimodal_config=VisionLanguageConfig( - image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES, - image_token_id=64000, - image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH), - image_feature_size=2928, - image_processor=MODEL_NAME, - image_processor_revision=None, - ), ) for asset in image_assets: @@ -90,7 +71,7 @@ def test_llava_next_image_processor(image_assets, dtype): ).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype]) vllm_result = MULTIMODAL_REGISTRY.map_input( model_config, - ImagePixelData(asset.pil_image), + {"image": asset.pil_image}, ) assert hf_result.keys() == vllm_result.keys() @@ -107,7 +88,6 @@ def test_llava_next_image_processor(image_assets, dtype): @pytest.mark.parametrize("dtype", ["float"]) def test_image_pixel_types(image_assets, dtype): MODEL_NAME = "llava-hf/llava-1.5-7b-hf" - IMAGE_HEIGHT = IMAGE_WIDTH = 560 model_config = ModelConfig( model=MODEL_NAME, @@ -117,23 +97,15 @@ def test_image_pixel_types(image_assets, dtype): seed=0, dtype=dtype, revision=None, - multimodal_config=VisionLanguageConfig( - image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES, - image_token_id=32000, - image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH), - image_feature_size=576, - image_processor=MODEL_NAME, - image_processor_revision=None, - )) - + ) for asset in image_assets: image_result = MULTIMODAL_REGISTRY.map_input( model_config, - ImagePixelData(asset.pil_image), + {"image": asset.pil_image}, ) tensor_result = MULTIMODAL_REGISTRY.map_input( model_config, - ImagePixelData(asset.pixel_values), + {"image": asset.pil_image}, ) assert image_result.keys() == tensor_result.keys() diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 60dfe33f2918..8ad8e9cb81ff 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -11,7 +11,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.lora.request import LoRARequest from vllm.model_executor.utils import set_random_seed -from vllm.multimodal import MultiModalData +from vllm.multimodal import MultiModalDataDict from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams from vllm.sequence import Logprob @@ -91,7 +91,7 @@ def generate( prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, + multi_modal_data: Optional[MultiModalDataDict] = None, ) -> List[RequestOutput]: if prompts is None: diff --git a/tests/tokenization/test_image_processor.py b/tests/tokenization/test_image_processor.py deleted file mode 100644 index 5ba232336741..000000000000 --- a/tests/tokenization/test_image_processor.py +++ /dev/null @@ -1,20 +0,0 @@ -import pytest -from transformers.image_processing_utils import BaseImageProcessor - -from vllm.transformers_utils.image_processor import get_image_processor - -IMAGE_PROCESSOR_NAMES = [ - "llava-hf/llava-1.5-7b-hf", - "llava-hf/llava-v1.6-34b-hf", -] - - -@pytest.mark.parametrize("processor_name", IMAGE_PROCESSOR_NAMES) -def test_image_processor_revision(processor_name: str): - # Assume that "main" branch always exists - image_processor = get_image_processor(processor_name, revision="main") - assert isinstance(image_processor, BaseImageProcessor) - - # Assume that "never" branch always does not exist - with pytest.raises(OSError, match='not a valid git identifier'): - get_image_processor(processor_name, revision="never") diff --git a/vllm/config.py b/vllm/config.py index 9854f175065a..b919b212da4f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1250,28 +1250,11 @@ def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): raise ValueError("LoRA is not supported with chunked prefill yet.") +# TODO: To be replaced by MultiModalConfig. @dataclass class VisionLanguageConfig: """Configs the input data format and how models should run for vision language models.""" - - class ImageInputType(enum.Enum): - """Image input type into the vision language model. - - An image roughly goes through the following transformation: - Raw image --> pixel values --> image features --> image embeddings. - - The difference between different image input types is where the - image encoder (pixel values --> image features) is run. - Different image input types also correspond to different tensor shapes. - - For example, for Llava, PIXEL_VALUES: (1, 3, 336, 336). - IMAGE_FEATURES: (1, 576, 1024). - """ - PIXEL_VALUES = enum.auto() - IMAGE_FEATURES = enum.auto() - - image_input_type: ImageInputType # The input id corresponding to image token. image_token_id: int # Used for running `run_prefill_max_token`. @@ -1279,19 +1262,6 @@ class ImageInputType(enum.Enum): # worst case scenario (biggest supported resolution). image_input_shape: tuple image_feature_size: int - # The image processor to load from HuggingFace - image_processor: Optional[str] - image_processor_revision: Optional[str] - - @classmethod - def get_image_input_enum_type(cls, value: str) -> ImageInputType: - """Get the image input type from a string.""" - try: - return cls.ImageInputType[value.upper()] - except KeyError as e: - raise ValueError(f"{value} is not a valid choice. " - f"Expecting to choose from " - f"{[x.name for x in cls.ImageInputType]}.") from e #TODO(ywang96): make this a cached property once we refactor the # VisionLanguageConfig class. @@ -1318,8 +1288,6 @@ def as_cli_args_dict(self) -> Dict[str, Any]: else: result[f.name] = value - result["disable_image_processor"] = self.image_processor is None - return result diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d4044adfce61..565b9e7791db 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1,7 +1,6 @@ import argparse import dataclasses import json -import warnings from dataclasses import dataclass from typing import List, Optional, Tuple, Union @@ -80,13 +79,9 @@ class EngineArgs: preemption_mode: Optional[str] = None # Related to Vision-language models such as llava - image_input_type: Optional[str] = None image_token_id: Optional[int] = None image_input_shape: Optional[str] = None image_feature_size: Optional[int] = None - image_processor: Optional[str] = None - image_processor_revision: Optional[str] = None - disable_image_processor: bool = False scheduler_delay_factor: float = 0.0 enable_chunked_prefill: bool = False @@ -114,14 +109,6 @@ def __post_init__(self): @staticmethod def add_cli_args_for_vlm( parser: FlexibleArgumentParser) -> FlexibleArgumentParser: - parser.add_argument('--image-input-type', - type=nullable_str, - default=None, - choices=[ - t.name.lower() - for t in VisionLanguageConfig.ImageInputType - ], - help=('The image input type passed into vLLM.')) parser.add_argument('--image-token-id', type=int, default=None, @@ -137,24 +124,6 @@ def add_cli_args_for_vlm( type=int, default=None, help=('The image feature size along the context dimension.')) - parser.add_argument( - '--image-processor', - type=str, - default=EngineArgs.image_processor, - help='Name or path of the huggingface image processor to use. ' - 'If unspecified, model name or path will be used.') - parser.add_argument( - '--image-processor-revision', - type=str, - default=None, - help='Revision of the huggingface image processor version to use. ' - 'It can be a branch name, a tag name, or a commit id. ' - 'If unspecified, will use the default version.') - parser.add_argument( - '--disable-image-processor', - action='store_true', - help='Disables the use of image processor, even if one is defined ' - 'for the model on huggingface.') return parser @@ -679,33 +648,16 @@ def create_engine_config(self, ) -> EngineConfig: raise ValueError( "BitsAndBytes load format and QLoRA adapter only support " f"'bitsandbytes' quantization, but got {self.quantization}") - if self.image_input_type: - if (not self.image_token_id or not self.image_input_shape - or not self.image_feature_size): + if self.image_token_id is not None: + if (not self.image_input_shape or not self.image_feature_size): raise ValueError( - 'Specify `image_token_id`, `image_input_shape` and ' - '`image_feature_size` together with `image_input_type`.') - - if self.image_processor is None: - self.image_processor = self.model - if self.disable_image_processor: - if self.image_processor != self.model: - warnings.warn( - "You've specified an image processor " - f"({self.image_processor}) but also disabled " - "it via `--disable-image-processor`.", - stacklevel=2) - - self.image_processor = None + 'Specify `image_input_shape` and ' + '`image_feature_size` together with `image_token_id`.') vision_language_config = VisionLanguageConfig( - image_input_type=VisionLanguageConfig. - get_image_input_enum_type(self.image_input_type), image_token_id=self.image_token_id, image_input_shape=str_to_int_tuple(self.image_input_shape), image_feature_size=self.image_feature_size, - image_processor=self.image_processor, - image_processor_revision=self.image_processor_revision, ) else: vision_language_config = None diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index a708176c254e..76879c96c31e 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -213,15 +213,6 @@ async def authentication(request: Request, call_next): engine_args = AsyncEngineArgs.from_cli_args(args) - # Enforce pixel values as image input type for vision language models - # when serving with API server - if engine_args.image_input_type is not None and \ - engine_args.image_input_type.upper() != "PIXEL_VALUES": - raise ValueError( - f"Invalid image_input_type: {engine_args.image_input_type}. " - "Only --image-input-type 'pixel_values' is supported for serving " - "vision language models with the vLLM API server.") - engine = AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 4a960fd7ebe1..e5b6b7f573a2 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -26,7 +26,7 @@ from vllm.logger import init_logger from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) -from vllm.multimodal.image import ImagePixelData +from vllm.multimodal import MultiModalDataDict from vllm.multimodal.utils import (async_get_and_parse_image, get_full_image_text_prompt) from vllm.outputs import RequestOutput @@ -47,7 +47,7 @@ class ConversationMessage(TypedDict): @dataclass(frozen=True) class ChatMessageParseResult: messages: List[ConversationMessage] - image_futures: List[Awaitable[ImagePixelData]] = field( + mm_futures: List[Awaitable[MultiModalDataDict]] = field( default_factory=list) @@ -103,7 +103,7 @@ def _parse_chat_message_content_parts( parts: Iterable[ChatCompletionContentPartParam], ) -> ChatMessageParseResult: texts: List[str] = [] - image_futures: List[Awaitable[ImagePixelData]] = [] + mm_futures: List[Awaitable[MultiModalDataDict]] = [] vlm_config: Optional[VisionLanguageConfig] = getattr( self.engine.engine, "vision_language_config", None) @@ -113,39 +113,34 @@ def _parse_chat_message_content_parts( part_type = part["type"] if part_type == "text": text = cast(ChatCompletionContentPartTextParam, part)["text"] - texts.append(text) elif part_type == "image_url": if vlm_config is None: raise ValueError( "'image_url' input is not supported as the loaded " "model is not multimodal.") + assert self.tokenizer is not None + image_url = cast(ChatCompletionContentPartImageParam, + part)["image_url"] - elif len(image_futures) == 0: - assert self.tokenizer is not None - image_url = cast(ChatCompletionContentPartImageParam, - part)["image_url"] - - if image_url.get("detail", "auto") != "auto": - logger.warning( - "'image_url.detail' is currently not supported and " - "will be ignored.") - - image_future = async_get_and_parse_image(image_url["url"]) - image_futures.append(image_future) + if image_url.get("detail", "auto") != "auto": + logger.warning( + "'image_url.detail' is currently not supported and " + "will be ignored.") - else: - raise NotImplementedError( - "Multiple 'image_url' input is currently not supported." - ) + mm_future = async_get_and_parse_image(image_url["url"]) + mm_futures.append(mm_future) else: raise NotImplementedError(f"Unknown part type: {part_type}") text_prompt = "\n".join(texts) - if vlm_config is not None and len(image_futures): + if vlm_config is not None and len(mm_futures): + assert len( + mm_futures + ) == 1, "Multiple 'image_url' input is currently not supported." (image_token_prompt, image_token_str) = vlm_config.get_image_token_text(self.tokenizer) @@ -171,8 +166,7 @@ def _parse_chat_message_content_parts( else: messages = [ConversationMessage(role=role, content=text_prompt)] - return ChatMessageParseResult(messages=messages, - image_futures=image_futures) + return ChatMessageParseResult(messages=messages, mm_futures=mm_futures) def _parse_chat_message_content( self, @@ -182,10 +176,10 @@ def _parse_chat_message_content( content = message.get("content") if content is None: - return ChatMessageParseResult(messages=[], image_futures=[]) + return ChatMessageParseResult(messages=[], mm_futures=[]) if isinstance(content, str): messages = [ConversationMessage(role=role, content=content)] - return ChatMessageParseResult(messages=messages, image_futures=[]) + return ChatMessageParseResult(messages=messages, mm_futures=[]) return self._parse_chat_message_content_parts(role, content) @@ -210,13 +204,13 @@ async def create_chat_completion( try: conversation: List[ConversationMessage] = [] - image_futures: List[Awaitable[ImagePixelData]] = [] + mm_futures: List[Awaitable[MultiModalDataDict]] = [] for msg in request.messages: chat_parsed_result = self._parse_chat_message_content(msg) conversation.extend(chat_parsed_result.messages) - image_futures.extend(chat_parsed_result.image_futures) + mm_futures.extend(chat_parsed_result.mm_futures) tool_dicts = None if request.tools is None else [ tool.model_dump() for tool in request.tools @@ -235,15 +229,14 @@ async def create_chat_completion( logger.error("Error in applying chat template from request: %s", e) return self.create_error_response(str(e)) - # Fetch image data - image_data: Optional[ImagePixelData] = None + mm_data: Optional[MultiModalDataDict] = None try: - if len(image_futures): - # since we support only single image currently - assert len(image_futures) == 1 - image_data = await image_futures[0] + if len(mm_futures): + # since we support only single mm data currently + assert len(mm_futures) == 1 + mm_data = await mm_futures[0] except Exception as e: - logger.error("Error in loading image data: %s", e) + logger.error("Error in loading multi-modal data: %s", e) return self.create_error_response(str(e)) request_id = f"cmpl-{random_uuid()}" @@ -274,8 +267,8 @@ async def create_chat_completion( "prompt": prompt_text, "prompt_token_ids": prompt_ids, } - if image_data is not None: - inputs["multi_modal_data"] = image_data + if mm_data is not None: + inputs["multi_modal_data"] = mm_data is_tracing_enabled = await self.engine.is_tracing_enabled() trace_headers = None diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 9b163b9cfb66..c6381fcc01e5 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -4,7 +4,7 @@ from typing_extensions import NotRequired if TYPE_CHECKING: - from vllm.multimodal import MultiModalData + from vllm.multimodal import MultiModalDataDict class ParsedText(TypedDict): @@ -72,7 +72,7 @@ class TextPrompt(TypedDict): prompt: str """The input text to be tokenized before passing to the model.""" - multi_modal_data: NotRequired["MultiModalData"] + multi_modal_data: NotRequired["MultiModalDataDict"] """ Optional multi-modal data to pass to the model, if the model supports it. @@ -85,7 +85,7 @@ class TokensPrompt(TypedDict): prompt_token_ids: List[int] """A list of token IDs to pass to the model.""" - multi_modal_data: NotRequired["MultiModalData"] + multi_modal_data: NotRequired["MultiModalDataDict"] """ Optional multi-modal data to pass to the model, if the model supports it. @@ -103,7 +103,7 @@ class TextTokensPrompt(TypedDict): prompt_token_ids: List[int] """The token IDs of the prompt.""" - multi_modal_data: NotRequired["MultiModalData"] + multi_modal_data: NotRequired["MultiModalDataDict"] """ Optional multi-modal data to pass to the model, if the model supports it. @@ -128,7 +128,6 @@ class LLMInputs(TypedDict): The inputs in :class:`~vllm.LLMEngine` before they are passed to the model executor. """ - prompt_token_ids: List[int] """The token IDs of the prompt.""" @@ -137,7 +136,7 @@ class LLMInputs(TypedDict): The original prompt text corresponding to the token IDs, if available. """ - multi_modal_data: NotRequired[Optional["MultiModalData"]] + multi_modal_data: NotRequired[Optional["MultiModalDataDict"]] """ Optional multi-modal data to pass to the model, if the model supports it. diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 8f4e108b8cca..3e28733383cb 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: from vllm.config import ModelConfig, VisionLanguageConfig - from vllm.multimodal import MultiModalData + from vllm.multimodal import MultiModalDataDict from vllm.sequence import SequenceData logger = init_logger(__name__) @@ -66,7 +66,8 @@ def get_hf_config(self, hf_config_type: Type[C]) -> C: N = TypeVar("N", bound=Type[nn.Module]) DummyDataFactory = Callable[[InputContext, int], - Tuple["SequenceData", Optional["MultiModalData"]]] + Tuple["SequenceData", + Optional["MultiModalDataDict"]]] """ Create dummy data to be inputted into the model. @@ -94,7 +95,7 @@ def _default_dummy_data_factory( self, ctx: InputContext, seq_len: int, - ) -> Tuple["SequenceData", Optional["MultiModalData"]]: + ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: """ The default dummy data factory represents the longest possible text that can be inputted to the model. diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 361dc7322f1b..b61ac7490d1f 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -84,9 +84,8 @@ def _get_model_initialization_kwargs( if supports_vision(model_class): if vlm_config is None: - raise ValueError("Provide `image_input_type` and other vision " - "related configurations through LLM entrypoint " - "or engine arguments.") + raise ValueError("Provide vision related configurations " + "through LLM entrypoint or engine arguments.") extra_kwargs["vlm_config"] = vlm_config diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 77fbade056ee..5212e2808fb3 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -12,7 +12,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.multimodal.image import ImageFeatureData, ImagePixelData from vllm.sequence import SequenceData @@ -49,7 +48,7 @@ def dummy_seq_data_for_clip( return SequenceData(token_ids) -def dummy_pixel_data_for_clip( +def dummy_image_for_clip( hf_config: CLIPVisionConfig, *, image_width_override: Optional[int] = None, @@ -62,22 +61,7 @@ def dummy_pixel_data_for_clip( height = image_height_override image = Image.new("RGB", (width, height), color=0) - return ImagePixelData(image) - - -def dummy_feature_data_for_clip( - hf_config: CLIPVisionConfig, - *, - image_feature_size_override: Optional[int] = None, -): - if image_feature_size_override is None: - image_feature_size = get_clip_image_feature_size(hf_config) - else: - image_feature_size = image_feature_size_override - - values = torch.zeros((1, image_feature_size, hf_config.hidden_size), - dtype=torch.float16) - return ImageFeatureData(values) + return {"image": image} # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index ba4496f9cfac..e0134c5c452d 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,4 +1,4 @@ -from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union +from typing import Iterable, List, Literal, Optional, Tuple, TypedDict import torch import torch.nn as nn @@ -17,11 +17,10 @@ from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import SamplerOutput -from .clip import (dummy_feature_data_for_clip, dummy_pixel_data_for_clip, - dummy_seq_data_for_clip) +from .clip import dummy_image_for_clip, dummy_seq_data_for_clip from .interfaces import SupportsVision _KEYS_TO_MODIFY_MAPPING = { @@ -76,17 +75,10 @@ class LlavaImagePixelInputs(TypedDict): """Shape: (batch_size, num_channels, height, width)""" -class LlavaImageFeatureInputs(TypedDict): - type: Literal["image_features"] - data: torch.Tensor - """Shape: (batch_size, image_feature_size, hidden_size)""" - - -LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs] +LlavaImageInputs = LlavaImagePixelInputs def dummy_data_for_llava(ctx: InputContext, seq_len: int): - multimodal_config = ctx.get_multimodal_config() hf_config = ctx.get_hf_config(LlavaConfig) vision_config = hf_config.vision_config @@ -97,22 +89,14 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int): image_token_id=hf_config.image_token_index, ) - image_input_type = multimodal_config.image_input_type - ImageInputType = VisionLanguageConfig.ImageInputType - mm_data: MultiModalData - if image_input_type == ImageInputType.PIXEL_VALUES: - mm_data = dummy_pixel_data_for_clip(vision_config) - elif image_input_type == ImageInputType.IMAGE_FEATURES: - mm_data = dummy_feature_data_for_clip(vision_config) - + mm_data = dummy_image_for_clip(vision_config) return seq_data, mm_data msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) -@MULTIMODAL_REGISTRY.register_image_feature_input_mapper() -@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper() +@MULTIMODAL_REGISTRY.register_image_input_mapper() @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava) class LlavaForConditionalGeneration(nn.Module, SupportsVision): @@ -126,11 +110,8 @@ def __init__(self, self.config = config self.vlm_config = vlm_config - if self.vlm_config.image_input_type == ( - VisionLanguageConfig.ImageInputType.PIXEL_VALUES): - self.vision_tower = CLIPVisionModel(config.vision_config) - else: - self.vision_tower = None + # TODO: Optionally initializes this for supporting embeddings. + self.vision_tower = CLIPVisionModel(config.vision_config) self.multi_modal_projector = LlavaMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, @@ -165,44 +146,18 @@ def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor: def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[LlavaImageInputs]: pixel_values = kwargs.pop("pixel_values", None) - image_features = kwargs.pop("image_features", None) - - expected_input_type = self.vlm_config.image_input_type - ImageInputType = VisionLanguageConfig.ImageInputType - - if expected_input_type == ImageInputType.PIXEL_VALUES: - if image_features is not None: - raise ValueError( - "Expected pixel values but got image features") - if pixel_values is None: - return None - - if not isinstance(pixel_values, torch.Tensor): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") - return LlavaImagePixelInputs( - type="pixel_values", - data=self._validate_image_data(pixel_values), - ) + if pixel_values is None: + return None - if expected_input_type == ImageInputType.IMAGE_FEATURES: - if pixel_values is not None: - raise ValueError( - "Expected image features but got pixel values") - if image_features is None: - return None + if not isinstance(pixel_values, torch.Tensor): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") - if not isinstance(image_features, torch.Tensor): - raise ValueError("Incorrect type of image features. " - f"Got type: {type(image_features)}") - - return LlavaImageFeatureInputs( - type="image_features", - data=self._validate_image_data(image_features), - ) - - return None + return LlavaImagePixelInputs( + type="pixel_values", + data=self._validate_image_data(pixel_values), + ) def _select_image_features(self, image_features: torch.Tensor, *, strategy: str) -> torch.Tensor: @@ -237,12 +192,8 @@ def _process_image_pixels(self, def _process_image_input(self, image_input: LlavaImageInputs) -> torch.Tensor: - if image_input["type"] == "pixel_values": - assert self.vision_tower is not None - image_features = self._process_image_pixels(image_input) - else: - image_features = image_input["data"] - + assert self.vision_tower is not None + image_features = self._process_image_pixels(image_input) return self.multi_modal_projector(image_features) def forward( @@ -273,25 +224,10 @@ def forward( This way, the `positions` and `attn_metadata` are consistent with the `input_ids`. - This model has two modes of image inputs: - `PIXEL_VALUES` and `IMAGE_FEATURES`. - Args: input_ids: Flattened (concatenated) input_ids corresponding to a batch. pixel_values: The pixels in each input image. - Expects a batch with shape `[1, 3, 336, 336]`. - (Only applicable to `PIXEL_VALUES` mode) - image_features: The image features for each input image outputted by - the vision tower before passing to the multi-modal projector. - Expects a batch with shape `[1, 576, 1024]`. - (Only applicable to `IMAGE_FEATURES` mode) - - See also: - Each input maps to huggingface implementation, as follows: - - - `pixel_values`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava/modeling_llava.py#L360 - - `image_features`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava/modeling_llava.py#L437 """ image_input = self._parse_and_validate_image_input(**kwargs) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 281431074671..3c0988137f7c 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -1,8 +1,8 @@ -from typing import (Dict, Iterable, List, Literal, Optional, Tuple, TypedDict, - Union) +from typing import Dict, Iterable, List, Literal, Optional, Tuple, TypedDict import torch import torch.nn as nn +from PIL import Image from transformers import CLIPVisionConfig, LlavaNextConfig from transformers.models.llava_next.modeling_llava_next import ( get_anyres_image_grid_shape, unpad_image) @@ -21,12 +21,11 @@ from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData -from vllm.multimodal.image import ImagePixelData +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import SamplerOutput -from .clip import (dummy_feature_data_for_clip, dummy_pixel_data_for_clip, - dummy_seq_data_for_clip, get_clip_patch_grid_length) +from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, + get_clip_patch_grid_length) from .interfaces import SupportsVision from .llava import LlavaMultiModalProjector, merge_vision_embeddings @@ -47,17 +46,7 @@ class LlavaNextImagePixelInputs(TypedDict): """Shape: (batch_size, 2)""" -class LlavaNextImageFeatureInputs(TypedDict): - type: Literal["image_features"] - data: torch.Tensor - """Shape: (batch_size, 1 + num_patches, image_feature_size, hidden_size)""" - - image_sizes: NotRequired[torch.Tensor] - """Shape: (batch_size, 2)""" - - -LlavaNextImageInputs = Union[LlavaNextImagePixelInputs, - LlavaNextImageFeatureInputs] +LlavaNextImageInputs = LlavaNextImagePixelInputs def _get_llava_next_num_unpadded_features( @@ -138,20 +127,11 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int): image_feature_size_override=image_feature_size, ) - image_input_type = multimodal_config.image_input_type - ImageInputType = VisionLanguageConfig.ImageInputType - mm_data: MultiModalData - if image_input_type == ImageInputType.PIXEL_VALUES: - mm_data = dummy_pixel_data_for_clip( - vision_config, - image_width_override=dummy_width, - image_height_override=dummy_height, - ) - elif image_input_type == ImageInputType.IMAGE_FEATURES: - mm_data = dummy_feature_data_for_clip( - vision_config, - image_feature_size_override=image_feature_size, - ) + mm_data = dummy_image_for_clip( + vision_config, + image_width_override=dummy_width, + image_height_override=dummy_height, + ) return seq_data, mm_data @@ -159,32 +139,26 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int): raise NotImplementedError(msg) -def _pixel_mapper(ctx: InputContext, - data: ImagePixelData) -> Dict[str, torch.Tensor]: - image = data.image +def _pixel_mapper(ctx: InputContext, image: object) -> Dict[str, torch.Tensor]: - if isinstance(image, torch.Tensor): - pixel_values = image.to(ctx.model_config.dtype) - batch_size, _, _, h, w = pixel_values.shape - image_sizes = torch.tensor([(w, h) for _ in range(batch_size)]) + if isinstance(image, Image.Image): - return {"pixel_values": pixel_values, "image_sizes": image_sizes} + # Temporary patch before dynamic number of image tokens is supported + _, _, h, w = ctx.get_multimodal_config().image_input_shape + if (w, h) != (image.width, image.height): + logger.warning( + "Dynamic image shape is currently not supported. " + "Resizing input image to (%d, %d).", w, h) - # Temporary patch before dynamic number of image tokens is supported - _, _, h, w = ctx.get_multimodal_config().image_input_shape - if (w, h) != (image.width, image.height): - logger.warning( - "Dynamic image shape is currently not supported. " - "Resizing input image to (%d, %d).", w, h) + image = image.resize((w, h)) - data.image = image.resize((w, h)) + return MULTIMODAL_REGISTRY._get_plugin("image") \ + ._default_input_mapper(ctx, image) - return MULTIMODAL_REGISTRY._get_plugin_for_data_type(ImagePixelData) \ - ._default_input_mapper(ctx, data) + raise TypeError(f"Invalid type for 'image': {type(image)}") -@MULTIMODAL_REGISTRY.register_image_feature_input_mapper() -@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper(_pixel_mapper) +@MULTIMODAL_REGISTRY.register_image_input_mapper(_pixel_mapper) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next) class LlavaNextForConditionalGeneration(nn.Module, SupportsVision): @@ -198,11 +172,7 @@ def __init__(self, self.config = config self.vlm_config = vlm_config - if self.vlm_config.image_input_type == ( - VisionLanguageConfig.ImageInputType.PIXEL_VALUES): - self.vision_tower = CLIPVisionModel(config=config.vision_config) - else: - raise TypeError("Image features are not supported by LLaVA-NeXT") + self.vision_tower = CLIPVisionModel(config=config.vision_config) self.multi_modal_projector = LlavaMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, @@ -255,36 +225,23 @@ def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[LlavaNextImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_sizes = kwargs.pop("image_sizes", None) - image_features = kwargs.pop("image_features", None) - - expected_input_type = self.vlm_config.image_input_type - ImageInputType = VisionLanguageConfig.ImageInputType - - if expected_input_type == ImageInputType.PIXEL_VALUES: - if image_features is not None: - raise ValueError( - "Expected pixel values but got image features") - if pixel_values is None: - return None - if not isinstance(pixel_values, torch.Tensor): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") + if pixel_values is None or image_sizes is None: + return None - if not isinstance(image_sizes, torch.Tensor): - raise ValueError("Incorrect type of image sizes. " - f"Got type: {type(image_sizes)}") + if not isinstance(pixel_values, torch.Tensor): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") - return LlavaNextImagePixelInputs( - type="pixel_values", - data=self._validate_image_pixels(pixel_values), - image_sizes=self._validate_image_sizes(image_sizes), - ) + if not isinstance(image_sizes, torch.Tensor): + raise ValueError("Incorrect type of image sizes. " + f"Got type: {type(image_sizes)}") - assert expected_input_type != ImageInputType.IMAGE_FEATURES, ( - "Failed to validate this at initialization time") - - return None + return LlavaNextImagePixelInputs( + type="pixel_values", + data=self._validate_image_pixels(pixel_values), + image_sizes=self._validate_image_sizes(image_sizes), + ) def _select_image_features(self, image_features: torch.Tensor, *, strategy: str) -> torch.Tensor: @@ -391,11 +348,8 @@ def _process_image_pixels( def _process_image_input( self, image_input: LlavaNextImageInputs) -> torch.Tensor: - if image_input["type"] == "pixel_values": - assert self.vision_tower is not None - image_features = self._process_image_pixels(image_input) - else: - image_features = image_input["data"] + assert self.vision_tower is not None + image_features = self._process_image_pixels(image_input) patch_embeddings = self.multi_modal_projector(image_features) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index bc3d3f0fbf19..a16f7f0ea570 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -35,10 +35,9 @@ from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.image import ImagePixelData from vllm.sequence import SamplerOutput -from .clip import dummy_pixel_data_for_clip, dummy_seq_data_for_clip +from .clip import dummy_image_for_clip, dummy_seq_data_for_clip from .interfaces import SupportsVision logger = init_logger(__name__) @@ -286,7 +285,7 @@ def dummy_data_for_phi3v(ctx: InputContext, seq_len: int): image_token_id=32044, image_feature_size_override=image_feature_size, ) - mm_data = dummy_pixel_data_for_clip( + mm_data = dummy_image_for_clip( CLIP_VIT_LARGE_PATCH14_336_CONFIG, image_width_override=dummy_width, image_height_override=dummy_height, @@ -331,8 +330,7 @@ def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16): def _image_processor(ctx: InputContext, - data: ImagePixelData) -> Dict[str, torch.Tensor]: - image = data.image + image: object) -> Dict[str, torch.Tensor]: if isinstance(image, Image.Image): # Temporary patch before dynamic number of image tokens is supported @@ -343,13 +341,14 @@ def _image_processor(ctx: InputContext, "Dynamic image shape is currently not supported. " "Resizing input image to (%d, %d).", w, h) - data.image = image.resize((w, h)) + image = image.resize((w, h)) - return MULTIMODAL_REGISTRY._get_plugin_for_data_type(ImagePixelData) \ - ._default_input_mapper(ctx, data) + return MULTIMODAL_REGISTRY._get_plugin("image") \ + ._default_input_mapper(ctx, image) + raise TypeError(f"Invalid type for 'image': {type(image)}") -@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper(_image_processor) +@MULTIMODAL_REGISTRY.register_image_input_mapper(_image_processor) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v) class Phi3VForCausalLM(nn.Module, SupportsVision): @@ -375,14 +374,6 @@ def _parse_and_validate_image_input( pixel_values = kwargs.pop("pixel_values", None) image_sizes = kwargs.pop("image_sizes", None) - expected_input_type = self.vlm_config.image_input_type - ImageInputType = VisionLanguageConfig.ImageInputType - - if expected_input_type != ImageInputType.PIXEL_VALUES: - raise ValueError( - f"Unexpected image input type: {expected_input_type}." - "Phi3v only support pixel_values input currently.") - if pixel_values is not None and image_sizes is not None: return Phi3VImagePixelInputs(type="pixel_values", data=pixel_values, diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 20bd87b8c443..256eadd2d7df 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -1,4 +1,4 @@ -from .base import MultiModalData, MultiModalPlugin +from .base import MultiModalDataDict, MultiModalPlugin from .registry import MultiModalRegistry MULTIMODAL_REGISTRY = MultiModalRegistry() @@ -11,6 +11,8 @@ """ __all__ = [ - "MultiModalData", "MultiModalPlugin", "MULTIMODAL_REGISTRY", - "MultiModalRegistry" + "MultiModalPlugin", + "MULTIMODAL_REGISTRY", + "MultiModalRegistry", + "MultiModalDataDict", ] diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index d47cdd559ad8..558cd1175298 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod -from typing import (TYPE_CHECKING, Callable, Dict, Generic, Optional, Type, - TypeVar) +from typing import (TYPE_CHECKING, Any, Callable, Dict, Optional, Type, + TypedDict, TypeVar, Union) from vllm.config import ModelConfig from vllm.inputs import InputContext @@ -8,38 +8,35 @@ if TYPE_CHECKING: import torch + from PIL import Image from torch import nn logger = init_logger(__name__) +N = TypeVar("N", bound=Type["nn.Module"]) -class MultiModalData: - """ - Base class that contains multi-modal data. - - To add a new modality, add a new file under ``multimodal`` directory. - In this new file, subclass :class:`~MultiModalData` and - :class:`~MultiModalPlugin`. +class MultiModalDataBuiltins(TypedDict, total=False): + image: "Image.Image" - Finally, register the new plugin to - :const:`vllm.multimodal.MULTIMODAL_REGISTRY`. - This enables models to call :meth:`MultiModalRegistry.map_input` for - the new modality. - """ - pass +MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]] +""" +A dictionary containing an item for each modality type to input. -D = TypeVar("D", bound=MultiModalData) -N = TypeVar("N", bound=Type["nn.Module"]) +The data belonging to each modality is converted into keyword arguments +to the model by the corresponding mapper. By default, the mapper of +the corresponding plugin with the same modality key is applied. +""" -MultiModalInputMapper = Callable[[InputContext, D], Dict[str, "torch.Tensor"]] +MultiModalInputMapper = Callable[[InputContext, object], Dict[str, + "torch.Tensor"]] """Return a dictionary to be passed as keyword arguments to :meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers and processors in HuggingFace Transformers.""" -class MultiModalPlugin(ABC, Generic[D]): +class MultiModalPlugin(ABC): """ Base class that defines data processing logic for a specific modality. @@ -52,19 +49,18 @@ class MultiModalPlugin(ABC, Generic[D]): def __init__(self) -> None: self._input_mappers: Dict[Type["nn.Module"], - MultiModalInputMapper[D]] = {} + MultiModalInputMapper] = {} @abstractmethod - def get_data_type(self) -> Type[D]: + def get_data_key(self) -> str: """ - Get the modality (subclass of :class:`~MultiModalData`) served by - this plugin. + Get the data key corresponding to the modality. """ raise NotImplementedError @abstractmethod def _default_input_mapper(self, ctx: InputContext, - data: D) -> Dict[str, "torch.Tensor"]: + data: object) -> Dict[str, "torch.Tensor"]: """Return a dictionary to be passed as keyword arguments to :meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers and processors in HuggingFace Transformers. @@ -73,11 +69,10 @@ def _default_input_mapper(self, ctx: InputContext, def register_input_mapper( self, - mapper: Optional[MultiModalInputMapper[D]] = None, + mapper: Optional[MultiModalInputMapper] = None, ): """ Register an input mapper to a model class. - When the model receives input data that matches the modality served by this plugin (see :meth:`get_data_type`), the provided function is invoked to transform the data into a dictionary of model inputs. @@ -102,11 +97,13 @@ def wrapper(model_cls: N) -> N: return wrapper def map_input(self, model_config: ModelConfig, - data: D) -> Dict[str, "torch.Tensor"]: + data: object) -> Dict[str, "torch.Tensor"]: """ - Apply an input mapper to a :class:`~MultiModalData` instance passed + Apply an input mapper to a data passed to the model, transforming the data into a dictionary of model inputs. + If the data is not something that the mapper expects, throws TypeError. + The model is identified by ``model_config``. TODO: Add guide [ref: PR #5276] diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index a9691575c2ea..a0b4206bf2ee 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -1,5 +1,5 @@ from functools import lru_cache -from typing import Dict, Type, Union +from typing import Dict import torch from PIL import Image @@ -9,105 +9,36 @@ from vllm.logger import init_logger from vllm.transformers_utils.image_processor import get_image_processor -from .base import MultiModalData, MultiModalPlugin +from .base import MultiModalPlugin logger = init_logger(__name__) cached_get_image_processor = lru_cache(get_image_processor) -class ImagePixelData(MultiModalData): - """ - The pixel data of an image. Can be one of: +class ImagePlugin(MultiModalPlugin): - - :class:`PIL.Image.Image`: An image object. Requires that a HuggingFace - processor is available to the model. - - :class:`torch.Tensor`: The raw pixel data which is passed to the model - without additional pre-processing. - """ - - def __init__(self, image: Union[Image.Image, torch.Tensor]) -> None: - if isinstance(image, Image.Image): - # So that this class can be created inside the Image context manager - image.load() - - self.image = image - - def __repr__(self) -> str: - image = self.image - if isinstance(image, Image.Image): - return f"{type(self).__name__}(image={image})" - - return (f"{type(self).__name__}(image=torch.Tensor(shape=" - f"{image.shape}, dtype={image.dtype}))") - - -class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]): - - def get_data_type(self) -> Type[ImagePixelData]: - return ImagePixelData + def get_data_key(self) -> str: + return "image" def _get_hf_image_processor(self, model_config: ModelConfig): - vlm_config = model_config.multimodal_config - if vlm_config is None or vlm_config.image_processor is None: - return None - return cached_get_image_processor( - vlm_config.image_processor, - trust_remote_code=model_config.trust_remote_code, - revision=vlm_config.image_processor_revision, - ) + model_config.model, + trust_remote_code=model_config.trust_remote_code) def _default_input_mapper(self, ctx: InputContext, - data: ImagePixelData) -> Dict[str, torch.Tensor]: + data: object) -> Dict[str, torch.Tensor]: model_config = ctx.model_config - image = data.image - - if isinstance(image, Image.Image): + if isinstance(data, Image.Image): image_processor = self._get_hf_image_processor(model_config) if image_processor is None: raise RuntimeError("No HuggingFace processor is available" "to process the image object") try: - return image_processor.preprocess(image, return_tensors="pt") \ + return image_processor.preprocess(data, return_tensors="pt") \ .to(model_config.dtype).data except Exception: - logger.error("Failed to process image (%s)", image) + logger.error("Failed to process image (%s)", data) raise - elif isinstance(image, torch.Tensor): - pixel_values = image.to(model_config.dtype) - - return {"pixel_values": pixel_values} - - raise TypeError(f"Invalid image type: {type(image)}") - - -class ImageFeatureData(MultiModalData): - """ - The feature vector of an image, passed directly to the model. - - This should be the output of the vision tower. - """ - - def __init__(self, image_features: torch.Tensor) -> None: - self.image_features = image_features - - def __repr__(self) -> str: - image_features = self.image_features - - return (f"{type(self).__name__}(image_features=torch.Tensor(shape=" - f"{image_features.shape}, dtype={image_features.dtype}))") - - -class ImageFeaturePlugin(MultiModalPlugin[ImageFeatureData]): - - def get_data_type(self) -> Type[ImageFeatureData]: - return ImageFeatureData - - def _default_input_mapper( - self, ctx: InputContext, - data: ImageFeatureData) -> Dict[str, torch.Tensor]: - model_config = ctx.model_config - image_features = data.image_features.to(model_config.dtype) - return {"image_features": image_features} + raise TypeError(f"Invalid type for 'image': {type(data)}") diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index abc88e4f9a9d..a09a80f89f4b 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -1,18 +1,16 @@ import functools -from typing import Any, Optional, Sequence, Type, TypeVar +from typing import Optional, Sequence, Type, TypeVar from torch import nn from vllm.config import ModelConfig from vllm.logger import init_logger -from .base import MultiModalData, MultiModalInputMapper, MultiModalPlugin -from .image import (ImageFeatureData, ImageFeaturePlugin, ImagePixelData, - ImagePixelPlugin) +from .base import MultiModalDataDict, MultiModalInputMapper, MultiModalPlugin +from .image import ImagePlugin logger = init_logger(__name__) -D = TypeVar("D", bound=MultiModalData) N = TypeVar("N", bound=Type[nn.Module]) @@ -20,81 +18,91 @@ class MultiModalRegistry: """ A registry to dispatch data processing according to its modality and the target model. + + The registry handles both external and internal data input. """ - DEFAULT_PLUGINS = (ImageFeaturePlugin(), ImagePixelPlugin()) + DEFAULT_PLUGINS = (ImagePlugin(), ) def __init__( - self, - *, - plugins: Sequence[MultiModalPlugin[Any]] = DEFAULT_PLUGINS, - ) -> None: - self._plugins_by_data_type = {p.get_data_type(): p for p in plugins} + self, + *, + plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None: + self._plugins = {p.get_data_key(): p for p in plugins} - def register_plugin(self, plugin: MultiModalPlugin[Any]) -> None: - data_type = plugin.get_data_type() + def register_plugin(self, plugin: MultiModalPlugin) -> None: + data_type_key = plugin.get_data_key() - if data_type in self._plugins_by_data_type: + if data_type_key in self._plugins: logger.warning( "A plugin is already registered for data type %s, " - "and will be overwritten by the new plugin %s.", data_type, + "and will be overwritten by the new plugin %s.", data_type_key, plugin) - self._plugins_by_data_type[data_type] = plugin + self._plugins[data_type_key] = plugin - def _get_plugin_for_data_type(self, data_type: Type[MultiModalData]): - for typ in data_type.mro(): - plugin = self._plugins_by_data_type.get(typ) - if plugin is not None: - return plugin + def _get_plugin(self, data_type_key: str): + plugin = self._plugins.get(data_type_key) + if plugin is not None: + return plugin - msg = f"Unknown multi-modal data type: {data_type}" + msg = f"Unknown multi-modal data type: {data_type_key}" raise NotImplementedError(msg) - def register_input_mapper( + def register_image_input_mapper( self, - data_type: Type[D], - mapper: Optional[MultiModalInputMapper[D]] = None, + mapper: Optional[MultiModalInputMapper] = None, ): """ - Register an input mapper for a specific modality to a model class. + Register an input mapper for image data to a model class. See :meth:`MultiModalPlugin.register_input_mapper` for more details. """ - return self._get_plugin_for_data_type(data_type) \ - .register_input_mapper(mapper) + return self.register_input_mapper("image", mapper) + + def _process_input(self, key: str, value: object, + model_config: ModelConfig): + plugin = self._plugins.get(key) + if plugin: + return plugin.map_input(model_config, value) + msg = f"Unknown multi-modal data type: {key}" + raise NotImplementedError(msg) - def register_image_pixel_input_mapper( + def register_input_mapper( self, - mapper: Optional[MultiModalInputMapper[ImagePixelData]] = None, + data_type: str, + mapper: Optional[MultiModalInputMapper] = None, ): """ - Register an input mapper for image pixel data to a model class. + Register an input mapper for a specific modality to a model class. See :meth:`MultiModalPlugin.register_input_mapper` for more details. """ - return self.register_input_mapper(ImagePixelData, mapper) - - def register_image_feature_input_mapper( - self, - mapper: Optional[MultiModalInputMapper[ImageFeatureData]] = None, - ): + plugin = self._plugins.get(data_type) + if not plugin: + msg = f"Unknown multi-modal data type: {data_type}" + raise NotImplementedError(msg) + return plugin.register_input_mapper(mapper) + + def register_image_input(self, + mapper: Optional[MultiModalInputMapper] = None): """ - Register an input mapper for image feature data to a model class. + Register an input mapper for image pixel data to a model class. See :meth:`MultiModalPlugin.register_input_mapper` for more details. """ - return self.register_input_mapper(ImageFeatureData, mapper) + return self.register_input_mapper("image", mapper) - def map_input(self, model_config: ModelConfig, data: MultiModalData): + def map_input(self, model_config: ModelConfig, data: MultiModalDataDict): """ - Apply an input mapper to a :class:`~MultiModalData` instance passed - to the model. + Apply an input mapper to the data passed to the model. See :meth:`MultiModalPlugin.map_input` for more details. """ - return self._get_plugin_for_data_type(type(data)) \ - .map_input(model_config, data) + result_list = [ + self._process_input(k, v, model_config) for k, v in data.items() + ] + return {k: v for d in result_list for k, v in d.items()} def create_input_mapper(self, model_config: ModelConfig): """ diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 0cf2c057f892..321b51e5a883 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -8,7 +8,7 @@ from vllm.config import ModelConfig from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT -from vllm.multimodal.image import ImagePixelData +from vllm.multimodal.base import MultiModalDataDict class ImageFetchAiohttp: @@ -53,14 +53,10 @@ async def fetch_image(cls, image_url: str) -> Image.Image: "Invalid 'image_url': A valid 'image_url' must start " "with either 'data:image' or 'http'.") + image.load() return image -async def async_get_and_parse_image(image_url: str) -> ImagePixelData: - with await ImageFetchAiohttp.fetch_image(image_url) as image: - return ImagePixelData(image) - - def encode_image_base64(image: Image.Image, format: str = 'JPEG') -> str: """Encode a pillow image to base64 format.""" @@ -91,3 +87,8 @@ def get_full_image_text_prompt(image_prompt: str, text_prompt: str, raise ValueError( f"Unsupported model type: {config.hf_config.model_type}") return full_prompt + + +async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict: + image = await ImageFetchAiohttp.fetch_image(image_url) + return {"image": image} diff --git a/vllm/sequence.py b/vllm/sequence.py index 21c558d4483d..3e7c31b8c1a8 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: from vllm.inputs import LLMInputs - from vllm.multimodal import MultiModalData + from vllm.multimodal import MultiModalDataDict from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics @@ -280,8 +280,8 @@ def prompt_token_ids(self) -> List[int]: return self.inputs["prompt_token_ids"] @property - def multi_modal_data(self) -> Optional["MultiModalData"]: - return self.inputs.get("multi_modal_data") + def multi_modal_data(self) -> "MultiModalDataDict": + return self.inputs.get("multi_modal_data") or {} @property def lora_int_id(self) -> int: @@ -457,7 +457,7 @@ def prompt_token_ids(self) -> List[int]: return next(iter(self.seqs_dict.values())).prompt_token_ids @property - def multi_modal_data(self) -> Optional["MultiModalData"]: + def multi_modal_data(self) -> Optional["MultiModalDataDict"]: # All sequences in the group should have the same multi-modal data. # We use the multi-modal data of an arbitrary sequence. return next(iter(self.seqs_dict.values())).multi_modal_data @@ -639,7 +639,7 @@ def __init__( lora_request: Optional[LoRARequest] = None, computed_block_nums: Optional[List[int]] = None, state: Optional[SequenceGroupState] = None, - multi_modal_data: Optional["MultiModalData"] = None, + multi_modal_data: Optional["MultiModalDataDict"] = None, encoder_seq_data: Optional[SequenceData] = None, cross_block_table: Optional[List[int]] = None, ) -> None: diff --git a/vllm/transformers_utils/image_processor.py b/vllm/transformers_utils/image_processor.py index 2bb5215d4846..354dcb526395 100644 --- a/vllm/transformers_utils/image_processor.py +++ b/vllm/transformers_utils/image_processor.py @@ -1,5 +1,3 @@ -from typing import Optional - from transformers import AutoImageProcessor from transformers.image_processing_utils import BaseImageProcessor @@ -12,7 +10,6 @@ def get_image_processor( processor_name: str, *args, trust_remote_code: bool = False, - revision: Optional[str] = None, **kwargs, ) -> BaseImageProcessor: """Gets an image processor for the given model name via HuggingFace.""" @@ -21,7 +18,6 @@ def get_image_processor( processor_name, *args, trust_remote_code=trust_remote_code, - revision=revision, **kwargs) except ValueError as e: # If the error pertains to the processor class not existing or not diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 942063677a42..0b20d5010d5e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -504,7 +504,7 @@ def _prepare_model_input_tensors( is not None else 1)) mm_data = seq_group_metadata.multi_modal_data - if mm_data is not None: + if mm_data: # Process multi-modal data mm_kwargs = self.multi_modal_input_mapper(mm_data) for k, v in mm_kwargs.items():