Minor style change (#15)
gau-nernst committed Aug 8, 2023
1 parent 55ea654 commit 68ca71c
Showing 2 changed files with 46 additions and 48 deletions.
6 changes: 6 additions & 0 deletions vision_toolbox/backbones/mlp_mixer.py
@@ -75,6 +75,7 @@ def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool =
         # Table 1 in https://arxiv.org/pdf/2105.01601.pdf
         n_layers, d_model = dict(S=(8, 512), B=(12, 768), L=(24, 1024), H=(32, 1280))[variant]
         m = MLPMixer(n_layers, d_model, patch_size, img_size)
+
         if pretrained:
             ckpt = {
                 ("S", 8): "gsam/Mixer-S_8.npz",
@@ -86,6 +87,7 @@ def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool =
             }[(variant, patch_size)]
             base_url = "https://storage.googleapis.com/mixer_models/"
             m.load_jax_weights(torch_hub_download(base_url + ckpt))
+
         return m
 
     @torch.no_grad()
@@ -97,21 +99,25 @@ def get_w(key: str) -> Tensor:
 
         self.patch_embed.weight.copy_(get_w("stem/kernel").permute(3, 2, 0, 1))
         self.patch_embed.bias.copy_(get_w("stem/bias"))
+
         for i, layer in enumerate(self.layers):
            layer: MixerBlock
            prefix = f"MixerBlock_{i}/"
+
            layer.norm1.weight.copy_(get_w(prefix + "LayerNorm_0/scale"))
            layer.norm1.bias.copy_(get_w(prefix + "LayerNorm_0/bias"))
            layer.token_mixing.linear1.weight.copy_(get_w(prefix + "token_mixing/Dense_0/kernel").T)
            layer.token_mixing.linear1.bias.copy_(get_w(prefix + "token_mixing/Dense_0/bias"))
            layer.token_mixing.linear2.weight.copy_(get_w(prefix + "token_mixing/Dense_1/kernel").T)
            layer.token_mixing.linear2.bias.copy_(get_w(prefix + "token_mixing/Dense_1/bias"))
+
            layer.norm2.weight.copy_(get_w(prefix + "LayerNorm_1/scale"))
            layer.norm2.bias.copy_(get_w(prefix + "LayerNorm_1/bias"))
            layer.channel_mixing.linear1.weight.copy_(get_w(prefix + "channel_mixing/Dense_0/kernel").T)
            layer.channel_mixing.linear1.bias.copy_(get_w(prefix + "channel_mixing/Dense_0/bias"))
            layer.channel_mixing.linear2.weight.copy_(get_w(prefix + "channel_mixing/Dense_1/kernel").T)
            layer.channel_mixing.linear2.bias.copy_(get_w(prefix + "channel_mixing/Dense_1/bias"))
+
         self.norm.weight.copy_(get_w("pre_head_layer_norm/scale"))
         self.norm.bias.copy_(get_w("pre_head_layer_norm/bias"))
         return self
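
For reference, a minimal usage sketch of the from_config / load_jax_weights path shown above (not part of the commit; the import path and the 224-pixel input size are assumptions):

import torch
from vision_toolbox.backbones.mlp_mixer import MLPMixer  # assumed import path

# Build Mixer-S/8 and copy in the JAX checkpoint listed in the ckpt table above
# (pretrained=True downloads gsam/Mixer-S_8.npz via torch_hub_download).
m = MLPMixer.from_config("S", 8, 224, pretrained=True)
feats = m(torch.randn(1, 3, 224, 224))  # forward pass on a dummy batch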
88 changes: 40 additions & 48 deletions vision_toolbox/backbones/vit.py
@@ -4,6 +4,7 @@
 
 from __future__ import annotations
 
+from functools import partial
 from typing import Mapping
 
 import numpy as np
@@ -12,6 +13,7 @@
 from torch import Tensor, nn
 
 from ..utils import torch_hub_download
+from .base import _act, _norm
 
 
 class MHA(nn.Module):
@@ -38,17 +40,23 @@ def forward(self, x: Tensor) -> Tensor:
         return out
 
 
-class TransformerEncoderLayer(nn.Module):
+class ViTBlock(nn.Module):
     def __init__(
-        self, d_model: int, n_heads: int, bias: bool = True, dropout: float = 0.0, norm_eps: float = 1e-6
+        self,
+        d_model: int,
+        n_heads: int,
+        bias: bool = True,
+        dropout: float = 0.0,
+        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
+        act: _act = nn.GELU,
     ) -> None:
         super().__init__()
-        self.norm1 = nn.LayerNorm(d_model, norm_eps)
+        self.norm1 = norm(d_model)
         self.mha = MHA(d_model, n_heads, bias, dropout)
-        self.norm2 = nn.LayerNorm(d_model, norm_eps)
+        self.norm2 = norm(d_model)
         self.mlp = nn.Sequential(
             nn.Linear(d_model, d_model * 4, bias),
-            nn.GELU(),
+            act(),
             nn.Linear(d_model * 4, d_model, bias),
             nn.Dropout(dropout),
         )
@@ -70,8 +78,10 @@ def __init__(
         cls_token: bool = True,
         bias: bool = True,
         dropout: float = 0.0,
-        norm_eps: float = 1e-6,
+        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
+        act: _act = nn.GELU,
     ) -> None:
+        assert img_size % patch_size == 0
         super().__init__()
         self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
         self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model)) if cls_token else None
@@ -82,18 +92,14 @@ def __init__(
         self.pe = nn.Parameter(torch.empty(1, pe_size, d_model))
         nn.init.normal_(self.pe, 0, 0.02)
 
-        self.layers = nn.Sequential()
-        for _ in range(n_layers):
-            self.layers.append(TransformerEncoderLayer(d_model, n_heads, bias, dropout))
-        self.norm = nn.LayerNorm(d_model, norm_eps)
+        self.layers = nn.Sequential(*[ViTBlock(d_model, n_heads, bias, dropout, norm, act) for _ in range(n_layers)])
+        self.norm = norm(d_model)
 
     def forward(self, imgs: Tensor) -> Tensor:
-        out = self.patch_embed(imgs)
-        out = out.flatten(2).transpose(1, 2)  # (N, C, H, W) -> (N, H*W, C)
+        out = self.patch_embed(imgs).flatten(2).transpose(1, 2)  # (N, C, H, W) -> (N, H*W, C)
         if self.cls_token is not None:
             out = torch.cat([self.cls_token.expand(out.shape[0], -1, -1), out], 1)
-        out = out + self.pe
-        out = self.layers(out)
+        out = self.layers(out + self.pe)
         out = self.norm(out)
         out = out[:, 0] if self.cls_token is not None else out.mean(1)
         return out
@@ -115,6 +121,15 @@ def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:
 
     @staticmethod
     def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool = False) -> ViT:
+        n_layers, d_model, n_heads = dict(
+            Ti=(12, 192, 3),
+            S=(12, 384, 6),
+            B=(12, 768, 12),
+            L=(24, 1024, 16),
+            H=(32, 1280, 16),
+        )[variant]
+        m = ViT(n_layers, d_model, n_heads, patch_size, img_size)
+
         if pretrained:
             ckpt = {
                 ("Ti", 16): "Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz",
@@ -125,50 +140,27 @@ def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool =
                 ("L", 16): "L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz",
             }[(variant, patch_size)]
             base_url = "https://storage.googleapis.com/vit_models/augreg/"
-            m = ViT.from_jax_weights(torch_hub_download(base_url + ckpt))
+            m.load_jax_weights(torch_hub_download(base_url + ckpt))
             if img_size != 224:
                 m.resize_pe(img_size)
 
-        else:
-            n_layers, d_model, n_heads = dict(
-                Ti=(12, 192, 3),
-                S=(12, 384, 6),
-                B=(12, 768, 12),
-                L=(24, 1024, 16),
-                H=(32, 1280, 16),
-            )[variant]
-            m = ViT(n_layers, d_model, n_heads, patch_size, img_size)
-
         return m
 
     # weights from https://github.com/google-research/vision_transformer
     @torch.no_grad()
-    @staticmethod
-    def from_jax_weights(path: str) -> ViT:
+    def load_jax_weights(self, path: str) -> ViT:
         jax_weights: Mapping[str, np.ndarray] = np.load(path)
 
         def get_w(key: str) -> Tensor:
             return torch.from_numpy(jax_weights[key])
 
-        n_layers = 1
-        while True:
-            if f"Transformer/encoderblock_{n_layers}/LayerNorm_0/bias" not in jax_weights:
-                break
-            n_layers += 1
-
-        d_model = jax_weights["cls"].shape[-1]
-        n_heads = jax_weights["Transformer/encoderblock_0/MultiHeadDotProductAttention_1/key/bias"].shape[0]
-        patch_size = jax_weights["embedding/kernel"].shape[0]
-        img_size = int((jax_weights["Transformer/posembed_input/pos_embedding"].shape[1] - 1) ** 0.5) * patch_size
-
-        m = ViT(n_layers, d_model, n_heads, patch_size, img_size)
-
-        m.cls_token.copy_(get_w("cls"))
-        m.patch_embed.weight.copy_(get_w("embedding/kernel").permute(3, 2, 0, 1))
-        m.patch_embed.bias.copy_(get_w("embedding/bias"))
-        m.pe.copy_(get_w("Transformer/posembed_input/pos_embedding"))
+        self.cls_token.copy_(get_w("cls"))
+        self.patch_embed.weight.copy_(get_w("embedding/kernel").permute(3, 2, 0, 1))
+        self.patch_embed.bias.copy_(get_w("embedding/bias"))
+        self.pe.copy_(get_w("Transformer/posembed_input/pos_embedding"))
 
-        for idx, layer in enumerate(m.layers):
+        for idx, layer in enumerate(self.layers):
+            layer: ViTBlock
             prefix = f"Transformer/encoderblock_{idx}/"
             mha_prefix = prefix + "MultiHeadDotProductAttention_1/"
 
@@ -188,6 +180,6 @@ def get_w(key: str) -> Tensor:
             layer.mlp[2].weight.copy_(get_w(prefix + "MlpBlock_3/Dense_1/kernel").T)
             layer.mlp[2].bias.copy_(get_w(prefix + "MlpBlock_3/Dense_1/bias"))
 
-        m.norm.weight.copy_(get_w("Transformer/encoder_norm/scale"))
-        m.norm.bias.copy_(get_w("Transformer/encoder_norm/bias"))
-        return m
+        self.norm.weight.copy_(get_w("Transformer/encoder_norm/scale"))
+        self.norm.bias.copy_(get_w("Transformer/encoder_norm/bias"))
+        return self
