Commit
Merge branch 'main' into mobile-vit
gau-nernst committed Jul 23, 2023
2 parents ee98523 + 5416169, commit e901280
Showing 4 changed files with 161 additions and 53 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
@@ -31,7 +31,7 @@ jobs:
      - name: Install dependencies
        run: |
          pip install torch==${{ matrix.pytorch-version }}.* torchvision==${{ matrix.torchvision-version }}.* --extra-index-url https://download.pytorch.org/whl/cpu
          pip install pytest
          pip install pytest timm
      - name: Run tests
        run: python -m pytest -v
22 changes: 22 additions & 0 deletions tests/test_vit.py
@@ -0,0 +1,22 @@
import timm
import torch

from vision_toolbox.backbones import ViT


def test_resize_pe():
    m = ViT.from_config("Ti", 16, 224)
    m(torch.randn(1, 3, 224, 224))
    m.resize_pe(256)
    m(torch.randn(1, 3, 256, 256))


def test_from_pretrained():
    m = ViT.from_config("Ti", 16, 224, True).eval()
    x = torch.randn(1, 3, 224, 224)
    out = m(x)

    m_timm = timm.create_model("vit_tiny_patch16_224.augreg_in21k", pretrained=True, num_classes=0).eval()
    out_timm = m_timm(x)

    torch.testing.assert_close(out, out_timm, rtol=2e-5, atol=2e-5)
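Note on test_resize_pe: with a 16-pixel patch, a 224x224 input gives 14*14 = 196 patch tokens (plus one CLS token), and resize_pe(256) bicubically interpolates the position embedding to 16*16 = 256 patch tokens. A minimal sketch of that effect, using only the API added in this commit (the embedding width printed depends on the "Ti" config):

import torch
from vision_toolbox.backbones import ViT

m = ViT.from_config("Ti", 16, 224)
print(m.pe.shape)   # (1, 14*14 + 1, d_model) -> 197 positions including the CLS token
m.resize_pe(256)
print(m.pe.shape)   # (1, 16*16 + 1, d_model) -> 257 positions after interpolation
m(torch.randn(1, 3, 256, 256))  # the resized model now accepts 256x256 inputs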
180 changes: 128 additions & 52 deletions vision_toolbox/backbones/vit.py
@@ -2,12 +2,17 @@
# https://arxiv.org/abs/2106.10270
# https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_vit.py

from __future__ import annotations

from typing import Mapping

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn

from ..utils import torch_hub_download


__all__ = ["ViT"]

@@ -20,6 +25,60 @@
    H=dict(n_layers=32, d_model=1280, n_heads=16),
)

checkpoints = {
    ("Ti", 16): "Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz",
    ("S", 32): "S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0.npz",
    ("S", 16): "S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz",
    ("B", 32): "B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz",
    ("B", 16): "B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz",
    ("L", 16): "L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz",
}


class MHA(nn.Module):
    def __init__(self, d_model: int, n_heads: int, bias: bool = True, dropout: float = 0.0) -> None:
        super().__init__()
        self.in_proj = nn.Linear(d_model, d_model * 3, bias)
        self.out_proj = nn.Linear(d_model, d_model)
        self.n_heads = n_heads
        self.dropout = dropout
        self.scale = (d_model // n_heads) ** (-0.5)

    def forward(self, x: Tensor) -> Tensor:
        qkv = self.in_proj(x)
        q, k, v = qkv.unflatten(-1, (3, self.n_heads, -1)).transpose(-2, -4).unbind(-3)

        if hasattr(F, "scaled_dot_product_attention"):
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout)
        else:
            attn = torch.softmax(q @ (k * self.scale).transpose(-1, -2), -1)
            out = F.dropout(attn, self.dropout, self.training) @ v

        out = out.transpose(-2, -3).flatten(-2)
        out = self.out_proj(out)
        return out


class TransformerEncoderLayer(nn.Module):
    def __init__(
        self, d_model: int, n_heads: int, bias: bool = True, dropout: float = 0.0, norm_eps: float = 1e-6
    ) -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model, norm_eps)
        self.mha = MHA(d_model, n_heads, bias, dropout)
        self.norm2 = nn.LayerNorm(d_model, norm_eps)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4, bias),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model, bias),
            nn.Dropout(dropout),
        )

    def forward(self, x: Tensor) -> Tensor:
        x = x + self.mha(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class ViT(nn.Module):
    def __init__(
@@ -29,11 +88,11 @@ def __init__(
        n_heads: int,
        patch_size: int,
        img_size: int,
        mlp_dim: int | None = None,
        cls_token: bool = True,
        bias: bool = True,
        dropout: float = 0.0,
        norm_eps: float = 1e-6,
    ):
    ) -> None:
        super().__init__()
        self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model)) if cls_token else None
@@ -44,29 +103,59 @@ def __init__(
        self.pe = nn.Parameter(torch.empty(1, pe_size, d_model))
        nn.init.normal_(self.pe, 0, 0.02)

        mlp_dim = mlp_dim or d_model * 4
        layer = nn.TransformerEncoderLayer(d_model, n_heads, mlp_dim, dropout, "gelu", norm_eps, True, True)
        self.encoder = nn.TransformerEncoder(layer, n_layers, nn.LayerNorm(d_model, norm_eps))
        self.layers = nn.Sequential()
        for _ in range(n_layers):
            self.layers.append(TransformerEncoderLayer(d_model, n_heads, bias, dropout))
        self.norm = nn.LayerNorm(d_model, norm_eps)

    def forward(self, imgs: Tensor) -> Tensor:
        out = self.patch_embed(imgs)
        out = out.flatten(2).transpose(1, 2)  # (N, C, H, W) -> (N, H*W, C)
        if self.cls_token is not None:
            out = torch.cat([self.cls_token.expand(out.shape[0], -1, -1), out], 1)
        out = out + self.pe
        out = self.encoder(out)
        out = self.layers(out)
        out = self.norm(out)
        out = out[:, 0] if self.cls_token is not None else out.mean(1)
        return out

    @torch.no_grad()
    def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:
        pe = self.pe if self.cls_token is None else self.pe[:, 1:]

        old_size = int(pe.shape[1] ** 0.5)
        new_size = size // self.patch_embed.weight.shape[2]
        pe = pe.unflatten(1, (old_size, old_size)).permute(0, 3, 1, 2)
        pe = F.interpolate(pe, (new_size, new_size), mode=interpolation_mode)
        pe = pe.permute(0, 2, 3, 1).flatten(1, 2)

        if self.cls_token is not None:
            pe = torch.cat((self.pe[:, 0:1], pe), 1)

        self.pe = nn.Parameter(pe)

    @staticmethod
    def from_config(variant: str, patch_size: int, img_size: int) -> "ViT":
        return ViT(**configs[variant], patch_size=patch_size, img_size=img_size)
    def from_config(variant: str, patch_size: int, img_size: int, pretrained: bool = False) -> ViT:
        if pretrained:
            if (variant, patch_size) not in checkpoints:
                raise ValueError(f"There is no pre-trained checkpoint for ViT-{variant}/{patch_size}")
            url = "https://storage.googleapis.com/vit_models/augreg/" + checkpoints[(variant, patch_size)]
            m = ViT.from_jax_weights(torch_hub_download(url))
            if img_size != 224:
                m.resize_pe(img_size)
        else:
            m = ViT(**configs[variant], patch_size=patch_size, img_size=img_size)
        return m

    # weights from https://github.com/google-research/vision_transformer
    @torch.no_grad()
    @staticmethod
    def from_jax_weights(path: str) -> "ViT":
    def from_jax_weights(path: str) -> ViT:
        jax_weights: Mapping[str, np.ndarray] = np.load(path)

        def get_w(key: str) -> Tensor:
            return torch.from_numpy(jax_weights[key])

        n_layers = 1
        while True:
            if f"Transformer/encoderblock_{n_layers}/LayerNorm_0/bias" not in jax_weights:
@@ -78,46 +167,33 @@ def from_jax_weights(path: str) -> "ViT":
        patch_size = jax_weights["embedding/kernel"].shape[0]
        img_size = int((jax_weights["Transformer/posembed_input/pos_embedding"].shape[1] - 1) ** 0.5) * patch_size

        def _get(key: str) -> Tensor:
            return torch.from_numpy(jax_weights[key])

        torch_weights = dict()

        def _convert_layer_norm(jax_prefix: str, torch_prefix: str) -> None:
            torch_weights[f"{torch_prefix}.weight"] = _get(f"{jax_prefix}/scale")
            torch_weights[f"{torch_prefix}.bias"] = _get(f"{jax_prefix}/bias")

        def _convert_linear(jax_prefix: str, torch_prefix: str) -> None:
            torch_weights[f"{torch_prefix}.weight"] = _get(f"{jax_prefix}/kernel").T
            torch_weights[f"{torch_prefix}.bias"] = _get(f"{jax_prefix}/bias")

        def _convert_mha(jax_prefix: str, torch_prefix: str) -> None:
            w = torch.stack([_get(f"{jax_prefix}/{x}/kernel") for x in ["query", "key", "value"]], 1)
            torch_weights[f"{torch_prefix}.in_proj_weight"] = w.flatten(1).T

            b = torch.stack([_get(f"{jax_prefix}/{x}/bias") for x in ["query", "key", "value"]], 0)
            torch_weights[f"{torch_prefix}.in_proj_bias"] = b.flatten()

            torch_weights[f"{torch_prefix}.out_proj.weight"] = _get(f"{jax_prefix}/out/kernel").flatten(0, 1).T
            torch_weights[f"{torch_prefix}.out_proj.bias"] = _get(f"{jax_prefix}/out/bias")

        torch_weights["cls_token"] = _get("cls")
        torch_weights["patch_embed.weight"] = _get("embedding/kernel").permute(3, 2, 0, 1)
        torch_weights["patch_embed.bias"] = _get("embedding/bias")
        torch_weights["pe"] = _get("Transformer/posembed_input/pos_embedding")

        for idx in range(n_layers):
            jax_prefix = f"Transformer/encoderblock_{idx}"
            torch_prefix = f"encoder.layers.{idx}"

            _convert_layer_norm(f"{jax_prefix}/LayerNorm_0", f"{torch_prefix}.norm1")
            _convert_mha(f"{jax_prefix}/MultiHeadDotProductAttention_1", f"{torch_prefix}.self_attn")
            _convert_layer_norm(f"{jax_prefix}/LayerNorm_2", f"{torch_prefix}.norm2")
            _convert_linear(f"{jax_prefix}/MlpBlock_3/Dense_0", f"{torch_prefix}.linear1")
            _convert_linear(f"{jax_prefix}/MlpBlock_3/Dense_1", f"{torch_prefix}.linear2")

        _convert_layer_norm("Transformer/encoder_norm", "encoder.norm")

        model = ViT(n_layers, d_model, n_heads, patch_size, img_size)
        model.load_state_dict(torch_weights)
        return model
        m = ViT(n_layers, d_model, n_heads, patch_size, img_size)

        m.cls_token.copy_(get_w("cls"))
        m.patch_embed.weight.copy_(get_w("embedding/kernel").permute(3, 2, 0, 1))
        m.patch_embed.bias.copy_(get_w("embedding/bias"))
        m.pe.copy_(get_w("Transformer/posembed_input/pos_embedding"))

        for idx, layer in enumerate(m.layers):
            prefix = f"Transformer/encoderblock_{idx}/"
            mha_prefix = prefix + "MultiHeadDotProductAttention_1/"

            layer.norm1.weight.copy_(get_w(prefix + "LayerNorm_0/scale"))
            layer.norm1.bias.copy_(get_w(prefix + "LayerNorm_0/bias"))
            w = torch.stack([get_w(mha_prefix + x + "/kernel") for x in ["query", "key", "value"]], 1)
            b = torch.stack([get_w(mha_prefix + x + "/bias") for x in ["query", "key", "value"]], 0)
            layer.mha.in_proj.weight.copy_(w.flatten(1).T)
            layer.mha.in_proj.bias.copy_(b.flatten())
            layer.mha.out_proj.weight.copy_(get_w(mha_prefix + "out/kernel").flatten(0, 1).T)
            layer.mha.out_proj.bias.copy_(get_w(mha_prefix + "out/bias"))

            layer.norm2.weight.copy_(get_w(prefix + "LayerNorm_2/scale"))
            layer.norm2.bias.copy_(get_w(prefix + "LayerNorm_2/bias"))
            layer.mlp[0].weight.copy_(get_w(prefix + "MlpBlock_3/Dense_0/kernel").T)
            layer.mlp[0].bias.copy_(get_w(prefix + "MlpBlock_3/Dense_0/bias"))
            layer.mlp[2].weight.copy_(get_w(prefix + "MlpBlock_3/Dense_1/kernel").T)
            layer.mlp[2].bias.copy_(get_w(prefix + "MlpBlock_3/Dense_1/bias"))

        m.norm.weight.copy_(get_w("Transformer/encoder_norm/scale"))
        m.norm.bias.copy_(get_w("Transformer/encoder_norm/bias"))
        return m
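Taken together, the new from_config path gives a one-line way to load the AugReg checkpoints: it downloads the .npz file, converts it with from_jax_weights, and calls resize_pe when img_size differs from 224. A minimal usage sketch based on the code above (network access is assumed for the download; the feature width depends on the chosen variant):

import torch
from vision_toolbox.backbones import ViT

# ViT-B/16 with ImageNet-21k AugReg weights, position embedding resized for 256x256 inputs
m = ViT.from_config("B", 16, 256, pretrained=True).eval()
with torch.no_grad():
    feats = m(torch.randn(1, 3, 256, 256))
print(feats.shape)  # (1, d_model): CLS-token features after the final LayerNorm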
10 changes: 10 additions & 0 deletions vision_toolbox/utils.py
@@ -0,0 +1,10 @@
import os

import torch


def torch_hub_download(url: str) -> str:
    save_path = os.path.join(torch.hub.get_dir(), os.path.basename(url))
    if not os.path.exists(save_path):
        torch.hub.download_url_to_file(url, save_path)
    return save_path
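torch_hub_download is a small caching helper: it stores the file under torch.hub.get_dir() using the URL's basename and skips the download when that file already exists. A short sketch using the checkpoint URL assembled in vit.py above:

import os

from vision_toolbox.utils import torch_hub_download

url = "https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz"
path = torch_hub_download(url)           # downloads on the first call
print(os.path.exists(path))              # True
print(path == torch_hub_download(url))   # True: the second call reuses the cached file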
