From a36355363140263541cd1d24e413f0349360fa93 Mon Sep 17 00:00:00 2001
From: Thien Tran
Date: Sun, 20 Aug 2023 12:01:18 +0800
Subject: [PATCH] add forward-able CaiT

---
 tests/test_cait.py                   |  20 +++
 vision_toolbox/backbones/__init__.py |   1 +
 vision_toolbox/backbones/cait.py     | 220 +++++++++++++++++++++++++++
 3 files changed, 241 insertions(+)
 create mode 100644 tests/test_cait.py
 create mode 100644 vision_toolbox/backbones/cait.py

diff --git a/tests/test_cait.py b/tests/test_cait.py
new file mode 100644
index 0000000..79d8c3f
--- /dev/null
+++ b/tests/test_cait.py
@@ -0,0 +1,20 @@
+import timm
+import torch
+
+from vision_toolbox.backbones import CaiT
+
+
+def test_forward():
+    m = CaiT.from_config("xxs_24", 224)
+    m(torch.randn(1, 3, 224, 224))
+
+
+def test_from_pretrained():
+    m = CaiT.from_config("xxs_24", 224, True).eval()
+    x = torch.randn(1, 3, 224, 224)
+    out = m(x)
+
+    m_timm = timm.create_model("cait_xxs24_224.fb_dist_in1k", pretrained=True, num_classes=0).eval()
+    out_timm = m_timm(x)
+
+    torch.testing.assert_close(out, out_timm, rtol=2e-5, atol=2e-5)
diff --git a/vision_toolbox/backbones/__init__.py b/vision_toolbox/backbones/__init__.py
index a724f7c..eeae4b1 100644
--- a/vision_toolbox/backbones/__init__.py
+++ b/vision_toolbox/backbones/__init__.py
@@ -1,3 +1,4 @@
+from .cait import CaiT
 from .convnext import ConvNeXt
 from .darknet import Darknet, DarknetYOLOv5
 from .mlp_mixer import MLPMixer
diff --git a/vision_toolbox/backbones/cait.py b/vision_toolbox/backbones/cait.py
new file mode 100644
index 0000000..8696e1b
--- /dev/null
+++ b/vision_toolbox/backbones/cait.py
@@ -0,0 +1,220 @@
+# https://arxiv.org/abs/2103.17239
+# https://github.com/facebookresearch/deit
+from __future__ import annotations
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+from ..components import LayerScale, StochasticDepth
+from .base import _act, _norm
+from .vit import MLP
+
+
+# basically attention pooling
+class ClassAttention(nn.Module):
+    def __init__(self, d_model: int, n_heads: int, bias: bool = True, dropout: float = 0.0) -> None:
+        super().__init__()
+        self.q_proj = nn.Linear(d_model, d_model, bias)
+        self.k_proj = nn.Linear(d_model, d_model, bias)
+        self.v_proj = nn.Linear(d_model, d_model, bias)
+        self.out_proj = nn.Linear(d_model, d_model, bias)
+
+        self.n_heads = n_heads
+        self.dropout = dropout
+        self.scale = (d_model // n_heads) ** (-0.5)
+
+    def forward(self, x: Tensor) -> Tensor:
+        q = self.q_proj(x[:, 0]).unflatten(-1, (self.n_heads, -1)).unsqueeze(2)  # (B, n_heads, 1, head_dim)
+        k = self.k_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)  # (B, n_heads, L, head_dim)
+        v = self.v_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)
+
+        if hasattr(F, "scaled_dot_product_attention"):
+            out = F.scaled_dot_product_attention(q, k, v, None, self.dropout if self.training else 0.0)
+        else:
+            attn = (q * self.scale) @ k.transpose(-1, -2)
+            out = F.dropout(torch.softmax(attn, -1), self.dropout, self.training) @ v
+
+        return self.out_proj(out.transpose(-2, -3).flatten(-2))  # (B, n_heads, 1, head_dim) -> (B, 1, d_model)
+
+
+# does not support flash attention
+class TalkingHeadAttention(nn.Module):
+    def __init__(self, d_model: int, n_heads: int, bias: bool = True, dropout: float = 0.0) -> None:
+        super().__init__()
+        self.q_proj = nn.Linear(d_model, d_model, bias)
+        self.k_proj = nn.Linear(d_model, d_model, bias)
+        self.v_proj = nn.Linear(d_model, d_model, bias)
+        self.out_proj = nn.Linear(d_model, d_model, bias)
+        self.talking_head_proj = nn.Sequential(
+            nn.Conv2d(n_heads, n_heads, 1),  # implemented as a 1x1 conv to avoid permuting the data
+            nn.Softmax(-1),
+            nn.Conv2d(n_heads, n_heads, 1),
+            nn.Dropout(dropout),
+        )
+        self.n_heads = n_heads
+        self.scale = (d_model // n_heads) ** (-0.5)
+
+    def forward(self, x: Tensor) -> Tensor:
+        q = self.q_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)  # (B, n_heads, L, head_dim)
+        k = self.k_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)
+        v = self.v_proj(x).unflatten(-1, (self.n_heads, -1)).transpose(-2, -3)
+
+        attn = q @ (k * self.scale).transpose(-1, -2)
+        out = self.talking_head_proj(attn) @ v
+        out = out.transpose(-2, -3).flatten(-2)
+        out = self.out_proj(out)
+        return out
+
+
+class CaiTCABlock(nn.Module):
+    def __init__(
+        self,
+        d_model: int,
+        n_heads: int,
+        bias: bool = True,
+        mlp_ratio: float = 4.0,
+        dropout: float = 0.0,
+        layer_scale_init: float | None = 1e-6,
+        stochastic_depth: float = 0.0,
+        norm: _norm = nn.LayerNorm,
+        act: _act = nn.GELU,
+    ) -> None:
+        super().__init__()
+        self.mha = nn.Sequential(
+            norm(d_model),
+            ClassAttention(d_model, n_heads, bias, dropout),
+            LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
+            StochasticDepth(stochastic_depth),
+        )
+        self.mlp = nn.Sequential(
+            norm(d_model),
+            MLP(d_model, int(d_model * mlp_ratio), dropout, act),
+            LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
+            StochasticDepth(stochastic_depth),
+        )
+
+    def forward(self, x: Tensor, cls_token: Tensor) -> Tensor:
+        cls_token = cls_token + self.mha(torch.cat((cls_token, x), 1))
+        cls_token = cls_token + self.mlp(cls_token)
+        return cls_token
+
+
+class CaiTSABlock(nn.Module):
+    def __init__(
+        self,
+        d_model: int,
+        n_heads: int,
+        bias: bool = True,
+        mlp_ratio: float = 4.0,
+        dropout: float = 0.0,
+        layer_scale_init: float | None = 1e-6,
+        stochastic_depth: float = 0.0,
+        norm: _norm = nn.LayerNorm,
+        act: _act = nn.GELU,
+    ) -> None:
+        super().__init__()
+        self.mha = nn.Sequential(
+            norm(d_model),
+            TalkingHeadAttention(d_model, n_heads, bias, dropout),
+            LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
+            StochasticDepth(stochastic_depth),
+        )
+        self.mlp = nn.Sequential(
+            norm(d_model),
+            MLP(d_model, int(d_model * mlp_ratio), dropout, act),
+            LayerScale(d_model, layer_scale_init) if layer_scale_init is not None else nn.Identity(),
+            StochasticDepth(stochastic_depth),
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = x + self.mha(x)
+        x = x + self.mlp(x)
+        return x
+
+
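+# CaiT splits the transformer into two stages: a stack of self-attention blocks that
+# update only the patch tokens, followed by class-attention blocks in which a class
+# token attends to the frozen patch tokens to pool them into the final representation.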
+class CaiT(nn.Module):
+    def __init__(
+        self,
+        d_model: int,
+        sa_depth: int,
+        ca_depth: int,
+        n_heads: int,
+        patch_size: int,
+        img_size: int,
+        bias: bool = True,
+        mlp_ratio: float = 4.0,
+        dropout: float = 0.0,
+        layer_scale_init: float | None = 1e-6,
+        stochastic_depth: float = 0.0,
+        norm: _norm = nn.LayerNorm,
+        act: _act = nn.GELU,
+    ) -> None:
+        assert img_size % patch_size == 0
+        super().__init__()
+        self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
+        self.pe = nn.Parameter(torch.empty(1, (img_size // patch_size) ** 2, d_model))
+        nn.init.normal_(self.pe, 0, 0.02)
+
+        self.sa_layers = nn.Sequential()
+        for _ in range(sa_depth):
+            block = CaiTSABlock(
+                d_model, n_heads, bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm, act
+            )
+            self.sa_layers.append(block)
+
+        self.ca_layers = nn.ModuleList()
+        for _ in range(ca_depth):
+            block = CaiTCABlock(
+                d_model, n_heads, bias, mlp_ratio, dropout, layer_scale_init, stochastic_depth, norm, act
+            )
+            self.ca_layers.append(block)
+
+        self.norm = norm(d_model)
+
+    def forward(self, imgs: Tensor) -> Tensor:
+        patches = self.patch_embed(imgs).flatten(2).transpose(1, 2)  # (N, C, H, W) -> (N, H*W, C)
+        patches = self.sa_layers(patches + self.pe)
+
+        cls_token = self.cls_token.expand(imgs.shape[0], -1, -1)  # (1, 1, d_model) -> (B, 1, d_model) for torch.cat
+        for block in self.ca_layers:
+            cls_token = block(patches, cls_token)
+        return self.norm(cls_token)
+
+    @staticmethod
+    def from_config(variant: str, img_size: int, pretrained: bool = False) -> CaiT:
+        variant, sa_depth = variant.split("_")
+
+        d_model = dict(xxs=192, xs=288, s=384, m=768)[variant]
+        sa_depth = int(sa_depth)
+        ca_depth = 2
+        n_heads = d_model // 48
+        patch_size = 16
+        m = CaiT(d_model, sa_depth, ca_depth, n_heads, patch_size, img_size)
+
+        if pretrained:
+            ckpt = dict(
+                xxs_24_224="XXS24_224.pth",
+                xxs_24_384="XXS24_384.pth",
+                xxs_36_224="XXS36_224.pth",
+                xxs_36_384="XXS36_384.pth",
+                xs_24_384="XS24_384.pth",
+                s_24_224="S24_224.pth",
+                s_24_384="S24_384.pth",
+                s_36_384="S36_384.pth",
+                m_36_384="M36_384.pth",
+                m_48_448="M48_448.pth",
+            )[f"{variant}_{sa_depth}_{img_size}"]
+            base_url = "https://dl.fbaipublicfiles.com/deit/"
+            state_dict = torch.hub.load_state_dict_from_url(base_url + ckpt)
+            m.load_official_ckpt(state_dict)
+
+        return m
+
+    @torch.no_grad()
+    def load_official_ckpt(self, state_dict: dict[str, Tensor]) -> None:
+        raise NotImplementedError()
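
Usage sketch (mirrors tests/test_cait.py; pretrained weight loading is still a stub, so only randomly initialized forward passes work for now):

    import torch
    from vision_toolbox.backbones import CaiT

    model = CaiT.from_config("xxs_24", 224).eval()  # CaiT-XXS24: d_model=192, 4 heads, 24 SA + 2 CA blocks
    x = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        features = model(x)  # (2, 1, 192): the layer-normed class token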