From 20db77e030868e75c9e1ce71bc06c432ecb8d87b Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 19 Aug 2023 11:16:47 +0800 Subject: [PATCH] Add Swin Transformer (#16) --- tests/test_swin.py | 20 ++ vision_toolbox/backbones/__init__.py | 1 + vision_toolbox/backbones/mlp_mixer.py | 19 +- vision_toolbox/backbones/swin.py | 251 ++++++++++++++++++++++++++ vision_toolbox/backbones/vit.py | 43 +++-- 5 files changed, 306 insertions(+), 28 deletions(-) create mode 100644 tests/test_swin.py create mode 100644 vision_toolbox/backbones/swin.py diff --git a/tests/test_swin.py b/tests/test_swin.py new file mode 100644 index 0000000..da0d2e9 --- /dev/null +++ b/tests/test_swin.py @@ -0,0 +1,20 @@ +import timm +import torch + +from vision_toolbox.backbones import SwinTransformer + + +def test_forward(): + m = SwinTransformer.from_config("T", 224) + m(torch.randn(1, 3, 224, 224)) + + +def test_from_pretrained(): + m = SwinTransformer.from_config("T", 224, True).eval() + x = torch.randn(1, 3, 224, 224) + out = m(x) + + m_timm = timm.create_model("swin_tiny_patch4_window7_224.ms_in22k", pretrained=True, num_classes=0).eval() + out_timm = m_timm(x) + + torch.testing.assert_close(out, out_timm, rtol=2e-5, atol=2e-5) diff --git a/vision_toolbox/backbones/__init__.py b/vision_toolbox/backbones/__init__.py index c814727..55a1857 100644 --- a/vision_toolbox/backbones/__init__.py +++ b/vision_toolbox/backbones/__init__.py @@ -1,6 +1,7 @@ from .darknet import Darknet, DarknetYOLOv5 from .mlp_mixer import MLPMixer from .patchconvnet import PatchConvNet +from .swin import SwinTransformer from .torchvision_models import EfficientNetExtractor, MobileNetExtractor, RegNetExtractor, ResNetExtractor from .vit import ViT from .vovnet import VoVNet diff --git a/vision_toolbox/backbones/mlp_mixer.py b/vision_toolbox/backbones/mlp_mixer.py index 124f7ce..3e07665 100644 --- a/vision_toolbox/backbones/mlp_mixer.py +++ b/vision_toolbox/backbones/mlp_mixer.py @@ -12,14 +12,7 @@ from ..utils import torch_hub_download from .base import _act, _norm - - -class MLP(nn.Sequential): - def __init__(self, in_dim: int, hidden_dim: float, act: _act = nn.GELU) -> None: - super().__init__() - self.linear1 = nn.Linear(in_dim, hidden_dim) - self.act = act() - self.linear2 = nn.Linear(hidden_dim, in_dim) +from .vit import MLP class MixerBlock(nn.Module): @@ -28,15 +21,16 @@ def __init__( n_tokens: int, d_model: int, mlp_ratio: tuple[int, int] = (0.5, 4.0), + dropout: float = 0.0, norm: _norm = partial(nn.LayerNorm, eps=1e-6), act: _act = nn.GELU, ) -> None: tokens_mlp_dim, channels_mlp_dim = [int(d_model * ratio) for ratio in mlp_ratio] super().__init__() self.norm1 = norm(d_model) - self.token_mixing = MLP(n_tokens, tokens_mlp_dim, act) + self.token_mixing = MLP(n_tokens, tokens_mlp_dim, dropout, act) self.norm2 = norm(d_model) - self.channel_mixing = MLP(d_model, channels_mlp_dim, act) + self.channel_mixing = MLP(d_model, channels_mlp_dim, dropout, act) def forward(self, x: Tensor) -> Tensor: # x -> (B, n_tokens, d_model) @@ -53,6 +47,7 @@ def __init__( patch_size: int, img_size: int, mlp_ratio: tuple[float, float] = (0.5, 4.0), + dropout: float = 0.0, norm: _norm = partial(nn.LayerNorm, eps=1e-6), act: _act = nn.GELU, ) -> None: @@ -60,7 +55,9 @@ def __init__( super().__init__() self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size) n_tokens = (img_size // patch_size) ** 2 - self.layers = nn.Sequential(*[MixerBlock(n_tokens, d_model, mlp_ratio, norm, act) for _ in range(n_layers)]) + self.layers = nn.Sequential( + 
*[MixerBlock(n_tokens, d_model, mlp_ratio, dropout, norm, act) for _ in range(n_layers)] + ) self.norm = norm(d_model) def forward(self, x: Tensor) -> Tensor: diff --git a/vision_toolbox/backbones/swin.py b/vision_toolbox/backbones/swin.py new file mode 100644 index 0000000..6de5beb --- /dev/null +++ b/vision_toolbox/backbones/swin.py @@ -0,0 +1,251 @@ +# https://arxiv.org/abs/2103.14030 +# https://github.com/microsoft/Swin-Transformer + +from __future__ import annotations + +import itertools + +import torch +from torch import Tensor, nn + +from .base import BaseBackbone, _act, _norm +from .vit import MHA, MLP + + +def window_partition(x: Tensor, window_size: int) -> tuple[Tensor, int, int]: + B, H, W, C = x.shape + nH, nW = H // window_size, W // window_size + x = x.view(B, nH, window_size, nW, window_size, C) + x = x.transpose(2, 3).reshape(B * nH * nW, window_size * window_size, C) + return x, nH, nW + + +def window_unpartition(x: Tensor, window_size: int, nH: int, nW: int) -> Tensor: + B = x.shape[0] // (nH * nW) + C = x.shape[2] + x = x.view(B, nH, nW, window_size, window_size, C) + x = x.transpose(2, 3).reshape(B, nH * window_size, nW * window_size, C) + return x + + +class WindowAttention(MHA): + def __init__( + self, + input_size: int, + d_model: int, + n_heads: int, + window_size: int = 7, + shift: bool = False, + bias: bool = True, + dropout: float = 0.0, + ) -> None: + super().__init__(d_model, n_heads, bias, dropout) + self.input_size = input_size + self.window_size = window_size + + if shift: + self.shift = window_size // 2 + + img_mask = torch.zeros(1, input_size, input_size, 1) + slices = (slice(0, -window_size), slice(-window_size, -self.shift), slice(-self.shift, None)) + for i, (h_slice, w_slice) in enumerate(itertools.product(slices, slices)): + img_mask[:, h_slice, w_slice, :] = i + + windows_mask, _, _ = window_partition(img_mask, window_size) # (nH * nW, win_size * win_size, 1) + attn_mask = windows_mask.transpose(1, 2) - windows_mask + self.register_buffer("attn_mask", (attn_mask != 0) * (-100.0), False) + self.attn_mask: Tensor + + else: + self.shift = 0 + self.attn_mask = None + + self.relative_pe_table = nn.Parameter(torch.empty(1, n_heads, (2 * window_size - 1) ** 2)) + nn.init.trunc_normal_(self.relative_pe_table, 0, 0.02) + + xy = torch.cartesian_prod(torch.arange(window_size), torch.arange(window_size)) # all possible (x,y) pairs + diff = xy.unsqueeze(1) - xy.unsqueeze(0) # difference between all (x,y) pairs + index = (diff[:, :, 0] + window_size - 1) * (2 * window_size - 1) + diff[:, :, 1] + window_size - 1 + self.register_buffer("relative_pe_index", index, False) + self.relative_pe_index: Tensor + + def forward(self, x: Tensor) -> Tensor: + assert x.shape[1] == self.input_size, (x.shape[1], self.input_size) + attn_bias = self.relative_pe_table[..., self.relative_pe_index] + if self.shift > 0: + x = x.roll((-self.shift, -self.shift), (1, 2)) + attn_bias = attn_bias + self.attn_mask.unsqueeze(1) # add n_heads dim + + x, nH, nW = window_partition(x, self.window_size) # (B * nH * nW, win_size * win_size, C) + x = super().forward(x, attn_bias) + x = window_unpartition(x, self.window_size, nH, nW) # (B, H, W, C) + + if self.shift > 0: + x = x.roll((self.shift, self.shift), (1, 2)) + return x + + +class SwinBlock(nn.Module): + def __init__( + self, + input_size: int, + d_model: int, + n_heads: int, + window_size: int = 7, + shift: bool = False, + mlp_ratio: float = 4.0, + bias: bool = True, + dropout: float = 0.0, + norm: _norm = nn.LayerNorm, + act: _act = 
nn.GELU, + ) -> None: + super().__init__() + self.norm1 = norm(d_model) + self.mha = WindowAttention(input_size, d_model, n_heads, window_size, shift, bias, dropout) + self.norm2 = norm(d_model) + self.mlp = MLP(d_model, int(d_model * mlp_ratio), dropout, act) + + def forward(self, x: Tensor) -> Tensor: + x = x + self.mha(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +class PatchMerging(nn.Module): + def __init__(self, d_model: int, norm: _norm = nn.LayerNorm) -> None: + super().__init__() + self.norm = norm(d_model * 4) + self.reduction = nn.Linear(d_model * 4, d_model * 2, False) + + def forward(self, x: Tensor) -> Tensor: + B, H, W, C = x.shape + x = x.view(B, H // 2, 2, W // 2, 2, C).transpose(2, 3).flatten(-3) + x = self.reduction(self.norm(x)) + x = x.view(B, H // 2, W // 2, C * 2) + return x + + +class SwinTransformer(BaseBackbone): + def __init__( + self, + img_size: int, + d_model: int, + n_heads: int, + depths: tuple[int, ...], + window_sizes: tuple[int, ...], + patch_size: int = 4, + mlp_ratio: float = 4.0, + bias: bool = True, + dropout: float = 0.0, + norm: _norm = nn.LayerNorm, + act: _act = nn.GELU, + ) -> None: + assert img_size % patch_size == 0 + assert d_model % n_heads == 0 + super().__init__() + self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size) + self.norm = norm(d_model) + self.dropout = nn.Dropout(dropout) + + input_size = img_size // patch_size + self.stages = nn.Sequential() + for i, (depth, window_size) in enumerate(zip(depths, window_sizes)): + stage = nn.Sequential() + if i > 0: + downsample = PatchMerging(d_model, norm) + input_size //= 2 + d_model *= 2 + n_heads *= 2 + else: + downsample = nn.Identity() + stage.append(downsample) + + for i in range(depth): + shift = (i % 2) and input_size > window_size + block = SwinBlock(input_size, d_model, n_heads, window_size, shift, mlp_ratio, bias, dropout, norm, act) + stage.append(block) + + self.stages.append(stage) + + self.head_norm = norm(d_model) + + def get_feature_maps(self, x: Tensor) -> list[Tensor]: + out = [self.dropout(self.norm(self.patch_embed(x).permute(0, 2, 3, 1)))] + for stage in self.stages: + out.append(stage(out[-1])) + return out[1:] + + def forward(self, x: Tensor) -> Tensor: + return self.head_norm(self.get_feature_maps(x)[-1]).mean((1, 2)) + + def resize_pe(self, img_size: int) -> None: + raise NotImplementedError() + + @staticmethod + def from_config(variant: str, img_size: int, pretrained: bool = False) -> SwinTransformer: + d_model, n_heads, depths, window_sizes, ckpt = { + # Sub-section 3.3 in https://arxiv.org/pdf/2103.14030.pdf + "T": (96, 3, (2, 2, 6, 2), (7, 7, 7, 7), "v1.0.8/swin_tiny_patch4_window7_224_22k.pth"), + "S": (96, 3, (2, 2, 18, 2), (7, 7, 7, 7), "v1.0.8/swin_small_patch4_window7_224_22k.pth"), + "B": (128, 4, (2, 2, 18, 2), (7, 7, 7, 7), "v1.0.0/swin_base_patch4_window7_224_22k.pth"), + "L": (192, 6, (2, 2, 18, 2), (7, 7, 7, 7), "v1.0.0/swin_large_patch4_window7_224_22k.pth"), + # https://github.com/microsoft/Cream/blob/main/AutoFormerV2/configs + "S3-T": (96, 3, (2, 2, 6, 2), (7, 7, 14, 7), "supernet-tiny.pth"), + "S3-S": (96, 3, (2, 2, 18, 2), (14, 14, 14, 14), "supernet-small.pth"), + "S3-B": (96, 3, (2, 2, 30, 2), (7, 7, 14, 7), "supernet-base.pth"), + }[variant] + m = SwinTransformer(224 if pretrained else img_size, d_model, n_heads, depths, window_sizes) + + if pretrained: + if variant.startswith("S3"): + base_url = "https://github.com/silent-chen/AutoFormer-model-zoo/releases/download/v1.0/" + else: + base_url = 
"https://github.com/SwinTransformer/storage/releases/download/" + state_dict = torch.hub.load_state_dict_from_url(base_url + ckpt)["model"] + m.load_official_ckpt(state_dict) + if img_size != 224: + m.resize_pe(img_size) + + return m + + @torch.no_grad() + def load_official_ckpt(self, state_dict: dict[str, Tensor]) -> None: + def copy_(m: nn.Linear | nn.LayerNorm, prefix: str) -> None: + m.weight.copy_(state_dict.pop(prefix + ".weight")) + m.bias.copy_(state_dict.pop(prefix + ".bias")) + + copy_(self.patch_embed, "patch_embed.proj") + copy_(self.norm, "patch_embed.norm") + + for stage_idx, stage in enumerate(self.stages): + if stage_idx > 0: + prefix = f"layers.{stage_idx-1}.downsample." + + def rearrange(p): + p1, p2, p3, p4 = p.chunk(4, -1) + return torch.cat((p1, p3, p2, p4), -1) + + stage[0].norm.weight.copy_(rearrange(state_dict.pop(prefix + "norm.weight"))) + stage[0].norm.bias.copy_(rearrange(state_dict.pop(prefix + "norm.bias"))) + stage[0].reduction.weight.copy_(rearrange(state_dict.pop(prefix + "reduction.weight"))) + + for block_idx in range(1, len(stage)): + block: SwinBlock = stage[block_idx] + prefix = f"layers.{stage_idx}.blocks.{block_idx - 1}." + block_idx += 1 + + if block.mha.attn_mask is not None: + torch.testing.assert_close(block.mha.attn_mask, state_dict.pop(prefix + "attn_mask")) + torch.testing.assert_close( + block.mha.relative_pe_index, state_dict.pop(prefix + "attn.relative_position_index") + ) + copy_(block.norm1, prefix + "norm1") + copy_(block.mha.in_proj, prefix + "attn.qkv") + copy_(block.mha.out_proj, prefix + "attn.proj") + block.mha.relative_pe_table.copy_(state_dict.pop(prefix + "attn.relative_position_bias_table").T) + copy_(block.norm2, prefix + "norm2") + copy_(block.mlp.linear1, prefix + "mlp.fc1") + copy_(block.mlp.linear2, prefix + "mlp.fc2") + + copy_(self.head_norm, "norm") + assert len(state_dict) == 2 # head.weight and head.bias diff --git a/vision_toolbox/backbones/vit.py b/vision_toolbox/backbones/vit.py index becda7d..c1e4ed8 100644 --- a/vision_toolbox/backbones/vit.py +++ b/vision_toolbox/backbones/vit.py @@ -25,27 +25,38 @@ def __init__(self, d_model: int, n_heads: int, bias: bool = True, dropout: float self.dropout = dropout self.scale = (d_model // n_heads) ** (-0.5) - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: Tensor, attn_bias: Tensor | None = None) -> Tensor: qkv = self.in_proj(x) - q, k, v = qkv.unflatten(-1, (3, self.n_heads, -1)).transpose(-2, -4).unbind(-3) - + q, k, v = qkv.unflatten(-1, (3, self.n_heads, -1)).transpose(-2, -4).unbind(-3) # (B, n_heads, L, head_dim) if hasattr(F, "scaled_dot_product_attention"): - out = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout) + out = F.scaled_dot_product_attention(q, k, v, attn_bias, self.dropout if self.training else 0.0) else: - attn = torch.softmax(q @ (k * self.scale).transpose(-1, -2), -1) - out = F.dropout(attn, self.dropout, self.training) @ v + attn = q @ (k * self.scale).transpose(-1, -2) + if attn_bias is not None: + attn = attn + attn_bias + out = F.dropout(torch.softmax(attn, -1), self.dropout, self.training) @ v out = out.transpose(-2, -3).flatten(-2) out = self.out_proj(out) return out +class MLP(nn.Sequential): + def __init__(self, in_dim: int, hidden_dim: float, dropout: float = 0.0, act: _act = nn.GELU) -> None: + super().__init__() + self.linear1 = nn.Linear(in_dim, hidden_dim) + self.act = act() + self.linear2 = nn.Linear(hidden_dim, in_dim) + self.dropout = nn.Dropout(dropout) + + class ViTBlock(nn.Module): def __init__( self, 
d_model: int, n_heads: int, bias: bool = True, + mlp_ratio: float = 4.0, dropout: float = 0.0, norm: _norm = partial(nn.LayerNorm, eps=1e-6), act: _act = nn.GELU, @@ -54,12 +65,7 @@ def __init__( self.norm1 = norm(d_model) self.mha = MHA(d_model, n_heads, bias, dropout) self.norm2 = norm(d_model) - self.mlp = nn.Sequential( - nn.Linear(d_model, d_model * 4, bias), - act(), - nn.Linear(d_model * 4, d_model, bias), - nn.Dropout(dropout), - ) + self.mlp = MLP(d_model, int(d_model * mlp_ratio), dropout, act) def forward(self, x: Tensor) -> Tensor: x = x + self.mha(self.norm1(x)) @@ -77,6 +83,7 @@ def __init__( img_size: int, cls_token: bool = True, bias: bool = True, + mlp_ratio: float = 4.0, dropout: float = 0.0, norm: _norm = partial(nn.LayerNorm, eps=1e-6), act: _act = nn.GELU, @@ -92,7 +99,9 @@ def __init__( self.pe = nn.Parameter(torch.empty(1, pe_size, d_model)) nn.init.normal_(self.pe, 0, 0.02) - self.layers = nn.Sequential(*[ViTBlock(d_model, n_heads, bias, dropout, norm, act) for _ in range(n_layers)]) + self.layers = nn.Sequential( + *[ViTBlock(d_model, n_heads, bias, mlp_ratio, dropout, norm, act) for _ in range(n_layers)] + ) self.norm = norm(d_model) def forward(self, imgs: Tensor) -> Tensor: @@ -175,10 +184,10 @@ def get_w(key: str) -> Tensor: layer.norm2.weight.copy_(get_w(prefix + "LayerNorm_2/scale")) layer.norm2.bias.copy_(get_w(prefix + "LayerNorm_2/bias")) - layer.mlp[0].weight.copy_(get_w(prefix + "MlpBlock_3/Dense_0/kernel").T) - layer.mlp[0].bias.copy_(get_w(prefix + "MlpBlock_3/Dense_0/bias")) - layer.mlp[2].weight.copy_(get_w(prefix + "MlpBlock_3/Dense_1/kernel").T) - layer.mlp[2].bias.copy_(get_w(prefix + "MlpBlock_3/Dense_1/bias")) + layer.mlp.linear1.weight.copy_(get_w(prefix + "MlpBlock_3/Dense_0/kernel").T) + layer.mlp.linear1.bias.copy_(get_w(prefix + "MlpBlock_3/Dense_0/bias")) + layer.mlp.linear2.weight.copy_(get_w(prefix + "MlpBlock_3/Dense_1/kernel").T) + layer.mlp.linear2.bias.copy_(get_w(prefix + "MlpBlock_3/Dense_1/bias")) self.norm.weight.copy_(get_w("Transformer/encoder_norm/scale")) self.norm.bias.copy_(get_w("Transformer/encoder_norm/bias"))
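
A minimal usage sketch for the backbone added in this patch (not part of the diff above). It only uses APIs introduced here (`SwinTransformer.from_config`, `get_feature_maps`, `forward`); the shape comments are my expected values for the "T" config at 224x224 (patch size 4, window size 7, d_model 96 doubling per stage), and `pretrained` is left at its default so no checkpoint download is needed:

    import torch

    from vision_toolbox.backbones import SwinTransformer

    # Swin-T: d_model=96, depths (2, 2, 6, 2), window size 7, patch size 4
    model = SwinTransformer.from_config("T", 224).eval()

    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        feats = model.get_feature_maps(x)  # one (B, H, W, C) map per stage
        pooled = model(x)                  # head_norm + mean over spatial dims

    for f in feats:
        print(tuple(f.shape))
    # expected: (1, 56, 56, 96), (1, 28, 28, 192), (1, 14, 14, 384), (1, 7, 7, 768)
    print(tuple(pooled.shape))  # expected: (1, 768)

Passing `pretrained=True` (as in tests/test_swin.py) additionally downloads the official Microsoft checkpoint and loads it via `load_official_ckpt`.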