Commit

fix some bugs with Swin
gau-nernst committed Aug 17, 2023
1 parent 0e95db4 commit c90825f
Showing 1 changed file with 27 additions and 22 deletions.
vision_toolbox/backbones/swin.py: 49 changes (27 additions, 22 deletions)
@@ -52,9 +52,8 @@ def __init__(
                     img_mask[0, h_slice, w_slice, 0] = i

             windows_mask, _, _ = window_partition(img_mask, window_size)  # (nH * nW, win_size * win_size, 1)
-            windows_mask = windows_mask.transpose(1, 2)  # (nH * nW, 1, win_size * win_size)
-            attn_mask = windows_mask.unsqueeze(2) - windows_mask.unsqueeze(3)
-            self.register_buffer("attn_mask", (attn_mask != 0) * (-100), False)
+            attn_mask = windows_mask - windows_mask.transpose(1, 2)
+            self.register_buffer("attn_mask", (attn_mask != 0).unsqueeze(1) * (-100), False)
             self.attn_mask: Tensor

         else:
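The two constructions are equivalent: only the nonzero pattern of the pairwise difference matters, so taking the difference before or after transposing yields the same (-100) attention mask, and the new form needs one fewer intermediate tensor. A minimal standalone check (my sketch; N stands for the number of windows, L for win_size * win_size):

    import torch

    N, L = 4, 9
    windows_mask = torch.randint(0, 3, (N, L, 1)).float()

    # old construction: transpose first, then pairwise difference
    wm_t = windows_mask.transpose(1, 2)                     # (N, 1, L)
    old = ((wm_t.unsqueeze(2) - wm_t.unsqueeze(3)) != 0) * (-100)

    # new construction: pairwise difference, then insert the head dim
    new = ((windows_mask - windows_mask.transpose(1, 2)) != 0).unsqueeze(1) * (-100)

    assert torch.equal(old, new)                            # both (N, 1, L, L)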
@@ -74,15 +73,15 @@ def forward(self, x: Tensor) -> Tensor:
         assert x.shape[1] == self.input_size, (x.shape[1], self.input_size)
         attn_bias = self.relative_pe_table[..., self.relative_pe_index]
         if self.shift > 0:
-            x = x.roll((self.shift, self.shift), (1, 2))
+            x = x.roll((-self.shift, -self.shift), (1, 2))
             attn_bias = attn_bias + self.attn_mask

         x, nH, nW = window_partition(x, self.window_size)  # (B * nH * nW, win_size * win_size, C)
         x = super().forward(x, attn_bias)
         x = window_unpartition(x, self.window_size, nH, nW)

         if self.shift > 0:
-            x = x.roll((-self.shift, -self.shift), (1, 2))
+            x = x.roll((self.shift, self.shift), (1, 2))
         return x
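The roll directions were swapped: the official Swin implementation shifts the feature map by (-shift, -shift) before windowed attention and rolls it back by (+shift, +shift) afterwards. Since torch.roll is cyclic, the round trip is lossless; a quick standalone check (my sketch):

    import torch

    x = torch.arange(16).view(1, 4, 4)   # (B, H, W)
    s = 2
    shifted = x.roll((-s, -s), (1, 2))   # move content up and to the left
    restored = shifted.roll((s, s), (1, 2))
    assert torch.equal(x, restored)      # the cyclic shift is exactly invertible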


@@ -107,8 +106,8 @@ def __init__(
         self.mlp = MLP(d_model, int(d_model * mlp_ratio), dropout, act)

     def forward(self, x: Tensor) -> Tensor:
-        x = self.mha(self.norm1(x))
-        x = self.mlp(self.norm2(x))
+        x = x + self.mha(self.norm1(x))
+        x = x + self.mlp(self.norm2(x))
         return x
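The old block dropped the skip connections entirely, so neither sub-layer was residual. The fix restores the standard pre-norm form x = x + Attn(LN(x)); x = x + MLP(LN(x)). A generic sketch of the pattern (stand-in modules, not the repo's classes):

    import torch
    from torch import nn

    class PreNormBlock(nn.Module):
        def __init__(self, dim: int) -> None:
            super().__init__()
            self.norm1 = nn.LayerNorm(dim)
            self.norm2 = nn.LayerNorm(dim)
            self.attn = nn.Linear(dim, dim)   # stand-in for window attention
            self.mlp = nn.Linear(dim, dim)    # stand-in for the MLP

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = x + self.attn(self.norm1(x))  # residual around attention
            x = x + self.mlp(self.norm2(x))   # residual around the MLP
            return x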


@@ -120,7 +119,8 @@ def __init__(self, d_model: int, norm: _norm = nn.LayerNorm) -> None:

     def forward(self, x: Tensor) -> Tensor:
         B, H, W, C = x.shape
-        x = x.view(B, H // 2, 2, W // 2, 2, C).transpose(2, 3).flatten(-3)
+        # x = x.view(B, H // 2, 2, W // 2, 2, C).transpose(2, 3).flatten(-3)
+        x = x.view(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(-3)
         x = self.reduction(self.norm(x))
         x = x.view(B, H // 2, W // 2, C * 2)
         return x
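transpose(2, 3) flattened the four sub-patches in row-major order (top-left, top-right, bottom-left, bottom-right), while the official Swin PatchMerging concatenates [x0, x1, x2, x3] = [even/even, odd/even, even/odd, odd/odd]; the permute reproduces that order, which matters because the pretrained reduction weights assume the official channel layout. A standalone check (my sketch):

    import torch

    B, H, W, C = 1, 4, 4, 3
    x = torch.randn(B, H, W, C)

    merged = x.view(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(-3)

    # the official Swin PatchMerging concatenation order
    official = torch.cat(
        [x[:, 0::2, 0::2], x[:, 1::2, 0::2], x[:, 0::2, 1::2], x[:, 1::2, 1::2]], dim=-1
    )
    assert torch.equal(merged, official)  # (B, H/2, W/2, 4C), same channel order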
@@ -153,14 +153,17 @@ def __init__(
         for i, (depth, window_size) in enumerate(zip(depths, window_sizes)):
             stage = nn.Sequential()
             if i > 0:
-                stage.append(PatchMerging(d_model, norm))
+                downsample = PatchMerging(d_model, norm)
                 input_size //= 2
                 d_model *= 2
                 n_heads *= 2
+            else:
+                downsample = nn.Identity()
+            stage.add_module("downsample", downsample)

             for i in range(depth):
                 blk = SwinBlock(input_size, d_model, n_heads, window_size, i % 2, mlp_ratio, bias, dropout, norm, act)
-                stage.append(blk)
+                stage.add_module(f"block_{i}", blk)
             self.stages.append(stage)

         self.head_norm = norm(d_model)
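Registering the downsample (an nn.Identity() in the first stage, so every stage has the same layout) and the blocks under stable names lets the checkpoint loader reach them as attributes instead of relying on positional indices that differ between stage 0 and later stages. A small illustration (my sketch):

    import torch
    from torch import nn

    stage = nn.Sequential()
    stage.add_module("downsample", nn.Identity())
    stage.add_module("block_0", nn.Linear(8, 8))

    x = torch.randn(2, 8)
    assert stage.block_0 is stage[1]                 # name and index access agree
    assert torch.equal(stage(x), stage.block_0(x))   # Identity is a no-op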
@@ -175,7 +178,7 @@ def forward(self, x: Tensor) -> Tensor:
         return self.head_norm(self.get_feature_maps(x)[-1]).mean((1, 2))

     def resize_pe(self, img_size: int) -> None:
-        pass
+        raise NotImplementedError()

     @staticmethod
     def from_config(variant: str, img_size: int, pretrained: bool = False) -> SwinTransformer:
@@ -190,16 +193,17 @@ def from_config(variant: str, img_size: int, pretrained: bool = False) -> SwinTr
             "S3-S": (96, 3, (2, 2, 18, 2), (14, 14, 14, 14), "supernet-small.pth"),
             "S3-B": (96, 3, (2, 2, 30, 2), (7, 7, 14, 7), "supernet-base.pth"),
         }[variant]
-        m = SwinTransformer(img_size, d_model, n_heads, depths, window_sizes)
+        m = SwinTransformer(224 if pretrained else img_size, d_model, n_heads, depths, window_sizes)

         if pretrained:
-            base_url = (
-                "https://github.com/silent-chen/AutoFormer-model-zoo/releases/download/v1.0/"
-                if variant.startswith("S3")
-                else "https://github.com/SwinTransformer/storage/releases/download/v1.0.8/"
-            )
+            if variant.startswith("S3"):
+                base_url = "https://github.com/silent-chen/AutoFormer-model-zoo/releases/download/v1.0/"
+            else:
+                base_url = "https://github.com/SwinTransformer/storage/releases/download/"
             state_dict = torch.hub.load_state_dict_from_url(base_url + ckpt)["model"]
             m.load_official_ckpt(state_dict)
+            if img_size != 224:
+                m.resize_pe(img_size)

         return m
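Building the model at 224 when pretrained=True makes the parameter shapes match the checkpoints, which were trained at 224; only afterwards is the model resized to the requested img_size (a path that currently raises NotImplementedError, per resize_pe above). A hypothetical usage sketch (the "T" variant key is my assumption; the truncated config table above defines the real keys):

    from vision_toolbox.backbones.swin import SwinTransformer

    # loads official 224x224 weights; img_size != 224 would hit resize_pe()
    model = SwinTransformer.from_config("T", img_size=224, pretrained=True)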

@@ -214,16 +218,17 @@ def copy_(m: nn.Linear | nn.LayerNorm, prefix: str) -> None:

         for stage_i, stage in enumerate(self.stages):
             if stage_i > 0:
-                downsample: PatchMerging = stage[0]
-                downsample.reduction.weight.copy_(state_dict[f"layers.{stage_i-1}.downsample.reduction.weight"])
-            for block_idx, block in enumerate(stage):
-                block: SwinBlock
+                stage.downsample.reduction.weight.copy_(state_dict[f"layers.{stage_i-1}.downsample.reduction.weight"])
+
+            for block_idx in range(len(stage) - 1):
+                block: SwinBlock = getattr(stage, f"block_{block_idx}")
                 prefix = f"layers.{stage_i}.blocks.{block_idx}."
+                block_idx += 1

                 copy_(block.norm1, prefix + "norm1")
                 copy_(block.mha.in_proj, prefix + "attn.qkv")
                 copy_(block.mha.out_proj, prefix + "attn.proj")
-                block.mha.relative_pe_table.copy_(state_dict[prefix + "attn.relative_position_bias_table"])
+                block.mha.relative_pe_table.copy_(state_dict[prefix + "attn.relative_position_bias_table"].T)
                 copy_(block.norm2, prefix + "norm2")
                 copy_(block.mlp.linear1, prefix + "mlp.fc1")
                 copy_(block.mlp.linear2, prefix + "mlp.fc2")
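The .T on load reflects a layout difference: the official checkpoints store relative_position_bias_table as (num_relative_positions, n_heads), whereas forward() here indexes the table head-first via relative_pe_table[..., relative_pe_index]. A standalone shape check (my sketch, window size 7):

    import torch

    n_heads, num_pos = 3, (2 * 7 - 1) ** 2      # 169 relative offsets
    official = torch.randn(num_pos, n_heads)    # checkpoint layout
    table = official.T                          # (n_heads, num_pos), layout used here
    idx = torch.randint(0, num_pos, (49, 49))   # stand-in for relative_pe_index
    bias = table[..., idx]                      # broadcasts to the attention bias
    assert bias.shape == (n_heads, 49, 49)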
