From c90825fa990b449d86282e1e2ef0a39706959e68 Mon Sep 17 00:00:00 2001
From: Thien Tran
Date: Thu, 17 Aug 2023 22:46:04 +0800
Subject: [PATCH] fix some bugs with Swin

---
 vision_toolbox/backbones/swin.py | 49 ++++++++++++++++++--------------
 1 file changed, 27 insertions(+), 22 deletions(-)

diff --git a/vision_toolbox/backbones/swin.py b/vision_toolbox/backbones/swin.py
index e54515e..4cfed40 100644
--- a/vision_toolbox/backbones/swin.py
+++ b/vision_toolbox/backbones/swin.py
@@ -52,9 +52,8 @@ def __init__(
                 img_mask[0, h_slice, w_slice, 0] = i
 
             windows_mask, _, _ = window_partition(img_mask, window_size)  # (nH * nW, win_size * win_size, 1)
-            windows_mask = windows_mask.transpose(1, 2)  # (nH * nW, 1, win_size * win_size)
-            attn_mask = windows_mask.unsqueeze(2) - windows_mask.unsqueeze(3)
-            self.register_buffer("attn_mask", (attn_mask != 0) * (-100), False)
+            attn_mask = windows_mask - windows_mask.transpose(1, 2)
+            self.register_buffer("attn_mask", (attn_mask != 0).unsqueeze(1) * (-100), False)
             self.attn_mask: Tensor
 
         else:
@@ -74,7 +73,7 @@ def forward(self, x: Tensor) -> Tensor:
         assert x.shape[1] == self.input_size, (x.shape[1], self.input_size)
         attn_bias = self.relative_pe_table[..., self.relative_pe_index]
         if self.shift > 0:
-            x = x.roll((self.shift, self.shift), (1, 2))
+            x = x.roll((-self.shift, -self.shift), (1, 2))
             attn_bias = attn_bias + self.attn_mask
 
         x, nH, nW = window_partition(x, self.window_size)  # (B * nH * nW, win_size * win_size, C)
@@ -82,7 +81,7 @@ def forward(self, x: Tensor) -> Tensor:
         x = window_unpartition(x, self.window_size, nH, nW)
 
         if self.shift > 0:
-            x = x.roll((-self.shift, -self.shift), (1, 2))
+            x = x.roll((self.shift, self.shift), (1, 2))
 
         return x
 
@@ -107,8 +106,8 @@ def __init__(
         self.mlp = MLP(d_model, int(d_model * mlp_ratio), dropout, act)
 
     def forward(self, x: Tensor) -> Tensor:
-        x = self.mha(self.norm1(x))
-        x = self.mlp(self.norm2(x))
+        x = x + self.mha(self.norm1(x))
+        x = x + self.mlp(self.norm2(x))
         return x
 
 
@@ -120,7 +119,8 @@ def __init__(self, d_model: int, norm: _norm = nn.LayerNorm) -> None:
 
     def forward(self, x: Tensor) -> Tensor:
         B, H, W, C = x.shape
-        x = x.view(B, H // 2, 2, W // 2, 2, C).transpose(2, 3).flatten(-3)
+        # x = x.view(B, H // 2, 2, W // 2, 2, C).transpose(2, 3).flatten(-3)
+        x = x.view(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(-3)
         x = self.reduction(self.norm(x))
         x = x.view(B, H // 2, W // 2, C * 2)
         return x
@@ -153,14 +153,17 @@ def __init__(
         for i, (depth, window_size) in enumerate(zip(depths, window_sizes)):
             stage = nn.Sequential()
             if i > 0:
-                stage.append(PatchMerging(d_model, norm))
+                downsample = PatchMerging(d_model, norm)
                 input_size //= 2
                 d_model *= 2
                 n_heads *= 2
+            else:
+                downsample = nn.Identity()
+            stage.add_module("downsample", downsample)
 
             for i in range(depth):
                 blk = SwinBlock(input_size, d_model, n_heads, window_size, i % 2, mlp_ratio, bias, dropout, norm, act)
-                stage.append(blk)
+                stage.add_module(f"block_{i}", blk)
             self.stages.append(stage)
 
         self.head_norm = norm(d_model)
@@ -175,7 +178,7 @@ def forward(self, x: Tensor) -> Tensor:
         return self.head_norm(self.get_feature_maps(x)[-1]).mean((1, 2))
 
     def resize_pe(self, img_size: int) -> None:
-        pass
+        raise NotImplementedError()
 
     @staticmethod
    def from_config(variant: str, img_size: int, pretrained: bool = False) -> SwinTransformer:
@@ -190,16 +193,17 @@ def from_config(variant: str, img_size: int, pretrained: bool = False) -> SwinTr
             "S3-S": (96, 3, (2, 2, 18, 2), (14, 14, 14, 14), "supernet-small.pth"),
             "S3-B": (96, 3, (2, 2, 30, 2), (7, 7, 14, 7), "supernet-base.pth"),
         }[variant]
-        m = SwinTransformer(img_size, d_model, n_heads, depths, window_sizes)
+        m = SwinTransformer(224 if pretrained else img_size, d_model, n_heads, depths, window_sizes)
 
         if pretrained:
-            base_url = (
-                "https://github.com/silent-chen/AutoFormer-model-zoo/releases/download/v1.0/"
-                if variant.startswith("S3")
-                else "https://github.com/SwinTransformer/storage/releases/download/v1.0.8/"
-            )
+            if variant.startswith("S3"):
+                base_url = "https://github.com/silent-chen/AutoFormer-model-zoo/releases/download/v1.0/"
+            else:
+                base_url = "https://github.com/SwinTransformer/storage/releases/download/"
             state_dict = torch.hub.load_state_dict_from_url(base_url + ckpt)["model"]
             m.load_official_ckpt(state_dict)
+            if img_size != 224:
+                m.resize_pe(img_size)
 
         return m
 
@@ -214,16 +218,17 @@ def copy_(m: nn.Linear | nn.LayerNorm, prefix: str) -> None:
 
         for stage_i, stage in enumerate(self.stages):
             if stage_i > 0:
-                downsample: PatchMerging = stage[0]
-                downsample.reduction.weight.copy_(state_dict[f"layers.{stage_i-1}.downsample.reduction.weight"])
+                stage.downsample.reduction.weight.copy_(state_dict[f"layers.{stage_i-1}.downsample.reduction.weight"])
 
-            for block_idx, block in enumerate(stage):
-                block: SwinBlock
+            for block_idx in range(len(stage) - 1):
+                block: SwinBlock = getattr(stage, f"block_{block_idx}")
                 prefix = f"layers.{stage_i}.blocks.{block_idx}."
+                block_idx += 1
+
                 copy_(block.norm1, prefix + "norm1")
                 copy_(block.mha.in_proj, prefix + "attn.qkv")
                 copy_(block.mha.out_proj, prefix + "attn.proj")
-                block.mha.relative_pe_table.copy_(state_dict[prefix + "attn.relative_position_bias_table"])
+                block.mha.relative_pe_table.copy_(state_dict[prefix + "attn.relative_position_bias_table"].T)
                 copy_(block.norm2, prefix + "norm2")
                 copy_(block.mlp.linear1, prefix + "mlp.fc1")
                 copy_(block.mlp.linear2, prefix + "mlp.fc2")