Commit bb7e4f7

Add DeiT and DeiT-III (#20)

gau-nernst committed Aug 20, 2023
1 parent 6fca040 commit bb7e4f7

Showing 5 changed files with 231 additions and 23 deletions.
39 changes: 39 additions & 0 deletions tests/test_deit.py
@@ -0,0 +1,39 @@
import pytest
import timm
import torch

from vision_toolbox.backbones import DeiT, DeiT3


@pytest.mark.parametrize("cls", (DeiT, DeiT3))
def test_forward(cls):
m = cls.from_config("Ti_16", 224)
m(torch.randn(1, 3, 224, 224))


@pytest.mark.parametrize("cls", (DeiT, DeiT3))
def test_resize_pe(cls):
m = cls.from_config("Ti_16", 224)
m(torch.randn(1, 3, 224, 224))
m.resize_pe(256)
m(torch.randn(1, 3, 256, 256))


@pytest.mark.parametrize(
"cls,variant,timm_name",
(
(DeiT, "Ti_16", "deit_tiny_distilled_patch16_224.fb_in1k"),
(DeiT3, "S_16", "deit3_small_patch16_224.fb_in22k_ft_in1k"),
),
)
def test_from_pretrained(cls, variant, timm_name):
    m = cls.from_config(variant, 224, True).eval()
    x = torch.randn(1, 3, 224, 224)
    out = m(x)
    # out = m.patch_embed(x).flatten(2).transpose(1, 2)

    m_timm = timm.create_model(timm_name, pretrained=True, num_classes=0).eval()
    out_timm = m_timm(x)
    # out_timm = m_timm.patch_embed(x)

    torch.testing.assert_close(out, out_timm, rtol=2e-5, atol=2e-5)
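
For reference, a minimal usage sketch of the new backbones, based on the tests above (the model variant and input sizes here are just examples):

import torch
from vision_toolbox.backbones import DeiT3

# Build a DeiT-III Small/16 backbone at 224 px and load the official weights.
model = DeiT3.from_config("S_16", 224, pretrained=True).eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))  # pooled features, shape (1, 384)

# Interpolate the positional embeddings to run at a different resolution.
model.resize_pe(256)
with torch.no_grad():
    feats_256 = model(torch.randn(1, 3, 256, 256))  # still shape (1, 384)
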
1 change: 1 addition & 0 deletions vision_toolbox/backbones/__init__.py
@@ -1,6 +1,7 @@
from .cait import CaiT
from .convnext import ConvNeXt
from .darknet import Darknet, DarknetYOLOv5
+from .deit import DeiT, DeiT3
from .mlp_mixer import MLPMixer
from .patchconvnet import PatchConvNet
from .swin import SwinTransformer
9 changes: 2 additions & 7 deletions vision_toolbox/backbones/cait.py
@@ -10,7 +10,7 @@
from torch import Tensor, nn

from .base import _act, _norm
-from .vit import MHA, ViTBlock
+from .vit import MHA, ViT, ViTBlock


# basically attention pooling
@@ -152,12 +152,7 @@ def forward(self, imgs: Tensor) -> Tensor:

    @torch.no_grad()
    def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:
-        old_size = int(self.pe.shape[1] ** 0.5)
-        new_size = size // self.patch_embed.weight.shape[2]
-        pe = self.pe.unflatten(1, (old_size, old_size)).permute(0, 3, 1, 2)
-        pe = F.interpolate(pe, (new_size, new_size), mode=interpolation_mode)
-        pe = pe.permute(0, 2, 3, 1).flatten(1, 2)
-        self.pe = nn.Parameter(pe)
+        ViT.resize_pe(self, size, interpolation_mode)

    @staticmethod
    def from_config(variant: str, img_size: int, pretrained: bool = False) -> CaiT:
180 changes: 180 additions & 0 deletions vision_toolbox/backbones/deit.py
@@ -0,0 +1,180 @@
# https://arxiv.org/abs/2012.12877
# https://arxiv.org/abs/2204.07118
# https://github.com/facebookresearch/deit

from __future__ import annotations

from functools import partial

import torch
from torch import Tensor, nn

from ..components import LayerScale
from .base import _act, _norm
from .vit import ViT, ViTBlock


class DeiT(ViT):
    def __init__(
        self,
        d_model: int,
        depth: int,
        n_heads: int,
        patch_size: int,
        img_size: int,
        bias: bool = True,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        layer_scale_init: float | None = None,
        stochastic_depth: float = 0.0,
        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
        act: _act = nn.GELU,
    ) -> None:
        # fmt: off
        super().__init__(
            d_model, depth, n_heads, patch_size, img_size, True, bias, mlp_ratio,
            dropout, layer_scale_init, stochastic_depth, norm, act
        )
        # fmt: on
        self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))

    def forward(self, imgs: Tensor) -> Tensor:
        out = self.patch_embed(imgs).flatten(2).transpose(1, 2)  # (N, C, H, W) -> (N, H*W, C)
        out = torch.cat([self.cls_token, self.dist_token, out + self.pe], 1)
        out = self.layers(out)
        return self.norm(out[:, :2]).mean(1)

    @staticmethod
    def from_config(variant: str, img_size: int, pretrained: bool = False) -> DeiT:
        variant, patch_size = variant.split("_")

        d_model, depth, n_heads = dict(
            Ti=(192, 12, 3),
            S=(384, 12, 6),
            M=(512, 12, 8),
            B=(768, 12, 12),
            L=(1024, 24, 16),
            H=(1280, 32, 16),
        )[variant]
        patch_size = int(patch_size)
        m = DeiT(d_model, depth, n_heads, patch_size, img_size)

        if pretrained:
            ckpt = dict(
                Ti_16_224="deit_tiny_distilled_patch16_224-b40b3cf7.pth",
                S_16_224="deit_small_distilled_patch16_224-649709d9.pth",
                B_16_224="deit_base_distilled_patch16_224-df68dfff.pth",
                B_16_384="deit_base_distilled_patch16_384-d0272ac0.pth",
            )[f"{variant}_{patch_size}_{img_size}"]
            base_url = "https://dl.fbaipublicfiles.com/deit/"
            state_dict = torch.hub.load_state_dict_from_url(base_url + ckpt)["model"]
            m.load_official_ckpt(state_dict)

        return m

    @torch.no_grad()
    def load_official_ckpt(self, state_dict: dict[str, Tensor]) -> None:
        def copy_(m: nn.Linear | nn.LayerNorm, prefix: str):
            m.weight.copy_(state_dict.pop(prefix + ".weight").view(m.weight.shape))
            m.bias.copy_(state_dict.pop(prefix + ".bias"))

        copy_(self.patch_embed, "patch_embed.proj")
        pe = state_dict.pop("pos_embed")
        self.pe.copy_(pe[:, -self.pe.shape[1] :])

        self.cls_token.copy_(state_dict.pop("cls_token"))
        if pe.shape[1] > self.pe.shape[1]:
            self.cls_token.add_(pe[:, 0])

        if hasattr(self, "dist_token"):
            self.dist_token.copy_(state_dict.pop("dist_token"))
            self.dist_token.add_(pe[:, 1])
            state_dict.pop("head_dist.weight")
            state_dict.pop("head_dist.bias")

        for i, block in enumerate(self.layers):
            block: ViTBlock
            prefix = f"blocks.{i}."

            copy_(block.mha[0], prefix + "norm1")
            q_w, k_w, v_w = state_dict.pop(prefix + "attn.qkv.weight").chunk(3, 0)
            block.mha[1].q_proj.weight.copy_(q_w)
            block.mha[1].k_proj.weight.copy_(k_w)
            block.mha[1].v_proj.weight.copy_(v_w)
            q_b, k_b, v_b = state_dict.pop(prefix + "attn.qkv.bias").chunk(3, 0)
            block.mha[1].q_proj.bias.copy_(q_b)
            block.mha[1].k_proj.bias.copy_(k_b)
            block.mha[1].v_proj.bias.copy_(v_b)
            copy_(block.mha[1].out_proj, prefix + "attn.proj")
            if isinstance(block.mha[2], LayerScale):
                block.mha[2].gamma.copy_(state_dict.pop(prefix + "gamma_1"))

            copy_(block.mlp[0], prefix + "norm2")
            copy_(block.mlp[1].linear1, prefix + "mlp.fc1")
            copy_(block.mlp[1].linear2, prefix + "mlp.fc2")
            if isinstance(block.mlp[2], LayerScale):
                block.mlp[2].gamma.copy_(state_dict.pop(prefix + "gamma_2"))

        copy_(self.norm, "norm")
        assert len(state_dict) == 2, state_dict.keys()


class DeiT3(ViT):
    def __init__(
        self,
        d_model: int,
        depth: int,
        n_heads: int,
        patch_size: int,
        img_size: int,
        cls_token: bool = True,
        bias: bool = True,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        layer_scale_init: float | None = 1e-6,
        stochastic_depth: float = 0.0,
        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
        act: _act = nn.GELU,
    ):
        # fmt: off
        super().__init__(
            d_model, depth, n_heads, patch_size, img_size, cls_token, bias,
            mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm, act
        )
        # fmt: on

    @staticmethod
    def from_config(variant: str, img_size: int, pretrained: bool = False) -> DeiT3:
        variant, patch_size = variant.split("_")

        d_model, depth, n_heads = dict(
            Ti=(192, 12, 3),
            S=(384, 12, 6),
            M=(512, 12, 8),
            B=(768, 12, 12),
            L=(1024, 24, 16),
            H=(1280, 32, 16),
        )[variant]
        patch_size = int(patch_size)
        m = DeiT3(d_model, depth, n_heads, patch_size, img_size)

        if pretrained:
            ckpt = dict(
                S_16_224="deit_3_small_224_21k.pth",
                S_16_384="deit_3_small_384_21k.pth",
                M_16_224="deit_3_medium_224_21k.pth",
                B_16_224="deit_3_base_224_21k.pth",
                B_16_384="deit_3_base_384_21k.pth",
                L_16_224="deit_3_large_224_21k.pth",
                L_16_384="deit_3_large_384_21k.pth",
                H_16_224="deit_3_huge_224_21k.pth",
            )[f"{variant}_{patch_size}_{img_size}"]
            base_url = "https://dl.fbaipublicfiles.com/deit/"
            state_dict = torch.hub.load_state_dict_from_url(base_url + ckpt)["model"]
            m.load_official_ckpt(state_dict)

        return m

    @torch.no_grad()
    def load_official_ckpt(self, state_dict: dict[str, Tensor]) -> None:
        DeiT.load_official_ckpt(self, state_dict)
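
A note on load_official_ckpt above: the official distilled checkpoints store the class-token and distillation-token positions inside pos_embed, while this implementation keeps self.pe for patch positions only, so the extra entries are folded into the tokens. A shape sketch with hypothetical tensors, assuming a Ti/16 distilled model at 224 px (196 = 14 * 14 patches, d_model = 192):

import torch

pos_embed = torch.randn(1, 198, 192)   # [cls_pos, dist_pos, 196 patch positions]
pe = pos_embed[:, -196:]               # copied into self.pe (patch positions only)
cls_extra = pos_embed[:, 0]            # added onto cls_token via add_()
dist_extra = pos_embed[:, 1]           # added onto dist_token via add_()
assert pe.shape == (1, 196, 192)
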
25 changes: 9 additions & 16 deletions vision_toolbox/backbones/vit.py
@@ -112,11 +112,7 @@ def __init__(
        super().__init__()
        self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model)) if cls_token else None
-
-        pe_size = (img_size // patch_size) ** 2
-        if cls_token:
-            pe_size += 1
-        self.pe = nn.Parameter(torch.empty(1, pe_size, d_model))
+        self.pe = nn.Parameter(torch.empty(1, (img_size // patch_size) ** 2, d_model))
        nn.init.normal_(self.pe, 0, 0.02)

        self.layers = nn.Sequential()
@@ -127,25 +123,19 @@ def __init__(
        self.norm = norm(d_model)

    def forward(self, imgs: Tensor) -> Tensor:
-        out = self.patch_embed(imgs).flatten(2).transpose(1, 2)  # (N, C, H, W) -> (N, H*W, C)
+        out = self.patch_embed(imgs).flatten(2).transpose(1, 2) + self.pe  # (N, C, H, W) -> (N, H*W, C)
        if self.cls_token is not None:
            out = torch.cat([self.cls_token, out], 1)
-        out = self.layers(out + self.pe)
+        out = self.layers(out)
        return self.norm(out[:, 0]) if self.cls_token is not None else self.norm(out).mean(1)

    @torch.no_grad()
    def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:
-        pe = self.pe if self.cls_token is None else self.pe[:, 1:]
-
-        old_size = int(pe.shape[1] ** 0.5)
+        old_size = int(self.pe.shape[1] ** 0.5)
        new_size = size // self.patch_embed.weight.shape[2]
-        pe = pe.unflatten(1, (old_size, old_size)).permute(0, 3, 1, 2)
+        pe = self.pe.unflatten(1, (old_size, old_size)).permute(0, 3, 1, 2)
        pe = F.interpolate(pe, (new_size, new_size), mode=interpolation_mode)
        pe = pe.permute(0, 2, 3, 1).flatten(1, 2)
-
-        if self.cls_token is not None:
-            pe = torch.cat((self.pe[:, 0:1], pe), 1)
-
        self.pe = nn.Parameter(pe)

    @staticmethod
@@ -155,6 +145,7 @@ def from_config(variant: str, img_size: int, pretrained: bool = False) -> ViT:
        d_model, depth, n_heads = dict(
            Ti=(192, 12, 3),
            S=(384, 12, 6),
+            M=(512, 12, 8),
            B=(768, 12, 12),
            L=(1024, 24, 16),
            H=(1280, 32, 16),
@@ -186,9 +177,11 @@ def get_w(key: str) -> Tensor:
            return torch.from_numpy(jax_weights[key])

        self.cls_token.copy_(get_w("cls"))
+        pe = get_w("Transformer/posembed_input/pos_embedding")
+        self.cls_token.add_(pe[:, 0])
+        self.pe.copy_(pe[:, 1:])
        self.patch_embed.weight.copy_(get_w("embedding/kernel").permute(3, 2, 0, 1))
        self.patch_embed.bias.copy_(get_w("embedding/bias"))
-        self.pe.copy_(get_w("Transformer/posembed_input/pos_embedding"))

        for idx, layer in enumerate(self.layers):
            layer: ViTBlock
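
The simplified ViT.resize_pe above (now shared by CaiT and inherited by DeiT/DeiT-III) treats the positional embedding as a square feature map and interpolates it. A standalone sketch of the same idea, assuming a (1, N, C) embedding where N is a perfect square (the function name here is just for illustration):

import torch
import torch.nn.functional as F

def interpolate_pos_embed(pe: torch.Tensor, new_size: int, mode: str = "bicubic") -> torch.Tensor:
    # pe: (1, N, C) with N = old_size ** 2; returns (1, new_size ** 2, C).
    old_size = int(pe.shape[1] ** 0.5)
    pe = pe.unflatten(1, (old_size, old_size)).permute(0, 3, 1, 2)  # (1, C, H, W)
    pe = F.interpolate(pe, (new_size, new_size), mode=mode)
    return pe.permute(0, 2, 3, 1).flatten(1, 2)  # back to (1, new_size ** 2, C)

pe = torch.randn(1, 14 * 14, 192)       # e.g. a 224 px, patch-16 grid
pe_256 = interpolate_pos_embed(pe, 16)  # 256 px // 16 = 16 -> shape (1, 256, 192)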
