Add SigLIP weights #22

Merged
merged 8 commits on Oct 29, 2023
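
For quick reference, a minimal usage sketch of the API this PR adds, based on the updated test and ViT.from_config in the diff below. The old pretrained flag becomes a keyword-only weights argument accepting "augreg" or "siglip"; the import path is an assumption from the repository layout.

import torch
from vision_toolbox.backbones.vit import ViT

# SigLIP ViT-B/16: no CLS token, MAP-style ("mha") pooling, big_vision checkpoint.
model = ViT.from_config("B_16", 224, weights="siglip").eval()

# AugReg weights go through the same keyword.
tiny = ViT.from_config("Ti_16", 224, weights="augreg").eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))  # pooled features, shape (1, 768)
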
14 changes: 11 additions & 3 deletions tests/test_vit.py
@@ -1,3 +1,4 @@
import pytest
import timm
import torch

@@ -16,12 +17,19 @@ def test_resize_pe():
m(torch.randn(1, 3, 256, 256))


def test_from_pretrained():
m = ViT.from_config("Ti_16", 224, True).eval()
@pytest.mark.parametrize(
"config,timm_name",
[
(dict(variant="Ti_16", img_size=224, weights="augreg"), "vit_tiny_patch16_224.augreg_in21k"),
(dict(variant="B_16", img_size=224, weights="siglip"), "vit_base_patch16_siglip_224"),
],
)
def test_from_pretrained(config, timm_name):
m = ViT.from_config(**config).eval()
x = torch.randn(1, 3, 224, 224)
out = m(x)

m_timm = timm.create_model("vit_tiny_patch16_224.augreg_in21k", pretrained=True, num_classes=0).eval()
m_timm = timm.create_model(timm_name, pretrained=True, num_classes=0).eval()
out_timm = m_timm(x)

torch.testing.assert_close(out, out_timm, rtol=2e-5, atol=2e-5)
6 changes: 3 additions & 3 deletions vision_toolbox/backbones/deit.py
@@ -28,8 +28,8 @@ def __init__(
) -> None:
# fmt: off
super().__init__(
d_model, depth, n_heads, patch_size, img_size, True, bias, mlp_ratio,
dropout, layer_scale_init, stochastic_depth, norm_eps
d_model, depth, n_heads, patch_size, img_size, True, "cls_token", bias,
mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm_eps,
)
# fmt: on
self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))
@@ -133,7 +133,7 @@ def __init__(
):
# fmt: off
super().__init__(
d_model, depth, n_heads, patch_size, img_size, cls_token, bias,
d_model, depth, n_heads, patch_size, img_size, cls_token, "cls_token", bias,
mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm_eps,
)
# fmt: on
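
Both calls above slot the new "cls_token" string into the pool_type parameter added to ViT.__init__ in the vit.py diff further down. As a reading aid only, a toy-sized sketch of the three pooling modes that parameter selects, with the argument order assumed from that diff:

import torch
from vision_toolbox.backbones.vit import ViT

x = torch.randn(1, 3, 224, 224)

# Classic ViT/DeiT pooling: prepend a CLS token and return it after the final norm.
vit_cls = ViT(192, 2, 3, 16, 224, cls_token=True, pool_type="cls_token")

# Global average pooling over the patch tokens.
vit_gap = ViT(192, 2, 3, 16, 224, cls_token=False, pool_type="gap")

# SigLIP-style MAP head: a learned probe attends over the patch tokens.
vit_map = ViT(192, 2, 3, 16, 224, cls_token=False, pool_type="mha")

with torch.no_grad():
    for m in (vit_cls, vit_gap, vit_map):
        assert m.eval()(x).shape == (1, 192)
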
40 changes: 12 additions & 28 deletions vision_toolbox/backbones/mlp_mixer.py
@@ -10,7 +10,7 @@
from torch import Tensor, nn

from ..utils import torch_hub_download
from .vit import MLP
from .vit import MLP, load_flax_conv2d, load_flax_linear, load_flax_ln


class MixerBlock(nn.Module):
@@ -84,33 +84,17 @@ def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool =
return m

@torch.no_grad()
def load_jax_weights(self, path: str) -> MLPMixer:
jax_weights: Mapping[str, np.ndarray] = np.load(path)
def load_jax_weights(self, path: str) -> None:
jax_weights = {k: torch.from_numpy(v) for k, v in np.load(path).items()}

def get_w(key: str) -> Tensor:
return torch.from_numpy(jax_weights[key])

self.patch_embed.weight.copy_(get_w("stem/kernel").permute(3, 2, 0, 1))
self.patch_embed.bias.copy_(get_w("stem/bias"))
load_flax_conv2d(self.patch_embed, jax_weights, "stem")
load_flax_ln(self.norm, jax_weights, "pre_head_layer_norm")

for i, layer in enumerate(self.layers):
layer: MixerBlock
prefix = f"MixerBlock_{i}/"

layer.norm1.weight.copy_(get_w(prefix + "LayerNorm_0/scale"))
layer.norm1.bias.copy_(get_w(prefix + "LayerNorm_0/bias"))
layer.token_mixing.linear1.weight.copy_(get_w(prefix + "token_mixing/Dense_0/kernel").T)
layer.token_mixing.linear1.bias.copy_(get_w(prefix + "token_mixing/Dense_0/bias"))
layer.token_mixing.linear2.weight.copy_(get_w(prefix + "token_mixing/Dense_1/kernel").T)
layer.token_mixing.linear2.bias.copy_(get_w(prefix + "token_mixing/Dense_1/bias"))

layer.norm2.weight.copy_(get_w(prefix + "LayerNorm_1/scale"))
layer.norm2.bias.copy_(get_w(prefix + "LayerNorm_1/bias"))
layer.channel_mixing.linear1.weight.copy_(get_w(prefix + "channel_mixing/Dense_0/kernel").T)
layer.channel_mixing.linear1.bias.copy_(get_w(prefix + "channel_mixing/Dense_0/bias"))
layer.channel_mixing.linear2.weight.copy_(get_w(prefix + "channel_mixing/Dense_1/kernel").T)
layer.channel_mixing.linear2.bias.copy_(get_w(prefix + "channel_mixing/Dense_1/bias"))

self.norm.weight.copy_(get_w("pre_head_layer_norm/scale"))
self.norm.bias.copy_(get_w("pre_head_layer_norm/bias"))
return self
load_flax_ln(layer.norm1, jax_weights, f"MixerBlock_{i}/LayerNorm_0")
load_flax_linear(layer.token_mixing.linear1, jax_weights, f"MixerBlock_{i}/token_mixing/Dense_0")
load_flax_linear(layer.token_mixing.linear2, jax_weights, f"MixerBlock_{i}/token_mixing/Dense_1")

load_flax_ln(layer.norm2, jax_weights, f"MixerBlock_{i}/LayerNorm_1")
load_flax_linear(layer.channel_mixing.linear1, jax_weights, f"MixerBlock_{i}/channel_mixing/Dense_0")
load_flax_linear(layer.channel_mixing.linear2, jax_weights, f"MixerBlock_{i}/channel_mixing/Dense_1")
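
The per-parameter copies that used to live here are now delegated to the shared load_flax_conv2d / load_flax_linear / load_flax_ln helpers defined at the bottom of vit.py in this diff. A minimal sketch of the layout conversion those helpers perform; the tensor shapes are illustrative only:

import torch

# Flax stores Dense kernels as (in_features, out_features); PyTorch nn.Linear
# expects (out_features, in_features), hence the .T in load_flax_linear.
flax_dense_kernel = torch.randn(192, 768)
linear_weight = flax_dense_kernel.T                       # (768, 192)

# Flax Conv kernels are (kH, kW, in_ch, out_ch); PyTorch nn.Conv2d expects
# (out_ch, in_ch, kH, kW), hence the permute(3, 2, 0, 1) in load_flax_conv2d.
flax_conv_kernel = torch.randn(16, 16, 3, 192)
conv_weight = flax_conv_kernel.permute(3, 2, 0, 1)        # (192, 3, 16, 16)
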
2 changes: 1 addition & 1 deletion vision_toolbox/backbones/swin.py
@@ -78,7 +78,7 @@ def forward(self, x: Tensor) -> Tensor:
attn_bias = attn_bias + self.attn_mask.unsqueeze(1) # add n_heads dim

x, nH, nW = window_partition(x, self.window_size) # (B * nH * nW, win_size * win_size, C)
x = super().forward(x, attn_bias)
x = super().forward(x, attn_bias=attn_bias)
x = window_unpartition(x, self.window_size, nH, nW) # (B, H, W, C)

if self.shift > 0:
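
Context for this one-line swin.py change: MHA.forward now accepts optional key/value tensors (see the vit.py diff below), so its second positional parameter is k and the attention bias is keyword-only. A short sketch of the implication:

# New MHA.forward signature introduced in vit.py:
#     def forward(self, q, k=None, v=None, *, attn_bias=None) -> Tensor
#
# super().forward(x, attn_bias) would bind the bias tensor to k instead of the
# keyword-only attn_bias, so the call must spell the keyword out explicitly.
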
183 changes: 132 additions & 51 deletions vision_toolbox/backbones/vit.py
@@ -5,7 +5,6 @@
from __future__ import annotations

from functools import partial
from typing import Mapping

import numpy as np
import torch
@@ -27,10 +26,14 @@ def __init__(self, d_model: int, n_heads: int, bias: bool = True, dropout: float
self.dropout = dropout
self.scale = (d_model // n_heads) ** (-0.5)

def forward(self, x: Tensor, attn_bias: Tensor | None = None) -> Tensor:
q = self.q_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3) # (B, n_heads, L, head_dim)
k = self.k_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)
v = self.v_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)
def forward(
self, q: Tensor, k: Tensor | None = None, v: Tensor | None = None, *, attn_bias: Tensor | None = None
) -> Tensor:
k = q if k is None else k
v = k if v is None else v
q = self.q_proj(q).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3) # (B, n_heads, L, head_dim)
k = self.k_proj(k).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)
v = self.v_proj(v).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)

if hasattr(F, "scaled_dot_product_attention"):
out = F.scaled_dot_product_attention(q, k, v, attn_bias, self.dropout if self.training else 0.0)
@@ -89,6 +92,22 @@ def forward(self, x: Tensor) -> Tensor:
return x


class MHAPooling(nn.Module):
def __init__(
self, d_model: int, n_heads: int, bias: bool = True, mlp_ratio: float = 4.0, norm_eps: float = 1e-6
) -> None:
super().__init__()
self.probe = nn.Parameter(torch.zeros(1, 1, d_model))
self.mha = MHA(d_model, n_heads, bias)
self.norm = nn.LayerNorm(d_model, norm_eps)
self.mlp = MLP(d_model, int(d_model * mlp_ratio))

def forward(self, x: Tensor) -> Tensor:
x = self.mha(self.probe, x).squeeze(1)
x = x + self.mlp(self.norm(x))
return x


class ViT(nn.Module):
def __init__(
self,
@@ -98,6 +117,7 @@ def __init__(
patch_size: int,
img_size: int,
cls_token: bool = True,
pool_type: str = "cls_token",
bias: bool = True,
mlp_ratio: float = 4.0,
dropout: float = 0.0,
@@ -118,13 +138,23 @@ def __init__(
self.layers.append(block)

self.norm = nn.LayerNorm(d_model, norm_eps)
self.pool_type = pool_type
self.pooler = MHAPooling(d_model, n_heads, bias, mlp_ratio, norm_eps) if pool_type == "mha" else None

def forward(self, imgs: Tensor) -> Tensor:
out = self.patch_embed(imgs).flatten(2).transpose(1, 2) + self.pe # (N, C, H, W) -> (N, H*W, C)
if self.cls_token is not None:
out = torch.cat([self.cls_token, out], 1)
out = self.layers(out)
return self.norm(out[:, 0]) if self.cls_token is not None else self.norm(out).mean(1)

if self.pool_type == "cls_token":
return self.norm(out[:, 0])
elif self.pool_type == "gap":
return self.norm(out).mean(1)
elif self.pool_type == "mha":
return self.pooler(self.norm(out))
else:
raise RuntimeError(f"Unsupported pool_type: {self.pool_type}")

@torch.no_grad()
def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:
@@ -136,7 +166,7 @@ def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:
self.pe = nn.Parameter(pe)

@staticmethod
def from_config(variant: str, img_size: int, pretrained: bool = False) -> ViT:
def from_config(variant: str, img_size: int, *, weights: str | None = None) -> ViT:
variant, patch_size = variant.split("_")

d_model, depth, n_heads = dict(
@@ -148,9 +178,13 @@ def from_config(variant: str, img_size: int, pretrained: bool = False) -> ViT:
H=(1280, 32, 16),
)[variant]
patch_size = int(patch_size)
m = ViT(d_model, depth, n_heads, patch_size, img_size)
kwargs = dict()
if weights == "siglip":
kwargs.update(cls_token=False, pool_type="mha")

m = ViT(d_model, depth, n_heads, patch_size, img_size, **kwargs)

if pretrained:
if weights == "augreg":
assert img_size == 224
ckpt = {
("Ti", 16): "Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz",
@@ -160,49 +194,96 @@ def from_config(variant: str, img_size: int, pretrained: bool = False) -> ViT:
("B", 16): "B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz",
("L", 16): "L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz",
}[(variant, patch_size)]
base_url = "https://storage.googleapis.com/vit_models/augreg/"
m.load_jax_weights(torch_hub_download(base_url + ckpt))
m.load_flax_ckpt(f"augreg/{ckpt}")

elif weights == "siglip":
ckpt = {
("B", 16, 224): "webli_en_b16_224_63724782.npz",
("B", 16, 256): "webli_en_b16_256_60500360.npz",
("B", 16, 384): "webli_en_b16_384_68578854.npz",
("B", 16, 512): "webli_en_b16_512_68580893.npz",
("L", 16, 256): "webli_en_l16_256_60552751.npz",
("L", 16, 384): "webli_en_l16_384_63634585.npz",
}[(variant, patch_size, img_size)]
m.load_flax_ckpt(f"siglip/{ckpt}", big_vision=True, prefix="params/img/")

elif weights is not None:
raise ValueError(f"Unsupported weights={weights}")

return m

# weights from https://github.com/google-research/vision_transformer
@torch.no_grad()
def load_jax_weights(self, path: str) -> ViT:
jax_weights: Mapping[str, np.ndarray] = np.load(path)

def get_w(key: str) -> Tensor:
return torch.from_numpy(jax_weights[key])

self.cls_token.copy_(get_w("cls"))
pe = get_w("Transformer/posembed_input/pos_embedding")
self.cls_token.add_(pe[:, 0])
self.pe.copy_(pe[:, 1:])
self.patch_embed.weight.copy_(get_w("embedding/kernel").permute(3, 2, 0, 1))
self.patch_embed.bias.copy_(get_w("embedding/bias"))

for idx, layer in enumerate(self.layers):
layer: ViTBlock
prefix = f"Transformer/encoderblock_{idx}/"
mha_prefix = prefix + "MultiHeadDotProductAttention_1/"

layer.mha[0].weight.copy_(get_w(prefix + "LayerNorm_0/scale"))
layer.mha[0].bias.copy_(get_w(prefix + "LayerNorm_0/bias"))
layer.mha[1].q_proj.weight.copy_(get_w(mha_prefix + "query/kernel").flatten(1).T)
layer.mha[1].k_proj.weight.copy_(get_w(mha_prefix + "key/kernel").flatten(1).T)
layer.mha[1].v_proj.weight.copy_(get_w(mha_prefix + "value/kernel").flatten(1).T)
layer.mha[1].q_proj.bias.copy_(get_w(mha_prefix + "query/bias").flatten())
layer.mha[1].k_proj.bias.copy_(get_w(mha_prefix + "key/bias").flatten())
layer.mha[1].v_proj.bias.copy_(get_w(mha_prefix + "value/bias").flatten())
layer.mha[1].out_proj.weight.copy_(get_w(mha_prefix + "out/kernel").flatten(0, 1).T)
layer.mha[1].out_proj.bias.copy_(get_w(mha_prefix + "out/bias"))

layer.mlp[0].weight.copy_(get_w(prefix + "LayerNorm_2/scale"))
layer.mlp[0].bias.copy_(get_w(prefix + "LayerNorm_2/bias"))
layer.mlp[1].linear1.weight.copy_(get_w(prefix + "MlpBlock_3/Dense_0/kernel").T)
layer.mlp[1].linear1.bias.copy_(get_w(prefix + "MlpBlock_3/Dense_0/bias"))
layer.mlp[1].linear2.weight.copy_(get_w(prefix + "MlpBlock_3/Dense_1/kernel").T)
layer.mlp[1].linear2.bias.copy_(get_w(prefix + "MlpBlock_3/Dense_1/bias"))

self.norm.weight.copy_(get_w("Transformer/encoder_norm/scale"))
self.norm.bias.copy_(get_w("Transformer/encoder_norm/bias"))
return self
def load_flax_ckpt(self, ckpt: str, *, big_vision: bool = False, prefix: str = "") -> None:
if big_vision:
# https://github.com/google-research/big_vision
gcs_bucket = "big_vision"
mha_norm = "LayerNorm_0"
mha = "MultiHeadDotProductAttention_0"
mlp_norm = "LayerNorm_1"
mlp = "MlpBlock_0"

else:
# https://github.com/google-research/vision_transformer
gcs_bucket = "vit_models"
mha_norm = "LayerNorm_0"
mha = "MultiHeadDotProductAttention_1"
mlp_norm = "LayerNorm_2"
mlp = "MlpBlock_3"

path = torch_hub_download(f"https://storage.googleapis.com/{gcs_bucket}/{ckpt}")
jax_weights = {k[len(prefix) :]: torch.from_numpy(v) for k, v in np.load(path).items() if k.startswith(prefix)}

if self.cls_token is not None:
self.cls_token.copy_(jax_weights.pop("cls"))
if big_vision:
self.pe.copy_(jax_weights.pop("pos_embedding"))
else:
pe = jax_weights.pop("Transformer/posembed_input/pos_embedding")
self.cls_token.add_(pe[:, 0])
self.pe.copy_(pe[:, 1:])
load_flax_conv2d(self.patch_embed, jax_weights, "embedding")
load_flax_ln(self.norm, jax_weights, "Transformer/encoder_norm")

for i, layer in enumerate(self.layers):
load_flax_ln(layer.mha[0], jax_weights, f"Transformer/encoderblock_{i}/{mha_norm}")
load_flax_mha(layer.mha[1], jax_weights, f"Transformer/encoderblock_{i}/{mha}")
load_flax_ln(layer.mlp[0], jax_weights, f"Transformer/encoderblock_{i}/{mlp_norm}")
load_flax_linear(layer.mlp[1].linear1, jax_weights, f"Transformer/encoderblock_{i}/{mlp}/Dense_0")
load_flax_linear(layer.mlp[1].linear2, jax_weights, f"Transformer/encoderblock_{i}/{mlp}/Dense_1")

# big_vision only
if self.pooler is not None:
self.pooler.probe.copy_(jax_weights.pop("MAPHead_0/probe"))
load_flax_mha(self.pooler.mha, jax_weights, "MAPHead_0/MultiHeadDotProductAttention_0")
load_flax_ln(self.pooler.norm, jax_weights, "MAPHead_0/LayerNorm_0")
load_flax_linear(self.pooler.mlp.linear1, jax_weights, "MAPHead_0/MlpBlock_0/Dense_0")
load_flax_linear(self.pooler.mlp.linear2, jax_weights, "MAPHead_0/MlpBlock_0/Dense_1")

if len(jax_weights) > 0:
print(jax_weights.keys())


def load_flax_ln(norm: nn.LayerNorm, weights: dict[str, Tensor], prefix: str) -> None:
norm.weight.copy_(weights.pop(f"{prefix}/scale"))
norm.bias.copy_(weights.pop(f"{prefix}/bias"))


def load_flax_linear(linear: nn.Linear, weights: dict[str, Tensor], prefix: str) -> None:
linear.weight.copy_(weights.pop(f"{prefix}/kernel").T)
linear.bias.copy_(weights.pop(f"{prefix}/bias"))


def load_flax_conv2d(conv2d: nn.Conv2d, weights: dict[str, Tensor], prefix: str) -> None:
conv2d.weight.copy_(weights.pop(f"{prefix}/kernel").permute(3, 2, 0, 1))
conv2d.bias.copy_(weights.pop(f"{prefix}/bias"))


def load_flax_mha(mha: MHA, weights: dict[str, Tensor], prefix: str) -> None:
mha.q_proj.weight.copy_(weights.pop(f"{prefix}/query/kernel").flatten(1).T)
mha.q_proj.bias.copy_(weights.pop(f"{prefix}/query/bias").flatten())
mha.k_proj.weight.copy_(weights.pop(f"{prefix}/key/kernel").flatten(1).T)
mha.k_proj.bias.copy_(weights.pop(f"{prefix}/key/bias").flatten())
mha.v_proj.weight.copy_(weights.pop(f"{prefix}/value/kernel").flatten(1).T)
mha.v_proj.bias.copy_(weights.pop(f"{prefix}/value/bias").flatten())
mha.out_proj.weight.copy_(weights.pop(f"{prefix}/out/kernel").flatten(0, 1).T)
mha.out_proj.bias.copy_(weights.pop(f"{prefix}/out/bias").flatten())
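
Finally, a shape-level sketch of the new MAP-style pooling path, assuming ViT-B dimensions (d_model=768, n_heads=12) and 14 x 14 = 196 patch tokens; the learned probe cross-attends over the sequence through the generalized MHA.forward above:

import torch
from vision_toolbox.backbones.vit import MHAPooling

pool = MHAPooling(d_model=768, n_heads=12).eval()
tokens = torch.randn(1, 196, 768)   # (batch, n_patches, d_model); no CLS token needed

with torch.no_grad():
    pooled = pool(tokens)           # probe attends over the tokens, then residual MLP

print(pooled.shape)                 # torch.Size([1, 768])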