Minor style change #15

Merged · 4 commits · Aug 8, 2023
6 changes: 6 additions & 0 deletions vision_toolbox/backbones/mlp_mixer.py
@@ -75,6 +75,7 @@ def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool =
         # Table 1 in https://arxiv.org/pdf/2105.01601.pdf
         n_layers, d_model = dict(S=(8, 512), B=(12, 768), L=(24, 1024), H=(32, 1280))[variant]
         m = MLPMixer(n_layers, d_model, patch_size, img_size)
+
         if pretrained:
             ckpt = {
                 ("S", 8): "gsam/Mixer-S_8.npz",
@@ -86,6 +87,7 @@ def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool =
             }[(variant, patch_size)]
             base_url = "https://storage.googleapis.com/mixer_models/"
             m.load_jax_weights(torch_hub_download(base_url + ckpt))
+
         return m
 
     @torch.no_grad()
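Side note, not part of the diff: with the blank lines in place, from_config reads as three steps, build, optionally load, return. A usage sketch, assuming the package is importable as vision_toolbox (the ("S", 8) entry is the one visible in this diff):

# Build a Mixer-S/8 and pull its GSAM checkpoint.
from vision_toolbox.backbones.mlp_mixer import MLPMixer

model = MLPMixer.from_config("S", patch_size=8, img_size=224, pretrained=True)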
@@ -97,21 +99,25 @@ def get_w(key: str) -> Tensor:
 
         self.patch_embed.weight.copy_(get_w("stem/kernel").permute(3, 2, 0, 1))
         self.patch_embed.bias.copy_(get_w("stem/bias"))
+
         for i, layer in enumerate(self.layers):
             layer: MixerBlock
             prefix = f"MixerBlock_{i}/"
+
             layer.norm1.weight.copy_(get_w(prefix + "LayerNorm_0/scale"))
             layer.norm1.bias.copy_(get_w(prefix + "LayerNorm_0/bias"))
             layer.token_mixing.linear1.weight.copy_(get_w(prefix + "token_mixing/Dense_0/kernel").T)
             layer.token_mixing.linear1.bias.copy_(get_w(prefix + "token_mixing/Dense_0/bias"))
             layer.token_mixing.linear2.weight.copy_(get_w(prefix + "token_mixing/Dense_1/kernel").T)
             layer.token_mixing.linear2.bias.copy_(get_w(prefix + "token_mixing/Dense_1/bias"))
+
             layer.norm2.weight.copy_(get_w(prefix + "LayerNorm_1/scale"))
             layer.norm2.bias.copy_(get_w(prefix + "LayerNorm_1/bias"))
             layer.channel_mixing.linear1.weight.copy_(get_w(prefix + "channel_mixing/Dense_0/kernel").T)
             layer.channel_mixing.linear1.bias.copy_(get_w(prefix + "channel_mixing/Dense_0/bias"))
             layer.channel_mixing.linear2.weight.copy_(get_w(prefix + "channel_mixing/Dense_1/kernel").T)
             layer.channel_mixing.linear2.bias.copy_(get_w(prefix + "channel_mixing/Dense_1/bias"))
+
         self.norm.weight.copy_(get_w("pre_head_layer_norm/scale"))
         self.norm.bias.copy_(get_w("pre_head_layer_norm/bias"))
         return self
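For anyone reading the .T and permute(3, 2, 0, 1) calls above: Flax stores Dense kernels as (in_features, out_features) and conv kernels as (H, W, in_ch, out_ch), while PyTorch expects (out_features, in_features) for nn.Linear and (out_ch, in_ch, H, W) for nn.Conv2d. A standalone sketch of the two conversions (shapes assume the Mixer-S/8 stem, d_model=512):

import numpy as np
import torch

# Dense: Flax (in, out) -> PyTorch (out, in), hence the .T above.
jax_dense = np.zeros((512, 2048), dtype=np.float32)
linear_weight = torch.from_numpy(jax_dense).T                 # (2048, 512)

# Conv: Flax (H, W, in_ch, out_ch) -> PyTorch (out_ch, in_ch, H, W),
# hence the permute for the patch-embedding stem.
jax_conv = np.zeros((8, 8, 3, 512), dtype=np.float32)
conv_weight = torch.from_numpy(jax_conv).permute(3, 2, 0, 1)  # (512, 3, 8, 8)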
88 changes: 40 additions & 48 deletions vision_toolbox/backbones/vit.py
@@ -4,6 +4,7 @@
 
 from __future__ import annotations
 
+from functools import partial
 from typing import Mapping
 
 import numpy as np
@@ -12,6 +13,7 @@
 from torch import Tensor, nn
 
 from ..utils import torch_hub_download
+from .base import _act, _norm
 
 
 class MHA(nn.Module):
@@ -38,17 +40,23 @@ def forward(self, x: Tensor) -> Tensor:
         return out
 
 
-class TransformerEncoderLayer(nn.Module):
+class ViTBlock(nn.Module):
     def __init__(
-        self, d_model: int, n_heads: int, bias: bool = True, dropout: float = 0.0, norm_eps: float = 1e-6
+        self,
+        d_model: int,
+        n_heads: int,
+        bias: bool = True,
+        dropout: float = 0.0,
+        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
+        act: _act = nn.GELU,
     ) -> None:
         super().__init__()
-        self.norm1 = nn.LayerNorm(d_model, norm_eps)
+        self.norm1 = norm(d_model)
         self.mha = MHA(d_model, n_heads, bias, dropout)
-        self.norm2 = nn.LayerNorm(d_model, norm_eps)
+        self.norm2 = norm(d_model)
         self.mlp = nn.Sequential(
             nn.Linear(d_model, d_model * 4, bias),
-            nn.GELU(),
+            act(),
             nn.Linear(d_model * 4, d_model, bias),
             nn.Dropout(dropout),
         )
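The point of this change: norm and act are factories rather than a bare eps value, so callers can swap the normalization layer or activation without a new block class. A hypothetical construction (assuming ViTBlock is importable; the _norm/_act aliases from .base are just callables that return modules):

from functools import partial

from torch import nn

from vision_toolbox.backbones.vit import ViTBlock

# Default behavior, spelled out explicitly:
block = ViTBlock(768, 12, norm=partial(nn.LayerNorm, eps=1e-6), act=nn.GELU)

# Variants the new signature allows:
block_silu = ViTBlock(768, 12, act=nn.SiLU)                           # swap the activation
block_eps = ViTBlock(768, 12, norm=partial(nn.LayerNorm, eps=1e-5))   # different eps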
@@ -70,8 +78,10 @@ def __init__(
         cls_token: bool = True,
         bias: bool = True,
         dropout: float = 0.0,
-        norm_eps: float = 1e-6,
+        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
+        act: _act = nn.GELU,
     ) -> None:
+        assert img_size % patch_size == 0
         super().__init__()
         self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
         self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model)) if cls_token else None
@@ -82,18 +92,14 @@ def __init__(
         self.pe = nn.Parameter(torch.empty(1, pe_size, d_model))
         nn.init.normal_(self.pe, 0, 0.02)
 
-        self.layers = nn.Sequential()
-        for _ in range(n_layers):
-            self.layers.append(TransformerEncoderLayer(d_model, n_heads, bias, dropout))
-        self.norm = nn.LayerNorm(d_model, norm_eps)
+        self.layers = nn.Sequential(*[ViTBlock(d_model, n_heads, bias, dropout, norm, act) for _ in range(n_layers)])
+        self.norm = norm(d_model)
 
     def forward(self, imgs: Tensor) -> Tensor:
-        out = self.patch_embed(imgs)
-        out = out.flatten(2).transpose(1, 2)  # (N, C, H, W) -> (N, H*W, C)
+        out = self.patch_embed(imgs).flatten(2).transpose(1, 2)  # (N, C, H, W) -> (N, H*W, C)
         if self.cls_token is not None:
             out = torch.cat([self.cls_token.expand(out.shape[0], -1, -1), out], 1)
-        out = out + self.pe
-        out = self.layers(out)
+        out = self.layers(out + self.pe)
         out = self.norm(out)
         out = out[:, 0] if self.cls_token is not None else out.mean(1)
         return out
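Shape check for the rewritten forward (a sketch for ViT-B/16 at 224x224 with the class token enabled; assumes ViT is importable as above):

import torch

from vision_toolbox.backbones.vit import ViT

# (N, 3, 224, 224) -> patch_embed -> (N, 768, 14, 14)
#   -> flatten/transpose -> (N, 196, 768) -> prepend cls -> (N, 197, 768)
#   -> +pe, layers, norm -> (N, 197, 768) -> cls token at index 0 -> (N, 768)
model = ViT(12, 768, 12, patch_size=16, img_size=224)
feats = model(torch.randn(2, 3, 224, 224))
assert feats.shape == (2, 768)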
@@ -115,6 +121,15 @@ def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:
 
     @staticmethod
     def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool = False) -> ViT:
+        n_layers, d_model, n_heads = dict(
+            Ti=(12, 192, 3),
+            S=(12, 384, 6),
+            B=(12, 768, 12),
+            L=(24, 1024, 16),
+            H=(32, 1280, 16),
+        )[variant]
+        m = ViT(n_layers, d_model, n_heads, patch_size, img_size)
+
         if pretrained:
             ckpt = {
                 ("Ti", 16): "Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz",
@@ -125,50 +140,27 @@ def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool =
                 ("L", 16): "L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz",
             }[(variant, patch_size)]
             base_url = "https://storage.googleapis.com/vit_models/augreg/"
-            m = ViT.from_jax_weights(torch_hub_download(base_url + ckpt))
+            m.load_jax_weights(torch_hub_download(base_url + ckpt))
             if img_size != 224:
                 m.resize_pe(img_size)
 
-        else:
-            n_layers, d_model, n_heads = dict(
-                Ti=(12, 192, 3),
-                S=(12, 384, 6),
-                B=(12, 768, 12),
-                L=(24, 1024, 16),
-                H=(32, 1280, 16),
-            )[variant]
-            m = ViT(n_layers, d_model, n_heads, patch_size, img_size)
-
         return m
 
     # weights from https://github.com/google-research/vision_transformer
     @torch.no_grad()
-    @staticmethod
-    def from_jax_weights(path: str) -> ViT:
+    def load_jax_weights(self, path: str) -> ViT:
         jax_weights: Mapping[str, np.ndarray] = np.load(path)
 
         def get_w(key: str) -> Tensor:
             return torch.from_numpy(jax_weights[key])
 
-        n_layers = 1
-        while True:
-            if f"Transformer/encoderblock_{n_layers}/LayerNorm_0/bias" not in jax_weights:
-                break
-            n_layers += 1
+        self.cls_token.copy_(get_w("cls"))
+        self.patch_embed.weight.copy_(get_w("embedding/kernel").permute(3, 2, 0, 1))
+        self.patch_embed.bias.copy_(get_w("embedding/bias"))
+        self.pe.copy_(get_w("Transformer/posembed_input/pos_embedding"))
 
-        d_model = jax_weights["cls"].shape[-1]
-        n_heads = jax_weights["Transformer/encoderblock_0/MultiHeadDotProductAttention_1/key/bias"].shape[0]
-        patch_size = jax_weights["embedding/kernel"].shape[0]
-        img_size = int((jax_weights["Transformer/posembed_input/pos_embedding"].shape[1] - 1) ** 0.5) * patch_size
-
-        m = ViT(n_layers, d_model, n_heads, patch_size, img_size)
-
-        m.cls_token.copy_(get_w("cls"))
-        m.patch_embed.weight.copy_(get_w("embedding/kernel").permute(3, 2, 0, 1))
-        m.patch_embed.bias.copy_(get_w("embedding/bias"))
-        m.pe.copy_(get_w("Transformer/posembed_input/pos_embedding"))
-
-        for idx, layer in enumerate(m.layers):
+        for idx, layer in enumerate(self.layers):
+            layer: ViTBlock
             prefix = f"Transformer/encoderblock_{idx}/"
             mha_prefix = prefix + "MultiHeadDotProductAttention_1/"
 
@@ -188,6 +180,6 @@ def get_w(key: str) -> Tensor:
             layer.mlp[2].weight.copy_(get_w(prefix + "MlpBlock_3/Dense_1/kernel").T)
             layer.mlp[2].bias.copy_(get_w(prefix + "MlpBlock_3/Dense_1/bias"))
 
-        m.norm.weight.copy_(get_w("Transformer/encoder_norm/scale"))
-        m.norm.bias.copy_(get_w("Transformer/encoder_norm/bias"))
-        return m
+        self.norm.weight.copy_(get_w("Transformer/encoder_norm/scale"))
+        self.norm.bias.copy_(get_w("Transformer/encoder_norm/bias"))
+        return self
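With from_jax_weights folded into load_jax_weights, conversion now fills in an existing model instead of inferring the architecture from checkpoint keys. A usage sketch (the local path is hypothetical; pretrained=True does the download-and-load automatically, including resize_pe for non-224 inputs):

from vision_toolbox.backbones.vit import ViT

# One-liner: build, download, convert, resize if needed.
model = ViT.from_config("Ti", patch_size=16, img_size=224, pretrained=True)

# Or in two steps, with an AugReg checkpoint already on disk
# (hypothetical path):
model = ViT.from_config("Ti", patch_size=16, img_size=224)
model.load_jax_weights("checkpoints/vit_ti16_augreg.npz")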