From 0e95db41db8816374e506d944e433192e2d2d819 Mon Sep 17 00:00:00 2001
From: Thien Tran
Date: Mon, 14 Aug 2023 22:51:17 +0800
Subject: [PATCH] add WIP weight loading

---
 vision_toolbox/backbones/swin.py | 84 ++++++++++++++++++++++++++++++----
 1 file changed, 74 insertions(+), 10 deletions(-)

diff --git a/vision_toolbox/backbones/swin.py b/vision_toolbox/backbones/swin.py
index 124d179..e54515e 100644
--- a/vision_toolbox/backbones/swin.py
+++ b/vision_toolbox/backbones/swin.py
@@ -8,7 +8,6 @@
 import torch
 from torch import Tensor, nn
 
-from ..utils import torch_hub_download
 from .base import BaseBackbone, _act, _norm
 from .vit import MHA, MLP
 
@@ -127,7 +126,7 @@ def forward(self, x: Tensor) -> Tensor:
         return x
 
 
-class SwinTransformer(nn.Module):
+class SwinTransformer(BaseBackbone):
     def __init__(
         self,
         img_size: int,
@@ -175,19 +174,84 @@ def get_feature_maps(self, x: Tensor) -> list[Tensor]:
     def forward(self, x: Tensor) -> Tensor:
         return self.head_norm(self.get_feature_maps(x)[-1]).mean((1, 2))
 
+    def resize_pe(self, img_size: int) -> None:
+        pass  # not implemented yet; currently a no-op
+
     @staticmethod
     def from_config(variant: str, img_size: int, pretrained: bool = False) -> SwinTransformer:
+        """Build a SwinTransformer for a named variant.
+
+        Args:
+            variant: Key into the config table below ("T", "S", "B", "L", "S3-T", "S3-S", "S3-B").
+            img_size: Forwarded to the SwinTransformer constructor.
+            pretrained: If True, download the variant's checkpoint and load it via load_official_ckpt().
+
+        Raises:
+            KeyError: If variant is not in the config table.
+        """
-        d_model, n_heads, depths, window_sizes = {
+        d_model, n_heads, depths, window_sizes, ckpt = {
             # Sub-section 3.3 in https://arxiv.org/pdf/2103.14030.pdf
-            "T": (96, 3, (2, 2, 6, 2), (7, 7, 7, 7)),
-            "S": (96, 3, (2, 2, 18, 2), (7, 7, 7, 7)),
-            "B": (128, 4, (2, 2, 18, 2), (7, 7, 7, 7)),
-            "L": (192, 6, (2, 2, 18, 2), (7, 7, 7, 7)),
+            "T": (96, 3, (2, 2, 6, 2), (7, 7, 7, 7), "v1.0.8/swin_tiny_patch4_window7_224_22k.pth"),
+            "S": (96, 3, (2, 2, 18, 2), (7, 7, 7, 7), "v1.0.8/swin_small_patch4_window7_224_22k.pth"),
+            "B": (128, 4, (2, 2, 18, 2), (7, 7, 7, 7), "v1.0.0/swin_base_patch4_window7_224_22k.pth"),
+            "L": (192, 6, (2, 2, 18, 2), (7, 7, 7, 7), "v1.0.0/swin_large_patch4_window7_224_22k.pth"),
 
             # https://github.com/microsoft/Cream/blob/main/AutoFormerV2/configs
-            "S3-T": (96, 3, (2, 2, 6, 2), (7, 7, 14, 7)),
-            "S3-S": (96, 3, (2, 2, 18, 2), (14, 14, 14, 14)),
-            "S3-B": (96, 3, (2, 2, 30, 2), (7, 7, 14, 7)),
+            "S3-T": (96, 3, (2, 2, 6, 2), (7, 7, 14, 7), "supernet-tiny.pth"),
+            "S3-S": (96, 3, (2, 2, 18, 2), (14, 14, 14, 14), "supernet-small.pth"),
+            "S3-B": (96, 3, (2, 2, 30, 2), (7, 7, 14, 7), "supernet-base.pth"),
         }[variant]
 
         m = SwinTransformer(img_size, d_model, n_heads, depths, window_sizes)
+        if pretrained:
+            # Official Swin entries in ckpt already start with their release tag
+            # ("v1.0.8/..." or "v1.0.0/..."), so that base URL must end at "download/".
+            base_url = (
+                "https://github.com/silent-chen/AutoFormer-model-zoo/releases/download/v1.0/"
+                if variant.startswith("S3")
+                else "https://github.com/SwinTransformer/storage/releases/download/"
+            )
+            # map_location="cpu" lets the checkpoint load on machines without a GPU.
+            state_dict = torch.hub.load_state_dict_from_url(base_url + ckpt, map_location="cpu")["model"]
+            m.load_official_ckpt(state_dict)
+        return m
+
+    @torch.no_grad()
+    def load_official_ckpt(self, state_dict: dict[str, Tensor]) -> None:
+        """Copy parameters from an official Swin checkpoint into this model in place.
+
+        Args:
+            state_dict: Tensors keyed by the official parameter names,
+                e.g. "layers.0.blocks.0.attn.qkv.weight".
+        """
+
+        def copy_(m: nn.Linear | nn.LayerNorm, prefix: str) -> None:
+            m.weight.copy_(state_dict[prefix + ".weight"])
+            m.bias.copy_(state_dict[prefix + ".bias"])
+
+        copy_(self.patch_embed, "patch_embed.proj")
+        # NOTE(review): assumes self.norm is the norm applied right after patch_embed - confirm.
+        copy_(self.norm, "patch_embed.norm")
+
+        for stage_i, stage in enumerate(self.stages):
+            blocks = list(stage)
+            if stage_i > 0:
+                # Later stages begin with a PatchMerging layer, which the checkpoint keys under
+                # the previous layer group. Pop it so it is not treated as a SwinBlock below.
+                downsample: PatchMerging = blocks.pop(0)
+                downsample.reduction.weight.copy_(state_dict[f"layers.{stage_i-1}.downsample.reduction.weight"])
+                # NOTE(review): official checkpoints also hold "downsample.norm.*"; it is not
+                # copied here - confirm whether PatchMerging has a norm that needs loading.
+
+            for block_idx, block in enumerate(blocks):
+                block: SwinBlock
+                prefix = f"layers.{stage_i}.blocks.{block_idx}."
+                copy_(block.norm1, prefix + "norm1")
+                copy_(block.mha.in_proj, prefix + "attn.qkv")
+                copy_(block.mha.out_proj, prefix + "attn.proj")
+                block.mha.relative_pe_table.copy_(state_dict[prefix + "attn.relative_position_bias_table"])
+                copy_(block.norm2, prefix + "norm2")
+                copy_(block.mlp.linear1, prefix + "mlp.fc1")
+                copy_(block.mlp.linear2, prefix + "mlp.fc2")
+
+        copy_(self.head_norm, "norm")