From f6e48d8255f57ee3c21aef40ebb84842d3a7237c Mon Sep 17 00:00:00 2001
From: Thien Tran
Date: Sat, 19 Aug 2023 10:57:29 +0800
Subject: [PATCH] don't shift when input_size = window_size

---
 tests/test_swin.py               | 23 ++++++----------
 vision_toolbox/backbones/swin.py | 45 ++++++++++++++++++++------------
 vision_toolbox/backbones/vit.py  |  2 +-
 3 files changed, 37 insertions(+), 33 deletions(-)

diff --git a/tests/test_swin.py b/tests/test_swin.py
index e9a0c6f..da0d2e9 100644
--- a/tests/test_swin.py
+++ b/tests/test_swin.py
@@ -1,14 +1,7 @@
+import timm
 import torch
 
 from vision_toolbox.backbones import SwinTransformer
-from vision_toolbox.backbones.swin import window_partition, window_unpartition
-
-
-def test_window_partition():
-    img = torch.randn(1, 224, 280, 3)
-    windows, nH, nW = window_partition(img, 7)
-    _img = window_unpartition(windows, 7, nH, nW)
-    torch.testing.assert_close(img, _img)
 
 
 def test_forward():
@@ -16,12 +9,12 @@ def test_forward():
     m(torch.randn(1, 3, 224, 224))
 
 
-# def test_from_pretrained():
-#     m = ViT.from_config("Ti", 16, 224, True).eval()
-#     x = torch.randn(1, 3, 224, 224)
-#     out = m(x)
+def test_from_pretrained():
+    m = SwinTransformer.from_config("T", 224, True).eval()
+    x = torch.randn(1, 3, 224, 224)
+    out = m(x)
 
-#     m_timm = timm.create_model("vit_tiny_patch16_224.augreg_in21k", pretrained=True, num_classes=0).eval()
-#     out_timm = m_timm(x)
+    m_timm = timm.create_model("swin_tiny_patch4_window7_224.ms_in22k", pretrained=True, num_classes=0).eval()
+    out_timm = m_timm(x)
 
-#     torch.testing.assert_close(out, out_timm, rtol=2e-5, atol=2e-5)
+    torch.testing.assert_close(out, out_timm, rtol=2e-5, atol=2e-5)
diff --git a/vision_toolbox/backbones/swin.py b/vision_toolbox/backbones/swin.py
index 4cfed40..c1e204c 100644
--- a/vision_toolbox/backbones/swin.py
+++ b/vision_toolbox/backbones/swin.py
@@ -49,11 +49,11 @@ def __init__(
             img_mask = torch.zeros(1, input_size, input_size, 1)
             slices = (slice(0, -window_size), slice(-window_size, -self.shift), slice(-self.shift, None))
             for i, (h_slice, w_slice) in enumerate(itertools.product(slices, slices)):
-                img_mask[0, h_slice, w_slice, 0] = i
+                img_mask[:, h_slice, w_slice, :] = i
 
             windows_mask, _, _ = window_partition(img_mask, window_size)  # (nH * nW, win_size * win_size, 1)
-            attn_mask = windows_mask - windows_mask.transpose(1, 2)
-            self.register_buffer("attn_mask", (attn_mask != 0).unsqueeze(1) * (-100), False)
+            attn_mask = windows_mask.transpose(1, 2) - windows_mask
+            self.register_buffer("attn_mask", (attn_mask != 0) * (-100.0), False)
             self.attn_mask: Tensor
 
         else:
@@ -74,11 +74,11 @@ def forward(self, x: Tensor) -> Tensor:
         attn_bias = self.relative_pe_table[..., self.relative_pe_index]
         if self.shift > 0:
             x = x.roll((-self.shift, -self.shift), (1, 2))
-            attn_bias = attn_bias + self.attn_mask
+            attn_bias = attn_bias + self.attn_mask.unsqueeze(1)  # add n_heads dim
 
         x, nH, nW = window_partition(x, self.window_size)  # (B * nH * nW, win_size * win_size, C)
         x = super().forward(x, attn_bias)
-        x = window_unpartition(x, self.window_size, nH, nW)
+        x = window_unpartition(x, self.window_size, nH, nW)  # (B, H, W, C)
 
         if self.shift > 0:
             x = x.roll((self.shift, self.shift), (1, 2))
@@ -159,11 +159,13 @@ def __init__(
                 n_heads *= 2
             else:
                 downsample = nn.Identity()
-            stage.add_module("downsample", downsample)
+            stage.append(downsample)
 
             for i in range(depth):
-                blk = SwinBlock(input_size, d_model, n_heads, window_size, i % 2, mlp_ratio, bias, dropout, norm, act)
-                stage.add_module(f"block_{i}", blk)
+                shift = (i % 2) and input_size > window_size
+                block = SwinBlock(input_size, d_model, n_heads, window_size, shift, mlp_ratio, bias, dropout, norm, act)
+                stage.append(block)
+
             self.stages.append(stage)
 
         self.head_norm = norm(d_model)
@@ -210,27 +212,36 @@ def from_config(variant: str, img_size: int, pretrained: bool = False) -> SwinTr
     @torch.no_grad()
     def load_official_ckpt(self, state_dict: dict[str, Tensor]) -> None:
         def copy_(m: nn.Linear | nn.LayerNorm, prefix: str) -> None:
-            m.weight.copy_(state_dict[prefix + ".weight"])
-            m.bias.copy_(state_dict[prefix + ".bias"])
+            m.weight.copy_(state_dict.pop(prefix + ".weight"))
+            if m.bias is not None:
+                m.bias.copy_(state_dict.pop(prefix + ".bias"))
 
         copy_(self.patch_embed, "patch_embed.proj")
         copy_(self.norm, "patch_embed.norm")
 
-        for stage_i, stage in enumerate(self.stages):
-            if stage_i > 0:
-                stage.downsample.reduction.weight.copy_(state_dict[f"layers.{stage_i-1}.downsample.reduction.weight"])
+        for stage_idx, stage in enumerate(self.stages):
+            if stage_idx > 0:
+                prefix = f"layers.{stage_idx-1}.downsample."
+                copy_(stage[0].norm, prefix + "norm")
+                copy_(stage[0].reduction, prefix + "reduction")
 
-            for block_idx in range(len(stage) - 1):
-                block: SwinBlock = getattr(stage, f"block_{block_idx}")
-                prefix = f"layers.{stage_i}.blocks.{block_idx}."
+            for block_idx in range(1, len(stage)):
+                block: SwinBlock = stage[block_idx]
+                prefix = f"layers.{stage_idx}.blocks.{block_idx - 1}."
                 block_idx += 1
+                if block.mha.attn_mask is not None:
+                    torch.testing.assert_close(block.mha.attn_mask, state_dict.pop(prefix + "attn_mask"))
+                torch.testing.assert_close(
+                    block.mha.relative_pe_index, state_dict.pop(prefix + "attn.relative_position_index")
+                )
                 copy_(block.norm1, prefix + "norm1")
                 copy_(block.mha.in_proj, prefix + "attn.qkv")
                 copy_(block.mha.out_proj, prefix + "attn.proj")
-                block.mha.relative_pe_table.copy_(state_dict[prefix + "attn.relative_position_bias_table"].T)
+                block.mha.relative_pe_table.copy_(state_dict.pop(prefix + "attn.relative_position_bias_table").T)
                 copy_(block.norm2, prefix + "norm2")
                 copy_(block.mlp.linear1, prefix + "mlp.fc1")
                 copy_(block.mlp.linear2, prefix + "mlp.fc2")
 
         copy_(self.head_norm, "norm")
+        assert len(state_dict) == 2  # head.weight and head.bias
 
diff --git a/vision_toolbox/backbones/vit.py b/vision_toolbox/backbones/vit.py
index cdb3a29..c1e4ed8 100644
--- a/vision_toolbox/backbones/vit.py
+++ b/vision_toolbox/backbones/vit.py
@@ -27,7 +27,7 @@ def __init__(self, d_model: int, n_heads: int, bias: bool = True, dropout: float
 
     def forward(self, x: Tensor, attn_bias: Tensor | None = None) -> Tensor:
         qkv = self.in_proj(x)
-        q, k, v = qkv.unflatten(-1, (3, self.n_heads, -1)).transpose(-2, -4).unbind(-3)
+        q, k, v = qkv.unflatten(-1, (3, self.n_heads, -1)).transpose(-2, -4).unbind(-3)  # (B, n_heads, L, head_dim)
        if hasattr(F, "scaled_dot_product_attention"):
             out = F.scaled_dot_product_attention(q, k, v, attn_bias, self.dropout if self.training else 0.0)
         else:
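
For context on the subject line: with a 224x224 input, Swin-T's stages run at 56x56, 28x28, 14x14 and 7x7, so the last stage's feature map is exactly one 7x7 window and a cyclic shift plus attention mask adds nothing there. The sketch below replays the `shift = (i % 2) and input_size > window_size` rule from the patch; the depths, window size and halving schedule are the standard Swin-T values and are assumed here purely for illustration.

# Illustrative sketch, not part of the patch: which blocks end up shifted for Swin-T at 224x224.
depths = (2, 2, 6, 2)       # standard Swin-T depths (assumption)
window_size = 7
input_size = 224 // 4       # 56x56 after the 4x4 patch embedding
for stage_idx, depth in enumerate(depths):
    if stage_idx > 0:
        input_size //= 2    # PatchMerging halves the feature map
    for i in range(depth):
        shift = (i % 2) and input_size > window_size
        print(f"stage {stage_idx} block {i}: input_size={input_size} shift={bool(shift)}")
# The last stage has input_size == window_size == 7, so all of its blocks use shift=False.

This is also why load_official_ckpt only pops an "attn_mask" entry when block.mha.attn_mask is not None: unshifted blocks never register the buffer, and the final assert len(state_dict) == 2 confirms the checkpoint carried masks only for the blocks that actually shift.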