use faster weight transpose
gau-nernst committed Aug 19, 2023
1 parent f6e48d8 commit b964ffe
Showing 1 changed file with 10 additions and 6 deletions.
vision_toolbox/backbones/swin.py
@@ -119,8 +119,7 @@ def __init__(self, d_model: int, norm: _norm = nn.LayerNorm) -> None:
 
     def forward(self, x: Tensor) -> Tensor:
         B, H, W, C = x.shape
-        # x = x.view(B, H // 2, 2, W // 2, 2, C).transpose(2, 3).flatten(-3)
-        x = x.view(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(-3)
+        x = x.view(B, H // 2, 2, W // 2, 2, C).transpose(2, 3).flatten(-3)
         x = self.reduction(self.norm(x))
         x = x.view(B, H // 2, W // 2, C * 2)
         return x
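
Note, not part of the commit: a minimal standalone check of why this swap is safe. Both versions collect the same four C-sized channel groups from each 2x2 window; the old permute (matching the official Swin checkpoint layout) emits them as (h0w0, h1w0, h0w1, h1w1), while transpose(2, 3) emits (h0w0, h0w1, h1w0, h1w1), i.e. only the middle two groups trade places:

import torch

B, H, W, C = 1, 4, 4, 3
x = torch.randn(B, H, W, C)
old = x.view(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(-3)
new = x.view(B, H // 2, 2, W // 2, 2, C).transpose(2, 3).flatten(-3)
p1, p2, p3, p4 = old.chunk(4, -1)  # groups in the old / checkpoint order
assert torch.equal(new, torch.cat((p1, p3, p2, p4), -1))  # swapping the middle two groups is self-inverse

Because the difference is a fixed channel reordering, it can be folded into the pretrained weights once at load time (second hunk below) instead of being paid on every forward pass.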
@@ -213,17 +212,22 @@ def from_config(variant: str, img_size: int, pretrained: bool = False) -> SwinTransformer:
     def load_official_ckpt(self, state_dict: dict[str, Tensor]) -> None:
         def copy_(m: nn.Linear | nn.LayerNorm, prefix: str) -> None:
             m.weight.copy_(state_dict.pop(prefix + ".weight"))
-            if m.bias is not None:
-                m.bias.copy_(state_dict.pop(prefix + ".bias"))
+            m.bias.copy_(state_dict.pop(prefix + ".bias"))
 
         copy_(self.patch_embed, "patch_embed.proj")
         copy_(self.norm, "patch_embed.norm")
 
         for stage_idx, stage in enumerate(self.stages):
             if stage_idx > 0:
                 prefix = f"layers.{stage_idx-1}.downsample."
-                copy_(stage[0].norm, prefix + "norm")
-                copy_(stage[0].reduction, prefix + "reduction")
+
+                def rearrange(p):
+                    p1, p2, p3, p4 = p.chunk(4, -1)
+                    return torch.cat((p1, p3, p2, p4), -1)
+
+                stage[0].norm.weight.copy_(rearrange(state_dict.pop(prefix + "norm.weight")))
+                stage[0].norm.bias.copy_(rearrange(state_dict.pop(prefix + "norm.bias")))
+                stage[0].reduction.weight.copy_(rearrange(state_dict.pop(prefix + "reduction.weight")))
 
             for block_idx in range(1, len(stage)):
                 block: SwinBlock = stage[block_idx]
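
Note, not part of the commit: dropping the bias guard in copy_ looks safe because reduction, presumably the only bias-free layer previously routed through it, is now copied explicitly above. A standalone sketch (tensor names w, b, red_w are made up, not from the repo) that rearranged checkpoint weights applied to the transpose-order channels reproduce the official slice-and-concat PatchMerging:

import torch
import torch.nn.functional as F

C = 8
w, b = torch.randn(4 * C), torch.randn(4 * C)  # stand-ins for pretrained norm weight/bias
red_w = torch.randn(2 * C, 4 * C)              # stand-in for pretrained reduction weight (no bias)

def rearrange(p):
    p1, p2, p3, p4 = p.chunk(4, -1)
    return torch.cat((p1, p3, p2, p4), -1)

x = torch.randn(2, 4, 4, C)
# official Swin order: x0 (even h, even w), x1 (odd h, even w), x2, x3
off = torch.cat([x[:, 0::2, 0::2], x[:, 1::2, 0::2], x[:, 0::2, 1::2], x[:, 1::2, 1::2]], -1)
new = x.view(2, 2, 2, 2, 2, C).transpose(2, 3).flatten(-3)

out_off = F.layer_norm(off, (4 * C,), w, b) @ red_w.t()
out_new = F.layer_norm(new, (4 * C,), rearrange(w), rearrange(b)) @ rearrange(red_w).t()
assert torch.allclose(out_off, out_new, atol=1e-5)

LayerNorm statistics are invariant to channel permutation, so permuting its affine parameters and the reduction's input columns by the same group swap leaves the output unchanged.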
