Remove VisionLanguageConfig from input mapper
DarkLight1337 committed Jun 3, 2024
1 parent 653537d commit a2f5a3c
Showing 7 changed files with 43 additions and 84 deletions.
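In short, this commit flips the multi-modal input mapper's calling convention from mapper(data, model_config, vlm_config) to mapper(model_config, data), dropping VisionLanguageConfig from the signature; code that still needs it reads model_config.multimodal_config instead (see the vllm/multimodal/image.py hunk below). A minimal before/after sketch, using only names that appear in this diff:

    # Before this commit: data first, with a separate vlm_config argument.
    mm_kwargs = MULTIMODAL_REGISTRY.map_input(
        ImagePixelData(image),
        model_config=model_config,
        vlm_config=vlm_config,
    )

    # After this commit: model_config first, data second, no vlm_config.
    mm_kwargs = MULTIMODAL_REGISTRY.map_input(model_config, ImagePixelData(image))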
43 changes: 19 additions & 24 deletions tests/multimodal/test_mapper.py
@@ -23,25 +23,23 @@ def test_clip_image_processor(hf_images, dtype):
         seed=0,
         dtype=dtype,
         revision=None,
-    )
-    vlm_config = VisionLanguageConfig(
-        image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
-        image_token_id=32000,
-        image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
-        image_feature_size=576,
-        image_processor=MODEL_NAME,
-        image_processor_revision=None,
-    )
+        multimodal_config=VisionLanguageConfig(
+            image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
+            image_token_id=32000,
+            image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
+            image_feature_size=576,
+            image_processor=MODEL_NAME,
+            image_processor_revision=None,
+        ))
 
     for image in hf_images:
         hf_result = hf_processor.preprocess(
             image,
             return_tensors="np",
         )
         vllm_result = MULTIMODAL_REGISTRY.map_input(
+            model_config,
             ImagePixelData(image),
-            model_config=model_config,
-            vlm_config=vlm_config,
         )
 
         assert hf_result.keys() == vllm_result.keys()
@@ -65,26 +63,23 @@ def test_image_pixel_types(hf_images, vllm_image_tensors, dtype):
         seed=0,
         dtype=dtype,
         revision=None,
-    )
-    vlm_config = VisionLanguageConfig(
-        image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
-        image_token_id=32000,
-        image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
-        image_feature_size=576,
-        image_processor=MODEL_NAME,
-        image_processor_revision=None,
-    )
+        multimodal_config=VisionLanguageConfig(
+            image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
+            image_token_id=32000,
+            image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
+            image_feature_size=576,
+            image_processor=MODEL_NAME,
+            image_processor_revision=None,
+        ))
 
     for image, tensor in zip(hf_images, vllm_image_tensors):
         image_result = MULTIMODAL_REGISTRY.map_input(
+            model_config,
             ImagePixelData(image),
-            model_config=model_config,
-            vlm_config=vlm_config,
         )
         tensor_result = MULTIMODAL_REGISTRY.map_input(
+            model_config,
             ImagePixelData(tensor),
-            model_config=model_config,
-            vlm_config=vlm_config,
        )
 
         assert image_result.keys() == tensor_result.keys()
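The test changes above also show where VisionLanguageConfig now lives: it is passed to ModelConfig under the multimodal_config keyword at construction time instead of being carried around separately. A trimmed sketch of the new setup (only the keyword arguments visible in the hunks are reproduced; the leading ModelConfig arguments are elided):

    model_config = ModelConfig(
        # ... model/tokenizer arguments elided ...
        seed=0,
        dtype=dtype,
        revision=None,
        multimodal_config=VisionLanguageConfig(
            image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
            image_token_id=32000,
            image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
            image_feature_size=576,
            image_processor=MODEL_NAME,
            image_processor_revision=None,
        ))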
2 changes: 1 addition & 1 deletion vllm/inputs/registry.py
@@ -280,7 +280,7 @@ def process_input(self, model_config: "ModelConfig",
 
         return processor(model_config, inputs)
 
-    def create_input_processor(self, model_config: ModelConfig):
+    def create_input_processor(self, model_config: "ModelConfig"):
         """
         Create an input processor (see :meth:`process_input`) for a
         specific model.
21 changes: 7 additions & 14 deletions vllm/multimodal/base.py
@@ -2,7 +2,7 @@
 from typing import (TYPE_CHECKING, Callable, Dict, Generic, Optional, Type,
                     TypeVar)
 
-from vllm.config import ModelConfig, VisionLanguageConfig
+from vllm.config import ModelConfig
 from vllm.logger import init_logger
 
 if TYPE_CHECKING:
@@ -32,8 +32,7 @@ class MultiModalData:
 D = TypeVar("D", bound=MultiModalData)
 N = TypeVar("N", bound=Type["nn.Module"])
 
-MultiModalInputMapper = Callable[[D, ModelConfig, VisionLanguageConfig],
-                                 Dict[str, "torch.Tensor"]]
+MultiModalInputMapper = Callable[[ModelConfig, D], Dict[str, "torch.Tensor"]]
 """Return a dictionary to be passed as keyword arguments to
 :meth:`torch.nn.Module.forward`. This is similar in concept to tokenizers
 and processors in HuggingFace Transformers."""
@@ -63,9 +62,8 @@ def get_data_type(self) -> Type[D]:
         raise NotImplementedError
 
     @abstractmethod
-    def _default_input_mapper(
-            self, data: D, model_config: ModelConfig,
-            vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]:
+    def _default_input_mapper(self, model_config: ModelConfig,
+                              data: D) -> Dict[str, "torch.Tensor"]:
         """Return a dictionary to be passed as keyword arguments to
         :meth:`torch.nn.Module.forward`. This is similar in concept to
         tokenizers and processors in HuggingFace Transformers.
@@ -99,16 +97,11 @@ def wrapper(model_cls: N) -> N:
 
         return wrapper
 
-    def map_input(
-            self, data: D, model_config: ModelConfig,
-            vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]:
+    def map_input(self, model_config: ModelConfig,
+                  data: D) -> Dict[str, "torch.Tensor"]:
         """
         Apply an input mapper to a :class:`~MultiModalData` instance passed
         to the model, transforming the data into a dictionary of model inputs.
-        The model is identified by ``model_config``. ``vlm_config`` is
-        for compatibility purposes and may be merged into ``model_config``
-        in the near future.
         """
         # Avoid circular import
         from vllm.model_executor.model_loader import get_model_architecture
@@ -120,4 +113,4 @@ def map_input(
             raise KeyError(f"No input mapper in {self} is registered for "
                            f"model class {model_cls.__name__}.")
 
-        return mapper(data, model_config, vlm_config)
+        return mapper(model_config, data)
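Under the new MultiModalInputMapper alias, a custom mapper is any callable taking (ModelConfig, data) and returning the keyword arguments for forward(). A hypothetical mapper for image features, mirroring the default one in vllm/multimodal/image.py below (the function name is illustrative and not part of this commit):

    from typing import Dict

    import torch

    from vllm.config import ModelConfig
    from vllm.multimodal.image import ImageFeatureData

    def my_feature_input_mapper(model_config: ModelConfig,
                                data: ImageFeatureData) -> Dict[str, torch.Tensor]:
        # Cast the features to the model's dtype, as the default mapper does.
        return {"image_features": data.image_features.to(model_config.dtype)}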
16 changes: 7 additions & 9 deletions vllm/multimodal/image.py
@@ -227,8 +227,8 @@ class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]):
     def get_data_type(self) -> Type[ImagePixelData]:
         return ImagePixelData
 
-    def _get_hf_image_processor(self, model_config: ModelConfig,
-                                vlm_config: VisionLanguageConfig):
+    def _get_hf_image_processor(self, model_config: ModelConfig):
+        vlm_config = model_config.multimodal_config
         if vlm_config is None or vlm_config.image_processor is None:
             return None
 
@@ -238,12 +238,10 @@ def _get_hf_image_processor(self, model_config: ModelConfig,
             revision=vlm_config.image_processor_revision,
         )
 
-    def _default_input_mapper(
-            self, data: ImagePixelData, model_config: ModelConfig,
-            vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
+    def _default_input_mapper(self, model_config: ModelConfig,
+                              data: ImagePixelData) -> Dict[str, torch.Tensor]:
         image = data.image
-        image_processor = self._get_hf_image_processor(model_config,
-                                                       vlm_config)
+        image_processor = self._get_hf_image_processor(model_config)
 
         if isinstance(image, Image.Image):
             if image_processor is None:
@@ -280,8 +278,8 @@ def get_data_type(self) -> Type[ImageFeatureData]:
         return ImageFeatureData
 
     def _default_input_mapper(
-            self, data: ImageFeatureData, model_config: ModelConfig,
-            vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
+            self, model_config: ModelConfig,
+            data: ImageFeatureData) -> Dict[str, torch.Tensor]:
         image_features = data.image_features.to(model_config.dtype)
 
         return {"image_features": image_features}
14 changes: 5 additions & 9 deletions vllm/multimodal/registry.py
@@ -3,7 +3,7 @@
 
 from torch import nn
 
-from vllm.config import ModelConfig, VisionLanguageConfig
+from vllm.config import ModelConfig
 from vllm.logger import init_logger
 
 from .base import MultiModalData, MultiModalInputMapper, MultiModalPlugin
@@ -86,22 +86,18 @@ def register_image_feature_input_mapper(
         """
         return self.register_input_mapper(ImageFeatureData, mapper)
 
-    def map_input(self, data: MultiModalData, model_config: ModelConfig,
-                  vlm_config: VisionLanguageConfig):
+    def map_input(self, model_config: ModelConfig, data: MultiModalData):
         """
         Apply an input mapper to a :class:`~MultiModalData` instance passed
         to the model.
 
         See :meth:`MultiModalPlugin.map_input` for more details.
         """
         return self._get_plugin_for_data_type(type(data)) \
-            .map_input(data, model_config, vlm_config)
+            .map_input(model_config, data)
 
-    def create_input_mapper(self, model_config: ModelConfig,
-                            vlm_config: VisionLanguageConfig):
+    def create_input_mapper(self, model_config: ModelConfig):
         """
         Create an input mapper (see :meth:`map_input`) for a specific model.
         """
-        return functools.partial(self.map_input,
-                                 model_config=model_config,
-                                 vlm_config=vlm_config)
+        return functools.partial(self.map_input, model_config=model_config)
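With vlm_config gone, create_input_mapper closes over model_config alone, so the workers below can hold a mapper that takes just the data. A usage sketch matching how the worker code invokes it (model_config and mm_data stand for values the worker already has):

    mapper = MULTIMODAL_REGISTRY.create_input_mapper(model_config)
    mm_kwargs = mapper(mm_data)  # intended to be equivalent to map_input(model_config, mm_data)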
16 changes: 2 additions & 14 deletions vllm/worker/cpu_model_runner.py
@@ -66,14 +66,8 @@ def __init__(
         )
 
         # Create processor for multi-modal data
-        if self.vision_language_config is not None:
-            self.multi_modal_input_mapper = INPUT_REGISTRY.MULTIMODAL \
-                .create_input_mapper(
-                    self.model_config,
-                    self.vision_language_config,
-                )
-        else:
-            self.multi_modal_input_mapper = None
+        self.multi_modal_input_mapper = INPUT_REGISTRY.MULTIMODAL \
+            .create_input_mapper(self.model_config)
 
         # Lazy initialization.
         self.model: nn.Module  # Set after init_Model
@@ -123,12 +117,6 @@ def _prepare_prompt(
 
             mm_data = seq_group_metadata.multi_modal_data
             if mm_data is not None:
-                # Process multi-modal data
-                if self.multi_modal_input_mapper is None:
-                    raise ValueError(
-                        "Multi-modal inputs are only supported by "
-                        "vision language models.")
-
                 mm_kwargs = self.multi_modal_input_mapper(mm_data)
                 for k, v in mm_kwargs.items():
                     multi_modal_kwargs_list[k].append(v)
15 changes: 2 additions & 13 deletions vllm/worker/model_runner.py
@@ -124,14 +124,8 @@ def __init__(
         )
 
         # Create processor for multi-modal data
-        if self.vision_language_config is not None:
-            self.multi_modal_input_mapper = INPUT_REGISTRY.MULTIMODAL \
-                .create_input_mapper(
-                    self.model_config,
-                    self.vision_language_config,
-                )
-        else:
-            self.multi_modal_input_mapper = None
+        self.multi_modal_input_mapper = INPUT_REGISTRY.MULTIMODAL \
+            .create_input_mapper(self.model_config)
 
         # Lazy initialization
         self.model: nn.Module  # Set after load_model
@@ -432,11 +426,6 @@ def _prepare_model_input(
             mm_data = seq_group_metadata.multi_modal_data
             if mm_data is not None:
                 # Process multi-modal data
-                if self.multi_modal_input_mapper is None:
-                    raise ValueError(
-                        "Multi-modal inputs are only supported by "
-                        "vision language models.")
-
                 mm_kwargs = self.multi_modal_input_mapper(mm_data)
                 for k, v in mm_kwargs.items():
                     multi_modal_kwargs_list[k].append(v)
