diff --git a/ovis/vllm/README.md b/ovis/vllm/README.md
new file mode 100644
index 0000000..6a8ea27
--- /dev/null
+++ b/ovis/vllm/README.md
@@ -0,0 +1,11 @@
+This directory contains the files needed to run Ovis2 with vLLM.
+Until the PR is accepted upstream, register the model with vLLM's `ModelRegistry` before building the engine
+(adjust `ovis_modeling_directory` to the import path of this directory):
+```python
+from vllm import LLM, ModelRegistry
+
+from ovis_modeling_directory.ovis_modeling import OvisForConditionalGeneration
+
+ModelRegistry.register_model("Ovis", OvisForConditionalGeneration)
+llm = LLM(model="AIDC-AI/Ovis2-2B")  # or any other Ovis2 model
+```
\ No newline at end of file
diff --git a/ovis/vllm/aimv2/configuration_aimv2.py b/ovis/vllm/aimv2/configuration_aimv2.py
new file mode 100644
index 0000000..06b2c6d
--- /dev/null
+++ b/ovis/vllm/aimv2/configuration_aimv2.py
@@ -0,0 +1,63 @@
+# copied from https://huggingface.co/apple/aimv2-huge-patch14-448
+from typing import Any
+
+from transformers.configuration_utils import PretrainedConfig
+
+__all__ = ["AIMv2Config"]
+
+
+class AIMv2Config(PretrainedConfig):
+    """This is the configuration class to store the configuration of an [`AIMv2Model`].
+
+    Instantiating a configuration with the defaults will yield a similar configuration
+    to that of the [apple/aimv2-large-patch14-224](https://huggingface.co/apple/aimv2-large-patch14-224).
+
+    Args:
+        hidden_size: Dimension of the hidden representations.
+        intermediate_size: Dimension of the SwiGLU representations.
+        num_hidden_layers: Number of hidden layers in the Transformer.
+        num_attention_heads: Number of attention heads for each attention layer
+            in the Transformer.
+        num_channels: Number of input channels.
+        image_size: Image size.
+        patch_size: Patch size.
+        rms_norm_eps: Epsilon value used for the RMS normalization layer.
+        attention_dropout: Dropout ratio for attention probabilities.
+        projection_dropout: Dropout ratio for the projection layer after the attention.
+        qkv_bias: Whether to add a bias to the queries, keys and values.
+        use_bias: Whether to add a bias in the feed-forward and projection layers.
+        kwargs: Keyword arguments for the [`PretrainedConfig`].
+ """ + + model_type: str = "aimv2" + + def __init__( + self, + hidden_size: int = 1024, + intermediate_size: int = 2816, + num_hidden_layers: int = 24, + num_attention_heads: int = 8, + num_channels: int = 3, + image_size: int = 224, + patch_size: int = 14, + rms_norm_eps: float = 1e-5, + attention_dropout: float = 0.0, + projection_dropout: float = 0.0, + qkv_bias: bool = False, + use_bias: bool = False, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.rms_norm_eps = rms_norm_eps + + self.projection_dropout = projection_dropout + self.qkv_bias = qkv_bias + self.use_bias = use_bias diff --git a/ovis/vllm/aimv2/modeling_aimv2.py b/ovis/vllm/aimv2/modeling_aimv2.py new file mode 100644 index 0000000..ecee908 --- /dev/null +++ b/ovis/vllm/aimv2/modeling_aimv2.py @@ -0,0 +1,205 @@ +from typing import Optional, Tuple, Union + +import torch +from vllm.attention import Attention, AttentionType +from vllm.config import VllmConfig, CacheConfig +from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear, ColumnParallelLinear +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig + +from .configuration_aimv2 import AIMv2Config +from torch import nn +from torch.nn import functional as F +from transformers.modeling_outputs import BaseModelOutputWithNoAttention +from transformers.modeling_utils import PreTrainedModel + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(dim)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + def extra_repr(self) -> str: + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + def _norm(self, x: torch.Tensor) -> torch.Tensor: + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + +class AIMv2SwiGLUFFN(nn.Module): + def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str): + super().__init__() + hidden_features = config.intermediate_size + in_features = config.hidden_size + bias = config.use_bias + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)#ColumnParallelLinear(in_features, + # hidden_features, + # bias=bias, + # quant_config=quant_config, + # prefix=f"{prefix}.fc1") + self.fc2 = nn.Linear(hidden_features, in_features, bias=bias)#ColumnParallelLinear(hidden_features, + # in_features, + # bias=bias, + # quant_config=quant_config, + # prefix=f"{prefix}.fc2") + self.fc3 = nn.Linear(in_features, hidden_features, bias=bias)#RowParallelLinear(in_features, + # hidden_features, + # bias=bias, + # quant_config=quant_config, + # prefix=f"{prefix}.fc3") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_parallel= self.fc1(x)#, _ = self.fc1(x) + gate= self.fc3(x)#, _ = self.fc3(x) + x_parallel = F.silu(x_parallel) * gate + out =self.fc2(x_parallel)#, _ = self.fc2(x_parallel) + return out + + + +class AIMv2PatchEmbed(nn.Module): + def __init__(self, config: AIMv2Config): + super().__init__() + self.proj = nn.Conv2d( + config.num_channels, + config.hidden_size, + 
kernel_size=(config.patch_size, config.patch_size), + stride=(config.patch_size, config.patch_size), + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x).flatten(2).transpose(1, 2) + x = self.norm(x) + return x + + +class AIMv2ViTPreprocessor(nn.Module): + def __init__(self, config: AIMv2Config): + super().__init__() + num_patches = (config.image_size // config.patch_size) ** 2 + + self.patchifier = AIMv2PatchEmbed(config) + self.pos_embed = nn.Parameter(torch.zeros((1, num_patches, config.hidden_size))) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + tokens = self.patchifier(x) + _, N, _ = tokens.shape + pos_embed = self.pos_embed.to(tokens.device) + tokens = tokens + pos_embed[:, :N] + return tokens + + +class AIMv2Attention(nn.Module): + def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str): + super().__init__() + dim = config.hidden_size + + self.num_heads = config.num_attention_heads + self.qkv = nn.Linear(dim, dim * 3, bias=config.qkv_bias)#QKVParallelLinear( + # hidden_size=dim, + # head_size=dim // config.num_attention_heads, + # total_num_heads=config.num_attention_heads, + # bias=config.qkv_bias, + # quant_config=quant_config, + # prefix=f"{prefix}.qkv") + self.attn_drop = nn.Dropout(config.attention_dropout) + self.proj = nn.Linear(dim, dim, bias=config.use_bias)#RowParallelLinear(input_size=dim, + # output_size=dim, + # bias = config.use_bias, + # quant_config=quant_config, + # prefix=f"{prefix}.proj") + + self.proj_drop = nn.Dropout(config.projection_dropout) + + def forward( # todo might implement multiple attn implementations + self, x: torch.Tensor, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x) #, _ = self.qkv(x) + + qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv.unbind(0) + + x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) + x = x.transpose(1, 2).contiguous().reshape(B, N, C) + x= self.proj(x)#, _ = self.proj(x) + x = self.proj_drop(x) + return x + + +class AIMv2Block(nn.Module): + def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str): + super().__init__() + self.attn = AIMv2Attention(config, quant_config=quant_config, prefix=f"{prefix}.attn") + self.norm_1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp = AIMv2SwiGLUFFN(config, quant_config=quant_config, prefix=f"{prefix}.mlp") + self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, x: torch.Tensor, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + x = x + self.attn(self.norm_1(x), mask) + x = x + self.mlp(self.norm_2(x)) + return x + + +class AIMv2Transformer(nn.Module): + def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str): + super().__init__() + + self.blocks = nn.ModuleList( + [AIMv2Block(config, quant_config, prefix=f"{prefix}.blocks.{i}") for i in range(config.num_hidden_layers)] + ) + self.post_trunk_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + tokens: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, ...]]]: + #outputs = [] + for block in self.blocks: # they take the -1 as the ref embeddings, like a clip skip + tokens = block(tokens, mask) + #outputs.append(tokens) + #tokens = self.post_trunk_norm(tokens) NO NORM IN THE OG 
IMPLEMENTATION
+        return tokens
+
+
+class AIMv2Model(torch.nn.Module):
+    def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str = ""):
+        super().__init__()
+        self.preprocessor = AIMv2ViTPreprocessor(config)
+        self.trunk = AIMv2Transformer(config, quant_config=quant_config, prefix=f"{prefix}.trunk")
+
+    @property
+    def dtype(self):
+        return self.trunk.blocks[0].attn.qkv.weight.dtype
+
+    @property
+    def device(self):
+        return self.trunk.blocks[0].attn.qkv.weight.device
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+    ) -> Union[
+        Tuple[torch.Tensor],
+        Tuple[torch.Tensor, Tuple[torch.Tensor, ...]],
+        BaseModelOutputWithNoAttention,
+    ]:
+        x = self.preprocessor(pixel_values)
+        x = self.trunk(x, mask)
+        return x
+
diff --git a/ovis/vllm/aimv2/visual_tokenizer_aimv2.py b/ovis/vllm/aimv2/visual_tokenizer_aimv2.py
new file mode 100644
index 0000000..43df38b
--- /dev/null
+++ b/ovis/vllm/aimv2/visual_tokenizer_aimv2.py
@@ -0,0 +1,120 @@
+import PIL
+import torch
+from torch import softmax
+from torch.nn.functional import gumbel_softmax, pad
+from transformers import AutoImageProcessor
+from vllm.config import QuantizationConfig
+from vllm.model_executor.layers.linear import ColumnParallelLinear
+from .modeling_aimv2 import AIMv2Model
+from ..ovis_config import BaseVisualTokenizerConfig
+
+IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305]  # kept for vocab prefixed tokens
+
+
+class Aimv2VisualTokenizer(torch.nn.Module):
+
+    def __init__(self, config: BaseVisualTokenizerConfig, quant_config: QuantizationConfig, prefix: str = "", **kwargs):
+        super().__init__()
+        self.image_processor = AutoImageProcessor.from_pretrained(kwargs['image_processor_name_or_path'])
+        self.config = config
+        self.image_processor.do_center_crop = False
+        self.backbone = AIMv2Model(
+            config=config.backbone_config,  # noqa
+            quant_config=quant_config,
+            prefix=f"{prefix}.visual_tokenizer"
+        )
+        head_dim = config.vocab_size - len(IMAGE_INDICATOR_IDS)  # reserved tokens for IMAGE_INDICATORS
+        self.head = torch.nn.Sequential(
+            ColumnParallelLinear(
+                config.backbone_config.hidden_size * config.hidden_stride * config.hidden_stride, head_dim,
+                bias=False,
+            ),
+            torch.nn.LayerNorm(head_dim)
+        )
+
+        assert all((self.image_processor.do_resize,
+                    not getattr(self.image_processor, 'do_center_crop', False),
+                    self.image_processor.do_rescale,
+                    self.image_processor.do_normalize
+                    )), f"image_processor `{self.image_processor}` is not supported currently"
+
+    @property
+    def dtype(self):
+        return self.backbone.dtype
+
+    @property
+    def device(self):
+        return self.backbone.device
+
+    def get_backbone(self):
+        return self.backbone
+
+    def get_image_processor(self):
+        return self.image_processor
+
+    def get_head(self):
+        return self.head
+
+    def get_image_size(self):
+        raise NotImplementedError
+
+    def tokenize(self, logits):
+        def st_argmax(y_soft, dim):  # straight-through argmax
+            index = y_soft.max(dim, keepdim=True)[1]
+            y_hard = torch.zeros_like(y_soft, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
+            ret = y_hard - y_soft.detach() + y_soft
+            return ret
+
+        if self.config.tokenize_function == 'softmax':
+            tokens = softmax(logits, dim=-1)
+        elif self.config.tokenize_function == 'gumbel_argmax':
+            tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True)
+        elif self.config.tokenize_function == 'st_argmax':
+            tokens = st_argmax(logits, dim=-1)
+        else:
+            raise ValueError(
+                f'Invalid `tokenize_function`, expected softmax or
gumbel_argmax or st_argmax, but got {self.config.tokenize_function}')
+        return tokens
+
+    def encode(self, pixel_values):
+        features = self.backbone(pixel_values)
+        if self.config.drop_cls_token:
+            features = features[:, 1:, :]
+
+        # merge `hidden_stride * hidden_stride` hidden states together to reduce the token sequence length
+        # e.g., for hidden_stride=2, this leads to a token length reduction: 1024 -> 256 for aimv2
+        if self.config.hidden_stride > 1:
+            n, l, d = features.shape  # this `d` may be different from the `d` above
+            sqrt_l = int(l ** 0.5)
+            assert sqrt_l ** 2 == l, "The token sequence length should be a perfect square."
+            features = features.reshape(n, sqrt_l, sqrt_l, d)
+            pl = (self.config.hidden_stride - (sqrt_l % self.config.hidden_stride)) % self.config.hidden_stride
+            features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0)
+            sqrt_l += pl
+            features = features.reshape(n, sqrt_l // self.config.hidden_stride, self.config.hidden_stride,
+                                        sqrt_l // self.config.hidden_stride, self.config.hidden_stride, d)
+            features = features.permute(0, 1, 3, 2, 4, 5)  # [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d]
+            features = features.flatten(3)  # [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d]
+            features = features.reshape(
+                n, -1, self.config.hidden_stride * self.config.hidden_stride * d)
+
+        return features
+
+    def forward(self, pixel_values) -> torch.Tensor:  # [BatchSize, ImageShape] -> [BatchSize, #Token, VocabSize]
+        features = self.encode(pixel_values)
+        logits, _ = self.head[0](features)  # call the Sequential's stages separately: ColumnParallelLinear returns (output, bias)
+        logits = self.head[1](logits)
+        tokens = self.tokenize(logits)
+        # tokens' shape is [BatchSize, #Token, VocabSize-5], so padding with [BatchSize, #Token, 5], after
+        # which, tokens' shape should become [BatchSize, #Token, VocabSize]
+        batch_size, token_len, _ = tokens.shape
+        padding_tensor = torch.zeros(size=(batch_size, token_len, len(IMAGE_INDICATOR_IDS)),
+                                     dtype=tokens.dtype,
+                                     device=tokens.device,
+                                     layout=tokens.layout,
+                                     requires_grad=False)
+        tokens = torch.cat((tokens, padding_tensor), dim=2)
+        return tokens
+
+
diff --git a/ovis/vllm/ovis_config.py b/ovis/vllm/ovis_config.py
new file mode 100644
index 0000000..6607ade
--- /dev/null
+++ b/ovis/vllm/ovis_config.py
@@ -0,0 +1,208 @@
+from abc import ABC, abstractmethod
+from typing import List, Dict, Union, Optional
+
+from transformers import PretrainedConfig, AutoConfig, AutoModel
+from .aimv2.configuration_aimv2 import AIMv2Config
+from .aimv2.modeling_aimv2 import AIMv2Model
+
+IGNORE_ID = -100
+IMAGE_TOKEN_ID = -200
+IMAGE_TOKEN = "<image>"
+IMAGE_ATOM_ID = -300
+IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305]
+
+AutoConfig.register("aimv2", AIMv2Config)
+AutoModel.register(AIMv2Config, AIMv2Model)
+
+# ----------------------------------------------------------------------
+# Visual Tokenizer Configuration
+# ----------------------------------------------------------------------
+class BaseVisualTokenizerConfig(PretrainedConfig):
+    def __init__(
+        self,
+        vocab_size=16384,
+        tokenize_function="softmax",
+        tau=1.0,
+        depths=None,
+        drop_cls_token=False,
+        backbone_config: Optional[Union[PretrainedConfig, dict]] = None,
+        hidden_stride: int = 1,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.vocab_size = vocab_size
+        self.tokenize_function = tokenize_function
+        self.tau = tau
+        if isinstance(depths, str):
+            depths = [int(x) for x in depths.split('|')]
+        self.depths = depths
+        self.backbone_kwargs = {}
+        self.drop_cls_token = drop_cls_token
+        if backbone_config is not None:
+            assert
isinstance(backbone_config, (PretrainedConfig, dict)), \ + f"expect `backbone_config` to be instance of PretrainedConfig or dict, but got {type(backbone_config)} type" + if not isinstance(backbone_config, PretrainedConfig): + model_type = backbone_config['model_type'] + backbone_config.pop('model_type') + backbone_config = AutoConfig.for_model(model_type, **backbone_config) + self.backbone_config = backbone_config + self.hidden_stride = hidden_stride + + +class Aimv2VisualTokenizerConfig(BaseVisualTokenizerConfig): + model_type = "aimv2_visual_tokenizer" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if self.drop_cls_token: + self.drop_cls_token = False + if self.depths: + assert len(self.depths) == 1 + self.backbone_kwargs['num_hidden_layers'] = self.depths[0] + + +AutoConfig.register("aimv2_visual_tokenizer", Aimv2VisualTokenizerConfig) + + +# ---------------------------------------------------------------------- +# Ovis Configuration +# ---------------------------------------------------------------------- +class OvisConfig(PretrainedConfig): + model_type = "chamaleon" # swithched to this to have compatible image token + + def __init__( + self, + llm_config: Optional[Union[PretrainedConfig, dict]] = None, + visual_tokenizer_config: Optional[Union[PretrainedConfig, dict]] = None, + multimodal_max_length=8192, + hidden_size=None, + conversation_formatter_class=None, + llm_attn_implementation=None, + disable_tie_weight=False, + **kwargs + ): + super().__init__(**kwargs) + if llm_config is not None: + assert isinstance(llm_config, (PretrainedConfig, dict)), \ + f"expect `llm_config` to be instance of PretrainedConfig or dict, but got {type(llm_config)} type" + if not isinstance(llm_config, PretrainedConfig): + model_type = llm_config['model_type'] + llm_config.pop('model_type') + llm_config = AutoConfig.for_model(model_type, **llm_config) + self.llm_config = llm_config + if visual_tokenizer_config is not None: + assert isinstance(visual_tokenizer_config, (PretrainedConfig, dict)), \ + f"expect `visual_tokenizer_config` to be instance of PretrainedConfig or dict, but got {type(visual_tokenizer_config)} type" + if not isinstance(visual_tokenizer_config, PretrainedConfig): + model_type = visual_tokenizer_config['model_type'] + visual_tokenizer_config.pop('model_type') + visual_tokenizer_config = AutoConfig.for_model(model_type, **visual_tokenizer_config) + self.visual_tokenizer_config = visual_tokenizer_config + self.multimodal_max_length = multimodal_max_length + self.hidden_size = hidden_size + self.conversation_formatter_class = conversation_formatter_class + self.llm_attn_implementation = llm_attn_implementation + self.disable_tie_weight = disable_tie_weight + #added to work with vllm + self.num_hidden_layers = llm_config.num_hidden_layers + self.vocab_size = llm_config.vocab_size + self.num_attention_heads = llm_config.num_attention_heads if llm_config else 0 + + +# ---------------------------------------------------------------------- +# Conversation Formatter +# ---------------------------------------------------------------------- +class ConversationFormatter(ABC): + support_tokenizer_types = None + + def __init__(self, tokenizer): + tokenizer_type = type(tokenizer).__name__ + assert tokenizer_type in self.support_tokenizer_types, \ + f'Invalid tokenizer type, expected one from `{self.support_tokenizer_types}`, but got `{tokenizer_type}`' + self.tokenizer = tokenizer + self.image_token = IMAGE_TOKEN + self.image_token_id = IMAGE_TOKEN_ID + self.ignore_id = IGNORE_ID + 
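+
+    # NOTE: `_tokenize_with_image_symbol` below splits the text on IMAGE_TOKEN
+    # ("<image>"), tokenizes each chunk without special tokens, and re-joins the
+    # chunks with IMAGE_TOKEN_ID (-200) in between, e.g. (illustrative):
+    #     "Describe <image> briefly" -> tokenize("Describe ") + [-200] + tokenize(" briefly")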
+ def _tokenize_with_image_symbol(self, text): + text_chunks = [self.tokenizer(chunk, add_special_tokens=False).input_ids for chunk in + text.split(self.image_token)] + token_ids = [] + num_chuck = len(text_chunks) + for i, chunk in enumerate(text_chunks): + token_ids.extend(chunk) + if i < num_chuck - 1: + token_ids.append(self.image_token_id) + return token_ids + + @abstractmethod + def format(self, conversations: List[Dict], generation_preface=None): + pass + + @abstractmethod + def format_query(self, query, generation_preface=""): + pass + + +class QwenConversationFormatter(ConversationFormatter): + support_tokenizer_types = ['QWenTokenizer', 'Qwen2TokenizerFast'] + + def __init__(self, tokenizer): + super().__init__(tokenizer) + self.from2role = { + "system": "<|im_start|>system\n", + "human": "<|im_start|>user\n", + "gpt": "<|im_start|>assistant\n", + } + self.gpt_token_num = None + self.im_end = "<|im_end|>\n" + self.default_system_prompt = "You are a helpful assistant." + + def format(self, conversations: List[Dict], generation_preface=None): + if self.gpt_token_num is None: + self.gpt_token_num = len(self.tokenizer(self.from2role["gpt"], add_special_tokens=False).input_ids) + + if conversations[0]["from"] != "system": + conversations.insert(0, { + "from": "system", + "value": self.default_system_prompt + }) + + if generation_preface is not None: + conversations.append({ + "from": "gpt", + "value": generation_preface + }) + + prompt = "" + input_ids = [] + labels = [] + num_conversation = len(conversations) + for i, conversation in enumerate(conversations): + frm = conversation["from"] + role = self.from2role[frm] + message = conversation["value"] + text = role + message + if i < num_conversation - 1 or generation_preface is None: + text += self.im_end + prompt += text + token_ids = self._tokenize_with_image_symbol(text) + input_ids.extend(token_ids) + label_ids = [self.ignore_id] * len(token_ids) + if frm == "gpt" and generation_preface is None: + # learning `\n` following `im_end` is meaningless, so the last `\n` token is ignored in label + label_ids[self.gpt_token_num:-1] = token_ids[self.gpt_token_num:-1] + labels.extend(label_ids) + + assert self._tokenize_with_image_symbol(prompt) == input_ids + assert len(input_ids) == len(labels) + + return prompt, input_ids, labels + + def format_query(self, query, generation_preface=""): + prompt, input_ids, _ = self.format([{ + "from": "human", + "value": query + }], generation_preface=generation_preface) + + return prompt, input_ids diff --git a/ovis/vllm/ovis_modeling.py b/ovis/vllm/ovis_modeling.py new file mode 100644 index 0000000..267f30d --- /dev/null +++ b/ovis/vllm/ovis_modeling.py @@ -0,0 +1,597 @@ +# SPDX-License-Identifier: Apache-2.0 + +# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/ovis/modeling_ovis.py +# Copyright 2023 The vLLM team. +# Copyright 2023 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" PyTorch Ovis model.""" +from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, + TypedDict, Union) + +import torch +import torch.nn as nn +from PIL.Image import Image +from torch import Tensor +from transformers import BatchFeature, AutoTokenizer + +from vllm.attention import AttentionMetadata +from vllm.config import VllmConfig +from vllm.model_executor.layers.linear import ColumnParallelLinear +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.models import SupportsMultiModal, SupportsPP +from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM +from vllm.model_executor.models.utils import maybe_prefix, flatten_bn, AutoWeightsLoader, init_vllm_registered_model +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, NestedTensors, + ) +from vllm.multimodal.parse import (ImageSize, + MultiModalDataItems) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement ) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors +from .aimv2.visual_tokenizer_aimv2 import Aimv2VisualTokenizer +from .processing_ovis import OvisProcessor +from .ovis_config import OvisConfig + +from torch.nn import init + +# Cannot find the following 2 numbers from hf config. +IGNORE_ID = -100 + + +MAX_SEGMENTS = 10 # default value in the ovis modeling +NUMBER_OF_TOKEN_TO_RESERVE_FOR_SEGMENT = 256 + +class OvisImagePatchInputs(TypedDict): + type: Literal["image_patches"] + flat_data: torch.Tensor + """ + Shape: + `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)` + """ + + patches_per_image: List[int] + """ + List of number of total patches for each image in the batch. + This is used to restore the first two dimensions of `flat_data`. + """ + +class VisualEmbedding(torch.nn.Embedding): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, visual_tokens: Tensor) -> Tensor: + if visual_tokens.dtype in [torch.int8, torch.int16, torch.int32, torch.int64, torch.long]: + return super().forward(visual_tokens) + return torch.matmul(visual_tokens, self.weight) + + def reset_parameters(self, mean=0., std=1.) 
-> None: + init.normal_(self.weight, mean=mean, std=std) + self._fill_padding_idx_with_zero() + + @property + def device(self): + return self.weight.device + + @property + def dtype(self): + return self.weight.dtype + +class OvisProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(OvisConfig) + + def get_hf_processor(self): + return self.ctx.get_hf_processor(OvisProcessor) + + def get_image_processor(self) -> OvisProcessor: + return self.get_hf_processor().image_processor # type: ignore + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1}# totest hte case where a single image ios passed self.get_hf_config().multimodal_max_length // (MAX_SEGMENTS * NUMBER_OF_TOKEN_TO_RESERVE_FOR_SEGMENT)} # 32k is model token limit at the moment + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + + return {"image": (mm_counts['image'] * MAX_SEGMENTS * 256) + 11} # 6 image pos token, don't ask why + + def get_image_size(self) -> ImageSize: + image_processor = self.get_image_processor() + return ImageSize(width=image_processor.size['shortest_edge'] * 9 * 2, + height=image_processor.size['shortest_edge'] * 9 * 2) + + +class OvisDummyInputsBuilder(BaseDummyInputsBuilder[OvisProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int] + ) -> ProcessorInputs: + target_width, target_height = \ + self.info.get_image_size() + num_images = mm_counts.get("image", 0) + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + } + + return ProcessorInputs( + prompt_text='''<|im_start|>system +You are a helpful assistant.<|im_end|> +<|im_start|>user + +Describe the image.<|im_end|> +<|im_start|>assistant''', + mm_data=mm_data, + + ) + + +class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]): + + def _get_token_value(self, tok): + return self.info.get_tokenizer()(self.info.get_tokenizer().extra_special_tokens[tok])["input_ids"] + + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + if not mm_data: + # # Avoid warning from HF logger for text-only input + prompt_ids = self.info.get_tokenizer().encode(prompt) + #prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) nope + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + return processed_outputs + + def _apply_hf_processor_tokens_only( + self, + prompt_tokens: list[int], + ) -> list[int]: + + return prompt_tokens + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image"), grids=MultiModalFieldConfig.batched("image")) + + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + + def get_replacement_tokens_ovis(grid): + """ + Calculates the placeholder for the sequence, starting from the grid + + Args: + grid: the grid tuple for the image + Returns: + list: Placeholder sequence for the image with padding + """ + hf_processor = self.info.get_hf_processor() + # Get the base 
placeholder tokens + placeholder_tokens = hf_processor.construct_image_placeholders(grid) + image_atom_token_id = \ + self.info.get_tokenizer()(self.info.get_tokenizer().extra_special_tokens['image_atom'])['input_ids'][0] + + # Extract the padding token ID from tokenizer + image_padding_token_id = \ + self.info.get_tokenizer()(self.info.get_tokenizer().extra_special_tokens['image_pad'])['input_ids'][0] + + # Create a new list with padding tokens inserted + padded_placeholder_tokens = [] + for token in placeholder_tokens: + padded_placeholder_tokens.append(token) + if token == image_atom_token_id: + # Add 255 padding tokens after each image atom token + padded_placeholder_tokens.extend([image_padding_token_id] * 255) + + return padded_placeholder_tokens + + return [ + PromptReplacement( + modality="image", + target= self.info.get_tokenizer()( + self.info.get_tokenizer() + .extra_special_tokens['image_token'] + )['input_ids'], + replacement=get_replacement_tokens_ovis(grid), + ) + for grid in out_mm_kwargs["grids"]] + + +#useful for comparison of numerical identity between implementations +'''import torch +import tempfile +import os +import numpy as np + + +def compare_with_saved_tensor(tensor, saved_tensor_path): + """ + Loads a tensor from disk and compares it with an existing tensor. + + Args: + tensor: torch.Tensor - The tensor to compare against + saved_tensor_path: str - Path to the saved tensor file + + Returns: + dict: Comparison metrics and information + """ + # Load the saved tensor properly using a file handle + try: + with open(saved_tensor_path, 'rb') as f: + loaded_tensor = torch.load(f) + except TypeError as e: + if 'BFloat16' in str(e): + # Handle BFloat16 tensors by loading with a different approach + with open(saved_tensor_path, 'rb') as f: + loaded_tensor = torch.load(f, map_location=torch.device('cpu')) + # Convert to a supported dtype if needed + loaded_tensor = loaded_tensor.to(dtype=torch.float32) + # Also convert the comparison tensor to float32 + tensor = tensor.to(dtype=torch.float32) + else: + raise e + + # Ensure both tensors are on the same device + if tensor.device != loaded_tensor.device: + loaded_tensor = loaded_tensor.to(tensor.device) + + # Basic shape comparison + shapes_match = tensor.shape == loaded_tensor.shape + + if not shapes_match: + return { + "shapes_match": False, + "tensor_shape": tensor.shape, + "loaded_tensor_shape": loaded_tensor.shape, + "error": "Shapes don't match, cannot compute element-wise metrics" + } + + # Element-wise comparison + diff = tensor - loaded_tensor + abs_diff = torch.abs(diff) + + # Compute metrics + metrics = { + "shapes_match": True, + "tensor_shape": tensor.shape, + "tensor_dtype": str(tensor.dtype), + "loaded_tensor_dtype": str(loaded_tensor.dtype), + "exact_match": torch.equal(tensor, loaded_tensor), + "mean_diff": diff.mean().item(), + "mean_abs_diff": abs_diff.mean().item(), + "max_abs_diff": abs_diff.max().item(), + "min_abs_diff": abs_diff.min().item(), + "std_diff": diff.std().item(), + "l2_norm_diff": torch.norm(diff).item(), + "percent_exact_match": (tensor == loaded_tensor).float().mean().item() * 100, + "nonzero_count_original": torch.count_nonzero(tensor).item(), + "nonzero_count_loaded": torch.count_nonzero(loaded_tensor).item() + } + + # Generate histogram data for the differences + diff_np = diff.flatten().cpu().float().numpy() + hist, bin_edges = np.histogram(diff_np, bins=10) + metrics["diff_histogram"] = { + "counts": hist.tolist(), + "bin_edges": bin_edges.tolist() + } + + # Find positions with 
largest differences + if tensor.numel() > 0: + top_k = min(10, tensor.numel()) + flat_indices = torch.topk(abs_diff.flatten(), k=top_k)[1] + + # Convert flat indices to multi-dimensional indices + top_diff_positions = [] + for idx in flat_indices: + idx = idx.item() + # Convert flat index to multi-dimensional index + multi_idx = np.unravel_index(idx, tensor.shape) + # Get values at this position + original_val = tensor[multi_idx].item() + loaded_val = loaded_tensor[multi_idx].item() + diff_val = diff[multi_idx].item() + + top_diff_positions.append({ + "position": multi_idx, + "original_value": original_val, + "loaded_value": loaded_val, + "difference": diff_val + }) + + metrics["top_differences"] = top_diff_positions + + return metrics''' + +@MULTIMODAL_REGISTRY.register_processor(OvisMultiModalProcessor, + info=OvisProcessingInfo, + dummy_inputs=OvisDummyInputsBuilder) +class OvisForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + self.padding_idx = config.pad_token_id + self.llm = init_vllm_registered_model( + vllm_config=vllm_config.with_hf_config(config.llm_config), + prefix=maybe_prefix(prefix, "language_model"), + architectures=["Qwen2ForCausalLM"], + ) + + self.visual_tokenizer = Aimv2VisualTokenizer( + config = config.visual_tokenizer_config, + quant_config=quant_config, + prefix=f"{prefix}.visual_tokenizer", + image_processor_name_or_path=config.visual_tokenizer_config.backbone_config.name_or_path, + ).to(self.config.torch_dtype) + + self.vte = VisualEmbedding( + self.config.visual_tokenizer_config.vocab_size, + self.config.hidden_size, + device='cuda', + dtype=self.visual_tokenizer.dtype + ) + + # we'll instantiate a tokenizer and keep just the external mapping + tokenizer = AutoTokenizer.from_pretrained(config.name_or_path) + + self.extra_token_mapping = { + k: tokenizer(v)['input_ids'][0] for k, v in tokenizer.extra_special_tokens.items() + } + + self.extra_token_mapping_for_substitution = { + k: tokenizer(v)['input_ids'][0] for k, v in tokenizer.extra_special_tokens.items() if k in + {'image_atom', + 'image_pad'} + } + + + self.visual_indicators_embeds_dict = None + #VocabParallelEmbedding( if enabled leads to numerical diff + # self.config.visual_tokenizer_config.vocab_size, + # self.config.hidden_size, + # params_dtype=self.visual_tokenizer.dtype, + # quant_config=quant_config, + # prefix=f"{prefix}.vte" + #) + + + #self.make_empty_intermediate_tensors = ( + # self.language_model.make_empty_intermediate_tensors) ? 
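+
+    # Overview of the multimodal path implemented below (descriptive comment):
+    #   1. `Aimv2VisualTokenizer` maps pixel_values to soft tokens over the
+    #      visual vocabulary (one probability distribution per visual token).
+    #   2. `VisualEmbedding` (self.vte) turns those soft tokens into embeddings
+    #      via a matmul with its embedding matrix.
+    #   3. `merge_multimodal` writes the visual embeddings into the text
+    #      embedding sequence at the `image_atom`/`image_pad` placeholder
+    #      positions before the sequence is fed to the Qwen2 language model.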
+ + + def _init_embed_representation(self): + if not self.visual_indicators_embeds_dict: + # we precalcualte the embeddings for the image tokens + visual_vocab_size = self.visual_tokenizer.config.vocab_size + visual_indicator_embeds = self.vte( + torch.tensor( + list(range(visual_vocab_size - 5, visual_vocab_size)), + dtype=torch.long, + device=self.vte.device + ) + ) + + self.visual_indicators_embeds_dict = { + 'image_start': visual_indicator_embeds[0], + 'image_prefix': visual_indicator_embeds[1], + 'image_col_sep': visual_indicator_embeds[2], + 'image_row_sep': visual_indicator_embeds[3], + 'image_end': visual_indicator_embeds[4], + } + + @property + def sampler(self): + return self.llm.sampler + + def merge_multimodal( + self, + text_input_ids: Union[List[torch.Tensor], torch.Tensor], + pixel_values: Optional[Union[List[torch.Tensor], torch.Tensor, object]], + left_padding: bool = True # must be true during inference + ): # todo check when different sized inputs are batched + # todo the tokenizer do not uses /n + # we need to decompose the pixel_value_tensor + # vllm batches it fi it is ccompatible otherwise it will pass it as list + self._init_embed_representation() + if pixel_values is not None and not isinstance(pixel_values, list): + if pixel_values.dim() == 6: + # if is [tensor_batch, 1, num_segments, ch, w, h] we need -> [tensor_batch, num_segments, ch, w, h] + pixel_values = pixel_values.squeeze(1) + pixel_values = [pixel_value.to(self.config.torch_dtype) for pixel_value in pixel_values] + else: + pixel_values = [pixel_values] + + # When inference, sample can include only text with `None` pixel_value + num_images = [x.shape[0] if x is not None else 0 for x in pixel_values] + if sum(num_images) > 0: + visual_tokens = self.visual_tokenizer( + torch.cat( + [x for x in pixel_values if x is not None], + dim=0).to(self.visual_tokenizer.dtype) + ) + + visual_embeds = self.vte(visual_tokens) # 1:1 numeric eq. 
+ + + else: + # just placeholders + visual_embeds = [None] * len(num_images) + + input_embeds = [] + + for text_input_id, visual_embed in zip(text_input_ids, visual_embeds): + + placeholder_token_mask = torch.zeros_like(text_input_id, dtype=torch.bool) + for value in self.extra_token_mapping_for_substitution.values(): + placeholder_token_mask |= torch.eq(text_input_id, value) + + text_embed = torch.zeros((text_input_id.shape[0],self.llm.model.norm.hidden_size), + device=text_input_id.device, dtype=self.visual_tokenizer.dtype) + text_embed[~placeholder_token_mask] = self.llm.model.embed_tokens(text_input_id[~placeholder_token_mask]) # 1:1 + + for key, indicator_id in self.extra_token_mapping.items(): + if key in self.visual_indicators_embeds_dict: + text_embed[text_input_id == indicator_id] = self.visual_indicators_embeds_dict[key].to(text_embed.device) + #image_atom_positions = torch.where(torch.eq(text_input_id, self.extra_token_mapping['image_atom']))[0].tolist() + #if len(image_atom_positions) > 0: + #if not is_testing: + # input_embed_parts = [] + # prev_image_atom_position = -1 + # for index, image_atom_position in enumerate(image_atom_positions): + # input_embed_parts.append( + # text_embed[prev_image_atom_position + 1:image_atom_position, :]) +# + # input_embed_parts.append(visual_embeds[index]) +# + # prev_image_atom_position = image_atom_position + # if prev_image_atom_position + 1 < text_input_id.shape[0]: + # input_embed_parts.append( + # text_embed[prev_image_atom_position + 1:, :]) +# + # input_embed = torch.cat(input_embed_parts, dim=0) + #else: + + # here we have already preallocated the multimodal tokens (in the testing phase) se the logic should be different + # we should check consider that each atom token should replace 256 text tokens embeddings + + # It just needs this unified verison, since if no images aare present it should just skip this + text_embed[placeholder_token_mask] = visual_embeds.view(-1, text_embed.shape[-1]) + + + #else: + # input_embed = text_embed + + input_embeds.append(text_embed) + + + batch_input_embeds = self.pad_truncate_sequence(input_embeds, batch_first=True, padding_value=0.0, + left_padding=left_padding) + + return batch_input_embeds + + def pad_truncate_sequence(self, sequences: List[torch.Tensor], batch_first: bool = True, padding_value: float = 0.0, left_padding: bool = False) -> torch.Tensor: + if not left_padding: + pad_sequence = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=batch_first, padding_value=padding_value) + return pad_sequence[:,:self.config.multimodal_max_length] + else: + pad_sequence = torch.nn.utils.rnn.pad_sequence([i.flip(dims=[0]) for i in sequences],batch_first=True, padding_value=padding_value).flip(dims=[1]) + return pad_sequence[:,-self.config.multimodal_max_length:] + + + def get_tensor_formatted(self, input: Union[torch.Tensor, List]) -> List[torch.Tensor]: + ''' + if thhe input is list check if its input arte 1d if so usueeze() them in 0 + if it is a tensor it needs to be splittend in a list + :param input: + :return: + ''' + if isinstance(input, list): + output_list = [] + for element in input: + if element.dim() == 1: + output_list.append(element.unsqueeze(0)) + else: + output_list.append(element) + return output_list + else: + return [tensor for tensor in input] if input.dim() > 1 else [input] + + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = 
None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None and 'pixel_values' in kwargs: # vllm batches the input or make it a list but does not have a attn mask + inputs_embeds = self.merge_multimodal(text_input_ids=self.get_tensor_formatted(input_ids) , + pixel_values=kwargs['pixel_values'],) + #is_testing = kv_caches[0].numel() == 0) valid approach but probably not needed + #input_ids = None + # up until here we have a inputs_embeds 100% numerical identity between the OG HF Transformers implementation and ours + hidden_states = self.llm( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.llm.logits_processor( + self.llm.lm_head, hidden_states, sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.llm.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) \ No newline at end of file diff --git a/ovis/vllm/processing_ovis.py b/ovis/vllm/processing_ovis.py new file mode 100644 index 0000000..ae5cb1c --- /dev/null +++ b/ovis/vllm/processing_ovis.py @@ -0,0 +1,373 @@ +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
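+
+# Usage sketch (illustrative only; assumes the processor has already been built
+# with this repo's image processor and a Qwen2 tokenizer):
+#
+#     processor = OvisProcessor(image_processor=image_processor, tokenizer=tokenizer)
+#     batch = processor(images=[pil_image], text=["<image>\nDescribe the image."])
+#
+# `batch` is a BatchFeature holding `input_ids`/`attention_mask` (with the image
+# token already expanded into placeholder ids), `pixel_values` (one tensor of
+# crops per image) and `grids` (the partition grid chosen for each image).
+# The "<image>" marker is assumed to be the tokenizer's `image_token` extra
+# special token.
+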
+from collections import defaultdict +from typing import List, Union + +import PIL +import torch +from transformers import BatchFeature +from transformers.image_utils import ImageInput +from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from transformers.tokenization_utils_base import TextInput, PreTokenizedInput + +__all__ = [ 'OvisProcessor'] +IGNORE_ID = -100 + +class OvisProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + }, + "images_kwargs": { + 'max_partition':9, + 'covering_threshold':0.9, + 'convert_to_rgb':True, + 'return_tensors':'pt'}, + } + + + +class OvisProcessor(ProcessorMixin): + r""" + Constructs a Ovis processor which wraps a Ovis image processor and a Qwen2 tokenizer into a single processor. + [`OvisProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the + [`~OvisProcessor.__call__`] and [`~OvisProcessor.decode`] for more information. + Args: + image_processor ([`Qwen2VLImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template"] + + image_processor_class = "AutoImageProcessor" + tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") + + def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs): + self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token + self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + **kwargs: Unpack[OvisProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to + Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch + tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. 
+ return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`. + - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. + - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. + - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`. + """ + output_kwargs = self._merge_kwargs( + OvisProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + # Process all images first + image_features = {} + if images is not None: + processed_images = [] + image_placeholders_list = [] + grids = [] + + # Process each image + for image in images if isinstance(images, list) else [images]: + pixel_values, image_placeholders, grid = self.preprocess_image( + image=image, **output_kwargs["images_kwargs"] + ) + processed_images.append(pixel_values) + image_placeholders_list.append(image_placeholders) + grids.append(grid) + + # assign all processed images + if processed_images: + image_features["image_placeholders"] = image_placeholders_list + + # Process text input + if text is not None: + + if not isinstance(text, list): + text = [text] + + tokenized_batched_text = self.tokenizer.batch_encode_plus( + text, + **output_kwargs["text_kwargs"] + ) + image_token_id = self.tokenizer(self.tokenizer.extra_special_tokens['image_token'])['input_ids'][0] + replaced_ids_list = [] + replaced_attn_mask_list = [] + idx = 0 + for ids_tensor, attn_mask in zip(tokenized_batched_text['input_ids'], + tokenized_batched_text['attention_mask']): + if image_token_id in ids_tensor and "image_placeholders" in image_features: + if idx < len(image_features["image_placeholders"]): + # Converts in list for ease of use + ids_list = ids_tensor.tolist() + attn_list = attn_mask.tolist() + placeholder_ids = image_features["image_placeholders"][idx] + + new_ids = [] + new_attn = [] + + # replace placeholders + for i, token_id in enumerate(ids_list): + if token_id == image_token_id: + new_ids.extend(placeholder_ids) + new_attn.extend([1] * len(placeholder_ids)) + else: + new_ids.append(token_id) + new_attn.append(attn_list[i]) + + # Converts back to tensors + ids_tensor = torch.tensor(new_ids, dtype=torch.long) + attn_mask = torch.tensor(new_attn, dtype=torch.long) + idx += 1 + else: + raise RuntimeError( + 'Mismatch between the images you provided and the number of placeholder present in the text') + + replaced_ids_list.append(ids_tensor) + replaced_attn_mask_list.append(attn_mask) + + if replaced_ids_list: + replaced_and_tokenized_ids = torch.stack(replaced_ids_list) + replaced_and_tokenized_attn_mask = 
torch.stack(replaced_attn_mask_list) + else: + replaced_and_tokenized_ids = torch.tensor([], dtype=torch.long) + replaced_and_tokenized_attn_mask = torch.tensor([], dtype=torch.long) + + # Create the output with text features + output = BatchFeature( + data={ + "input_ids": replaced_and_tokenized_ids, + "attention_mask": replaced_and_tokenized_attn_mask, + } + ) + + # Add image features if present + if image_features: + output["pixel_values"] = processed_images + output['grids'] = grids + + return output + + + # If only images were provided + return BatchFeature(data=image_features) + + + def get_image_size(self): + height = self.image_processor.crop_size["height"] + width = self.image_processor.crop_size["width"] + return height, width + + def get_token_value(self, tok): + return self.tokenizer(self.tokenizer.extra_special_tokens[tok])["input_ids"][0] + + def construct_image_placeholders(self, grid): + + image_placeholders = [self.get_token_value('image_start'), + self.get_token_value('image_atom'), + self.get_token_value('image_prefix')] + if grid[0] * grid[1] > 1: + for r in range(grid[0]): + for c in range(grid[1]): + image_placeholders.append(self.get_token_value('image_atom') ) + if c < grid[1] - 1: + image_placeholders.append(self.get_token_value('image_col_sep')) + if r < grid[0] - 1: + image_placeholders.append(self.get_token_value('image_row_sep')) + image_placeholders.append(self.get_token_value('image_end')) + return image_placeholders + + def preprocess_image(self, image: PIL.Image.Image, max_partition, covering_threshold, convert_to_rgb, return_tensors): + def _preprocess(img: PIL.Image.Image, side): + # first resize and preprocess + w, h = img.size + if w == h: + new_width = new_height = side + elif w > h: + new_width = side + new_height = int(h / w * new_width) + else: + new_height = side + new_width = int(w / h * new_height) + new_size = dict(height=new_height, width=new_width) + pixel_values = self.image_processor.preprocess(img, size=new_size, return_tensors=return_tensors)['pixel_values'] + + # then pad to square + square_values = torch.zeros([1, 3, side, side], dtype=pixel_values.dtype, device=pixel_values.device) + new_height, new_width = pixel_values.shape[2:] + if new_height == new_width: + square_values[:, :, :, :] = pixel_values + elif new_height > new_width: + from_index = (side - new_width) // 2 + square_values[:, :, :, from_index:from_index + new_width] = pixel_values + else: + from_index = (side - new_height) // 2 + square_values[:, :, from_index:from_index + new_height, :] = pixel_values + + return square_values + + def _partition(img, grid): + w, h = img.size + row_height = h // grid[0] + col_width = w // grid[1] + + partition = [] + for row in range(grid[0]): + for col in range(grid[1]): + left = col * col_width + upper = row * row_height + right = w if col == grid[1] - 1 else (col + 1) * col_width + lower = h if row == grid[0] - 1 else (row + 1) * row_height + partition.append((left, upper, right, lower)) + + return partition + + def _covering_area(left, upper, right, lower, side): + w = right - left + h = lower - upper + w, h = max(w, h), min(w, h) + if w > side: + h = h / w * side + w = side + return w * h + + def _get_best_grid(img, side): + img_area = img.size[0] * img.size[1] + + candidate_grids = [] + for i in range(1, max_partition + 1): + for j in range(1, max_partition + 1): + if i * j <= max_partition: + candidate_grids.append((i, j)) + + all_grids = [] + good_grids = [] + for grid in candidate_grids: + partition = _partition(img, grid) + 
covering_ratio = sum([_covering_area(*p, side) for p in partition]) / img_area + assert covering_ratio <= 1.0 + all_grids.append((grid, covering_ratio)) + if covering_ratio > covering_threshold: + good_grids.append((grid, covering_ratio)) + + if len(good_grids) > 0: + # pick the good partition with minimum #sub_images and break the tie using covering_ratio + return sorted(good_grids, key=lambda x: (x[0][0] * x[0][1], -x[1]))[0][0] + else: + # pick the partition with maximum covering_ratio and break the tie using #sub_images + return sorted(all_grids, key=lambda x: (-x[1], x[0][0] * x[0][1]))[0][0] + + if convert_to_rgb and image.mode != 'RGB': + image = image.convert('RGB') + + + sides = self.get_image_size() + if sides[0] != sides[1]: + raise ValueError('get_image_size() returns non-square size') + side = sides[0] + grid = _get_best_grid(image, side) + partition = _partition(image, grid) + crops = [image.crop(p) for p in partition] + if len(crops) > 1: + crops.insert(0, image) + pixel_values = torch.cat([_preprocess(crop, side) for crop in crops], dim=0) + image_placeholders = self.construct_image_placeholders(grid) + return pixel_values, image_placeholders, grid + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + def post_process_image_text_to_text(self, generated_outputs): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + + Returns: + `List[str]`: The decoded text. + """ + return self.tokenizer.batch_decode( + generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + return names_from_processor + ["second_per_grid_ts"] + + +__all__ = ["OvisProcessor"]