Add SigLIP weights (#22)
gau-nernst committed Oct 29, 2023
1 parent f6d7c4b commit a767dc4
Showing 5 changed files with 159 additions and 86 deletions.
14 changes: 11 additions & 3 deletions tests/test_vit.py
@@ -1,3 +1,4 @@
import pytest
import timm
import torch

@@ -16,12 +17,19 @@ def test_resize_pe():
m(torch.randn(1, 3, 256, 256))


def test_from_pretrained():
m = ViT.from_config("Ti_16", 224, True).eval()
@pytest.mark.parametrize(
"config,timm_name",
[
(dict(variant="Ti_16", img_size=224, weights="augreg"), "vit_tiny_patch16_224.augreg_in21k"),
(dict(variant="B_16", img_size=224, weights="siglip"), "vit_base_patch16_siglip_224"),
],
)
def test_from_pretrained(config, timm_name):
m = ViT.from_config(**config).eval()
x = torch.randn(1, 3, 224, 224)
out = m(x)

m_timm = timm.create_model("vit_tiny_patch16_224.augreg_in21k", pretrained=True, num_classes=0).eval()
m_timm = timm.create_model(timm_name, pretrained=True, num_classes=0).eval()
out_timm = m_timm(x)

torch.testing.assert_close(out, out_timm, rtol=2e-5, atol=2e-5)
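
Note: ViT.from_config() now selects pretrained weights via a weights= keyword ("augreg" or "siglip") instead of the old boolean pretrained flag, as exercised by the parametrized test above. A minimal usage sketch; the import path is taken from the repository layout and other arguments are left at their defaults:

import torch
from vision_toolbox.backbones.vit import ViT

# Load the SigLIP image tower (B/16 at 224 px) and extract pooled features.
model = ViT.from_config("B_16", 224, weights="siglip").eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))  # shape (1, 768) for the B variant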
6 changes: 3 additions & 3 deletions vision_toolbox/backbones/deit.py
@@ -28,8 +28,8 @@ def __init__(
) -> None:
# fmt: off
super().__init__(
d_model, depth, n_heads, patch_size, img_size, True, bias, mlp_ratio,
dropout, layer_scale_init, stochastic_depth, norm_eps
d_model, depth, n_heads, patch_size, img_size, True, "cls_token", bias,
mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm_eps,
)
# fmt: on
self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))
@@ -133,7 +133,7 @@ def __init__(
):
# fmt: off
super().__init__(
d_model, depth, n_heads, patch_size, img_size, cls_token, bias,
d_model, depth, n_heads, patch_size, img_size, cls_token, "cls_token", bias,
mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm_eps,
)
# fmt: on
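Note: the extra "cls_token" string passed to super().__init__() above corresponds to the new pool_type parameter that ViT.__init__ gained directly after cls_token (see vit.py below), keeping the remaining positional arguments aligned. Equivalent keyword form, as a sketch with Ti-sized dimensions and other arguments left at their defaults:

from vision_toolbox.backbones.vit import ViT

# pool_type="cls_token" reproduces the previous pooling behaviour.
vit = ViT(d_model=192, depth=12, n_heads=3, patch_size=16, img_size=224,
          cls_token=True, pool_type="cls_token")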
40 changes: 12 additions & 28 deletions vision_toolbox/backbones/mlp_mixer.py
@@ -10,7 +10,7 @@
from torch import Tensor, nn

from ..utils import torch_hub_download
from .vit import MLP
from .vit import MLP, load_flax_conv2d, load_flax_linear, load_flax_ln


class MixerBlock(nn.Module):
@@ -84,33 +84,17 @@ def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool =
return m

@torch.no_grad()
def load_jax_weights(self, path: str) -> MLPMixer:
jax_weights: Mapping[str, np.ndarray] = np.load(path)
def load_jax_weights(self, path: str) -> None:
jax_weights = {k: torch.from_numpy(v) for k, v in np.load(path).items()}

def get_w(key: str) -> Tensor:
return torch.from_numpy(jax_weights[key])

self.patch_embed.weight.copy_(get_w("stem/kernel").permute(3, 2, 0, 1))
self.patch_embed.bias.copy_(get_w("stem/bias"))
load_flax_conv2d(self.patch_embed, jax_weights, "stem")
load_flax_ln(self.norm, jax_weights, "pre_head_layer_norm")

for i, layer in enumerate(self.layers):
layer: MixerBlock
prefix = f"MixerBlock_{i}/"

layer.norm1.weight.copy_(get_w(prefix + "LayerNorm_0/scale"))
layer.norm1.bias.copy_(get_w(prefix + "LayerNorm_0/bias"))
layer.token_mixing.linear1.weight.copy_(get_w(prefix + "token_mixing/Dense_0/kernel").T)
layer.token_mixing.linear1.bias.copy_(get_w(prefix + "token_mixing/Dense_0/bias"))
layer.token_mixing.linear2.weight.copy_(get_w(prefix + "token_mixing/Dense_1/kernel").T)
layer.token_mixing.linear2.bias.copy_(get_w(prefix + "token_mixing/Dense_1/bias"))

layer.norm2.weight.copy_(get_w(prefix + "LayerNorm_1/scale"))
layer.norm2.bias.copy_(get_w(prefix + "LayerNorm_1/bias"))
layer.channel_mixing.linear1.weight.copy_(get_w(prefix + "channel_mixing/Dense_0/kernel").T)
layer.channel_mixing.linear1.bias.copy_(get_w(prefix + "channel_mixing/Dense_0/bias"))
layer.channel_mixing.linear2.weight.copy_(get_w(prefix + "channel_mixing/Dense_1/kernel").T)
layer.channel_mixing.linear2.bias.copy_(get_w(prefix + "channel_mixing/Dense_1/bias"))

self.norm.weight.copy_(get_w("pre_head_layer_norm/scale"))
self.norm.bias.copy_(get_w("pre_head_layer_norm/bias"))
return self
load_flax_ln(layer.norm1, jax_weights, f"MixerBlock_{i}/LayerNorm_0")
load_flax_linear(layer.token_mixing.linear1, jax_weights, f"MixerBlock_{i}/token_mixing/Dense_0")
load_flax_linear(layer.token_mixing.linear2, jax_weights, f"MixerBlock_{i}/token_mixing/Dense_1")

load_flax_ln(layer.norm2, jax_weights, f"MixerBlock_{i}/LayerNorm_1")
load_flax_linear(layer.channel_mixing.linear1, jax_weights, f"MixerBlock_{i}/channel_mixing/Dense_0")
load_flax_linear(layer.channel_mixing.linear2, jax_weights, f"MixerBlock_{i}/channel_mixing/Dense_1")
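
Note: load_jax_weights() now loads in place through the shared Flax helpers and returns None instead of self, so it can no longer be chained. A hedged sketch; the variant string and checkpoint filename are placeholders, not verified names:

from vision_toolbox.backbones.mlp_mixer import MLPMixer

m = MLPMixer.from_config("B", 16, 224)      # variant string is a placeholder
m.load_jax_weights("Mixer-B_16.npz")        # modifies m in place; do not write m = m.load_jax_weights(...)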
2 changes: 1 addition & 1 deletion vision_toolbox/backbones/swin.py
@@ -78,7 +78,7 @@ def forward(self, x: Tensor) -> Tensor:
attn_bias = attn_bias + self.attn_mask.unsqueeze(1) # add n_heads dim

x, nH, nW = window_partition(x, self.window_size) # (B * nH * nW, win_size * win_size, C)
x = super().forward(x, attn_bias)
x = super().forward(x, attn_bias=attn_bias)
x = window_unpartition(x, self.window_size, nH, nW) # (B, H, W, C)

if self.shift > 0:
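Note: attn_bias is now passed by keyword because MHA.forward() (see vit.py below) accepts optional k and v tensors for cross-attention, so a positional second argument would be interpreted as k. Illustrative sketch of both call patterns; the shapes are arbitrary:

import torch
from vision_toolbox.backbones.vit import MHA

attn = MHA(96, 3)                                        # d_model=96, n_heads=3
x = torch.randn(2, 49, 96)
y_self = attn(x)                                         # self-attention: k and v default to q
y_bias = attn(x, attn_bias=torch.zeros(1, 3, 49, 49))    # bias must be passed by keyword
probe = torch.randn(2, 1, 96)
y_cross = attn(probe, x)                                 # cross-attention: probe attends over x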
183 changes: 132 additions & 51 deletions vision_toolbox/backbones/vit.py
@@ -5,7 +5,6 @@
from __future__ import annotations

from functools import partial
from typing import Mapping

import numpy as np
import torch
@@ -27,10 +26,14 @@ def __init__(self, d_model: int, n_heads: int, bias: bool = True, dropout: float
self.dropout = dropout
self.scale = (d_model // n_heads) ** (-0.5)

def forward(self, x: Tensor, attn_bias: Tensor | None = None) -> Tensor:
q = self.q_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3) # (B, n_heads, L, head_dim)
k = self.k_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)
v = self.v_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)
def forward(
self, q: Tensor, k: Tensor | None = None, v: Tensor | None = None, *, attn_bias: Tensor | None = None
) -> Tensor:
k = q if k is None else k
v = k if v is None else v
q = self.q_proj(q).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3) # (B, n_heads, L, head_dim)
k = self.k_proj(k).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)
v = self.v_proj(v).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)

if hasattr(F, "scaled_dot_product_attention"):
out = F.scaled_dot_product_attention(q, k, v, attn_bias, self.dropout if self.training else 0.0)
@@ -89,6 +92,22 @@ def forward(self, x: Tensor) -> Tensor:
return x


class MHAPooling(nn.Module):
def __init__(
self, d_model: int, n_heads: int, bias: bool = True, mlp_ratio: float = 4.0, norm_eps: float = 1e-6
) -> None:
super().__init__()
self.probe = nn.Parameter(torch.zeros(1, 1, d_model))
self.mha = MHA(d_model, n_heads, bias)
self.norm = nn.LayerNorm(d_model, norm_eps)
self.mlp = MLP(d_model, int(d_model * mlp_ratio))

def forward(self, x: Tensor) -> Tensor:
x = self.mha(self.probe, x).squeeze(1)
x = x + self.mlp(self.norm(x))
return x


class ViT(nn.Module):
def __init__(
self,
@@ -98,6 +117,7 @@
patch_size: int,
img_size: int,
cls_token: bool = True,
pool_type: str = "cls_token",
bias: bool = True,
mlp_ratio: float = 4.0,
dropout: float = 0.0,
@@ -118,13 +138,23 @@
self.layers.append(block)

self.norm = nn.LayerNorm(d_model, norm_eps)
self.pool_type = pool_type
self.pooler = MHAPooling(d_model, n_heads, bias, mlp_ratio, norm_eps) if pool_type == "mha" else None

def forward(self, imgs: Tensor) -> Tensor:
out = self.patch_embed(imgs).flatten(2).transpose(1, 2) + self.pe # (N, C, H, W) -> (N, H*W, C)
if self.cls_token is not None:
out = torch.cat([self.cls_token, out], 1)
out = self.layers(out)
return self.norm(out[:, 0]) if self.cls_token is not None else self.norm(out).mean(1)

if self.pool_type == "cls_token":
return self.norm(out[:, 0])
elif self.pool_type == "gap":
return self.norm(out).mean(1)
elif self.pool_type == "mha":
return self.pooler(self.norm(out))
else:
raise RuntimeError(f"Unsupported pool_type={self.pool_type}")

@torch.no_grad()
def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:
@@ -136,7 +166,7 @@ def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:
self.pe = nn.Parameter(pe)

@staticmethod
def from_config(variant: str, img_size: int, pretrained: bool = False) -> ViT:
def from_config(variant: str, img_size: int, *, weights: str | None = None) -> ViT:
variant, patch_size = variant.split("_")

d_model, depth, n_heads = dict(
@@ -148,9 +178,13 @@ def from_config(variant: str, img_size: int, pretrained: bool = False) -> ViT:
H=(1280, 32, 16),
)[variant]
patch_size = int(patch_size)
m = ViT(d_model, depth, n_heads, patch_size, img_size)
kwargs = dict()
if weights == "siglip":
kwargs.update(cls_token=False, pool_type="mha")

m = ViT(d_model, depth, n_heads, patch_size, img_size, **kwargs)

if pretrained:
if weights == "augreg":
assert img_size == 224
ckpt = {
("Ti", 16): "Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz",
@@ -160,49 +194,96 @@ def from_config(variant: str, img_size: int, pretrained: bool = False) -> ViT:
("B", 16): "B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz",
("L", 16): "L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz",
}[(variant, patch_size)]
base_url = "https://storage.googleapis.com/vit_models/augreg/"
m.load_jax_weights(torch_hub_download(base_url + ckpt))
m.load_flax_ckpt(f"augreg/{ckpt}")

elif weights == "siglip":
ckpt = {
("B", 16, 224): "webli_en_b16_224_63724782.npz",
("B", 16, 256): "webli_en_b16_256_60500360.npz",
("B", 16, 384): "webli_en_b16_384_68578854.npz",
("B", 16, 512): "webli_en_b16_512_68580893.npz",
("L", 16, 256): "webli_en_l16_256_60552751.npz",
("L", 16, 384): "webli_en_l16_384_63634585.npz",
}[(variant, patch_size, img_size)]
m.load_flax_ckpt(f"siglip/{ckpt}", big_vision=True, prefix="params/img/")

elif weights is not None:
raise ValueError(f"Unsupported weights={weights}")

return m

# weights from https://github.com/google-research/vision_transformer
@torch.no_grad()
def load_jax_weights(self, path: str) -> ViT:
jax_weights: Mapping[str, np.ndarray] = np.load(path)

def get_w(key: str) -> Tensor:
return torch.from_numpy(jax_weights[key])

self.cls_token.copy_(get_w("cls"))
pe = get_w("Transformer/posembed_input/pos_embedding")
self.cls_token.add_(pe[:, 0])
self.pe.copy_(pe[:, 1:])
self.patch_embed.weight.copy_(get_w("embedding/kernel").permute(3, 2, 0, 1))
self.patch_embed.bias.copy_(get_w("embedding/bias"))

for idx, layer in enumerate(self.layers):
layer: ViTBlock
prefix = f"Transformer/encoderblock_{idx}/"
mha_prefix = prefix + "MultiHeadDotProductAttention_1/"

layer.mha[0].weight.copy_(get_w(prefix + "LayerNorm_0/scale"))
layer.mha[0].bias.copy_(get_w(prefix + "LayerNorm_0/bias"))
layer.mha[1].q_proj.weight.copy_(get_w(mha_prefix + "query/kernel").flatten(1).T)
layer.mha[1].k_proj.weight.copy_(get_w(mha_prefix + "key/kernel").flatten(1).T)
layer.mha[1].v_proj.weight.copy_(get_w(mha_prefix + "value/kernel").flatten(1).T)
layer.mha[1].q_proj.bias.copy_(get_w(mha_prefix + "query/bias").flatten())
layer.mha[1].k_proj.bias.copy_(get_w(mha_prefix + "key/bias").flatten())
layer.mha[1].v_proj.bias.copy_(get_w(mha_prefix + "value/bias").flatten())
layer.mha[1].out_proj.weight.copy_(get_w(mha_prefix + "out/kernel").flatten(0, 1).T)
layer.mha[1].out_proj.bias.copy_(get_w(mha_prefix + "out/bias"))

layer.mlp[0].weight.copy_(get_w(prefix + "LayerNorm_2/scale"))
layer.mlp[0].bias.copy_(get_w(prefix + "LayerNorm_2/bias"))
layer.mlp[1].linear1.weight.copy_(get_w(prefix + "MlpBlock_3/Dense_0/kernel").T)
layer.mlp[1].linear1.bias.copy_(get_w(prefix + "MlpBlock_3/Dense_0/bias"))
layer.mlp[1].linear2.weight.copy_(get_w(prefix + "MlpBlock_3/Dense_1/kernel").T)
layer.mlp[1].linear2.bias.copy_(get_w(prefix + "MlpBlock_3/Dense_1/bias"))

self.norm.weight.copy_(get_w("Transformer/encoder_norm/scale"))
self.norm.bias.copy_(get_w("Transformer/encoder_norm/bias"))
return self
def load_flax_ckpt(self, ckpt: str, *, big_vision: bool = False, prefix: str = "") -> None:
if big_vision:
# https://github.com/google-research/big_vision
gcs_bucket = "big_vision"
mha_norm = "LayerNorm_0"
mha = "MultiHeadDotProductAttention_0"
mlp_norm = "LayerNorm_1"
mlp = "MlpBlock_0"

else:
# https://github.com/google-research/vision_transformer
gcs_bucket = "vit_models"
mha_norm = "LayerNorm_0"
mha = "MultiHeadDotProductAttention_1"
mlp_norm = "LayerNorm_2"
mlp = "MlpBlock_3"

path = torch_hub_download(f"https://storage.googleapis.com/{gcs_bucket}/{ckpt}")
jax_weights = {k[len(prefix) :]: torch.from_numpy(v) for k, v in np.load(path).items() if k.startswith(prefix)}

if self.cls_token is not None:
self.cls_token.copy_(jax_weights.pop("cls"))
if big_vision:
self.pe.copy_(jax_weights.pop("pos_embedding"))
else:
pe = jax_weights.pop("Transformer/posembed_input/pos_embedding")
self.cls_token.add_(pe[:, 0])
self.pe.copy_(pe[:, 1:])
load_flax_conv2d(self.patch_embed, jax_weights, "embedding")
load_flax_ln(self.norm, jax_weights, "Transformer/encoder_norm")

for i, layer in enumerate(self.layers):
load_flax_ln(layer.mha[0], jax_weights, f"Transformer/encoderblock_{i}/{mha_norm}")
load_flax_mha(layer.mha[1], jax_weights, f"Transformer/encoderblock_{i}/{mha}")
load_flax_ln(layer.mlp[0], jax_weights, f"Transformer/encoderblock_{i}/{mlp_norm}")
load_flax_linear(layer.mlp[1].linear1, jax_weights, f"Transformer/encoderblock_{i}/{mlp}/Dense_0")
load_flax_linear(layer.mlp[1].linear2, jax_weights, f"Transformer/encoderblock_{i}/{mlp}/Dense_1")

# big_vision only
if self.pooler is not None:
self.pooler.probe.copy_(jax_weights.pop("MAPHead_0/probe"))
load_flax_mha(self.pooler.mha, jax_weights, "MAPHead_0/MultiHeadDotProductAttention_0")
load_flax_ln(self.pooler.norm, jax_weights, "MAPHead_0/LayerNorm_0")
load_flax_linear(self.pooler.mlp.linear1, jax_weights, "MAPHead_0/MlpBlock_0/Dense_0")
load_flax_linear(self.pooler.mlp.linear2, jax_weights, "MAPHead_0/MlpBlock_0/Dense_1")

if len(jax_weights) > 0:
print(jax_weights.keys())


def load_flax_ln(norm: nn.LayerNorm, weights: dict[str, Tensor], prefix: str) -> None:
norm.weight.copy_(weights.pop(f"{prefix}/scale"))
norm.bias.copy_(weights.pop(f"{prefix}/bias"))


def load_flax_linear(linear: nn.Linear, weights: dict[str, Tensor], prefix: str) -> None:
linear.weight.copy_(weights.pop(f"{prefix}/kernel").T)
linear.bias.copy_(weights.pop(f"{prefix}/bias"))


def load_flax_conv2d(conv2d: nn.Conv2d, weights: dict[str, Tensor], prefix: str) -> None:
conv2d.weight.copy_(weights.pop(f"{prefix}/kernel").permute(3, 2, 0, 1))
conv2d.bias.copy_(weights.pop(f"{prefix}/bias"))


def load_flax_mha(mha: MHA, weights: dict[str, Tensor], prefix: str) -> None:
mha.q_proj.weight.copy_(weights.pop(f"{prefix}/query/kernel").flatten(1).T)
mha.q_proj.bias.copy_(weights.pop(f"{prefix}/query/bias").flatten())
mha.k_proj.weight.copy_(weights.pop(f"{prefix}/key/kernel").flatten(1).T)
mha.k_proj.bias.copy_(weights.pop(f"{prefix}/key/bias").flatten())
mha.v_proj.weight.copy_(weights.pop(f"{prefix}/value/kernel").flatten(1).T)
mha.v_proj.bias.copy_(weights.pop(f"{prefix}/value/bias").flatten())
mha.out_proj.weight.copy_(weights.pop(f"{prefix}/out/kernel").flatten(0, 1).T)
mha.out_proj.bias.copy_(weights.pop(f"{prefix}/out/bias").flatten())
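
For reference, a sketch of the layout conversions that load_flax_linear() and load_flax_conv2d() above perform: Flax stores Dense kernels as (in, out) and Conv kernels as (H, W, in, out), while PyTorch expects (out, in) and (out, in, H, W). The tensors below are random stand-ins, not checkpoint data:

import torch
from torch import nn

linear = nn.Linear(192, 768)
flax_dense_kernel = torch.randn(192, 768)                    # Flax Dense kernel: (in, out)
with torch.no_grad():
    linear.weight.copy_(flax_dense_kernel.T)                 # PyTorch Linear weight: (out, in)

conv = nn.Conv2d(3, 192, kernel_size=16, stride=16)
flax_conv_kernel = torch.randn(16, 16, 3, 192)               # Flax Conv kernel: (H, W, in, out)
with torch.no_grad():
    conv.weight.copy_(flax_conv_kernel.permute(3, 2, 0, 1))  # PyTorch Conv2d weight: (out, in, H, W)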
