7 changes: 7 additions & 0 deletions ovis/vllm/README.md
@@ -0,0 +1,7 @@
This directory contains the files needed to run Ovis2 with vLLM.
Until the PR is accepted upstream, register the model manually before creating the engine (a fuller usage sketch follows below):
```python
from vllm import LLM, ModelRegistry
from ovis_modeling_directory.ovis_modeling import OvisForConditionalGeneration

ModelRegistry.register_model("Ovis", OvisForConditionalGeneration)
llm = LLM(model="AIDC-AI/Ovis2-2B")  # or any other Ovis2 model
```
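
For reference, a fuller end-to-end sketch might look like the following. The prompt template, image path, sampling settings, and `trust_remote_code=True` flag are illustrative assumptions rather than part of this PR; adapt them to your setup and to the model's chat template.
```python
from PIL import Image
from vllm import LLM, ModelRegistry, SamplingParams

from ovis_modeling_directory.ovis_modeling import OvisForConditionalGeneration

ModelRegistry.register_model("Ovis", OvisForConditionalGeneration)

llm = LLM(model="AIDC-AI/Ovis2-2B", trust_remote_code=True)  # trust_remote_code is assumed here
sampling = SamplingParams(temperature=0.0, max_tokens=256)

image = Image.open("example.jpg")  # placeholder image path
outputs = llm.generate(
    {
        "prompt": "<image>\nDescribe this image.",  # adjust to the model's chat template
        "multi_modal_data": {"image": image},
    },
    sampling,
)
print(outputs[0].outputs[0].text)
```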
63 changes: 63 additions & 0 deletions ovis/vllm/aimv2/configuration_aimv2.py
@@ -0,0 +1,63 @@
# copied from https://huggingface.co/apple/aimv2-huge-patch14-448
from typing import Any

from transformers.configuration_utils import PretrainedConfig

__all__ = ["AIMv2Config"]


class AIMv2Config(PretrainedConfig):
"""This is the configuration class to store the configuration of an [`AIMv2Model`].

Instantiating a configuration with the defaults will yield a similar configuration
to that of the [apple/aimv2-large-patch14-224](https://huggingface.co/apple/aimv2-large-patch14-224).

Args:
hidden_size: Dimension of the hidden representations.
intermediate_size: Dimension of the SwiGLU representations.
num_hidden_layers: Number of hidden layers in the Transformer.
num_attention_heads: Number of attention heads for each attention layer
in the Transformer.
num_channels: Number of input channels.
image_size: Image size.
patch_size: Patch size.
rms_norm_eps: Epsilon value used for the RMS normalization layer.
attention_dropout: Dropout ratio for attention probabilities.
projection_dropout: Dropout ratio for the projection layer after the attention.
qkv_bias: Whether to add a bias to the queries, keys and values.
use_bias: Whether to add a bias in the feed-forward and projection layers.
kwargs: Keyword arguments for the [`PretrainedConfig`].
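
    Example (illustrative sketch only; the values below are hypothetical and not tied
    to a specific released checkpoint):

        >>> config = AIMv2Config(image_size=448, patch_size=14)
        >>> config.hidden_size  # other fields keep their defaults
        1024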
"""

model_type: str = "aimv2"

def __init__(
self,
hidden_size: int = 1024,
intermediate_size: int = 2816,
num_hidden_layers: int = 24,
num_attention_heads: int = 8,
num_channels: int = 3,
image_size: int = 224,
patch_size: int = 14,
rms_norm_eps: float = 1e-5,
attention_dropout: float = 0.0,
projection_dropout: float = 0.0,
qkv_bias: bool = False,
use_bias: bool = False,
**kwargs: Any,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.attention_dropout = attention_dropout
self.rms_norm_eps = rms_norm_eps

self.projection_dropout = projection_dropout
self.qkv_bias = qkv_bias
self.use_bias = use_bias
205 changes: 205 additions & 0 deletions ovis/vllm/aimv2/modeling_aimv2.py
@@ -0,0 +1,205 @@
from typing import Optional

import torch
from torch import nn
from torch.nn import functional as F

# ColumnParallelLinear / QKVParallelLinear / RowParallelLinear are only referenced by the
# commented-out tensor-parallel code paths below; they are kept imported so those paths
# can be restored easily.
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

from .configuration_aimv2 import AIMv2Config

class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.eps = eps

def forward(self, x: torch.Tensor) -> torch.Tensor:
output = self._norm(x.float()).type_as(x)
return output * self.weight

def extra_repr(self) -> str:
return f"{tuple(self.weight.shape)}, eps={self.eps}"

def _norm(self, x: torch.Tensor) -> torch.Tensor:
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)


class AIMv2SwiGLUFFN(nn.Module):
def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str):
super().__init__()
hidden_features = config.intermediate_size
in_features = config.hidden_size
bias = config.use_bias

        # NOTE: plain nn.Linear is used for now; the tensor-parallel variants are kept
        # below (commented out) for reference.
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        # self.fc1 = ColumnParallelLinear(in_features, hidden_features, bias=bias,
        #                                 quant_config=quant_config, prefix=f"{prefix}.fc1")
        self.fc2 = nn.Linear(hidden_features, in_features, bias=bias)
        # self.fc2 = ColumnParallelLinear(hidden_features, in_features, bias=bias,
        #                                 quant_config=quant_config, prefix=f"{prefix}.fc2")
        self.fc3 = nn.Linear(in_features, hidden_features, bias=bias)
        # self.fc3 = RowParallelLinear(in_features, hidden_features, bias=bias,
        #                              quant_config=quant_config, prefix=f"{prefix}.fc3")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # With the parallel layers each call would return (output, bias), e.g. `x_parallel, _ = self.fc1(x)`.
        x_parallel = self.fc1(x)
        gate = self.fc3(x)
        x_parallel = F.silu(x_parallel) * gate
        out = self.fc2(x_parallel)
        return out



class AIMv2PatchEmbed(nn.Module):
def __init__(self, config: AIMv2Config):
super().__init__()
self.proj = nn.Conv2d(
config.num_channels,
config.hidden_size,
kernel_size=(config.patch_size, config.patch_size),
stride=(config.patch_size, config.patch_size),
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x).flatten(2).transpose(1, 2)
x = self.norm(x)
return x


class AIMv2ViTPreprocessor(nn.Module):
def __init__(self, config: AIMv2Config):
super().__init__()
num_patches = (config.image_size // config.patch_size) ** 2

self.patchifier = AIMv2PatchEmbed(config)
self.pos_embed = nn.Parameter(torch.zeros((1, num_patches, config.hidden_size)))

def forward(self, x: torch.Tensor) -> torch.Tensor:
tokens = self.patchifier(x)
_, N, _ = tokens.shape
pos_embed = self.pos_embed.to(tokens.device)
tokens = tokens + pos_embed[:, :N]
return tokens


class AIMv2Attention(nn.Module):
def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str):
super().__init__()
dim = config.hidden_size

self.num_heads = config.num_attention_heads
        # NOTE: plain nn.Linear is used for now; the QKVParallelLinear / RowParallelLinear
        # variants are kept below (commented out) for reference.
        self.qkv = nn.Linear(dim, dim * 3, bias=config.qkv_bias)
        # self.qkv = QKVParallelLinear(hidden_size=dim,
        #                              head_size=dim // config.num_attention_heads,
        #                              total_num_heads=config.num_attention_heads,
        #                              bias=config.qkv_bias,
        #                              quant_config=quant_config,
        #                              prefix=f"{prefix}.qkv")
        self.attn_drop = nn.Dropout(config.attention_dropout)
        self.proj = nn.Linear(dim, dim, bias=config.use_bias)
        # self.proj = RowParallelLinear(input_size=dim,
        #                               output_size=dim,
        #                               bias=config.use_bias,
        #                               quant_config=quant_config,
        #                               prefix=f"{prefix}.proj")

self.proj_drop = nn.Dropout(config.projection_dropout)

    def forward(  # TODO: might implement multiple attention backends
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        B, N, C = x.shape
        # With QKVParallelLinear this would be `qkv, _ = self.qkv(x)`.
        qkv = self.qkv(x)
        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        x = x.transpose(1, 2).contiguous().reshape(B, N, C)
        # With RowParallelLinear this would be `x, _ = self.proj(x)`.
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class AIMv2Block(nn.Module):
def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str):
super().__init__()
self.attn = AIMv2Attention(config, quant_config=quant_config, prefix=f"{prefix}.attn")
self.norm_1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.mlp = AIMv2SwiGLUFFN(config, quant_config=quant_config, prefix=f"{prefix}.mlp")
self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def forward(
self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
x = x + self.attn(self.norm_1(x), mask)
x = x + self.mlp(self.norm_2(x))
return x


class AIMv2Transformer(nn.Module):
def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str):
super().__init__()

self.blocks = nn.ModuleList(
[AIMv2Block(config, quant_config, prefix=f"{prefix}.blocks.{i}") for i in range(config.num_hidden_layers)]
)
self.post_trunk_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        tokens: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Ovis takes the last block's output (index -1) as the reference embeddings,
        # similar to a "clip skip": intermediate outputs are not collected and
        # `post_trunk_norm` is not applied, matching the original implementation.
        for block in self.blocks:
            tokens = block(tokens, mask)
        return tokens


class AIMv2Model(torch.nn.Module):
def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str = ""):
super().__init__()
self.preprocessor = AIMv2ViTPreprocessor(config)
self.trunk = AIMv2Transformer(config, quant_config=quant_config, prefix=f"{prefix}.trunk")

@property
def dtype(self):
return self.trunk.blocks[0].attn.qkv.weight.dtype

@property
def device(self):
        return self.trunk.blocks[0].attn.qkv.weight.device

    def forward(
        self,
        pixel_values: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        x = self.preprocessor(pixel_values)
        x = self.trunk(x, mask)
        return x

120 changes: 120 additions & 0 deletions ovis/vllm/aimv2/visual_tokenizer_aimv2.py
@@ -0,0 +1,120 @@
import torch
from torch import softmax
from torch.nn.functional import gumbel_softmax, pad
from transformers import AutoImageProcessor
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

from .modeling_aimv2 import AIMv2Model
from ..ovis_config import BaseVisualTokenizerConfig

IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305]  # reserved ids kept for the vocab's image-indicator tokens

class Aimv2VisualTokenizer(torch.nn.Module):

def __init__(self, config: BaseVisualTokenizerConfig, quant_config: QuantizationConfig, prefix: str = "", **kwargs):
super().__init__()
self.image_processor = AutoImageProcessor.from_pretrained(kwargs['image_processor_name_or_path'])
self.config = config
self.image_processor.do_center_crop = False
self.backbone = AIMv2Model(
config = config.backbone_config, # noqa
quant_config=quant_config,
prefix=f"{prefix}.visual_tokenizer"
)
head_dim = config.vocab_size - len(IMAGE_INDICATOR_IDS) # reserved tokens for IMAGE_INDICATORS
self.head = torch.nn.Sequential(
ColumnParallelLinear(
config.backbone_config.hidden_size * config.hidden_stride * config.hidden_stride, head_dim,
bias=False,
),
torch.nn.LayerNorm(head_dim)
)

assert all((self.image_processor.do_resize,
not getattr(self.image_processor, 'do_center_crop', False),
self.image_processor.do_rescale,
self.image_processor.do_normalize
)), f"image_processor `{self.image_processor}` is not supported currently"

@property
def dtype(self):
return self.backbone.dtype

@property
def device(self):
return self.backbone.device

def get_backbone(self):
return self.backbone

def get_image_processor(self):
return self.image_processor

def get_head(self):
return self.head

def get_image_size(self):
raise NotImplementedError


def tokenize(self, logits):
        def st_argmax(y_soft, dim):  # straight-through argmax
index = y_soft.max(dim, keepdim=True)[1]
y_hard = torch.zeros_like(y_soft, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
ret = y_hard - y_soft.detach() + y_soft
return ret

if self.config.tokenize_function == 'softmax':
tokens = softmax(logits, dim=-1)
elif self.config.tokenize_function == 'gumbel_argmax':
tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True)
elif self.config.tokenize_function == 'st_argmax':
tokens = st_argmax(logits, dim=-1)
else:
            raise ValueError(
                f'Invalid `tokenize_function`, expected softmax, gumbel_argmax or st_argmax, '
                f'but got {self.config.tokenize_function}')
return tokens

def encode(self, pixel_values):
features = self.backbone(pixel_values)
if self.config.drop_cls_token:
features = features[:, 1:, :]

# merge number of `hidden_stride * hidden_stride` hidden states together to reduce token sequence length
# e.g., for hidden_stride=2, this leads to a token length reduction: 1024 -> 256 for aimv2
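        # Worked example (hypothetical shapes) with hidden_stride=2 and a 32x32 patch grid:
        #   [n, 1024, d] -> reshape [n, 32, 32, d] -> reshape [n, 16, 2, 16, 2, d]
        #   -> permute [n, 16, 16, 2, 2, d] -> flatten/reshape [n, 256, 4 * d]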
if self.config.hidden_stride > 1:
            n, l, d = features.shape  # n: batch size, l: token sequence length, d: backbone hidden size
sqrt_l = int(l ** 0.5)
assert sqrt_l ** 2 == l, "The token sequence length should be a perfect square."
features = features.reshape(n, sqrt_l, sqrt_l, d)
pl = (self.config.hidden_stride - (sqrt_l % self.config.hidden_stride)) % self.config.hidden_stride
features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0)
sqrt_l += pl
features = features.reshape(n, sqrt_l // self.config.hidden_stride, self.config.hidden_stride,
sqrt_l // self.config.hidden_stride, self.config.hidden_stride, d)
features = features.permute(0, 1, 3, 2, 4, 5) # [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d]
features = features.flatten(3) # [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d]
features = features.reshape(
n, -1, self.config.hidden_stride * self.config.hidden_stride * d)

return features

def forward(self, pixel_values) -> torch.Tensor: # [BatchSize, ImageShape] -> [BatchSize, #Token, VocabSize]
features = self.encode(pixel_values)
        # ColumnParallelLinear returns an (output, bias) tuple, so the Sequential head
        # is applied step by step instead of being called directly.
        logits, _ = self.head[0](features)
        logits = self.head[1](logits)
tokens = self.tokenize(logits)
        # `tokens` has shape [BatchSize, #Token, VocabSize - 5]; pad it with a zero tensor of
        # shape [BatchSize, #Token, 5] so the final shape is [BatchSize, #Token, VocabSize].
batch_size, token_len, _ = tokens.shape
padding_tensor = torch.zeros(size=(batch_size, token_len, len(IMAGE_INDICATOR_IDS)),
dtype=tokens.dtype,
device=tokens.device,
layout=tokens.layout,
requires_grad=False)
tokens = torch.cat((tokens, padding_tensor), dim=2)
return tokens



