Skip to content

Commit

Permalink
[VLM] Remove image_input_type from VLM config (#5852)
Browse files Browse the repository at this point in the history
Signed-off-by: Xiaowei Jiang <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
  • Loading branch information
3 people committed Jul 2, 2024
1 parent 2c37540 commit 98d6682
Show file tree
Hide file tree
Showing 35 changed files with 325 additions and 747 deletions.
4 changes: 0 additions & 4 deletions .buildkite/download-images.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@ set -o pipefail
# aws s3 sync s3://air-example-data-2/vllm_opensource_llava/ images/
mkdir -p images
cd images
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign_pixel_values.pt
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign_image_features.pt
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom_pixel_values.pt
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom_image_features.pt
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign.jpg
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom.jpg

Expand Down
16 changes: 4 additions & 12 deletions docs/requirements-docs.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,5 @@
sphinx == 6.2.1
sphinx-book-theme == 1.0.1
sphinx-copybutton == 0.5.2
myst-parser == 2.0.0
sphinx==6.2.1
sphinx-book-theme==1.0.1
sphinx-copybutton==0.5.2
myst-parser==2.0.0
sphinx-argparse

# packages to install to build the documentation
pydantic
-f https://download.pytorch.org/whl/cpu
torch
py-cpuinfo
transformers
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
8 changes: 5 additions & 3 deletions docs/source/dev/multimodal/multimodal_index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ vLLM provides experimental support for multi-modal models through the :mod:`vllm
which allows you to pass in multi-modal input alongside text and token prompts.

By default, vLLM models do not support multi-modal inputs. To enable multi-modal support for a model,
you must decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_dummy_data <MultiModalRegistry.register_dummy_data>`,
as well as :meth:`MULTIMODAL_REGISTRY.register_input <MultiModalRegistry.register_input>` for each modality type to support.
you must decorate the model class with :meth:`InputRegistry.register_dummy_data <vllm.inputs.registry.InputRegistry.register_dummy_data>`,
as well as :meth:`MULTIMODAL_REGISTRY.register_input_mapper <MultiModalRegistry.register_input_mapper>` for each modality type to support.

# TODO: Add more instructions on how to do that once embeddings is in.

Module Contents
+++++++++++++++
Expand All @@ -29,7 +31,7 @@ Registry
Base Classes
------------

.. autoclass:: vllm.multimodal.MultiModalData
.. autoclass:: vllm.multimodal.MultiModalDataDict
:members:
:show-inheritance:

Expand Down
11 changes: 7 additions & 4 deletions docs/source/models/vlm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM``
llm = LLM(
model="llava-hf/llava-1.5-7b-hf",
image_input_type="pixel_values",
image_token_id=32000,
image_input_shape="1,3,336,336",
image_feature_size=576,
Expand All @@ -49,7 +48,12 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM``
To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`:

* ``prompt``: The prompt should have a number of ``<image>`` tokens equal to ``image_feature_size``.
* ``multi_modal_data``: This should be an instance of :class:`~vllm.multimodal.image.ImagePixelData` or :class:`~vllm.multimodal.image.ImageFeatureData`.
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.

.. note::

``multi_modal_data`` can accept keys and values beyond the builtin ones, as long as a customized plugin is registered through
:class:`vllm.multimodal.MULTIMODAL_REGISTRY`.

.. code-block:: python
Expand All @@ -61,7 +65,7 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptS
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": ImagePixelData(image),
"multi_modal_data": {"image": image},
})
for o in outputs:
Expand Down Expand Up @@ -93,7 +97,6 @@ Below is an example on how to launch the same ``llava-hf/llava-1.5-7b-hf`` with
python -m vllm.entrypoints.openai.api_server \
--model llava-hf/llava-1.5-7b-hf \
--image-input-type pixel_values \
--image-token-id 32000 \
--image-input-shape 1,3,336,336 \
--image-feature-size 576 \
Expand Down
56 changes: 8 additions & 48 deletions examples/llava_example.py
Original file line number Diff line number Diff line change
@@ -1,84 +1,44 @@
import argparse
import os
import subprocess

import torch
from PIL import Image

from vllm import LLM
from vllm.multimodal.image import ImageFeatureData, ImagePixelData

# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
# You can use `.buildkite/download-images.sh` to download them


def run_llava_pixel_values(*, disable_image_processor: bool = False):
def run_llava():
llm = LLM(
model="llava-hf/llava-1.5-7b-hf",
image_input_type="pixel_values",
image_token_id=32000,
image_input_shape="1,3,336,336",
image_feature_size=576,
disable_image_processor=disable_image_processor,
)

prompt = "<image>" * 576 + (
"\nUSER: What is the content of this image?\nASSISTANT:")

if disable_image_processor:
image = torch.load("images/stop_sign_pixel_values.pt")
else:
image = Image.open("images/stop_sign.jpg")
image = Image.open("images/stop_sign.jpg")

outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": ImagePixelData(image),
"multi_modal_data": {
"image": image
},
})

for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)


def run_llava_image_features():
llm = LLM(
model="llava-hf/llava-1.5-7b-hf",
image_input_type="image_features",
image_token_id=32000,
image_input_shape="1,576,1024",
image_feature_size=576,
)

prompt = "<image>" * 576 + (
"\nUSER: What is the content of this image?\nASSISTANT:")

image: torch.Tensor = torch.load("images/stop_sign_image_features.pt")

outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": ImageFeatureData(image),
})

for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)


def main(args):
if args.type == "pixel_values":
run_llava_pixel_values()
else:
run_llava_image_features()
def main():
run_llava()


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Demo on Llava")
parser.add_argument("--type",
type=str,
choices=["pixel_values", "image_features"],
default="pixel_values",
help="image input type")
args = parser.parse_args()
# Download from s3
s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
local_directory = "images"
Expand All @@ -95,4 +55,4 @@ def main(args):
local_directory,
"--no-sign-request",
])
main(args)
main()
61 changes: 35 additions & 26 deletions examples/llava_next_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,44 @@
from PIL import Image

from vllm import LLM, SamplingParams
from vllm.multimodal.image import ImagePixelData

# Dynamic image input is currently not supported and therefore
# a fixed image input shape and its corresponding feature size is required.
# See https://github.com/vllm-project/vllm/pull/4199 for the complete
# configuration matrix.

llm = LLM(
model="llava-hf/llava-v1.6-mistral-7b-hf",
image_input_type="pixel_values",
image_token_id=32000,
image_input_shape="1,3,336,336",
image_feature_size=1176,
)

prompt = "[INST] " + "<image>" * 1176 + "\nWhat is shown in this image? [/INST]"
url = "https://h2o-release.s3.amazonaws.com/h2ogpt/bigben.jpg"
image = Image.open(BytesIO(requests.get(url).content))
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=100)

outputs = llm.generate(
{
"prompt": prompt,
"multi_modal_data": ImagePixelData(image),
},
sampling_params=sampling_params)

generated_text = ""
for o in outputs:
generated_text += o.outputs[0].text

print(f"LLM output:{generated_text}")

def run_llava_next():
llm = LLM(
model="llava-hf/llava-v1.6-mistral-7b-hf",
image_token_id=32000,
image_input_shape="1,3,336,336",
image_feature_size=1176,
)

prompt = "[INST] " + "<image>" * 1176 + (
"\nWhat is shown in this image? [/INST]")
url = "https://h2o-release.s3.amazonaws.com/h2ogpt/bigben.jpg"
image = Image.open(BytesIO(requests.get(url).content))
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
max_tokens=100)

outputs = llm.generate(
{
"prompt": prompt,
"multi_modal_data": {
"image": image
}
},
sampling_params=sampling_params)

generated_text = ""
for o in outputs:
generated_text += o.outputs[0].text

print(f"LLM output:{generated_text}")


if __name__ == "__main__":
run_llava_next()
1 change: 0 additions & 1 deletion examples/openai_vision_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
Launch the vLLM server with the following command:
python -m vllm.entrypoints.openai.api_server \
--model llava-hf/llava-1.5-7b-hf \
--image-input-type pixel_values \
--image-token-id 32000 \
--image-input-shape 1,3,336,336 \
--image-feature-size 576 \
Expand Down
6 changes: 3 additions & 3 deletions examples/phi3v_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from PIL import Image

from vllm import LLM, SamplingParams
from vllm.multimodal.image import ImagePixelData


def run_phi3v():
Expand All @@ -17,7 +16,6 @@ def run_phi3v():
llm = LLM(
model=model_path,
trust_remote_code=True,
image_input_type="pixel_values",
image_token_id=32044,
image_input_shape="1,3,1008,1344",
image_feature_size=1921,
Expand All @@ -35,7 +33,9 @@ def run_phi3v():
outputs = llm.generate(
{
"prompt": prompt,
"multi_modal_data": ImagePixelData(image),
"multi_modal_data": {
"image": image
},
},
sampling_params=sampling_params)
for o in outputs:
Expand Down
38 changes: 8 additions & 30 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,17 @@
AutoTokenizer, BatchEncoding)

from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
from vllm.config import TokenizerPoolConfig
from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel)
from vllm.inputs import TextPrompt
from vllm.logger import init_logger
from vllm.sequence import SampleLogprobs
from vllm.utils import cuda_device_count_stateless, is_cpu

if TYPE_CHECKING:
from vllm.multimodal import MultiModalData
else:
# it will call torch.cuda.device_count()
MultiModalData = None
from vllm.sequence import SampleLogprobs
from vllm.utils import cuda_device_count_stateless, is_cpu
from vllm.multimodal import MultiModalDataDict

logger = init_logger(__name__)

Expand All @@ -51,35 +49,15 @@ def _read_prompts(filename: str) -> List[str]:
class ImageAsset:
name: Literal["stop_sign", "cherry_blossom"]

@cached_property
def pixel_values(self) -> torch.Tensor:
return torch.load(_IMAGE_DIR / f"{self.name}_pixel_values.pt")

@cached_property
def image_features(self) -> torch.Tensor:
return torch.load(_IMAGE_DIR / f"{self.name}_image_features.pt")

@cached_property
def pil_image(self) -> Image.Image:
return Image.open(_IMAGE_DIR / f"{self.name}.jpg")

def for_hf(self) -> Image.Image:
return self.pil_image

def for_vllm(self, vision_config: VisionLanguageConfig) -> MultiModalData:
# don't put this import at the top level
# it will call torch.cuda.device_count()
from vllm.multimodal.image import ImageFeatureData # noqa: F401
from vllm.multimodal.image import ImagePixelData
image_input_type = vision_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType

if image_input_type == ImageInputType.IMAGE_FEATURES:
return ImageFeatureData(self.image_features)
if image_input_type == ImageInputType.PIXEL_VALUES:
return ImagePixelData(self.pil_image)

raise NotImplementedError
def for_vllm(self) -> Dict[str, Any]:
return {"image": self.pil_image}


class _ImageAssetPrompts(TypedDict):
Expand Down Expand Up @@ -453,7 +431,7 @@ def generate(
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional[List[MultiModalData]] = None,
images: Optional[List["MultiModalDataDict"]] = None,
) -> List[Tuple[List[List[int]], List[str]]]:
if images is not None:
assert len(prompts) == len(images)
Expand Down Expand Up @@ -502,7 +480,7 @@ def generate_greedy(
self,
prompts: List[str],
max_tokens: int,
images: Optional[List[MultiModalData]] = None,
images: Optional[List["MultiModalDataDict"]] = None,
) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = self.generate(prompts, greedy_params, images=images)
Expand Down
2 changes: 0 additions & 2 deletions tests/entrypoints/openai/test_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ def server():
"--max-model-len",
"4096",
"--enforce-eager",
"--image-input-type",
"pixel_values",
"--image-token-id",
"32000",
"--image-input-shape",
Expand Down
Loading

0 comments on commit 98d6682

Please sign in to comment.