From b964ffef51f34d0bb1bd8030eb26353dbc7fd3bc Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 19 Aug 2023 11:14:34 +0800 Subject: [PATCH] use faster weight transpose --- vision_toolbox/backbones/swin.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/vision_toolbox/backbones/swin.py b/vision_toolbox/backbones/swin.py index c1e204c..6de5beb 100644 --- a/vision_toolbox/backbones/swin.py +++ b/vision_toolbox/backbones/swin.py @@ -119,8 +119,7 @@ def __init__(self, d_model: int, norm: _norm = nn.LayerNorm) -> None: def forward(self, x: Tensor) -> Tensor: B, H, W, C = x.shape - # x = x.view(B, H // 2, 2, W // 2, 2, C).transpose(2, 3).flatten(-3) - x = x.view(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(-3) + x = x.view(B, H // 2, 2, W // 2, 2, C).transpose(2, 3).flatten(-3) x = self.reduction(self.norm(x)) x = x.view(B, H // 2, W // 2, C * 2) return x @@ -213,8 +212,7 @@ def from_config(variant: str, img_size: int, pretrained: bool = False) -> SwinTr def load_official_ckpt(self, state_dict: dict[str, Tensor]) -> None: def copy_(m: nn.Linear | nn.LayerNorm, prefix: str) -> None: m.weight.copy_(state_dict.pop(prefix + ".weight")) - if m.bias is not None: - m.bias.copy_(state_dict.pop(prefix + ".bias")) + m.bias.copy_(state_dict.pop(prefix + ".bias")) copy_(self.patch_embed, "patch_embed.proj") copy_(self.norm, "patch_embed.norm") @@ -222,8 +220,14 @@ def copy_(m: nn.Linear | nn.LayerNorm, prefix: str) -> None: for stage_idx, stage in enumerate(self.stages): if stage_idx > 0: prefix = f"layers.{stage_idx-1}.downsample." - copy_(stage[0].norm, prefix + "norm") - copy_(stage[0].reduction, prefix + "reduction") + + def rearrange(p): + p1, p2, p3, p4 = p.chunk(4, -1) + return torch.cat((p1, p3, p2, p4), -1) + + stage[0].norm.weight.copy_(rearrange(state_dict.pop(prefix + "norm.weight"))) + stage[0].norm.bias.copy_(rearrange(state_dict.pop(prefix + "norm.bias"))) + stage[0].reduction.weight.copy_(rearrange(state_dict.pop(prefix + "reduction.weight"))) for block_idx in range(1, len(stage)): block: SwinBlock = stage[block_idx]