diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 95e7d5d602ac..7613f97bcbe4 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -1014,6 +1014,13 @@ See [this page](#generative-models) for more information on how to use generativ * * ✅︎ * ✅︎ +- * `Ovis2ForConditionalGeneration`^ + * Ovis2 + * T + I+ + * `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis2-2B`, etc. + * + * + * ✅︎ - * `PaliGemmaForConditionalGeneration` * PaliGemma, PaliGemma 2 * T + IE diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index d02ac17cfdd6..d455ea2de85d 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -725,6 +725,36 @@ def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData: ) +# Ovis2 +def run_ovis2(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + + model_name = "AIDC-AI/Ovis2-1B" + tokenizer = "Isotr0py/Ovis2-tokenizer" + + engine_args = EngineArgs( + model=model_name, + tokenizer=tokenizer, + max_model_len=4096, + max_num_seqs=2, + trust_remote_code=True, + dtype="half", + hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]}, + limit_mm_per_prompt={"image": 1}, + ) + + placeholder = "\n" + prompts = [("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + f"<|im_start|>user\n{placeholder}" + f"{question}<|im_end|>\n" + "<|im_start|>assistant\n") for question in questions] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # PaliGemma def run_paligemma(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1041,6 +1071,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData: "llama4": run_llama4, "molmo": run_molmo, "NVLM_D": run_nvlm_d, + "ovis2": run_ovis2, "paligemma": run_paligemma, "paligemma2": run_paligemma2, "phi3_v": run_phi3v, diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index 7f6608559f9c..f160339931b5 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -436,6 +436,36 @@ def load_nvlm_d(question: str, image_urls: list[str]) -> ModelRequestData: ) +# Ovis2 +def load_ovis2(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "AIDC-AI/Ovis2-1B" + tokenizer = "Isotr0py/Ovis2-tokenizer" + + engine_args = EngineArgs( + model=model_name, + tokenizer=tokenizer, + max_model_len=8192, + max_num_seqs=2, + trust_remote_code=True, + dtype="half", + limit_mm_per_prompt={"image": len(image_urls)}, + hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]}, + ) + + placeholder = '\n'.join( + [f'Image {i+1}: ' for i in range(len(image_urls))]) + '\n' + prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + f"<|im_start|>user\n{placeholder}" + f"{question}<|im_end|>\n" + "<|im_start|>assistant\n") + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=[fetch_image(url) for url in image_urls], + ) + + def load_pixtral_hf(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "mistral-community/pixtral-12b" @@ -685,6 +715,7 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData: "mistral3": load_mistral3, "mllama": load_mllama, "NVLM_D": load_nvlm_d, + "ovis2": 
load_ovis2, "phi3_v": load_phi3v, "phi4_mm": load_phi4mm, "pixtral_hf": load_pixtral_hf, diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index 6073364c0199..3dd82b93fae8 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -467,6 +467,18 @@ max_num_seqs=2, patch_hf_runner=model_utils.molmo_patch_hf_runner, ), + "ovis2": VLMTestInfo( + models=["AIDC-AI/Ovis2-1B"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "\n", # noqa: E501 + max_model_len=4096, + max_num_seqs=2, + dtype="half", + # use sdpa mode for hf runner since ovis2 didn't work with flash_attn + hf_model_kwargs={"llm_attn_implementation": "sdpa"}, + patch_hf_runner=model_utils.ovis2_patch_hf_runner, + ), "phi3v": VLMTestInfo( models=["microsoft/Phi-3.5-vision-instruct"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), diff --git a/tests/models/decoder_only/vision_language/vlm_utils/core.py b/tests/models/decoder_only/vision_language/vlm_utils/core.py index fd046f3cd8e8..c3d20f56855f 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/core.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/core.py @@ -67,7 +67,7 @@ def run_test( "disable_mm_preprocessor_cache": True, } if model_info.tokenizer: - vllm_runner_kwargs_["tokenizer"] = model_info.tokenizer + vllm_runner_kwargs_["tokenizer_name"] = model_info.tokenizer if model_info.tokenizer_mode: vllm_runner_kwargs_["tokenizer_mode"] = model_info.tokenizer_mode if model_info.hf_overrides: diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py index 1185d80b97e3..c856fb198b32 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py @@ -676,3 +676,33 @@ def _generate(self, max_new_tokens=None, do_sample=None, **kwargs): hf_model.model.generate = types.MethodType(_generate, hf_model.model) return hf_model + + +def ovis2_patch_hf_runner(hf_model: HfRunner) -> HfRunner: + """Patches and returns an instance of the HfRunner to use for Ovis2.""" + hf_model.model.visual_tokenizer.to(hf_model.dtype) + hf_model.model.vte.to(hf_model.dtype) + hf_model.model.llm.to(hf_model.dtype) + + hf_model.model.get_output_embeddings = lambda: \ + hf_model.model.llm.get_output_embeddings() + + def processor(*args, text="", images=None, **kwargs): + text_tokenizer = hf_model.model.get_text_tokenizer() + images = [images] if isinstance(images, Image) else images + + text = text.split("<|im_start|>user\n")[1].split("<|im_end|>\n")[0] + + prompt, input_ids, pixel_values = hf_model.model.preprocess_inputs( + text_or_conversations=text, images=images) + attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id) + + inputs = { + "inputs": input_ids.unsqueeze(0), + "pixel_values": pixel_values.unsqueeze(0), + "attention_mask": attention_mask.unsqueeze(0), + } + return BatchFeature(data=inputs, tensor_type="pt") + + hf_model.processor = processor + return hf_model diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 4dc49d18c514..2b1d38dfda97 100644 
--- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -274,6 +274,7 @@ def _test_processing_correctness_mistral( "allenai/Molmo-7B-D-0924", "allenai/Molmo-7B-O-0924", "nvidia/NVLM-D-72B", + "AIDC-AI/Ovis2-1B", "google/paligemma-3b-mix-224", "google/paligemma2-3b-ft-docci-448", "microsoft/Phi-4-multimodal-instruct", diff --git a/tests/models/registry.py b/tests/models/registry.py index 8b330109d9ab..0305f756c55d 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -347,6 +347,10 @@ def check_available_online( max_transformers_version="4.48", transformers_version_reason="Use of deprecated imports which have been removed.", # noqa: E501 extras={"phi3.5": "microsoft/Phi-3.5-vision-instruct"}), # noqa: E501 + "Ovis2ForConditionalGeneration": _HfExamplesInfo("AIDC-AI/Ovis2-1B", + tokenizer="Isotr0py/Ovis2-tokenizer", + trust_remote_code=True, + hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]}), # noqa: E501 "Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct", trust_remote_code=True), "PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409", # noqa: E501 diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index fcaa24eec8c8..23dded7f226f 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -496,9 +496,10 @@ def _placeholder_str(self, modality: ModalityStr, if model_type.startswith("llava"): return self._cached_token_str(self._tokenizer, hf_config.image_token_index) + if model_type in ("aya_vision", "chameleon", "deepseek_vl_v2", - "internvl_chat", "skywork_chat", "NVLM_D", - "h2ovl_chat", "idefics3", "smolvlm"): + "internvl_chat", "ovis2", "skywork_chat", + "NVLM_D", "h2ovl_chat", "idefics3", "smolvlm"): return "" if model_type in ("mllama", "llama4"): return "<|image|>" diff --git a/vllm/model_executor/models/aimv2.py b/vllm/model_executor/models/aimv2.py new file mode 100644 index 000000000000..730e770dc3d6 --- /dev/null +++ b/vllm/model_executor/models/aimv2.py @@ -0,0 +1,322 @@ +# SPDX-License-Identifier: Apache-2.0 + +# A modified implementation of the AIMv2 Transformer +# inserted here also the image tokenizer used by Ovis2 +from typing import Optional + +import torch +from torch import nn, softmax +from torch.nn import functional as F +from torch.nn.functional import gumbel_softmax, pad + +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.transformers_utils.configs.ovis2 import (AIMv2Config, + Aimv2VisualTokenizerConfig) + +IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, + -305] # kept for vocab prefixed tokens + + +def st_argmax(y_soft: torch.Tensor, dim: int): # straight-through softmax + index = y_soft.max(dim, keepdim=True)[1] + y_hard = torch.zeros_like( + y_soft, memory_format=torch.legacy_contiguous_format).scatter_( + dim, index, 1.0) + ret = y_hard - y_soft.detach() + y_soft + return ret + + +class Aimv2VisualTokenizer(torch.nn.Module): + + def __init__(self, + config: Aimv2VisualTokenizerConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + **kwargs): + super().__init__() + self.config = config + self.backbone = AIMv2Model( + config=config.backbone_config, # noqa + quant_config=quant_config, + prefix=f"{prefix}.visual_tokenizer") + # reserved tokens for IMAGE_INDICATORS + head_dim = 
config.vocab_size - len(IMAGE_INDICATOR_IDS) + self.head = torch.nn.Sequential( + ReplicatedLinear( + config.backbone_config.hidden_size * config.hidden_stride * + config.hidden_stride, + head_dim, + bias=False, + ), torch.nn.LayerNorm(head_dim)) + + @property + def dtype(self): + return self.backbone.dtype + + @property + def device(self): + return self.backbone.device + + def tokenize(self, logits): + if self.config.tokenize_function == 'softmax': + tokens = softmax(logits, dim=-1) + elif self.config.tokenize_function == 'gumbel_argmax': + tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True) + elif self.config.tokenize_function == 'st_argmax': + tokens = st_argmax(logits, dim=-1) + else: + raise ValueError( + 'Invalid `max_type`, expected softmax or gumbel_argmax ' + f'or st_argmax, but got {self.config.tokenize_function}') + return tokens + + def encode(self, pixel_values): + features = self.backbone(pixel_values) + if self.config.drop_cls_token: + features = features[:, 1:, :] + + # merge number of `hidden_stride * hidden_stride` hidden states together + # to reduce token sequence length + # e.g., for hidden_stride=2, this leads to a token length reduction: + # 1024 -> 256 for aimv2 + if self.config.hidden_stride > 1: + # this `d` maybe different from the above `d`` + n, L, d = features.shape + sqrt_l = int(L**0.5) + assert sqrt_l**2 == L, ( + "The token sequence length should be a perfect square.") + features = features.reshape(n, sqrt_l, sqrt_l, d) + pl = (self.config.hidden_stride - + (sqrt_l % + self.config.hidden_stride)) % self.config.hidden_stride + features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0) + sqrt_l += pl + features = features.reshape(n, sqrt_l // self.config.hidden_stride, + self.config.hidden_stride, + sqrt_l // self.config.hidden_stride, + self.config.hidden_stride, d) + # [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d] + features = features.permute(0, 1, 3, 2, 4, 5) + # [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d] + features = features.flatten(3) + # [n, sqrt_l/hs*sqrt_l/hs, hs*hs*d] + features = features.reshape( + n, -1, + self.config.hidden_stride * self.config.hidden_stride * d) + + return features + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + """[BatchSize, ImageShape] -> [BatchSize, Token, VocabSize]""" + features = self.encode(pixel_values) + logits, _ = self.head[0]( + features) # we spllit the sequncial here for not throwing an error + logits = self.head[1](logits) + tokens = self.tokenize(logits) + # tokens' shape is [BatchSize, #Token, VocabSize-5], so padding with + # [BatchSize, #Token, 5], after which, tokens' shape should become + # [BatchSize, #Token, VocabSize] + batch_size, token_len, _ = tokens.shape + padding_tensor = torch.zeros(size=(batch_size, token_len, + len(IMAGE_INDICATOR_IDS)), + dtype=tokens.dtype, + device=tokens.device, + layout=tokens.layout, + requires_grad=False) + tokens = torch.cat((tokens, padding_tensor), dim=2) + return tokens + + +class AIMv2SwiGLUFFN(nn.Module): + + def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, + prefix: str): + super().__init__() + hidden_features = config.intermediate_size + in_features = config.hidden_size + bias = config.use_bias + + # TODO(Isotr0py): investigate if we can add TP to visual tokenizer + self.fc1 = ReplicatedLinear(in_features, + hidden_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.fc1") + self.fc2 = ReplicatedLinear(hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + 
prefix=f"{prefix}.fc2") + self.fc3 = ReplicatedLinear(in_features, + hidden_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.fc3") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_parallel, _ = self.fc1(x) + gate, _ = self.fc3(x) + x_parallel = F.silu(x_parallel) * gate + out, _ = self.fc2(x_parallel) + return out + + +class AIMv2PatchEmbed(nn.Module): + + def __init__(self, config: AIMv2Config): + super().__init__() + self.proj = nn.Conv2d( + config.num_channels, + config.hidden_size, + kernel_size=(config.patch_size, config.patch_size), + stride=(config.patch_size, config.patch_size), + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x).flatten(2).transpose(1, 2) + x = self.norm.forward_native(x) + return x + + +class AIMv2ViTPreprocessor(nn.Module): + + def __init__(self, config: AIMv2Config): + super().__init__() + num_patches = (config.image_size // config.patch_size)**2 + + self.patchifier = AIMv2PatchEmbed(config) + self.pos_embed = nn.Parameter( + torch.zeros((1, num_patches, config.hidden_size))) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + tokens = self.patchifier(x) + _, N, _ = tokens.shape + pos_embed = self.pos_embed.to(tokens.device) + tokens = tokens + pos_embed[:, :N] + return tokens + + +class AIMv2Attention(nn.Module): + + def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, + prefix: str): + super().__init__() + dim = config.hidden_size + + # TODO(Isotr0py): investigate if we can add TP to visual tokenizer + self.num_heads = config.num_attention_heads + self.qkv = ReplicatedLinear(dim, dim * 3, bias=config.qkv_bias) + # self.qkv = QKVParallelLinear( + # hidden_size=dim, + # head_size=dim // config.num_attention_heads, + # total_num_heads=config.num_attention_heads, + # bias=config.qkv_bias, + # quant_config=quant_config, + # prefix=f"{prefix}.qkv") + self.proj = ReplicatedLinear(dim, dim, bias=config.use_bias) + # self.proj = RowParallelLinear(input_size=dim, + # output_size=dim, + # bias = config.use_bias, + # quant_config=quant_config, + # prefix=f"{prefix}.proj") + + def forward( # todo might implement multiple attn implementations + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: + B, N, C = x.shape + qkv, _ = self.qkv(x) + + qkv = qkv.reshape(B, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv.unbind(0) + + x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) + x = x.transpose(1, 2).contiguous().reshape(B, N, C) + x, _ = self.proj(x) + return x + + +class AIMv2Block(nn.Module): + + def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, + prefix: str): + super().__init__() + self.attn = AIMv2Attention(config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + self.norm_1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp = AIMv2SwiGLUFFN(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward(self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: + x = x + self.attn(self.norm_1.forward_native(x), mask) + x = x + self.mlp(self.norm_2.forward_native(x)) + return x + + +class AIMv2Transformer(nn.Module): + + def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, + prefix: str): + super().__init__() + + self.blocks = nn.ModuleList([ + AIMv2Block(config, 
quant_config, prefix=f"{prefix}.blocks.{i}") + for i in range(config.num_hidden_layers) + ]) + self.post_trunk_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + tokens: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # they take the -1 as the ref embeddings, like a clip skip + for block in self.blocks: + tokens = block(tokens, mask) + # NO NORM IN THE OG IMPLEMENTATION + # tokens = self.post_trunk_norm(tokens) + return tokens + + +class AIMv2Model(torch.nn.Module): + + def __init__(self, + config: AIMv2Config, + quant_config: QuantizationConfig, + prefix: str = ""): + super().__init__() + self.preprocessor = AIMv2ViTPreprocessor(config) + self.trunk = AIMv2Transformer(config, + quant_config=quant_config, + prefix=f"{prefix}.trunk") + + @property + def dtype(self): + return self.trunk.blocks[0].attn.qkv.weight.dtype + + @property + def device(self): + return self.trunk.blocks[0].attn.qkv.device + + def forward( + self, + pixel_values: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + x = self.preprocessor(pixel_values) + x = self.trunk(x, mask) + + return x diff --git a/vllm/model_executor/models/ovis2.py b/vllm/model_executor/models/ovis2.py new file mode 100644 index 000000000000..638077bc87d5 --- /dev/null +++ b/vllm/model_executor/models/ovis2.py @@ -0,0 +1,331 @@ +# SPDX-License-Identifier: Apache-2.0 + +# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/ovis/modeling_ovis.py +# Copyright 2023 The vLLM team. +# Copyright 2023 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Ovis2 model.""" +from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, + TypedDict, Union) + +import torch +import torch.nn as nn +from torch import Tensor +from transformers import BatchFeature + +from vllm.config import VllmConfig +from vllm.model_executor.models.aimv2 import Aimv2VisualTokenizer +from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn, + init_vllm_registered_model, + maybe_prefix) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) +from vllm.multimodal.parse import ImageSize, MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.ovis2 import OvisConfig +from vllm.transformers_utils.processors.ovis2 import OvisProcessor + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal +from .utils import merge_multimodal_embeddings + +# Cannot find the following number from hf config. 
+IMAGE_TOKEN = "" +IMAGE_ATOM_TOKEN_ID = 151666 +IMAGE_PAD_TOKEN_ID = 151672 +NUMBER_OF_TOKEN_TO_RESERVE_FOR_SEGMENT = 256 + + +class Ovis2ImagePatchInputs(TypedDict): + type: Literal["image_patches"] + flat_data: torch.Tensor + """ + Shape: + `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)` + """ + + patches_per_image: List[int] + """ + List of number of total patches for each image in the batch. + This is used to restore the first two dimensions of `flat_data`. + """ + + +class VisualEmbedding(torch.nn.Embedding): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, visual_tokens: Tensor) -> Tensor: + if visual_tokens.dtype in [ + torch.int8, torch.int16, torch.int32, torch.int64, torch.long + ]: + return super().forward(visual_tokens) + return torch.matmul(visual_tokens, self.weight) + + @property + def device(self): + return self.weight.device + + @property + def dtype(self): + return self.weight.dtype + + +class Ovis2ProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(OvisConfig) + + def get_hf_processor(self, **kwargs): + return self.ctx.get_hf_processor(OvisProcessor) + + def get_image_processor(self) -> OvisProcessor: + return self.get_hf_processor().image_processor # type: ignore + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return { # 32k is model token limit at the moment + "image": + self.get_hf_config().multimodal_max_length // + ((9 + 1) * NUMBER_OF_TOKEN_TO_RESERVE_FOR_SEGMENT) + } + + def get_image_size_with_most_features(self) -> ImageSize: + image_processor = self.get_image_processor() + return ImageSize(width=image_processor.size['shortest_edge'] * 9 * 2, + height=image_processor.size['shortest_edge'] * 9 * 2) + + +class Ovis2DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2ProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + return IMAGE_TOKEN * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = \ + self.info.get_image_size_with_most_features() + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + } + return mm_data + + +class Ovis2MultiModalProcessor(BaseMultiModalProcessor[Ovis2ProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + if not mm_data: + # # Avoid warning from HF logger for text-only input + prompt_ids = self.info.get_tokenizer().encode(prompt) + # prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) nope + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + return processed_outputs + + def _apply_hf_processor_tokens_only( + self, + prompt_tokens: list[int], + ) -> list[int]: + + return prompt_tokens + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image"), + grids=MultiModalFieldConfig.batched("image")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + 
out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + + def get_replacement_ovis(item_idx): + grid = out_mm_kwargs["grids"][item_idx] + + hf_processor = self.info.get_hf_processor() + return hf_processor.construct_image_placeholders(grid) + + return [ + PromptReplacement( + modality="image", + target=IMAGE_TOKEN, + replacement=get_replacement_ovis, + ), + ] + + +@MULTIMODAL_REGISTRY.register_processor(Ovis2MultiModalProcessor, + info=Ovis2ProcessingInfo, + dummy_inputs=Ovis2DummyInputsBuilder) +class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.config: OvisConfig = config + self.llm = init_vllm_registered_model( + vllm_config=vllm_config.with_hf_config(config.get_text_config()), + prefix=maybe_prefix(prefix, "llm"), + ) + + self.visual_tokenizer = Aimv2VisualTokenizer( + config=config.visual_tokenizer_config, + quant_config=quant_config, + prefix=f"{prefix}.visual_tokenizer", + image_processor_name_or_path=config.visual_tokenizer_config. + backbone_config.name_or_path, + ) + + self.vte = VisualEmbedding( + self.config.visual_tokenizer_config.vocab_size, + self.config.hidden_size) + + # TODO(Isotr0py): PP support + # self.make_empty_intermediate_tensors = ( + # self.language_model.make_empty_intermediate_tensors) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Ovis2ImagePatchInputs]: + pixel_values = kwargs.pop("pixel_values", None) + if pixel_values is None: + return None + + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + return Ovis2ImagePatchInputs( + type="image_patches", + flat_data=flatten_bn(flatten_bn(pixel_values), concat=True), + patches_per_image=[ + x.shape[0] for x in flatten_bn(pixel_values) + ], + ) + + raise AssertionError("This line should be unreachable.") + + def _process_image_input( + self, image_input: Ovis2ImagePatchInputs) -> MultiModalEmbeddings: + image_patches_flat = image_input["flat_data"] + patches_per_image = image_input["patches_per_image"] + + target_dtype = self.visual_tokenizer.dtype + visual_tokens = self.visual_tokenizer( + image_patches_flat.to(target_dtype)) + visual_embeds = self.vte(visual_tokens) # 1:1 numeric eq. 
+ + return tuple( + x.flatten(0, 1) + for x in visual_embeds.split(patches_per_image, dim=0)) + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + + image_features = self._process_image_input(image_input) + + return image_features + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.llm.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + [IMAGE_ATOM_TOKEN_ID, IMAGE_PAD_TOKEN_ID]) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + # up until here we have a inputs_embeds 100% numerical identity + # between the OG HF Transformers implementation and ours + hidden_states = self.llm( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.llm.logits_processor(self.llm.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + def get_language_model(self) -> torch.nn.Module: + return self.llm diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index df5b2323212b..156a201de35a 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -195,6 +195,7 @@ "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"), # noqa: E501 "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"), "NVLM_D": ("nvlm_d", "NVLM_D_Model"), + "Ovis2ForConditionalGeneration": ("ovis2", "Ovis2ForConditionalGeneration"), "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501 diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 5ddfadb02471..f6c2b35535b6 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -38,9 +38,9 @@ MiniMaxVL01Config, MllamaConfig, MLPSpeculatorConfig, MPTConfig, NemotronConfig, NVLM_D_Config, - RWConfig, SkyworkR1VChatConfig, - SolarConfig, Telechat2Config, - UltravoxConfig) + OvisConfig, RWConfig, + SkyworkR1VChatConfig, SolarConfig, + Telechat2Config, UltravoxConfig) # yapf: enable from vllm.transformers_utils.utils import check_gguf_file from vllm.utils import 
resolve_obj_by_qualname @@ -79,6 +79,7 @@ "minimax_vl_01": MiniMaxVL01Config, "nemotron": NemotronConfig, "NVLM_D": NVLM_D_Config, + "ovis": OvisConfig, "solar": SolarConfig, "skywork_chat": SkyworkR1VChatConfig, "telechat": Telechat2Config, diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 8945c45ea86e..db3efafeef96 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -23,6 +23,7 @@ from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.nemotron import NemotronConfig from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config +from vllm.transformers_utils.configs.ovis2 import OvisConfig from vllm.transformers_utils.configs.skyworkr1v import SkyworkR1VChatConfig from vllm.transformers_utils.configs.solar import SolarConfig from vllm.transformers_utils.configs.telechat2 import Telechat2Config @@ -49,6 +50,7 @@ "KimiVLConfig", "NemotronConfig", "NVLM_D_Config", + "OvisConfig", "SkyworkR1VChatConfig", "SolarConfig", "Telechat2Config", diff --git a/vllm/transformers_utils/configs/ovis2.py b/vllm/transformers_utils/configs/ovis2.py new file mode 100644 index 000000000000..437a16e778c2 --- /dev/null +++ b/vllm/transformers_utils/configs/ovis2.py @@ -0,0 +1,170 @@ +# SPDX-License-Identifier: Apache-2.0 + +# yapf: disable +# ruff: noqa: E501 +# copied from https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/configuration_aimv2.py +# and https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/configuration_ovis.py +from typing import Any, Optional, Union + +from transformers import AutoConfig, PretrainedConfig + + +class AIMv2Config(PretrainedConfig): + """This is the configuration class to store the configuration of an [`AIMv2Model`]. + + Instantiating a configuration with the defaults will yield a similar configuration + to that of the [apple/aimv2-large-patch14-224](https://huggingface.co/apple/aimv2-large-patch14-224). + + Args: + hidden_size: Dimension of the hidden representations. + intermediate_size: Dimension of the SwiGLU representations. + num_hidden_layers: Number of hidden layers in the Transformer. + num_attention_heads: Number of attention heads for each attention layer + in the Transformer. + num_channels: Number of input channels. + image_size: Image size. + patch_size: Patch size. + rms_norm_eps: Epsilon value used for the RMS normalization layer. + attention_dropout: Dropout ratio for attention probabilities. + projection_dropout: Dropout ratio for the projection layer after the attention. + qkv_bias: Whether to add a bias to the queries, keys and values. + use_bias: Whether to add a bias in the feed-forward and projection layers. + kwargs: Keyword arguments for the [`PretrainedConfig`]. 
+ """ + + model_type: str = "aimv2" + + def __init__( + self, + hidden_size: int = 1024, + intermediate_size: int = 2816, + num_hidden_layers: int = 24, + num_attention_heads: int = 8, + num_channels: int = 3, + image_size: int = 224, + patch_size: int = 14, + rms_norm_eps: float = 1e-5, + attention_dropout: float = 0.0, + projection_dropout: float = 0.0, + qkv_bias: bool = False, + use_bias: bool = False, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.rms_norm_eps = rms_norm_eps + + self.projection_dropout = projection_dropout + self.qkv_bias = qkv_bias + self.use_bias = use_bias + + +IGNORE_ID = -100 +IMAGE_TOKEN_ID = -200 +IMAGE_TOKEN = "" +IMAGE_ATOM_ID = -300 +IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305] + +AutoConfig.register("aimv2", AIMv2Config) + + +# ---------------------------------------------------------------------- +# Visual Tokenizer Configuration +# ---------------------------------------------------------------------- +class BaseVisualTokenizerConfig(PretrainedConfig): + + def __init__(self, + vocab_size=16384, + tokenize_function="softmax", + tau=1.0, + depths=None, + drop_cls_token=False, + backbone_config: Optional[Union[PretrainedConfig, + dict]] = None, + hidden_stride: int = 1, + **kwargs): + super().__init__(**kwargs) + self.vocab_size = vocab_size + self.tokenize_function = tokenize_function + self.tau = tau + if isinstance(depths, str): + depths = [int(x) for x in depths.split('|')] + self.depths = depths + self.backbone_kwargs = dict[str, Any]() + self.drop_cls_token = drop_cls_token + if backbone_config is not None: + assert isinstance(backbone_config, (PretrainedConfig, dict)), \ + f"expect `backbone_config` to be instance of PretrainedConfig or dict, but got {type(backbone_config)} type" + if not isinstance(backbone_config, PretrainedConfig): + model_type = backbone_config['model_type'] + backbone_config.pop('model_type') + backbone_config = AutoConfig.for_model(model_type, + **backbone_config) + self.backbone_config = backbone_config + self.hidden_stride = hidden_stride + + +class Aimv2VisualTokenizerConfig(BaseVisualTokenizerConfig): + model_type = "aimv2_visual_tokenizer" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if self.drop_cls_token: + self.drop_cls_token = False + if self.depths: + assert len(self.depths) == 1 + self.backbone_kwargs['num_hidden_layers'] = self.depths[0] + + +AutoConfig.register("aimv2_visual_tokenizer", Aimv2VisualTokenizerConfig) + + +# ---------------------------------------------------------------------- +# Ovis Configuration +# ---------------------------------------------------------------------- +class OvisConfig(PretrainedConfig): + model_type = "ovis" + + def __init__(self, + llm_config: Optional[Union[PretrainedConfig, dict]] = None, + visual_tokenizer_config: Optional[Union[PretrainedConfig, + dict]] = None, + multimodal_max_length=8192, + hidden_size=None, + conversation_formatter_class=None, + llm_attn_implementation=None, + disable_tie_weight=False, + **kwargs): + super().__init__(**kwargs) + if llm_config is not None: + assert isinstance(llm_config, (PretrainedConfig, dict)), \ + f"expect `llm_config` to be instance of PretrainedConfig or dict, but got {type(llm_config)} 
type" + if not isinstance(llm_config, PretrainedConfig): + model_type = llm_config['model_type'] + llm_config.pop('model_type') + llm_config = AutoConfig.for_model(model_type, **llm_config) + + # map llm_config to text_config + self.text_config = llm_config + if visual_tokenizer_config is not None: + assert isinstance(visual_tokenizer_config, (PretrainedConfig, dict)), \ + f"expect `visual_tokenizer_config` to be instance of PretrainedConfig or dict, but got {type(visual_tokenizer_config)} type" + if not isinstance(visual_tokenizer_config, PretrainedConfig): + model_type = visual_tokenizer_config['model_type'] + visual_tokenizer_config.pop('model_type') + visual_tokenizer_config = AutoConfig.for_model( + model_type, **visual_tokenizer_config) + + self.visual_tokenizer_config = visual_tokenizer_config + self.multimodal_max_length = multimodal_max_length + self.hidden_size = hidden_size + self.conversation_formatter_class = conversation_formatter_class + self.llm_attn_implementation = llm_attn_implementation + self.disable_tie_weight = disable_tie_weight diff --git a/vllm/transformers_utils/processors/__init__.py b/vllm/transformers_utils/processors/__init__.py index 4696f0c49df9..2e9cf3e4d90b 100644 --- a/vllm/transformers_utils/processors/__init__.py +++ b/vllm/transformers_utils/processors/__init__.py @@ -2,5 +2,6 @@ from vllm.transformers_utils.processors.deepseek_vl2 import ( DeepseekVLV2Processor) +from vllm.transformers_utils.processors.ovis2 import OvisProcessor -__all__ = ["DeepseekVLV2Processor"] +__all__ = ["DeepseekVLV2Processor", "OvisProcessor"] diff --git a/vllm/transformers_utils/processors/ovis2.py b/vllm/transformers_utils/processors/ovis2.py new file mode 100644 index 000000000000..fa5bdd40e727 --- /dev/null +++ b/vllm/transformers_utils/processors/ovis2.py @@ -0,0 +1,397 @@ +# SPDX-License-Identifier: Apache-2.0 + +# yapf: disable +# ruff: noqa: E501 +# coding=utf-8 +# adapted from https://github.com/AIDC-AI/Ovis/blob/35ab51a1a1e3542fa6db260a1084cefbc8f164bb/ovis/vllm/processing_ovis.py +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import List, Union + +import PIL +import torch +from transformers import AutoProcessor, BatchFeature +from transformers.image_utils import ImageInput +from transformers.processing_utils import (ProcessingKwargs, ProcessorMixin, + Unpack) +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput + +__all__ = [ 'OvisProcessor'] +IGNORE_ID = -100 + +class OvisProcessorKwargs(ProcessingKwargs, total=False): # type: ignore[call-arg] + _defaults = { + "text_kwargs": { + "padding": False, + }, + "images_kwargs": { + 'max_partition':9, + 'covering_threshold':0.9, + 'convert_to_rgb':True, + 'return_tensors':'pt'}, + } + + + +class OvisProcessor(ProcessorMixin): + r""" + Constructs a Ovis processor which wraps a Ovis image processor and a Qwen2 tokenizer into a single processor. + [`OvisProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the + [`~OvisProcessor.__call__`] and [`~OvisProcessor.decode`] for more information. + Args: + image_processor ([`Qwen2VLImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template"] + + image_processor_class = "AutoImageProcessor" + tokenizer_class = "Qwen2Tokenizer" + + def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs): + self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token + self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + self.extra_special_tokens = { + "image_token": "", + "image_atom": "", + "image_start": "", + "image_prefix": "
",
+            "image_col_sep": "",
+            "image_row_sep": "",
+            "image_end": "",
+            'image_pad': '',
+        }
+
+    def __call__(
+        self,
+        images: ImageInput = None,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        **kwargs: Unpack[OvisProcessorKwargs],
+    ) -> BatchFeature:
+        """
+        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
+        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
+        the text. To prepare the vision inputs, this method forwards the `images` and `kwargs` arguments to the
+        image processor if `images` is not `None`.
+            Args:
+                images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                    The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                    tensor. Both channels-first and channels-last formats are supported.
+                text (`str`, `List[str]`, `List[List[str]]`):
+                    The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                    (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                    `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+                return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                    If set, will return tensors of a particular framework. Acceptable values are:
+                    - `'tf'`: Return TensorFlow `tf.constant` objects.
+                    - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                    - `'np'`: Return NumPy `np.ndarray` objects.
+                    - `'jax'`: Return JAX `jnp.ndarray` objects.
+            Returns:
+                [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+                - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+                - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+                  `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+                  `None`).
+                - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+                - **grids** -- List of partition grids `(rows, cols)` used to tile each image. Returned when `images` is not `None`.
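+            Example (illustrative sketch only; assumes an already-instantiated `processor` and a PIL `image`):
+                >>> img_tok = processor.extra_special_tokens["image_token"]
+                >>> out = processor(images=[image], text=[f"{img_tok} Describe this image."],
+                ...                 return_tensors="pt")
+                >>> out["input_ids"]  # the image token has been expanded into placeholder ids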
+        """
+        output_kwargs = self._merge_kwargs(
+            OvisProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
+        # Process all images first
+        image_features = {}
+        if images is not None:
+            processed_images = []
+            image_placeholders_list = []
+            grids = []
+
+            # Process each image
+            for image in images if isinstance(images, list) else [images]:
+                pixel_values, image_placeholders, grid = self.preprocess_image(
+                    image=image, **output_kwargs["images_kwargs"]
+                )
+                processed_images.append(pixel_values)
+                image_placeholders_list.append(image_placeholders)
+                grids.append(grid)
+
+            # keep the per-image placeholder sequences for the text pass below
+            if processed_images:
+                image_features["image_placeholders"] = image_placeholders_list
+
+        # Process text input
+        if text is not None:
+
+            if not isinstance(text, list):
+                text = [text]
+
+            tokenized_batched_text = self.tokenizer.batch_encode_plus(
+                text,
+                **output_kwargs["text_kwargs"]
+            )
+            image_token_id = self.get_token_value("image_token")
+            replaced_ids_list = []
+            replaced_attn_mask_list = []
+            idx = 0
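+            # Each image token in the prompt is expanded into that image's full
+            # placeholder sequence; `idx` tracks which image is being spliced in.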
+            for ids_tensor, attn_mask in zip(tokenized_batched_text['input_ids'],
+                                             tokenized_batched_text['attention_mask']):
+                if image_token_id in ids_tensor and "image_placeholders" in image_features:
+                    if idx < len(image_features["image_placeholders"]):
+                        # Convert to plain lists for ease of use
+                        ids_list = ids_tensor.tolist()
+                        attn_list = attn_mask.tolist()
+
+                        new_ids = []
+                        new_attn = []
+
+                        # replace placeholders
+                        for i, token_id in enumerate(ids_list):
+                            if token_id == image_token_id:
+                                placeholder_ids = image_features["image_placeholders"][idx]
+                                new_ids.extend(placeholder_ids)
+                                new_attn.extend([1] * len(placeholder_ids))
+                                idx += 1
+                            else:
+                                new_ids.append(token_id)
+                                new_attn.append(attn_list[i])
+
+                        # Converts back to tensors
+                        ids_tensor = torch.tensor(new_ids, dtype=torch.long)
+                        attn_mask = torch.tensor(new_attn, dtype=torch.long)
+                    else:
+                        raise RuntimeError(
+                            'Mismatch between the images you provided and the number of image placeholders present in the text')
+
+                replaced_ids_list.append(ids_tensor)
+                replaced_attn_mask_list.append(attn_mask)
+
+            if replaced_ids_list:
+                replaced_and_tokenized_ids = torch.stack(replaced_ids_list)
+                replaced_and_tokenized_attn_mask = torch.stack(replaced_attn_mask_list)
+            else:
+                replaced_and_tokenized_ids = torch.tensor([], dtype=torch.long)
+                replaced_and_tokenized_attn_mask = torch.tensor([], dtype=torch.long)
+
+            # Create the output with text features
+            output = BatchFeature(
+                data={
+                    "input_ids": replaced_and_tokenized_ids,
+                    "attention_mask": replaced_and_tokenized_attn_mask,
+                }
+            )
+
+            # Add image features if present
+            if image_features:
+                output["pixel_values"] = processed_images
+                output['grids'] = grids
+
+            return output
+
+
+        # If only images were provided
+        return BatchFeature(data=image_features)
+
+
+    def get_image_size(self):
+        height = self.image_processor.crop_size["height"]
+        width = self.image_processor.crop_size["width"]
+        return height, width
+
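+    # Resolve one of the extra special tokens (e.g. image_atom, image_pad) to
+    # its id in the tokenizer vocabulary.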
+    def get_token_value(self, tok):
+        return self.tokenizer.get_vocab()[self.extra_special_tokens[tok]]
+
+    def construct_image_placeholders(self, grid):
+
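+        # One image expands to: an image-start token, an atom for the overview
+        # crop and a grid prefix, then (for multi-tile grids) one atom per tile
+        # separated by column/row separator tokens, and finally an image-end token.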
+        image_placeholders = [self.get_token_value('image_start'),
+                              self.get_token_value('image_atom'),
+                              self.get_token_value('image_prefix')]
+        if grid[0] * grid[1] > 1:
+            for r in range(grid[0]):
+                for c in range(grid[1]):
+                    image_placeholders.append(self.get_token_value('image_atom'))
+                    if c < grid[1] - 1:
+                        image_placeholders.append(self.get_token_value('image_col_sep'))
+                if r < grid[0] - 1:
+                    image_placeholders.append(self.get_token_value('image_row_sep'))
+        image_placeholders.append(self.get_token_value('image_end'))
+        # return image_placeholders
+
+        image_atom_token_id = self.get_token_value('image_atom')
+        # Extract the padding token ID from tokenizer
+        image_padding_token_id = self.get_token_value('image_pad')
+
+        # Create a new list with padding tokens inserted
+        padded_placeholder_tokens = []
+        for token in image_placeholders:
+            padded_placeholder_tokens.append(token)
+            if token == image_atom_token_id:
+                # Add 255 padding tokens after each image atom token
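+                # (1 atom + 255 pads = 256 tokens per image segment, matching
+                # NUMBER_OF_TOKEN_TO_RESERVE_FOR_SEGMENT in the vLLM model code)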
+                padded_placeholder_tokens.extend([image_padding_token_id] * 255)
+        return padded_placeholder_tokens
+
+    def preprocess_image(self, image: PIL.Image.Image, max_partition, covering_threshold, convert_to_rgb, return_tensors):
+        def _preprocess(img: PIL.Image.Image, side):
+            # first resize and preprocess
+            w, h = img.size
+            if w == h:
+                new_width = new_height = side
+            elif w > h:
+                new_width = side
+                new_height = int(h / w * new_width)
+            else:
+                new_height = side
+                new_width = int(w / h * new_height)
+            new_size = dict(height=new_height, width=new_width)
+            pixel_values = self.image_processor.preprocess(img, size=new_size, return_tensors=return_tensors)['pixel_values']
+
+            # then pad to square
+            square_values = torch.zeros([1, 3, side, side], dtype=pixel_values.dtype, device=pixel_values.device)
+            new_height, new_width = pixel_values.shape[2:]
+            if new_height == new_width:
+                square_values[:, :, :, :] = pixel_values
+            elif new_height > new_width:
+                from_index = (side - new_width) // 2
+                square_values[:, :, :, from_index:from_index + new_width] = pixel_values
+            else:
+                from_index = (side - new_height) // 2
+                square_values[:, :, from_index:from_index + new_height, :] = pixel_values
+
+            return square_values
+
+        def _partition(img, grid) -> list[tuple[int, int, int, int]]:
+            w, h = img.size
+            row_height = h // grid[0]
+            col_width = w // grid[1]
+
+            partition = []
+            for row in range(grid[0]):
+                for col in range(grid[1]):
+                    left = col * col_width
+                    upper = row * row_height
+                    right = w if col == grid[1] - 1 else (col + 1) * col_width
+                    lower = h if row == grid[0] - 1 else (row + 1) * row_height
+                    partition.append((left, upper, right, lower))
+
+            return partition
+
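+        # Effective area of a tile once its longer edge is capped at `side`;
+        # used to measure how well a candidate grid covers the original image.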
+        def _covering_area(left, upper, right, lower, side):
+            w = right - left
+            h = lower - upper
+            w, h = max(w, h), min(w, h)
+            if w > side:
+                h = h / w * side
+                w = side
+            return w * h
+
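+        # Pick the tiling grid: among grids whose covering ratio exceeds
+        # `covering_threshold`, prefer the fewest tiles; otherwise fall back to
+        # the grid with the highest covering ratio.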
+        def _get_best_grid(img, side):
+            img_area = img.size[0] * img.size[1]
+
+            candidate_grids = []
+            for i in range(1, max_partition + 1):
+                for j in range(1, max_partition + 1):
+                    if i * j <= max_partition:
+                        candidate_grids.append((i, j))
+
+            all_grids = []
+            good_grids = []
+            for grid in candidate_grids:
+                partition = _partition(img, grid)
+                covering_ratio = sum([_covering_area(*p, side) for p in partition]) / img_area
+                assert covering_ratio <= 1.0
+                all_grids.append((grid, covering_ratio))
+                if covering_ratio > covering_threshold:
+                    good_grids.append((grid, covering_ratio))
+
+            if len(good_grids) > 0:
+                # pick the good partition with minimum #sub_images and break the tie using covering_ratio
+                return sorted(good_grids, key=lambda x: (x[0][0] * x[0][1], -x[1]))[0][0]
+            else:
+                # pick the partition with maximum covering_ratio and break the tie using #sub_images
+                return sorted(all_grids, key=lambda x: (-x[1], x[0][0] * x[0][1]))[0][0]
+
+        if convert_to_rgb and image.mode != 'RGB':
+            image = image.convert('RGB')
+
+
+        sides = self.get_image_size()
+        if sides[0] != sides[1]:
+            raise ValueError('get_image_size() returns non-square size')
+        side = sides[0]
+        grid = _get_best_grid(image, side)
+        partition = _partition(image, grid)
+        crops = [image.crop(p) for p in partition]
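+        # When the image is tiled, prepend the full image as a global overview
+        # crop; its features pair with the leading image atom placeholder.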
+        if len(crops) > 1:
+            crops.insert(0, image)
+        pixel_values = torch.cat([_preprocess(crop, side) for crop in crops], dim=0)
+        image_placeholders = self.construct_image_placeholders(grid)
+        return pixel_values, image_placeholders, grid
+
+    def batch_decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+        refer to the docstring of this method for more information.
+        """
+        return self.tokenizer.batch_decode(*args, **kwargs)
+
+    def decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+        the docstring of this method for more information.
+        """
+        return self.tokenizer.decode(*args, **kwargs)
+
+    def post_process_image_text_to_text(self, generated_outputs):
+        """
+        Post-process the output of the model to decode the text.
+        Args:
+            generated_outputs (`torch.Tensor` or `np.ndarray`):
+                The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
+                or `(sequence_length,)`.
+        Returns:
+            `List[str]`: The decoded text.
+        """
+        return self.tokenizer.batch_decode(
+            generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
+
+    @property
+    def model_input_names(self):
+        tokenizer_input_names = self.tokenizer.model_input_names
+        image_processor_input_names = self.image_processor.model_input_names
+        names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
+        return names_from_processor + ["second_per_grid_ts"]
+
+
+AutoProcessor.register("OvisProcessor", OvisProcessor)
\ No newline at end of file