Add DeiT and DeiT-III #20

Merged · 5 commits · Aug 20, 2023
39 changes: 39 additions & 0 deletions tests/test_deit.py
@@ -0,0 +1,39 @@
import pytest
import timm
import torch

from vision_toolbox.backbones import DeiT, DeiT3


@pytest.mark.parametrize("cls", (DeiT, DeiT3))
def test_forward(cls):
m = cls.from_config("Ti_16", 224)
m(torch.randn(1, 3, 224, 224))


@pytest.mark.parametrize("cls", (DeiT, DeiT3))
def test_resize_pe(cls):
m = cls.from_config("Ti_16", 224)
m(torch.randn(1, 3, 224, 224))
m.resize_pe(256)
m(torch.randn(1, 3, 256, 256))


@pytest.mark.parametrize(
"cls,variant,timm_name",
(
(DeiT, "Ti_16", "deit_tiny_distilled_patch16_224.fb_in1k"),
(DeiT3, "S_16", "deit3_small_patch16_224.fb_in22k_ft_in1k"),
),
)
def test_from_pretrained(cls, variant, timm_name):
m = cls.from_config(variant, 224, True).eval()
x = torch.randn(1, 3, 224, 224)
out = m(x)
# out = m.patch_embed(x).flatten(2).transpose(1, 2)

m_timm = timm.create_model(timm_name, pretrained=True, num_classes=0).eval()
out_timm = m_timm(x)
# out_timm = m_timm.patch_embed(x)

torch.testing.assert_close(out, out_timm, rtol=2e-5, atol=2e-5)
1 change: 1 addition & 0 deletions vision_toolbox/backbones/__init__.py
@@ -1,6 +1,7 @@
from .cait import CaiT
from .convnext import ConvNeXt
from .darknet import Darknet, DarknetYOLOv5
+from .deit import DeiT, DeiT3
from .mlp_mixer import MLPMixer
from .patchconvnet import PatchConvNet
from .swin import SwinTransformer
9 changes: 2 additions & 7 deletions vision_toolbox/backbones/cait.py
@@ -10,7 +10,7 @@
from torch import Tensor, nn

from .base import _act, _norm
-from .vit import MHA, ViTBlock
+from .vit import MHA, ViT, ViTBlock


# basically attention pooling
@@ -152,12 +152,7 @@ def forward(self, imgs: Tensor) -> Tensor:

@torch.no_grad()
def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:
-        old_size = int(self.pe.shape[1] ** 0.5)
-        new_size = size // self.patch_embed.weight.shape[2]
-        pe = self.pe.unflatten(1, (old_size, old_size)).permute(0, 3, 1, 2)
-        pe = F.interpolate(pe, (new_size, new_size), mode=interpolation_mode)
-        pe = pe.permute(0, 2, 3, 1).flatten(1, 2)
-        self.pe = nn.Parameter(pe)
+        ViT.resize_pe(self, size, interpolation_mode)

@staticmethod
def from_config(variant: str, img_size: int, pretrained: bool = False) -> CaiT:
180 changes: 180 additions & 0 deletions vision_toolbox/backbones/deit.py
@@ -0,0 +1,180 @@
# https://arxiv.org/abs/2012.12877
# https://arxiv.org/abs/2204.07118
# https://github.com/facebookresearch/deit

from __future__ import annotations

from functools import partial

import torch
from torch import Tensor, nn

from ..components import LayerScale
from .base import _act, _norm
from .vit import ViT, ViTBlock


class DeiT(ViT):
def __init__(
self,
d_model: int,
depth: int,
n_heads: int,
patch_size: int,
img_size: int,
bias: bool = True,
mlp_ratio: float = 4.0,
dropout: float = 0.0,
layer_scale_init: float | None = None,
stochastic_depth: float = 0.0,
norm: _norm = partial(nn.LayerNorm, eps=1e-6),
act: _act = nn.GELU,
) -> None:
# fmt: off
super().__init__(
d_model, depth, n_heads, patch_size, img_size, True, bias, mlp_ratio,
dropout, layer_scale_init, stochastic_depth, norm, act
)
# fmt: on
self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))

def forward(self, imgs: Tensor) -> Tensor:
out = self.patch_embed(imgs).flatten(2).transpose(1, 2) # (N, C, H, W) -> (N, H*W, C)
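        # prepend the class and distillation tokens; self.pe covers only the patch tokens,
        # since the tokens' own position embeddings are folded in at checkpoint-load time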
out = torch.cat([self.cls_token, self.dist_token, out + self.pe], 1)
out = self.layers(out)
return self.norm(out[:, :2]).mean(1)

@staticmethod
def from_config(variant: str, img_size: int, pretrained: bool = False) -> DeiT:
variant, patch_size = variant.split("_")

d_model, depth, n_heads = dict(
Ti=(192, 12, 3),
S=(384, 12, 6),
M=(512, 12, 8),
B=(768, 12, 12),
L=(1024, 24, 16),
H=(1280, 32, 16),
)[variant]
patch_size = int(patch_size)
m = DeiT(d_model, depth, n_heads, patch_size, img_size)

if pretrained:
ckpt = dict(
Ti_16_224="deit_tiny_distilled_patch16_224-b40b3cf7.pth",
S_16_224="deit_small_distilled_patch16_224-649709d9.pth",
B_16_224="deit_base_distilled_patch16_224-df68dfff.pth",
B_16_384="deit_base_distilled_patch16_384-d0272ac0.pth",
)[f"{variant}_{patch_size}_{img_size}"]
base_url = "https://dl.fbaipublicfiles.com/deit/"
state_dict = torch.hub.load_state_dict_from_url(base_url + ckpt)["model"]
m.load_official_ckpt(state_dict)

return m

@torch.no_grad()
def load_official_ckpt(self, state_dict: dict[str, Tensor]) -> None:
def copy_(m: nn.Linear | nn.LayerNorm, prefix: str):
m.weight.copy_(state_dict.pop(prefix + ".weight").view(m.weight.shape))
m.bias.copy_(state_dict.pop(prefix + ".bias"))

copy_(self.patch_embed, "patch_embed.proj")
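        # official checkpoints keep the cls/dist position embeddings inside pos_embed;
        # only the patch grid goes into self.pe, and the extra slots are folded into the tokens below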
pe = state_dict.pop("pos_embed")
self.pe.copy_(pe[:, -self.pe.shape[1] :])

self.cls_token.copy_(state_dict.pop("cls_token"))
if pe.shape[1] > self.pe.shape[1]:
self.cls_token.add_(pe[:, 0])

if hasattr(self, "dist_token"):
self.dist_token.copy_(state_dict.pop("dist_token"))
self.dist_token.add_(pe[:, 1])
state_dict.pop("head_dist.weight")
state_dict.pop("head_dist.bias")

for i, block in enumerate(self.layers):
block: ViTBlock
prefix = f"blocks.{i}."

copy_(block.mha[0], prefix + "norm1")
q_w, k_w, v_w = state_dict.pop(prefix + "attn.qkv.weight").chunk(3, 0)
block.mha[1].q_proj.weight.copy_(q_w)
block.mha[1].k_proj.weight.copy_(k_w)
block.mha[1].v_proj.weight.copy_(v_w)
q_b, k_b, v_b = state_dict.pop(prefix + "attn.qkv.bias").chunk(3, 0)
block.mha[1].q_proj.bias.copy_(q_b)
block.mha[1].k_proj.bias.copy_(k_b)
block.mha[1].v_proj.bias.copy_(v_b)
copy_(block.mha[1].out_proj, prefix + "attn.proj")
if isinstance(block.mha[2], LayerScale):
block.mha[2].gamma.copy_(state_dict.pop(prefix + "gamma_1"))

copy_(block.mlp[0], prefix + "norm2")
copy_(block.mlp[1].linear1, prefix + "mlp.fc1")
copy_(block.mlp[1].linear2, prefix + "mlp.fc2")
if isinstance(block.mlp[2], LayerScale):
block.mlp[2].gamma.copy_(state_dict.pop(prefix + "gamma_2"))

copy_(self.norm, "norm")
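        # the only leftovers should be the classification head (head.weight / head.bias), which this backbone drops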
assert len(state_dict) == 2, state_dict.keys()


class DeiT3(ViT):
def __init__(
self,
d_model: int,
depth: int,
n_heads: int,
patch_size: int,
img_size: int,
cls_token: bool = True,
bias: bool = True,
mlp_ratio: float = 4.0,
dropout: float = 0.0,
layer_scale_init: float | None = 1e-6,
stochastic_depth: float = 0.0,
norm: _norm = partial(nn.LayerNorm, eps=1e-6),
act: _act = nn.GELU,
):
# fmt: off
super().__init__(
d_model, depth, n_heads, patch_size, img_size, cls_token, bias,
mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm, act
)
# fmt: on

@staticmethod
    def from_config(variant: str, img_size: int, pretrained: bool = False) -> DeiT3:
variant, patch_size = variant.split("_")

d_model, depth, n_heads = dict(
Ti=(192, 12, 3),
S=(384, 12, 6),
M=(512, 12, 8),
B=(768, 12, 12),
L=(1024, 24, 16),
H=(1280, 32, 16),
)[variant]
patch_size = int(patch_size)
m = DeiT3(d_model, depth, n_heads, patch_size, img_size)

if pretrained:
ckpt = dict(
S_16_224="deit_3_small_224_21k.pth",
S_16_384="deit_3_small_384_21k.pth",
M_16_224="deit_3_medium_224_21k.pth",
B_16_224="deit_3_base_224_21k.pth",
B_16_384="deit_3_base_384_21k.pth",
L_16_224="deit_3_large_224_21k.pth",
L_16_384="deit_3_large_384_21k.pth",
H_16_224="deit_3_huge_224_21k.pth",
)[f"{variant}_{patch_size}_{img_size}"]
base_url = "https://dl.fbaipublicfiles.com/deit/"
state_dict = torch.hub.load_state_dict_from_url(base_url + ckpt)["model"]
m.load_official_ckpt(state_dict)

return m

@torch.no_grad()
def load_official_ckpt(self, state_dict: dict[str, Tensor]) -> None:
DeiT.load_official_ckpt(self, state_dict)
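Note for reviewers: below is a minimal usage sketch of the two new backbones. It is a sketch only: it assumes torch.hub can download the official checkpoints from dl.fbaipublicfiles.com, and it uses the same variant strings as the tests above.

import torch

from vision_toolbox.backbones import DeiT, DeiT3

# DeiT-Ti/16 with the official distilled checkpoint; the forward pass returns
# the mean of the normalized class and distillation tokens, shape (1, 192)
m = DeiT.from_config("Ti_16", 224, pretrained=True).eval()
with torch.no_grad():
    feats = m(torch.randn(1, 3, 224, 224))

# DeiT-III-S/16 keeps the plain ViT forward; resize the positional
# embeddings to run at 256x256 instead of the pretraining resolution
m3 = DeiT3.from_config("S_16", 224, pretrained=True).eval()
m3.resize_pe(256)
with torch.no_grad():
    feats3 = m3(torch.randn(1, 3, 256, 256))  # (1, 384)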
25 changes: 9 additions & 16 deletions vision_toolbox/backbones/vit.py
@@ -112,11 +112,7 @@ def __init__(
super().__init__()
self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model)) if cls_token else None
-
-        pe_size = (img_size // patch_size) ** 2
-        if cls_token:
-            pe_size += 1
-        self.pe = nn.Parameter(torch.empty(1, pe_size, d_model))
+        self.pe = nn.Parameter(torch.empty(1, (img_size // patch_size) ** 2, d_model))
nn.init.normal_(self.pe, 0, 0.02)

self.layers = nn.Sequential()
@@ -127,25 +123,19 @@ def __init__
self.norm = norm(d_model)

def forward(self, imgs: Tensor) -> Tensor:
-        out = self.patch_embed(imgs).flatten(2).transpose(1, 2)  # (N, C, H, W) -> (N, H*W, C)
+        out = self.patch_embed(imgs).flatten(2).transpose(1, 2) + self.pe  # (N, C, H, W) -> (N, H*W, C)
if self.cls_token is not None:
out = torch.cat([self.cls_token, out], 1)
-        out = self.layers(out + self.pe)
+        out = self.layers(out)
return self.norm(out[:, 0]) if self.cls_token is not None else self.norm(out).mean(1)

@torch.no_grad()
def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:
-        pe = self.pe if self.cls_token is None else self.pe[:, 1:]
-
-        old_size = int(pe.shape[1] ** 0.5)
+        old_size = int(self.pe.shape[1] ** 0.5)
new_size = size // self.patch_embed.weight.shape[2]
-        pe = pe.unflatten(1, (old_size, old_size)).permute(0, 3, 1, 2)
+        pe = self.pe.unflatten(1, (old_size, old_size)).permute(0, 3, 1, 2)
pe = F.interpolate(pe, (new_size, new_size), mode=interpolation_mode)
        pe = pe.permute(0, 2, 3, 1).flatten(1, 2)
-
-        if self.cls_token is not None:
-            pe = torch.cat((self.pe[:, 0:1], pe), 1)
-
self.pe = nn.Parameter(pe)

@staticmethod
Expand All @@ -155,6 +145,7 @@ def from_config(variant: str, img_size: int, pretrained: bool = False) -> ViT:
d_model, depth, n_heads = dict(
Ti=(192, 12, 3),
S=(384, 12, 6),
+            M=(512, 12, 8),
B=(768, 12, 12),
L=(1024, 24, 16),
H=(1280, 32, 16),
@@ -186,9 +177,11 @@ def get_w(key: str) -> Tensor:
return torch.from_numpy(jax_weights[key])

self.cls_token.copy_(get_w("cls"))
+        pe = get_w("Transformer/posembed_input/pos_embedding")
+        self.cls_token.add_(pe[:, 0])
+        self.pe.copy_(pe[:, 1:])
self.patch_embed.weight.copy_(get_w("embedding/kernel").permute(3, 2, 0, 1))
self.patch_embed.bias.copy_(get_w("embedding/bias"))
-        self.pe.copy_(get_w("Transformer/posembed_input/pos_embedding"))

for idx, layer in enumerate(self.layers):
layer: ViTBlock
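Note on the vit.py refactor: self.pe now holds only the patch-grid embeddings, and checkpoint loading folds the class-position embedding into cls_token itself (cls_token.add_(pe[:, 0])), so resize_pe can interpolate the whole grid without special-casing the first slot. A rough sketch of the resulting resize, using hypothetical sizes (a 14x14 grid resized to 16x16):

import torch
import torch.nn.functional as F

d_model, old_size, new_size = 192, 14, 16
pe_old = torch.randn(1, 1 + old_size**2, d_model)  # old layout: class slot at index 0

cls_pe = pe_old[:, 0:1]  # folded into cls_token when loading a checkpoint
pe = pe_old[:, 1:]       # the patch grid, which self.pe now stores

# resize_pe: reshape the flat sequence to a 2D grid, interpolate, flatten back
pe = pe.unflatten(1, (old_size, old_size)).permute(0, 3, 1, 2)  # (1, d_model, 14, 14)
pe = F.interpolate(pe, (new_size, new_size), mode="bicubic")    # (1, d_model, 16, 16)
pe = pe.permute(0, 2, 3, 1).flatten(1, 2)                       # (1, 256, d_model)
assert pe.shape == (1, new_size**2, d_model)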