From bfc91e5696a5069ebbb039302873bcd1df2e3542 Mon Sep 17 00:00:00 2001
From: Thien Tran
Date: Wed, 9 Aug 2023 22:00:56 +0800
Subject: [PATCH] make swin run. fix attn_bias

---
 tests/test_swin.py                   | 27 ++++++
 vision_toolbox/backbones/__init__.py |  1 +
 .../{swin_transformer.py => swin.py} | 93 ++++++++++---------
 vision_toolbox/backbones/vit.py      |  5 +-
 4 files changed, 77 insertions(+), 49 deletions(-)
 create mode 100644 tests/test_swin.py
 rename vision_toolbox/backbones/{swin_transformer.py => swin.py} (65%)

diff --git a/tests/test_swin.py b/tests/test_swin.py
new file mode 100644
index 0000000..e9a0c6f
--- /dev/null
+++ b/tests/test_swin.py
@@ -0,0 +1,27 @@
+import torch
+
+from vision_toolbox.backbones import SwinTransformer
+from vision_toolbox.backbones.swin import window_partition, window_unpartition
+
+
+def test_window_partition():
+    img = torch.randn(1, 224, 280, 3)
+    windows, nH, nW = window_partition(img, 7)
+    _img = window_unpartition(windows, 7, nH, nW)
+    torch.testing.assert_close(img, _img)
+
+
+def test_forward():
+    m = SwinTransformer.from_config("T", 224)
+    m(torch.randn(1, 3, 224, 224))
+
+
+# def test_from_pretrained():
+#     m = ViT.from_config("Ti", 16, 224, True).eval()
+#     x = torch.randn(1, 3, 224, 224)
+#     out = m(x)
+
+#     m_timm = timm.create_model("vit_tiny_patch16_224.augreg_in21k", pretrained=True, num_classes=0).eval()
+#     out_timm = m_timm(x)
+
+#     torch.testing.assert_close(out, out_timm, rtol=2e-5, atol=2e-5)
diff --git a/vision_toolbox/backbones/__init__.py b/vision_toolbox/backbones/__init__.py
index c814727..3ca3664 100644
--- a/vision_toolbox/backbones/__init__.py
+++ b/vision_toolbox/backbones/__init__.py
@@ -4,3 +4,4 @@
 from .torchvision_models import EfficientNetExtractor, MobileNetExtractor, RegNetExtractor, ResNetExtractor
 from .vit import ViT
 from .vovnet import VoVNet
+from .swin import SwinTransformer
diff --git a/vision_toolbox/backbones/swin_transformer.py b/vision_toolbox/backbones/swin.py
similarity index 65%
rename from vision_toolbox/backbones/swin_transformer.py
rename to vision_toolbox/backbones/swin.py
index 36e30b6..4ca7e7d 100644
--- a/vision_toolbox/backbones/swin_transformer.py
+++ b/vision_toolbox/backbones/swin.py
@@ -41,6 +41,7 @@ def __init__(
         dropout: float = 0.0,
     ) -> None:
         super().__init__(d_model, n_heads, bias, dropout)
+        self.input_size = input_size
         self.window_size = window_size
 
         if shift:
@@ -51,8 +52,9 @@ def __init__(
             for i, (h_slice, w_slice) in enumerate(itertools.product(slices, slices)):
                 img_mask[0, h_slice, w_slice, 0] = i
 
-            windows_mask, _, _ = window_partition(img_mask)
-            attn_mask = windows_mask.unsqueeze(1) - windows_mask.unsqueeze(2)
+            windows_mask, _, _ = window_partition(img_mask, window_size)  # (nH * nW, win_size * win_size, 1)
+            windows_mask = windows_mask.transpose(1, 2)  # (nH * nW, 1, win_size * win_size)
+            attn_mask = windows_mask.unsqueeze(2) - windows_mask.unsqueeze(3)
             self.register_buffer("attn_mask", (attn_mask != 0) * (-100), False)
             self.attn_mask: Tensor
 
@@ -60,17 +62,18 @@ def __init__(
             self.shift = 0
             self.attn_mask = None
 
-        self.relative_pe_table = nn.Parameter(torch.empty(n_heads, (2 * window_size - 1) ** 2))
+        self.relative_pe_table = nn.Parameter(torch.empty(1, n_heads, (2 * window_size - 1) ** 2))
         nn.init.trunc_normal_(self.relative_pe_table, 0, 0.02)
 
         xy = torch.cartesian_prod(torch.arange(window_size), torch.arange(window_size))  # all possible (x,y) pairs
         diff = xy.unsqueeze(1) - xy.unsqueeze(0)  # difference between all (x,y) pairs
         index = (diff[:, :, 0] + window_size - 1) * (2 * window_size - 1) + diff[:, :, 1] + window_size - 1
-        self.register_buffer("relative_pe_index", index.flatten(), False)
+        self.register_buffer("relative_pe_index", index, False)
         self.relative_pe_index: Tensor
 
     def forward(self, x: Tensor) -> Tensor:
-        attn_bias = self.relative_pe_table[:, self.relative_pe_index].unsqueeze(0)
+        assert x.shape[1] == self.input_size, (x.shape[1], self.input_size)
+        attn_bias = self.relative_pe_table[..., self.relative_pe_index]
         if self.shift > 0:
             x = x.roll((self.shift, self.shift), (1, 2))
             attn_bias = attn_bias + self.attn_mask
@@ -100,9 +103,9 @@ def __init__(
     ) -> None:
         super().__init__()
         self.norm1 = norm(d_model)
-        self.mha = WindowAttention(input_size, d_model, window_size, shift, n_heads, bias, dropout)
+        self.mha = WindowAttention(input_size, d_model, n_heads, window_size, shift, bias, dropout)
         self.norm2 = norm(d_model)
-        self.mlp = MLP(d_model, int(d_model * mlp_ratio), act)
+        self.mlp = MLP(d_model, int(d_model * mlp_ratio), dropout, act)
 
     def forward(self, x: Tensor) -> Tensor:
         x = self.mha(self.norm1(x))
@@ -117,39 +120,20 @@ def __init__(self, d_model: int, norm: _norm = nn.LayerNorm) -> None:
         self.reduction = nn.Linear(d_model * 4, d_model * 2, False)
 
     def forward(self, x: Tensor) -> Tensor:
-        x, _, _ = window_partition(x, 2)
-        return self.reduction(self.norm(x))
-
-
-class SwinStage(nn.Sequential):
-    def __init__(
-        self,
-        input_size: int,
-        d_model: int,
-        n_heads: int,
-        depth: int,
-        downsample: bool = False,
-        window_size: int = 7,
-        mlp_ratio: float = 4.0,
-        bias: bool = True,
-        dropout: float = 0.0,
-        norm: _norm = nn.LayerNorm,
-        act: _act = nn.GELU,
-    ) -> None:
-        super().__init__()
-        for i in range(depth):
-            blk = SwinBlock(input_size, d_model, n_heads, window_size, i % 2 == 1, mlp_ratio, bias, dropout, norm, act)
-            self.append(blk)
-        self.downsample = PatchMerging(d_model, norm) if downsample else None
+        B, H, W, C = x.shape
+        x = x.view(B, H // 2, 2, W // 2, 2, C).transpose(2, 3).flatten(-3)
+        x = self.reduction(self.norm(x))
+        x = x.view(B, H // 2, W // 2, C * 2)
+        return x
 
 
-class SwinTransformer(BaseBackbone):
+class SwinTransformer(nn.Module):
     def __init__(
         self,
         img_size: int,
         d_model: int,
         n_heads: int,
-        depths: tuple[int, int, int, int],
+        depths: tuple[int, ...],
         patch_size: int = 4,
         window_size: int = 7,
         mlp_ratio: float = 4.0,
@@ -158,29 +142,46 @@ def __init__(
         norm: _norm = nn.LayerNorm,
         act: _act = nn.GELU,
     ) -> None:
+        assert img_size % patch_size == 0
+        assert img_size % window_size == 0
+        assert d_model % n_heads == 0
         super().__init__()
         self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
         self.norm = norm(d_model)
+        self.dropout = nn.Dropout(dropout)
+
+        input_size = img_size // patch_size
+        self.stages = nn.Sequential()
+        for i, depth in enumerate(depths):
+            stage = nn.Sequential()
+            for i in range(depth):
+                blk = SwinBlock(input_size, d_model, n_heads, window_size, i % 2, mlp_ratio, bias, dropout, norm, act)
+                stage.append(blk)
+
+            if i < len(depths) - 1:
+                stage.append(PatchMerging(d_model, norm))
+                input_size //= 2
+                d_model *= 2
+                n_heads *= 2
 
-        self.stages = nn.ModuleList()
-        for depth in depths:
-            stage = SwinStage(img_size, d_model, n_heads, depth, window_size, mlp_ratio, bias, dropout, norm, act)
             self.stages.append(stage)
-            img_size //= 2
-            d_model *= 2
-            n_heads *= 2
 
-    def forward_features(self, x: Tensor) -> Tensor:
-        x = self.norm(self.patch_embed(x).permute(0, 2, 3, 1))
-        for stage in self.stages:
-            x = stage(x)
+        self.head_norm = norm(d_model)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.dropout(self.norm(self.patch_embed(x).permute(0, 2, 3, 1)))
+        x = self.stages(x)
+        x = self.head_norm(x).mean((1, 2))
+        return x
 
     @staticmethod
-    def from_config(variant: str, pretrained: bool = False) -> SwinTransformer:
-        d_model, n_heads, n_layers = dict(
+    def from_config(variant: str, img_size: int, pretrained: bool = False) -> SwinTransformer:
+        d_model, n_heads, depths = dict(
             T=(96, 3, (2, 2, 6, 2)),
             S=(96, 3, (2, 2, 18, 2)),
             B=(128, 4, (2, 2, 18, 2)),
             L=(192, 6, (2, 2, 18, 2)),
         )[variant]
-        m = SwinTransformer(d_model, n_heads, n_layers)
+        m = SwinTransformer(img_size, d_model, n_heads, depths)
+
+        return m
diff --git a/vision_toolbox/backbones/vit.py b/vision_toolbox/backbones/vit.py
index 81cb1ff..cdb3a29 100644
--- a/vision_toolbox/backbones/vit.py
+++ b/vision_toolbox/backbones/vit.py
@@ -28,14 +28,13 @@ def __init__(self, d_model: int, n_heads: int, bias: bool = True, dropout: float
     def forward(self, x: Tensor, attn_bias: Tensor | None = None) -> Tensor:
         qkv = self.in_proj(x)
         q, k, v = qkv.unflatten(-1, (3, self.n_heads, -1)).transpose(-2, -4).unbind(-3)
-
         if hasattr(F, "scaled_dot_product_attention"):
             out = F.scaled_dot_product_attention(q, k, v, attn_bias, self.dropout if self.training else 0.0)
         else:
-            attn = torch.softmax(q @ (k * self.scale).transpose(-1, -2), -1)
+            attn = q @ (k * self.scale).transpose(-1, -2)
             if attn_bias is not None:
                 attn = attn + attn_bias
-            out = F.dropout(attn, self.dropout, self.training) @ v
+            out = F.dropout(torch.softmax(attn, -1), self.dropout, self.training) @ v
 
         out = out.transpose(-2, -3).flatten(-2)
         out = self.out_proj(out)
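
Note: the new test imports window_partition and window_unpartition from vision_toolbox.backbones.swin, but their definitions are untouched by this patch and therefore do not appear in the hunks above. The sketch below only illustrates the contract that the test and the attention-mask code rely on; the actual helpers in swin.py may be implemented differently.

import torch
from torch import Tensor


def window_partition(x: Tensor, window_size: int) -> tuple[Tensor, int, int]:
    # (B, H, W, C) -> (B * nH * nW, window_size * window_size, C), assuming H and W
    # are divisible by window_size (224 and 280 are both divisible by 7 in the test).
    B, H, W, C = x.shape
    nH, nW = H // window_size, W // window_size
    x = x.view(B, nH, window_size, nW, window_size, C)
    windows = x.transpose(2, 3).reshape(B * nH * nW, window_size * window_size, C)
    return windows, nH, nW


def window_unpartition(windows: Tensor, window_size: int, nH: int, nW: int) -> Tensor:
    # exact inverse of window_partition(): recover (B, H, W, C) from the windows.
    B, C = windows.shape[0] // (nH * nW), windows.shape[2]
    x = windows.view(B, nH, nW, window_size, window_size, C)
    x = x.transpose(2, 3).reshape(B, nH * window_size, nW * window_size, C)
    return x

With a batch size of 1, as in the attention-mask construction, window_partition returns a (nH * nW, window_size * window_size, 1) tensor, which is the shape noted in the comments added to WindowAttention, and test_window_partition passes because the two functions are exact inverses of each other.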