[Core] Registry for processing model inputs #5214

Merged · Jun 28, 2024 · 63 commits

Commits (showing changes from all commits)
34bfa79
Introduce a higher level `INPUT_REGISTRY`
DarkLight1337 Jun 3, 2024
df2aa19
Move dummy data generation to input registry
DarkLight1337 Jun 3, 2024
c72d2b3
Update docs
DarkLight1337 Jun 3, 2024
d8c6488
Rename `process_input` to `map_input`
DarkLight1337 Jun 3, 2024
f18de48
Reorder arguments
DarkLight1337 Jun 3, 2024
653537d
Apply input processor
DarkLight1337 Jun 3, 2024
a2f5a3c
Remove `VisionLanguageConfig` from input mapper
DarkLight1337 Jun 3, 2024
378ad80
Fix bad use of `functools.partial`
DarkLight1337 Jun 3, 2024
7aa3778
Use default input processor
DarkLight1337 Jun 3, 2024
c774168
Merge branch 'upstream' into mm-image-tokenizer
DarkLight1337 Jun 4, 2024
532f863
Fix wrong arguments
DarkLight1337 Jun 4, 2024
080d40c
Use pillow image instead of tensor to avoid bypassing the processor b…
DarkLight1337 Jun 5, 2024
662693a
Update interface of dummy data factory and input processor
DarkLight1337 Jun 5, 2024
9bc5fcc
Use `InputContext` to handle checked type cast of config types
DarkLight1337 Jun 5, 2024
29c3bb3
Fix LLaVA-NeXT input processor and cleanup code
DarkLight1337 Jun 5, 2024
7bb6cbf
Add sanity check
DarkLight1337 Jun 6, 2024
ccf49c4
Merge branch 'upstream' into mm-image-tokenizer
DarkLight1337 Jun 6, 2024
3482d32
Merge branch 'upstream' into mm-image-tokenizer
DarkLight1337 Jun 6, 2024
8ea8468
Merge branch 'upstream' into mm-image-tokenizer
DarkLight1337 Jun 8, 2024
be3d64f
Merge branch 'upstream' into mm-image-tokenizer
DarkLight1337 Jun 8, 2024
2ff5be6
Merge branch 'upstream' into mm-image-tokenizer
DarkLight1337 Jun 10, 2024
8e2ff86
Update LLaVA-NeXT
DarkLight1337 Jun 11, 2024
553f684
Merge branch 'upstream' into mm-image-tokenizer
DarkLight1337 Jun 11, 2024
b134dfc
Update name
DarkLight1337 Jun 11, 2024
7e33706
Merge branch 'upstream' into mm-image-tokenizer
DarkLight1337 Jun 11, 2024
3fb622c
Remove `MULTIMODAL` convenience property as it was causing some (impo…
DarkLight1337 Jun 11, 2024
6a70e4f
Add docs
DarkLight1337 Jun 12, 2024
52a0116
Add docs
DarkLight1337 Jun 12, 2024
b7a8683
Merge branch 'upstream' into mm-image-tokenizer
DarkLight1337 Jun 12, 2024
25f9949
Merge branch 'upstream' into mm-image-tokenizer
DarkLight1337 Jun 13, 2024
fd7d954
Merge branch 'upstream' into mm-image-tokenizer
DarkLight1337 Jun 15, 2024
49dac3e
Merge branch 'upstream' into mm-image-tokenizer
DarkLight1337 Jun 15, 2024
0104218
Merge branch 'upstream' into mm-image-tokenizer
DarkLight1337 Jun 18, 2024
18cc7e0
Set up dummy data factory for phi3v
DarkLight1337 Jun 18, 2024
2291617
Move dummy data factories to model files
DarkLight1337 Jun 18, 2024
adf5503
Merge branch 'upstream' into mm-image-tokenizer
DarkLight1337 Jun 18, 2024
fecf1f0
Fix wrong feature size
DarkLight1337 Jun 18, 2024
086e0fe
Fix wrong feature size
DarkLight1337 Jun 18, 2024
c036b86
Merge branch 'upstream' into mm-image-tokenizer
DarkLight1337 Jun 24, 2024
bfa5aa9
Remove redundant code
DarkLight1337 Jun 24, 2024
07e695d
Apply isort
DarkLight1337 Jun 24, 2024
7229b07
Move `DummyImageDataFactories` into CLIP model file
DarkLight1337 Jun 25, 2024
d9a4150
Merge branch 'upstream' into mm-image-tokenizer
DarkLight1337 Jun 26, 2024
4b947ad
Merge branch 'upstream' into mm-image-tokenizer
DarkLight1337 Jun 26, 2024
9e82a26
Clarify docs and add todo
DarkLight1337 Jun 26, 2024
6b19e6c
Expand docs
DarkLight1337 Jun 26, 2024
f451668
Add ref
DarkLight1337 Jun 26, 2024
1abb8a7
Add docs
DarkLight1337 Jun 26, 2024
698830f
Fix name
DarkLight1337 Jun 26, 2024
36ab12d
Fix and add links
DarkLight1337 Jun 26, 2024
bf3281c
modify llava_next
ywang96 Jun 27, 2024
56e2d3b
Update comment
DarkLight1337 Jun 27, 2024
d2f8c6d
Update docs
DarkLight1337 Jun 27, 2024
7c197d2
Use dynamic image feature size calculation
DarkLight1337 Jun 27, 2024
f5ffd3e
Fix phi3v not handling `image_sizes` correctly
DarkLight1337 Jun 27, 2024
66aad21
Apply formatter
DarkLight1337 Jun 27, 2024
f2e4633
Merge branch 'upstream' into mm-image-tokenizer
DarkLight1337 Jun 27, 2024
a6e3162
Merge branch 'upstream' into mm-image-tokenizer
DarkLight1337 Jun 27, 2024
ce06541
Fix config
DarkLight1337 Jun 27, 2024
7e80ecc
Merge branch 'upstream' into mm-image-tokenizer
DarkLight1337 Jun 28, 2024
487d742
Merge branch 'upstream' into mm-image-tokenizer
DarkLight1337 Jun 28, 2024
43350b8
update example
ywang96 Jun 28, 2024
57791de
update doc
ywang96 Jun 28, 2024
20 changes: 20 additions & 0 deletions docs/source/dev/input_processing/input_processing_pipeline.rst
@@ -0,0 +1,20 @@
.. _input_processing_pipeline:

Input Processing Pipeline
=========================

1. Input data is passed to :class:`~vllm.LLMEngine` (or :class:`~vllm.AsyncLLMEngine`).

2. Tokenize the data if necessary.

3. Process the inputs using :meth:`INPUT_REGISTRY.process_input <vllm.inputs.registry.InputRegistry.process_input>`.

- For example, add placeholder tokens to reserve KV cache for multi-modal embeddings.

4. Send the processed inputs to :class:`~vllm.executor.executor_base.ExecutorBase`.

5. Distribute the inputs via :class:`~vllm.worker.worker_base.WorkerBase` to :class:`~vllm.worker.model_runner_base.ModelRunnerBase`.

6. If the inputs contain multi-modal data, convert it into keyword arguments using :meth:`MULTIMODAL_REGISTRY.map_input <vllm.multimodal.MultiModalRegistry.map_input>`.

- For example, convert a :class:`PIL.Image.Image` input to its pixel values for a vision language model.
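Below is a minimal sketch of how the two registry calls above relate on the engine side. It is not the engine's actual code: the helper function is hypothetical and simply strings together `INPUT_REGISTRY.create_input_processor` and `MULTIMODAL_REGISTRY.map_input` as they are used elsewhere in this PR.

```python
from vllm.config import ModelConfig
from vllm.inputs import INPUT_REGISTRY, LLMInputs
from vllm.multimodal import MULTIMODAL_REGISTRY


def process_request(model_config: ModelConfig, llm_inputs: LLMInputs):
    """Hypothetical helper tracing steps 3 and 6 of the pipeline above."""
    # Step 3: apply the model-specific input processor. The engine creates
    # this once at startup; it is created inline here only for clarity.
    input_processor = INPUT_REGISTRY.create_input_processor(model_config)
    processed_inputs = input_processor(llm_inputs)

    # Step 6 (normally inside the model runner): map raw multi-modal data
    # to keyword arguments for the model's forward pass.
    multi_modal_data = processed_inputs.get("multi_modal_data")
    multi_modal_kwargs = (
        MULTIMODAL_REGISTRY.map_input(model_config, multi_modal_data)
        if multi_modal_data is not None else {})

    return processed_inputs, multi_modal_kwargs
```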
39 changes: 39 additions & 0 deletions docs/source/dev/input_processing/model_inputs_index.rst
@@ -0,0 +1,39 @@
.. _input_processing:

Input Processing
================

.. currentmodule:: vllm.inputs

vLLM provides a mechanism for defining input processors for each model so that the inputs are processed
in :class:`~vllm.LLMEngine` before they are passed to model executors.

Currently, this mechanism is only utilized in **multi-modal models** for preprocessing multi-modal input
data in addition to the input prompt, but it can be extended to text-only language models when needed.

Guides
++++++

.. toctree::
:maxdepth: 1

input_processing_pipeline

Module Contents
+++++++++++++++

LLM Engine Inputs
-----------------

.. autoclass:: vllm.inputs.LLMInputs
:members:
:show-inheritance:

Registry
--------

.. autodata:: vllm.inputs.INPUT_REGISTRY

.. automodule:: vllm.inputs.registry
:members:
:show-inheritance:
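To make the registry concrete, here is a hedged sketch of registering an input processor for a model. The model class, token id, and feature size are invented, and the decorator and processor signature follow the `InputRegistry` API added in this PR as best it can be inferred from the docs and commits, so treat the names as illustrative.

```python
import torch.nn as nn

from vllm.inputs import INPUT_REGISTRY, LLMInputs
from vllm.inputs.registry import InputContext

# Illustrative constants (not taken from any real model).
IMAGE_TOKEN_ID = 32000
IMAGE_FEATURE_SIZE = 576


def my_input_processor(ctx: InputContext, llm_inputs: LLMInputs) -> LLMInputs:
    """Hypothetical processor: expand one image placeholder into many."""
    # `ctx` gives checked access to the model/HF configs when needed.
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None:
        return llm_inputs  # text-only prompt, nothing to do

    prompt_token_ids = llm_inputs["prompt_token_ids"]
    new_token_ids = [IMAGE_TOKEN_ID] * IMAGE_FEATURE_SIZE + [
        t for t in prompt_token_ids if t != IMAGE_TOKEN_ID
    ]
    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=llm_inputs.get("prompt"),
                     multi_modal_data=multi_modal_data)


@INPUT_REGISTRY.register_input_processor(my_input_processor)
class MyVLModelForCausalLM(nn.Module):  # hypothetical model class
    ...
```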
8 changes: 1 addition & 7 deletions docs/source/dev/multimodal/multimodal_index.rst
@@ -12,10 +12,6 @@ By default, vLLM models do not support multi-modal inputs. To enable multi-modal
you must decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_dummy_data <MultiModalRegistry.register_dummy_data>`,
as well as :meth:`MULTIMODAL_REGISTRY.register_input <MultiModalRegistry.register_input>` for each modality type that you wish to support.

.. contents::
:local:
:backlinks: none

Module Contents
+++++++++++++++

@@ -24,9 +20,7 @@ Module Contents
Registry
--------

.. data:: vllm.multimodal.MULTIMODAL_REGISTRY

The global :class:`MultiModalRegistry` which is used by model runners.
.. autodata:: vllm.multimodal.MULTIMODAL_REGISTRY

.. autoclass:: vllm.multimodal.MultiModalRegistry
:members:
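A rough sketch of the decoration described at the top of this page follows. The decorator name and argument order mirror that prose and are not verified against this revision, and the mapper body is a placeholder rather than real preprocessing.

```python
import torch
import torch.nn as nn

from vllm.inputs.registry import InputContext
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import ImagePixelData


def my_pixel_mapper(ctx: InputContext, data: ImagePixelData) -> dict:
    """Placeholder mapper: turn an ImagePixelData into model kwargs."""
    # A real mapper would run the HF image processor here; this just
    # returns a dummy tensor of a plausible shape.
    return {"pixel_values": torch.zeros(1, 3, 336, 336)}


# Decorator name/signature taken from the prose above; treat as illustrative.
@MULTIMODAL_REGISTRY.register_input(ImagePixelData, my_pixel_mapper)
class MyVisionLanguageModel(nn.Module):  # hypothetical model class
    ...
```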
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -120,6 +120,7 @@ Documentation
dev/offline_inference/offline_index
dev/engine/engine_index
dev/kernel/paged_attention
dev/input_processing/model_inputs_index
dev/multimodal/multimodal_index
dev/dockerfile/dockerfile

4 changes: 2 additions & 2 deletions docs/source/models/adding_model.rst
@@ -37,7 +37,7 @@ For instance, vLLM's `OPT model <https://github.com/vllm-project/vllm/blob/main/
2. Rewrite the :code:`forward` methods
--------------------------------------

Next, you need to rewrite the :code:`forward` methods of your model by following these steps:
Next, you need to rewrite the :meth:`~torch.nn.Module.forward` method of your model by following these steps:

1. Remove any unnecessary code, such as the code only used for training.
2. Change the input parameters:
@@ -75,7 +75,7 @@ Next, you need to rewrite the :code:`forward` methods of your model by following

If your model is too large to fit into a single GPU, you can use tensor parallelism to manage it.
To do this, substitute your model's linear and embedding layers with their tensor-parallel versions.
For the embedding layer, you can simply replace :code:`nn.Embedding` with :code:`VocabParallelEmbedding`. For the output LM head, you can use :code:`ParallelLMHead`.
For the embedding layer, you can simply replace :class:`torch.nn.Embedding` with :code:`VocabParallelEmbedding`. For the output LM head, you can use :code:`ParallelLMHead`.
When it comes to the linear layers, we provide the following options to parallelize them:

* :code:`ReplicatedLinear`: Replicates the inputs and weights across multiple GPUs. No memory saving.
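A before/after sketch of the substitution described above (constructor arguments are simplified, and actually instantiating the parallel layers requires vLLM's distributed state to be initialized, so treat this as structural guidance only):

```python
import torch.nn as nn

from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)

hidden_size, intermediate_size, vocab_size = 4096, 11008, 32000

# Single-GPU layers in the original Transformers-style implementation:
embed_tokens = nn.Embedding(vocab_size, hidden_size)
up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# Tensor-parallel replacements described above (arguments simplified):
embed_tokens = VocabParallelEmbedding(vocab_size, hidden_size)
up_proj = ColumnParallelLinear(hidden_size, intermediate_size, bias=False)
down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False)
lm_head = ParallelLMHead(vocab_size, hidden_size, bias=False)
```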
3 changes: 2 additions & 1 deletion examples/phi3v_example.py
@@ -11,14 +11,15 @@ def run_phi3v():
model_path = "microsoft/Phi-3-vision-128k-instruct"

# Note: The model has 128k context length by default which may cause OOM
# If that's the case, override `max_model_len` with a smaller value via args
# In this example, we override max_model_len to 2048.
llm = LLM(
model=model_path,
trust_remote_code=True,
image_input_type="pixel_values",
image_token_id=32044,
image_input_shape="1,3,1008,1344",
image_feature_size=1921,
max_model_len=2048,
)

image = Image.open("images/cherry_blossom.jpg")
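The remainder of this example is truncated above; below is a hedged sketch of the usual follow-up, where the prompt template and the shape of the multi-modal `generate` call are assumptions for this revision (`llm` and `image` come from the example itself):

```python
from vllm import SamplingParams
from vllm.multimodal.image import ImagePixelData

prompt = ("<|user|>\n<|image_1|>\n"
          "What is shown in this image?<|end|>\n<|assistant|>\n")

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": ImagePixelData(image),  # call shape assumed
    },
    sampling_params=SamplingParams(temperature=0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```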
@@ -25,25 +25,24 @@ def test_clip_image_processor(image_assets, dtype):
seed=0,
dtype=dtype,
revision=None,
)
vlm_config = VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=32000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=576,
image_processor=MODEL_NAME,
image_processor_revision=None,
multimodal_config=VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=32000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=576,
image_processor=MODEL_NAME,
image_processor_revision=None,
),
)

for asset in image_assets:
hf_result = hf_processor.preprocess(
asset.pil_image,
return_tensors="pt",
).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype])
vllm_result = MULTIMODAL_REGISTRY.process_input(
vllm_result = MULTIMODAL_REGISTRY.map_input(
model_config,
ImagePixelData(asset.pil_image),
model_config=model_config,
vlm_config=vlm_config,
)

assert hf_result.keys() == vllm_result.keys()
@@ -74,25 +73,24 @@ def test_llava_next_image_processor(image_assets, dtype):
seed=0,
dtype=dtype,
revision=None,
)
vlm_config = VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=64000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=2928,
image_processor=MODEL_NAME,
image_processor_revision=None,
multimodal_config=VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=64000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=2928,
image_processor=MODEL_NAME,
image_processor_revision=None,
),
)

for asset in image_assets:
hf_result = hf_processor.preprocess(
asset.pil_image,
return_tensors="pt",
).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype])
vllm_result = MULTIMODAL_REGISTRY.process_input(
vllm_result = MULTIMODAL_REGISTRY.map_input(
model_config,
ImagePixelData(asset.pil_image),
model_config=model_config,
vlm_config=vlm_config,
)

assert hf_result.keys() == vllm_result.keys()
@@ -119,26 +117,23 @@ def test_image_pixel_types(image_assets, dtype):
seed=0,
dtype=dtype,
revision=None,
)
vlm_config = VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=32000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=576,
image_processor=MODEL_NAME,
image_processor_revision=None,
)
multimodal_config=VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=32000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=576,
image_processor=MODEL_NAME,
image_processor_revision=None,
))

for asset in image_assets:
image_result = MULTIMODAL_REGISTRY.process_input(
image_result = MULTIMODAL_REGISTRY.map_input(
model_config,
ImagePixelData(asset.pil_image),
model_config=model_config,
vlm_config=vlm_config,
)
tensor_result = MULTIMODAL_REGISTRY.process_input(
tensor_result = MULTIMODAL_REGISTRY.map_input(
model_config,
ImagePixelData(asset.pixel_values),
model_config=model_config,
vlm_config=vlm_config,
)

assert image_result.keys() == tensor_result.keys()
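In short, these tests move from the old keyword-style call to the new positional form; a simplified side-by-side, where `model_config`, `vlm_config`, and `image` stand in for the objects built in the tests:

```python
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import ImagePixelData

# Before this PR (illustrative): the vision config was passed separately.
old_result = MULTIMODAL_REGISTRY.process_input(
    ImagePixelData(image), model_config=model_config, vlm_config=vlm_config)

# After this PR: the vision config lives on ModelConfig.multimodal_config,
# so only the model config and the data need to be passed.
new_result = MULTIMODAL_REGISTRY.map_input(model_config, ImagePixelData(image))
```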
3 changes: 3 additions & 0 deletions vllm/config.py
@@ -109,6 +109,7 @@ def __init__(
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
multimodal_config: Optional["VisionLanguageConfig"] = None,
) -> None:
self.model = model
self.tokenizer = tokenizer
@@ -159,6 +160,8 @@ def __init__(
sliding_window_len=self.get_hf_config_sliding_window())
self.served_model_name = get_served_model_name(model,
served_model_name)
self.multimodal_config = multimodal_config

if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()
self._verify_embedding_mode()
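A small sketch of what threading `multimodal_config` through `ModelConfig` enables downstream; the helper is hypothetical, and per the commit history `InputContext` additionally offers checked, typed access to this config:

```python
from typing import Optional

from vllm.config import ModelConfig


def get_image_token_id(model_config: ModelConfig) -> Optional[int]:
    """Hypothetical helper reading the config attached by this PR."""
    mm_config = model_config.multimodal_config
    if mm_config is None:
        return None  # text-only model
    return mm_config.image_token_id
```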
64 changes: 32 additions & 32 deletions vllm/engine/arg_utils.py
@@ -643,6 +643,36 @@ def create_engine_config(self, ) -> EngineConfig:
raise ValueError(
"BitsAndBytes load format and QLoRA adapter only support "
f"'bitsandbytes' quantization, but got {self.quantization}")
if self.image_input_type:
if (not self.image_token_id or not self.image_input_shape
or not self.image_feature_size):
raise ValueError(
'Specify `image_token_id`, `image_input_shape` and '
'`image_feature_size` together with `image_input_type`.')

if self.image_processor is None:
self.image_processor = self.model
if self.disable_image_processor:
if self.image_processor != self.model:
warnings.warn(
"You've specified an image processor "
f"({self.image_processor}) but also disabled "
"it via `--disable-image-processor`.",
stacklevel=2)

self.image_processor = None

vision_language_config = VisionLanguageConfig(
image_input_type=VisionLanguageConfig.
get_image_input_enum_type(self.image_input_type),
image_token_id=self.image_token_id,
image_input_shape=str_to_int_tuple(self.image_input_shape),
image_feature_size=self.image_feature_size,
image_processor=self.image_processor,
image_processor_revision=self.image_processor_revision,
)
else:
vision_language_config = None

device_config = DeviceConfig(device=self.device)
model_config = ModelConfig(
@@ -666,7 +696,8 @@ def create_engine_config(self, ) -> EngineConfig:
max_logprobs=self.max_logprobs,
disable_sliding_window=self.disable_sliding_window,
skip_tokenizer_init=self.skip_tokenizer_init,
served_model_name=self.served_model_name)
served_model_name=self.served_model_name,
multimodal_config=vision_language_config)
cache_config = CacheConfig(
block_size=self.block_size,
gpu_memory_utilization=self.gpu_memory_utilization,
@@ -742,37 +773,6 @@ def create_engine_config(self, ) -> EngineConfig:
model_loader_extra_config=self.model_loader_extra_config,
)

if self.image_input_type:
if (not self.image_token_id or not self.image_input_shape
or not self.image_feature_size):
raise ValueError(
'Specify `image_token_id`, `image_input_shape` and '
'`image_feature_size` together with `image_input_type`.')

if self.image_processor is None:
self.image_processor = self.model
if self.disable_image_processor:
if self.image_processor != self.model:
warnings.warn(
"You've specified an image processor "
f"({self.image_processor}) but also disabled "
"it via `--disable-image-processor`.",
stacklevel=2)

self.image_processor = None

vision_language_config = VisionLanguageConfig(
image_input_type=VisionLanguageConfig.
get_image_input_enum_type(self.image_input_type),
image_token_id=self.image_token_id,
image_input_shape=str_to_int_tuple(self.image_input_shape),
image_feature_size=self.image_feature_size,
image_processor=self.image_processor,
image_processor_revision=self.image_processor_revision,
)
else:
vision_language_config = None

decoding_config = DecodingConfig(
guided_decoding_backend=self.guided_decoding_backend)
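The validation block moved above is driven by the image-related engine arguments; here is a hedged sketch of setting them programmatically, where the field names come from this diff and the model name and values are illustrative:

```python
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="llava-hf/llava-1.5-7b-hf",   # illustrative model
    image_input_type="pixel_values",
    image_token_id=32000,
    image_input_shape="1,3,336,336",    # parsed via str_to_int_tuple above
    image_feature_size=576,
)

engine_config = engine_args.create_engine_config()
# With this PR, the resulting VisionLanguageConfig is also attached to the
# model config (see the vllm/config.py change above).
assert engine_config.model_config.multimodal_config is not None
```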

8 changes: 5 additions & 3 deletions vllm/engine/async_llm_engine.py
@@ -278,9 +278,11 @@ async def process_model_inputs_async(
else:
prompt_token_ids = inputs["prompt_token_ids"]

return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))

return self.input_processor(llm_inputs)

async def add_request_async(
self,
13 changes: 9 additions & 4 deletions vllm/engine/llm_engine.py
@@ -20,7 +20,7 @@
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import LLMInputs, PromptInputs
from vllm.inputs import INPUT_REGISTRY, LLMInputs, PromptInputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
@@ -227,6 +227,9 @@ def __init__(
self.generation_config_fields = _load_generation_config_dict(
model_config)

self.input_processor = INPUT_REGISTRY.create_input_processor(
self.model_config)

self.model_executor = executor_class(
model_config=model_config,
cache_config=cache_config,
@@ -511,9 +514,11 @@ def process_model_inputs(
else:
prompt_token_ids = inputs["prompt_token_ids"]

return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))

return self.input_processor(llm_inputs)

def add_request(
self,