diff --git a/tests/test_deit.py b/tests/test_deit.py
new file mode 100644
index 0000000..2b65e95
--- /dev/null
+++ b/tests/test_deit.py
@@ -0,0 +1,39 @@
+import pytest
+import timm
+import torch
+
+from vision_toolbox.backbones import DeiT, DeiT3
+
+
+@pytest.mark.parametrize("cls", (DeiT, DeiT3))
+def test_forward(cls):
+    m = cls.from_config("Ti_16", 224)
+    m(torch.randn(1, 3, 224, 224))
+
+
+@pytest.mark.parametrize("cls", (DeiT, DeiT3))
+def test_resize_pe(cls):
+    m = cls.from_config("Ti_16", 224)
+    m(torch.randn(1, 3, 224, 224))
+    m.resize_pe(256)
+    m(torch.randn(1, 3, 256, 256))
+
+
+@pytest.mark.parametrize(
+    "cls,variant,timm_name",
+    (
+        (DeiT, "Ti_16", "deit_tiny_distilled_patch16_224.fb_in1k"),
+        (DeiT3, "S_16", "deit3_small_patch16_224.fb_in22k_ft_in1k"),
+    ),
+)
+def test_from_pretrained(cls, variant, timm_name):
+    m = cls.from_config(variant, 224, True).eval()
+    x = torch.randn(1, 3, 224, 224)
+    out = m(x)
+    # out = m.patch_embed(x).flatten(2).transpose(1, 2)
+
+    m_timm = timm.create_model(timm_name, pretrained=True, num_classes=0).eval()
+    out_timm = m_timm(x)
+    # out_timm = m_timm.patch_embed(x)
+
+    torch.testing.assert_close(out, out_timm, rtol=2e-5, atol=2e-5)
diff --git a/vision_toolbox/backbones/__init__.py b/vision_toolbox/backbones/__init__.py
index eeae4b1..4b7d7fd 100644
--- a/vision_toolbox/backbones/__init__.py
+++ b/vision_toolbox/backbones/__init__.py
@@ -1,6 +1,7 @@
 from .cait import CaiT
 from .convnext import ConvNeXt
 from .darknet import Darknet, DarknetYOLOv5
+from .deit import DeiT, DeiT3
 from .mlp_mixer import MLPMixer
 from .patchconvnet import PatchConvNet
 from .swin import SwinTransformer
diff --git a/vision_toolbox/backbones/cait.py b/vision_toolbox/backbones/cait.py
index 88346a0..c7e18f4 100644
--- a/vision_toolbox/backbones/cait.py
+++ b/vision_toolbox/backbones/cait.py
@@ -10,7 +10,7 @@
 from torch import Tensor, nn
 
 from .base import _act, _norm
-from .vit import MHA, ViTBlock
+from .vit import MHA, ViT, ViTBlock
 
 
 # basically attention pooling
@@ -152,12 +152,7 @@ def forward(self, imgs: Tensor) -> Tensor:
 
     @torch.no_grad()
     def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:
-        old_size = int(self.pe.shape[1] ** 0.5)
-        new_size = size // self.patch_embed.weight.shape[2]
-        pe = self.pe.unflatten(1, (old_size, old_size)).permute(0, 3, 1, 2)
-        pe = F.interpolate(pe, (new_size, new_size), mode=interpolation_mode)
-        pe = pe.permute(0, 2, 3, 1).flatten(1, 2)
-        self.pe = nn.Parameter(pe)
+        ViT.resize_pe(self, size, interpolation_mode)
 
     @staticmethod
     def from_config(variant: str, img_size: int, pretrained: bool = False) -> CaiT:
diff --git a/vision_toolbox/backbones/deit.py b/vision_toolbox/backbones/deit.py
new file mode 100644
index 0000000..102f60f
--- /dev/null
+++ b/vision_toolbox/backbones/deit.py
@@ -0,0 +1,180 @@
+# https://arxiv.org/abs/2012.12877
+# https://arxiv.org/abs/2204.07118
+# https://github.com/facebookresearch/deit
+
+from __future__ import annotations
+
+from functools import partial
+
+import torch
+from torch import Tensor, nn
+
+from ..components import LayerScale
+from .base import _act, _norm
+from .vit import ViT, ViTBlock
+
+
+class DeiT(ViT):
+    def __init__(
+        self,
+        d_model: int,
+        depth: int,
+        n_heads: int,
+        patch_size: int,
+        img_size: int,
+        bias: bool = True,
+        mlp_ratio: float = 4.0,
+        dropout: float = 0.0,
+        layer_scale_init: float | None = None,
+        stochastic_depth: float = 0.0,
+        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
+        act: _act = nn.GELU,
+    ) -> None:
+        # fmt: off
+        super().__init__(
+            d_model, depth, n_heads, patch_size, img_size, True, bias, mlp_ratio,
+            dropout, layer_scale_init, stochastic_depth, norm, act
+        )
+        # fmt: on
+        self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))
+
+    def forward(self, imgs: Tensor) -> Tensor:
+        out = self.patch_embed(imgs).flatten(2).transpose(1, 2)  # (N, C, H, W) -> (N, H*W, C)
+        out = torch.cat([self.cls_token, self.dist_token, out + self.pe], 1)
+        out = self.layers(out)
+        return self.norm(out[:, :2]).mean(1)
+
+    @staticmethod
+    def from_config(variant: str, img_size: int, pretrained: bool = False) -> DeiT:
+        variant, patch_size = variant.split("_")
+
+        d_model, depth, n_heads = dict(
+            Ti=(192, 12, 3),
+            S=(384, 12, 6),
+            M=(512, 12, 8),
+            B=(768, 12, 12),
+            L=(1024, 24, 16),
+            H=(1280, 32, 16),
+        )[variant]
+        patch_size = int(patch_size)
+        m = DeiT(d_model, depth, n_heads, patch_size, img_size)
+
+        if pretrained:
+            ckpt = dict(
+                Ti_16_224="deit_tiny_distilled_patch16_224-b40b3cf7.pth",
+                S_16_224="deit_small_distilled_patch16_224-649709d9.pth",
+                B_16_224="deit_base_distilled_patch16_224-df68dfff.pth",
+                B_16_384="deit_base_distilled_patch16_384-d0272ac0.pth",
+            )[f"{variant}_{patch_size}_{img_size}"]
+            base_url = "https://dl.fbaipublicfiles.com/deit/"
+            state_dict = torch.hub.load_state_dict_from_url(base_url + ckpt)["model"]
+            m.load_official_ckpt(state_dict)
+
+        return m
+
+    @torch.no_grad()
+    def load_official_ckpt(self, state_dict: dict[str, Tensor]) -> None:
+        def copy_(m: nn.Conv2d | nn.Linear | nn.LayerNorm, prefix: str):
+            m.weight.copy_(state_dict.pop(prefix + ".weight").view(m.weight.shape))
+            m.bias.copy_(state_dict.pop(prefix + ".bias"))
+
+        copy_(self.patch_embed, "patch_embed.proj")
+        pe = state_dict.pop("pos_embed")
+        self.pe.copy_(pe[:, -self.pe.shape[1] :])
+
+        self.cls_token.copy_(state_dict.pop("cls_token"))
+        if pe.shape[1] > self.pe.shape[1]:
+            self.cls_token.add_(pe[:, 0])
+
+        if hasattr(self, "dist_token"):
+            self.dist_token.copy_(state_dict.pop("dist_token"))
+            self.dist_token.add_(pe[:, 1])
+            state_dict.pop("head_dist.weight")
+            state_dict.pop("head_dist.bias")
+
+        for i, block in enumerate(self.layers):
+            block: ViTBlock
+            prefix = f"blocks.{i}."
+
+            copy_(block.mha[0], prefix + "norm1")
+            q_w, k_w, v_w = state_dict.pop(prefix + "attn.qkv.weight").chunk(3, 0)
+            block.mha[1].q_proj.weight.copy_(q_w)
+            block.mha[1].k_proj.weight.copy_(k_w)
+            block.mha[1].v_proj.weight.copy_(v_w)
+            q_b, k_b, v_b = state_dict.pop(prefix + "attn.qkv.bias").chunk(3, 0)
+            block.mha[1].q_proj.bias.copy_(q_b)
+            block.mha[1].k_proj.bias.copy_(k_b)
+            block.mha[1].v_proj.bias.copy_(v_b)
+            copy_(block.mha[1].out_proj, prefix + "attn.proj")
+            if isinstance(block.mha[2], LayerScale):
+                block.mha[2].gamma.copy_(state_dict.pop(prefix + "gamma_1"))
+
+            copy_(block.mlp[0], prefix + "norm2")
+            copy_(block.mlp[1].linear1, prefix + "mlp.fc1")
+            copy_(block.mlp[1].linear2, prefix + "mlp.fc2")
+            if isinstance(block.mlp[2], LayerScale):
+                block.mlp[2].gamma.copy_(state_dict.pop(prefix + "gamma_2"))
+
+        copy_(self.norm, "norm")
+        assert len(state_dict) == 2, state_dict.keys()
+
+
+class DeiT3(ViT):
+    def __init__(
+        self,
+        d_model: int,
+        depth: int,
+        n_heads: int,
+        patch_size: int,
+        img_size: int,
+        cls_token: bool = True,
+        bias: bool = True,
+        mlp_ratio: float = 4.0,
+        dropout: float = 0.0,
+        layer_scale_init: float | None = 1e-6,
+        stochastic_depth: float = 0.0,
+        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
+        act: _act = nn.GELU,
+    ) -> None:
+        # fmt: off
+        super().__init__(
+            d_model, depth, n_heads, patch_size, img_size, cls_token, bias,
+            mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm, act
+        )
+        # fmt: on
+
+    @staticmethod
+    def from_config(variant: str, img_size: int, pretrained: bool = False) -> DeiT3:
+        variant, patch_size = variant.split("_")
+
+        d_model, depth, n_heads = dict(
+            Ti=(192, 12, 3),
+            S=(384, 12, 6),
+            M=(512, 12, 8),
+            B=(768, 12, 12),
+            L=(1024, 24, 16),
+            H=(1280, 32, 16),
+        )[variant]
+        patch_size = int(patch_size)
+        m = DeiT3(d_model, depth, n_heads, patch_size, img_size)
+
+        if pretrained:
+            ckpt = dict(
+                S_16_224="deit_3_small_224_21k.pth",
+                S_16_384="deit_3_small_384_21k.pth",
+                M_16_224="deit_3_medium_224_21k.pth",
+                B_16_224="deit_3_base_224_21k.pth",
+                B_16_384="deit_3_base_384_21k.pth",
+                L_16_224="deit_3_large_224_21k.pth",
+                L_16_384="deit_3_large_384_21k.pth",
+                H_16_224="deit_3_huge_224_21k.pth",
+            )[f"{variant}_{patch_size}_{img_size}"]
+            base_url = "https://dl.fbaipublicfiles.com/deit/"
+            state_dict = torch.hub.load_state_dict_from_url(base_url + ckpt)["model"]
+            m.load_official_ckpt(state_dict)
+
+        return m
+
+    @torch.no_grad()
+    def load_official_ckpt(self, state_dict: dict[str, Tensor]) -> None:
+        DeiT.load_official_ckpt(self, state_dict)
diff --git a/vision_toolbox/backbones/vit.py b/vision_toolbox/backbones/vit.py
index 850450c..bd14c2d 100644
--- a/vision_toolbox/backbones/vit.py
+++ b/vision_toolbox/backbones/vit.py
@@ -112,11 +112,7 @@ def __init__(
         super().__init__()
         self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
         self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model)) if cls_token else None
-
-        pe_size = (img_size // patch_size) ** 2
-        if cls_token:
-            pe_size += 1
-        self.pe = nn.Parameter(torch.empty(1, pe_size, d_model))
+        self.pe = nn.Parameter(torch.empty(1, (img_size // patch_size) ** 2, d_model))
         nn.init.normal_(self.pe, 0, 0.02)
 
         self.layers = nn.Sequential()
@@ -127,25 +123,19 @@ def __init__(
         self.norm = norm(d_model)
 
     def forward(self, imgs: Tensor) -> Tensor:
-        out = self.patch_embed(imgs).flatten(2).transpose(1, 2)  # (N, C, H, W) -> (N, H*W, C)
+        out = self.patch_embed(imgs).flatten(2).transpose(1, 2) + self.pe  # (N, C, H, W) -> (N, H*W, C)
         if self.cls_token is not None:
             out = torch.cat([self.cls_token, out], 1)
-        out = self.layers(out + self.pe)
+        out = self.layers(out)
         return self.norm(out[:, 0]) if self.cls_token is not None else self.norm(out).mean(1)
 
     @torch.no_grad()
     def resize_pe(self, size: int, interpolation_mode: str = "bicubic") -> None:
-        pe = self.pe if self.cls_token is None else self.pe[:, 1:]
-
-        old_size = int(pe.shape[1] ** 0.5)
+        old_size = int(self.pe.shape[1] ** 0.5)
         new_size = size // self.patch_embed.weight.shape[2]
-        pe = pe.unflatten(1, (old_size, old_size)).permute(0, 3, 1, 2)
+        pe = self.pe.unflatten(1, (old_size, old_size)).permute(0, 3, 1, 2)
         pe = F.interpolate(pe, (new_size, new_size), mode=interpolation_mode)
         pe = pe.permute(0, 2, 3, 1).flatten(1, 2)
-
-        if self.cls_token is not None:
-            pe = torch.cat((self.pe[:, 0:1], pe), 1)
-
         self.pe = nn.Parameter(pe)
 
     @staticmethod
@@ -155,6 +145,7 @@ def from_config(variant: str, img_size: int, pretrained: bool = False) -> ViT:
         d_model, depth, n_heads = dict(
             Ti=(192, 12, 3),
             S=(384, 12, 6),
+            M=(512, 12, 8),
             B=(768, 12, 12),
             L=(1024, 24, 16),
             H=(1280, 32, 16),
@@ -186,9 +177,11 @@ def get_w(key: str) -> Tensor:
             return torch.from_numpy(jax_weights[key])
 
         self.cls_token.copy_(get_w("cls"))
+        pe = get_w("Transformer/posembed_input/pos_embedding")
+        self.cls_token.add_(pe[:, 0])
+        self.pe.copy_(pe[:, 1:])
         self.patch_embed.weight.copy_(get_w("embedding/kernel").permute(3, 2, 0, 1))
         self.patch_embed.bias.copy_(get_w("embedding/bias"))
-        self.pe.copy_(get_w("Transformer/posembed_input/pos_embedding"))
 
         for idx, layer in enumerate(self.layers):
             layer: ViTBlock
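
For reference, a minimal usage sketch of the new backbones (not part of the diff; variant and image-size arguments follow from_config above, and pretrained weights are downloaded from dl.fbaipublicfiles.com on first use):

    import torch

    from vision_toolbox.backbones import DeiT3

    # DeiT3-S/16 at 224px with the official ImageNet-21k-pretrained weights,
    # wired up by DeiT3.from_config(..., pretrained=True) above.
    model = DeiT3.from_config("S_16", 224, pretrained=True).eval()

    with torch.no_grad():
        feats = model(torch.randn(1, 3, 224, 224))  # (1, 384) pooled features

    # resize_pe() interpolates the position embeddings (bicubic by default),
    # so the same weights can run at a new resolution, e.g. 256px.
    model.resize_pe(256)
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 256, 256))  # still (1, 384)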