From a767dc423ab90c0ea371e843fbd0d94f413df760 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 29 Oct 2023 14:54:04 +0800 Subject: [PATCH] Add SigLIP weights (#22) --- tests/test_vit.py | 14 +- vision_toolbox/backbones/deit.py | 6 +- vision_toolbox/backbones/mlp_mixer.py | 40 ++---- vision_toolbox/backbones/swin.py | 2 +- vision_toolbox/backbones/vit.py | 183 +++++++++++++++++++------- 5 files changed, 159 insertions(+), 86 deletions(-) diff --git a/tests/test_vit.py b/tests/test_vit.py index 207bfc0..c898728 100644 --- a/tests/test_vit.py +++ b/tests/test_vit.py @@ -1,3 +1,4 @@ +import pytest import timm import torch @@ -16,12 +17,19 @@ def test_resize_pe(): m(torch.randn(1, 3, 256, 256)) -def test_from_pretrained(): - m = ViT.from_config("Ti_16", 224, True).eval() +@pytest.mark.parametrize( + "config,timm_name", + [ + (dict(variant="Ti_16", img_size=224, weights="augreg"), "vit_tiny_patch16_224.augreg_in21k"), + (dict(variant="B_16", img_size=224, weights="siglip"), "vit_base_patch16_siglip_224"), + ], +) +def test_from_pretrained(config, timm_name): + m = ViT.from_config(**config).eval() x = torch.randn(1, 3, 224, 224) out = m(x) - m_timm = timm.create_model("vit_tiny_patch16_224.augreg_in21k", pretrained=True, num_classes=0).eval() + m_timm = timm.create_model(timm_name, pretrained=True, num_classes=0).eval() out_timm = m_timm(x) torch.testing.assert_close(out, out_timm, rtol=2e-5, atol=2e-5) diff --git a/vision_toolbox/backbones/deit.py b/vision_toolbox/backbones/deit.py index a5252c8..ae3e486 100644 --- a/vision_toolbox/backbones/deit.py +++ b/vision_toolbox/backbones/deit.py @@ -28,8 +28,8 @@ def __init__( ) -> None: # fmt: off super().__init__( - d_model, depth, n_heads, patch_size, img_size, True, bias, mlp_ratio, - dropout, layer_scale_init, stochastic_depth, norm_eps + d_model, depth, n_heads, patch_size, img_size, True, "cls_token", bias, + mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm_eps, ) # fmt: on self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model)) @@ -133,7 +133,7 @@ def __init__( ): # fmt: off super().__init__( - d_model, depth, n_heads, patch_size, img_size, cls_token, bias, + d_model, depth, n_heads, patch_size, img_size, cls_token, "cls_token", bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm_eps, ) # fmt: on diff --git a/vision_toolbox/backbones/mlp_mixer.py b/vision_toolbox/backbones/mlp_mixer.py index 85d0081..1b2c3ba 100644 --- a/vision_toolbox/backbones/mlp_mixer.py +++ b/vision_toolbox/backbones/mlp_mixer.py @@ -10,7 +10,7 @@ from torch import Tensor, nn from ..utils import torch_hub_download -from .vit import MLP +from .vit import MLP, load_flax_conv2d, load_flax_linear, load_flax_ln class MixerBlock(nn.Module): @@ -84,33 +84,17 @@ def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool = return m @torch.no_grad() - def load_jax_weights(self, path: str) -> MLPMixer: - jax_weights: Mapping[str, np.ndarray] = np.load(path) + def load_jax_weights(self, path: str) -> None: + jax_weights = {k: torch.from_numpy(v) for k, v in np.load(path).items()} - def get_w(key: str) -> Tensor: - return torch.from_numpy(jax_weights[key]) - - self.patch_embed.weight.copy_(get_w("stem/kernel").permute(3, 2, 0, 1)) - self.patch_embed.bias.copy_(get_w("stem/bias")) + load_flax_conv2d(self.patch_embed, jax_weights, "stem") + load_flax_ln(self.norm, jax_weights, "pre_head_layer_norm") for i, layer in enumerate(self.layers): - layer: MixerBlock - prefix = f"MixerBlock_{i}/" - - layer.norm1.weight.copy_(get_w(prefix + 
"LayerNorm_0/scale")) - layer.norm1.bias.copy_(get_w(prefix + "LayerNorm_0/bias")) - layer.token_mixing.linear1.weight.copy_(get_w(prefix + "token_mixing/Dense_0/kernel").T) - layer.token_mixing.linear1.bias.copy_(get_w(prefix + "token_mixing/Dense_0/bias")) - layer.token_mixing.linear2.weight.copy_(get_w(prefix + "token_mixing/Dense_1/kernel").T) - layer.token_mixing.linear2.bias.copy_(get_w(prefix + "token_mixing/Dense_1/bias")) - - layer.norm2.weight.copy_(get_w(prefix + "LayerNorm_1/scale")) - layer.norm2.bias.copy_(get_w(prefix + "LayerNorm_1/bias")) - layer.channel_mixing.linear1.weight.copy_(get_w(prefix + "channel_mixing/Dense_0/kernel").T) - layer.channel_mixing.linear1.bias.copy_(get_w(prefix + "channel_mixing/Dense_0/bias")) - layer.channel_mixing.linear2.weight.copy_(get_w(prefix + "channel_mixing/Dense_1/kernel").T) - layer.channel_mixing.linear2.bias.copy_(get_w(prefix + "channel_mixing/Dense_1/bias")) - - self.norm.weight.copy_(get_w("pre_head_layer_norm/scale")) - self.norm.bias.copy_(get_w("pre_head_layer_norm/bias")) - return self + load_flax_ln(layer.norm1, jax_weights, f"MixerBlock_{i}/LayerNorm_0") + load_flax_linear(layer.token_mixing.linear1, jax_weights, f"MixerBlock_{i}/token_mixing/Dense_0") + load_flax_linear(layer.token_mixing.linear2, jax_weights, f"MixerBlock_{i}/token_mixing/Dense_1") + + load_flax_ln(layer.norm2, jax_weights, f"MixerBlock_{i}/LayerNorm_1") + load_flax_linear(layer.channel_mixing.linear1, jax_weights, f"MixerBlock_{i}/channel_mixing/Dense_0") + load_flax_linear(layer.channel_mixing.linear2, jax_weights, f"MixerBlock_{i}/channel_mixing/Dense_1") diff --git a/vision_toolbox/backbones/swin.py b/vision_toolbox/backbones/swin.py index da0fe35..b373ba7 100644 --- a/vision_toolbox/backbones/swin.py +++ b/vision_toolbox/backbones/swin.py @@ -78,7 +78,7 @@ def forward(self, x: Tensor) -> Tensor: attn_bias = attn_bias + self.attn_mask.unsqueeze(1) # add n_heads dim x, nH, nW = window_partition(x, self.window_size) # (B * nH * nW, win_size * win_size, C) - x = super().forward(x, attn_bias) + x = super().forward(x, attn_bias=attn_bias) x = window_unpartition(x, self.window_size, nH, nW) # (B, H, W, C) if self.shift > 0: diff --git a/vision_toolbox/backbones/vit.py b/vision_toolbox/backbones/vit.py index 5ded0b0..a09f1f7 100644 --- a/vision_toolbox/backbones/vit.py +++ b/vision_toolbox/backbones/vit.py @@ -5,7 +5,6 @@ from __future__ import annotations from functools import partial -from typing import Mapping import numpy as np import torch @@ -27,10 +26,14 @@ def __init__(self, d_model: int, n_heads: int, bias: bool = True, dropout: float self.dropout = dropout self.scale = (d_model // n_heads) ** (-0.5) - def forward(self, x: Tensor, attn_bias: Tensor | None = None) -> Tensor: - q = self.q_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3) # (B, n_heads, L, head_dim) - k = self.k_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3) - v = self.v_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3) + def forward( + self, q: Tensor, k: Tensor | None = None, v: Tensor | None = None, *, attn_bias: Tensor | None = None + ) -> Tensor: + k = q if k is None else k + v = k if v is None else v + q = self.q_proj(q).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3) # (B, n_heads, L, head_dim) + k = self.k_proj(k).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3) + v = self.v_proj(v).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3) if hasattr(F, "scaled_dot_product_attention"): out = F.scaled_dot_product_attention(q, k, v, attn_bias, 
self.dropout if self.training else 0.0) @@ -89,6 +92,22 @@ def forward(self, x: Tensor) -> Tensor: return x +class MHAPooling(nn.Module): + def __init__( + self, d_model: int, n_heads: int, bias: bool = True, mlp_ratio: float = 4.0, norm_eps: float = 1e-6 + ) -> None: + super().__init__() + self.probe = nn.Parameter(torch.zeros(1, 1, d_model)) + self.mha = MHA(d_model, n_heads, bias) + self.norm = nn.LayerNorm(d_model, norm_eps) + self.mlp = MLP(d_model, int(d_model * mlp_ratio)) + + def forward(self, x: Tensor) -> Tensor: + x = self.mha(self.probe, x).squeeze(1) + x = x + self.mlp(self.norm(x)) + return x + + class ViT(nn.Module): def __init__( self, @@ -98,6 +117,7 @@ def __init__( patch_size: int, img_size: int, cls_token: bool = True, + pool_type: str = "cls_token", bias: bool = True, mlp_ratio: float = 4.0, dropout: float = 0.0, @@ -118,13 +138,23 @@ def __init__( self.layers.append(block) self.norm = nn.LayerNorm(d_model, norm_eps) + self.pool_type = pool_type + self.pooler = MHAPooling(d_model, n_heads, bias, mlp_ratio, norm_eps) if pool_type == "mha" else None def forward(self, imgs: Tensor) -> Tensor: out = self.patch_embed(imgs).flatten(2).transpose(1, 2) + self.pe # (N, C, H, W) -> (N, H*W, C) if self.cls_token is not None: out = torch.cat([self.cls_token, out], 1) out = self.layers(out) - return self.norm(out[:, 0]) if self.cls_token is not None else self.norm(out).mean(1) + + if self.pool_type == "cls_token": + return self.norm(out[:, 0]) + elif self.pool_type == "gap": + return self.norm(out).mean(1) + elif self.pool_type == "mha": + return self.pooler(self.norm(out)) + else: + raise RuntimeError @torch.no_grad() def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None: @@ -136,7 +166,7 @@ def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None: self.pe = nn.Parameter(pe) @staticmethod - def from_config(variant: str, img_size: int, pretrained: bool = False) -> ViT: + def from_config(variant: str, img_size: int, *, weights: str | None = None) -> ViT: variant, patch_size = variant.split("_") d_model, depth, n_heads = dict( @@ -148,9 +178,13 @@ def from_config(variant: str, img_size: int, pretrained: bool = False) -> ViT: H=(1280, 32, 16), )[variant] patch_size = int(patch_size) - m = ViT(d_model, depth, n_heads, patch_size, img_size) + kwargs = dict() + if weights == "siglip": + kwargs.update(cls_token=False, pool_type="mha") + + m = ViT(d_model, depth, n_heads, patch_size, img_size, **kwargs) - if pretrained: + if weights == "augreg": assert img_size == 224 ckpt = { ("Ti", 16): "Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz", @@ -160,49 +194,96 @@ def from_config(variant: str, img_size: int, pretrained: bool = False) -> ViT: ("B", 16): "B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz", ("L", 16): "L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz", }[(variant, patch_size)] - base_url = "https://storage.googleapis.com/vit_models/augreg/" - m.load_jax_weights(torch_hub_download(base_url + ckpt)) + m.load_flax_ckpt(f"augreg/{ckpt}") + + elif weights == "siglip": + ckpt = { + ("B", 16, 224): "webli_en_b16_224_63724782.npz", + ("B", 16, 256): "webli_en_b16_256_60500360.npz", + ("B", 16, 384): "webli_en_b16_384_68578854.npz", + ("B", 16, 512): "webli_en_b16_512_68580893.npz", + ("L", 16, 256): "webli_en_l16_256_60552751.npz", + ("L", 16, 384): "webli_en_l16_384_63634585.npz", + }[(variant, patch_size, img_size)] + m.load_flax_ckpt(f"siglip/{ckpt}", big_vision=True, prefix="params/img/") + + elif not 
weights is None: + raise ValueError(f"Unsupported weights={weights}") return m - # weights from https://github.com/google-research/vision_transformer @torch.no_grad() - def load_jax_weights(self, path: str) -> ViT: - jax_weights: Mapping[str, np.ndarray] = np.load(path) - - def get_w(key: str) -> Tensor: - return torch.from_numpy(jax_weights[key]) - - self.cls_token.copy_(get_w("cls")) - pe = get_w("Transformer/posembed_input/pos_embedding") - self.cls_token.add_(pe[:, 0]) - self.pe.copy_(pe[:, 1:]) - self.patch_embed.weight.copy_(get_w("embedding/kernel").permute(3, 2, 0, 1)) - self.patch_embed.bias.copy_(get_w("embedding/bias")) - - for idx, layer in enumerate(self.layers): - layer: ViTBlock - prefix = f"Transformer/encoderblock_{idx}/" - mha_prefix = prefix + "MultiHeadDotProductAttention_1/" - - layer.mha[0].weight.copy_(get_w(prefix + "LayerNorm_0/scale")) - layer.mha[0].bias.copy_(get_w(prefix + "LayerNorm_0/bias")) - layer.mha[1].q_proj.weight.copy_(get_w(mha_prefix + "query/kernel").flatten(1).T) - layer.mha[1].k_proj.weight.copy_(get_w(mha_prefix + "key/kernel").flatten(1).T) - layer.mha[1].v_proj.weight.copy_(get_w(mha_prefix + "value/kernel").flatten(1).T) - layer.mha[1].q_proj.bias.copy_(get_w(mha_prefix + "query/bias").flatten()) - layer.mha[1].k_proj.bias.copy_(get_w(mha_prefix + "key/bias").flatten()) - layer.mha[1].v_proj.bias.copy_(get_w(mha_prefix + "value/bias").flatten()) - layer.mha[1].out_proj.weight.copy_(get_w(mha_prefix + "out/kernel").flatten(0, 1).T) - layer.mha[1].out_proj.bias.copy_(get_w(mha_prefix + "out/bias")) - - layer.mlp[0].weight.copy_(get_w(prefix + "LayerNorm_2/scale")) - layer.mlp[0].bias.copy_(get_w(prefix + "LayerNorm_2/bias")) - layer.mlp[1].linear1.weight.copy_(get_w(prefix + "MlpBlock_3/Dense_0/kernel").T) - layer.mlp[1].linear1.bias.copy_(get_w(prefix + "MlpBlock_3/Dense_0/bias")) - layer.mlp[1].linear2.weight.copy_(get_w(prefix + "MlpBlock_3/Dense_1/kernel").T) - layer.mlp[1].linear2.bias.copy_(get_w(prefix + "MlpBlock_3/Dense_1/bias")) - - self.norm.weight.copy_(get_w("Transformer/encoder_norm/scale")) - self.norm.bias.copy_(get_w("Transformer/encoder_norm/bias")) - return self + def load_flax_ckpt(self, ckpt: str, *, big_vision: bool = False, prefix: str = "") -> None: + if big_vision: + # https://github.com/google-research/big_vision + gcs_bucket = "big_vision" + mha_norm = "LayerNorm_0" + mha = "MultiHeadDotProductAttention_0" + mlp_norm = "LayerNorm_1" + mlp = "MlpBlock_0" + + else: + # https://github.com/google-research/vision_transformer + gcs_bucket = "vit_models" + mha_norm = "LayerNorm_0" + mha = "MultiHeadDotProductAttention_1" + mlp_norm = "LayerNorm_2" + mlp = "MlpBlock_3" + + path = torch_hub_download(f"https://storage.googleapis.com/{gcs_bucket}/{ckpt}") + jax_weights = {k[len(prefix) :]: torch.from_numpy(v) for k, v in np.load(path).items() if k.startswith(prefix)} + + if self.cls_token is not None: + self.cls_token.copy_(jax_weights.pop("cls")) + if big_vision: + self.pe.copy_(jax_weights.pop("pos_embedding")) + else: + pe = jax_weights.pop("Transformer/posembed_input/pos_embedding") + self.cls_token.add_(pe[:, 0]) + self.pe.copy_(pe[:, 1:]) + load_flax_conv2d(self.patch_embed, jax_weights, "embedding") + load_flax_ln(self.norm, jax_weights, "Transformer/encoder_norm") + + for i, layer in enumerate(self.layers): + load_flax_ln(layer.mha[0], jax_weights, f"Transformer/encoderblock_{i}/{mha_norm}") + load_flax_mha(layer.mha[1], jax_weights, f"Transformer/encoderblock_{i}/{mha}") + load_flax_ln(layer.mlp[0], jax_weights, 
f"Transformer/encoderblock_{i}/{mlp_norm}") + load_flax_linear(layer.mlp[1].linear1, jax_weights, f"Transformer/encoderblock_{i}/{mlp}/Dense_0") + load_flax_linear(layer.mlp[1].linear2, jax_weights, f"Transformer/encoderblock_{i}/{mlp}/Dense_1") + + # big_vision only + if self.pooler is not None: + self.pooler.probe.copy_(jax_weights.pop("MAPHead_0/probe")) + load_flax_mha(self.pooler.mha, jax_weights, "MAPHead_0/MultiHeadDotProductAttention_0") + load_flax_ln(self.pooler.norm, jax_weights, "MAPHead_0/LayerNorm_0") + load_flax_linear(self.pooler.mlp.linear1, jax_weights, "MAPHead_0/MlpBlock_0/Dense_0") + load_flax_linear(self.pooler.mlp.linear2, jax_weights, "MAPHead_0/MlpBlock_0/Dense_1") + + if len(jax_weights) > 0: + print(jax_weights.keys()) + + +def load_flax_ln(norm: nn.LayerNorm, weights: dict[str, Tensor], prefix: str) -> None: + norm.weight.copy_(weights.pop(f"{prefix}/scale")) + norm.bias.copy_(weights.pop(f"{prefix}/bias")) + + +def load_flax_linear(linear: nn.Linear, weights: dict[str, Tensor], prefix: str) -> None: + linear.weight.copy_(weights.pop(f"{prefix}/kernel").T) + linear.bias.copy_(weights.pop(f"{prefix}/bias")) + + +def load_flax_conv2d(conv2d: nn.Conv2d, weights: dict[str, Tensor], prefix: str) -> None: + conv2d.weight.copy_(weights.pop(f"{prefix}/kernel").permute(3, 2, 0, 1)) + conv2d.bias.copy_(weights.pop(f"{prefix}/bias")) + + +def load_flax_mha(mha: MHA, weights: dict[str, Tensor], prefix: str) -> None: + mha.q_proj.weight.copy_(weights.pop(f"{prefix}/query/kernel").flatten(1).T) + mha.q_proj.bias.copy_(weights.pop(f"{prefix}/query/bias").flatten()) + mha.k_proj.weight.copy_(weights.pop(f"{prefix}/key/kernel").flatten(1).T) + mha.k_proj.bias.copy_(weights.pop(f"{prefix}/key/bias").flatten()) + mha.v_proj.weight.copy_(weights.pop(f"{prefix}/value/kernel").flatten(1).T) + mha.v_proj.bias.copy_(weights.pop(f"{prefix}/value/bias").flatten()) + mha.out_proj.weight.copy_(weights.pop(f"{prefix}/out/kernel").flatten(0, 1).T) + mha.out_proj.bias.copy_(weights.pop(f"{prefix}/out/bias").flatten())