
Commit

add WIP weight loading
gau-nernst committed Aug 14, 2023
1 parent fd9a270 commit 0e95db4
Showing 1 changed file with 48 additions and 10 deletions.
58 changes: 48 additions & 10 deletions vision_toolbox/backbones/swin.py
@@ -8,7 +8,6 @@
import torch
from torch import Tensor, nn

from ..utils import torch_hub_download
from .base import BaseBackbone, _act, _norm
from .vit import MHA, MLP

@@ -127,7 +126,7 @@ def forward(self, x: Tensor) -> Tensor:
        return x


class SwinTransformer(nn.Module):
class SwinTransformer(BaseBackbone):
    def __init__(
        self,
        img_size: int,
@@ -175,19 +174,58 @@ def get_feature_maps(self, x: Tensor) -> list[Tensor]:
    def forward(self, x: Tensor) -> Tensor:
        return self.head_norm(self.get_feature_maps(x)[-1]).mean((1, 2))

    def resize_pe(self, img_size: int) -> None:
        pass

    @staticmethod
    def from_config(variant: str, img_size: int, pretrained: bool = False) -> SwinTransformer:
        d_model, n_heads, depths, window_sizes = {
        d_model, n_heads, depths, window_sizes, ckpt = {
            # Sub-section 3.3 in https://arxiv.org/pdf/2103.14030.pdf
            "T": (96, 3, (2, 2, 6, 2), (7, 7, 7, 7)),
            "S": (96, 3, (2, 2, 18, 2), (7, 7, 7, 7)),
            "B": (128, 4, (2, 2, 18, 2), (7, 7, 7, 7)),
            "L": (192, 6, (2, 2, 18, 2), (7, 7, 7, 7)),
            "T": (96, 3, (2, 2, 6, 2), (7, 7, 7, 7), "v1.0.8/swin_tiny_patch4_window7_224_22k.pth"),
            "S": (96, 3, (2, 2, 18, 2), (7, 7, 7, 7), "v1.0.8/swin_small_patch4_window7_224_22k.pth"),
            "B": (128, 4, (2, 2, 18, 2), (7, 7, 7, 7), "v1.0.0/swin_base_patch4_window7_224_22k.pth"),
            "L": (192, 6, (2, 2, 18, 2), (7, 7, 7, 7), "v1.0.0/swin_large_patch4_window7_224_22k.pth"),
            # https://github.com/microsoft/Cream/blob/main/AutoFormerV2/configs
            "S3-T": (96, 3, (2, 2, 6, 2), (7, 7, 14, 7)),
            "S3-S": (96, 3, (2, 2, 18, 2), (14, 14, 14, 14)),
            "S3-B": (96, 3, (2, 2, 30, 2), (7, 7, 14, 7)),
            "S3-T": (96, 3, (2, 2, 6, 2), (7, 7, 14, 7), "supernet-tiny.pth"),
            "S3-S": (96, 3, (2, 2, 18, 2), (14, 14, 14, 14), "supernet-small.pth"),
            "S3-B": (96, 3, (2, 2, 30, 2), (7, 7, 14, 7), "supernet-base.pth"),
        }[variant]
        m = SwinTransformer(img_size, d_model, n_heads, depths, window_sizes)

        if pretrained:
            base_url = (
                "https://github.com/silent-chen/AutoFormer-model-zoo/releases/download/v1.0/"
                if variant.startswith("S3")
                # ckpt already starts with the release tag (e.g. "v1.0.8/"), so the base URL must not repeat it
                else "https://github.com/SwinTransformer/storage/releases/download/"
            )
            state_dict = torch.hub.load_state_dict_from_url(base_url + ckpt)["model"]
            m.load_official_ckpt(state_dict)

        return m

    @torch.no_grad()
    def load_official_ckpt(self, state_dict: dict[str, Tensor]) -> None:
        def copy_(m: nn.Linear | nn.LayerNorm, prefix: str) -> None:
            m.weight.copy_(state_dict[prefix + ".weight"])
            m.bias.copy_(state_dict[prefix + ".bias"])

        copy_(self.patch_embed, "patch_embed.proj")
        copy_(self.norm, "patch_embed.norm")

        for stage_i, stage in enumerate(self.stages):
            if stage_i > 0:
                downsample: PatchMerging = stage[0]
                downsample.reduction.weight.copy_(state_dict[f"layers.{stage_i-1}.downsample.reduction.weight"])

            # stage[0] is the PatchMerging for later stages: skip it so block indices
            # line up with the official "blocks.{i}" checkpoint keys
            blocks = list(stage)[1:] if stage_i > 0 else stage
            for block_idx, block in enumerate(blocks):
                block: SwinBlock
                prefix = f"layers.{stage_i}.blocks.{block_idx}."
                copy_(block.norm1, prefix + "norm1")
                copy_(block.mha.in_proj, prefix + "attn.qkv")
                copy_(block.mha.out_proj, prefix + "attn.proj")
                block.mha.relative_pe_table.copy_(state_dict[prefix + "attn.relative_position_bias_table"])
                copy_(block.norm2, prefix + "norm2")
                copy_(block.mlp.linear1, prefix + "mlp.fc1")
                copy_(block.mlp.linear2, prefix + "mlp.fc2")

        copy_(self.head_norm, "norm")
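
A quick usage sketch (not part of the commit): assuming the package is importable as vision_toolbox and the layout shown in the diff, the new pretrained path can be exercised like this. The "T" variant and 224 input size are just example values.

import torch
from vision_toolbox.backbones.swin import SwinTransformer

model = SwinTransformer.from_config("T", 224, pretrained=True)  # downloads and converts the official checkpoint
model.eval()

with torch.inference_mode():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # pooled last-stage features; (1, 768) for Swin-T since channels double each stage (96 -> 768)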
