Commit

don't shift when input_size = window_size
gau-nernst committed Aug 19, 2023
1 parent c90825f commit f6e48d8
Showing 3 changed files with 37 additions and 33 deletions.
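Context on the change (my reading of the diff, not text from the commit): a block now only uses a shifted window when the feature map is larger than the window (shift = (i % 2) and input_size > window_size), which matches how the reference Swin implementation handles this case. For the Swin-T / 224 configuration exercised by the updated test (patch size 4, window size 7), the stage resolutions are 56, 28, 14 and 7, so the last stage is a single 7x7 window and shifting it serves no purpose. A rough sketch of the gating, assuming those standard hyper-parameters:

# Hypothetical illustration; img_size / patch_size / window_size are the standard Swin-T values.
img_size, patch_size, window_size = 224, 4, 7

input_size = img_size // patch_size  # 56 tokens per side after patch embedding
for stage_idx in range(4):
    if stage_idx > 0:
        input_size //= 2  # patch merging halves the resolution: 56 -> 28 -> 14 -> 7
    for i in range(2):  # two blocks are enough to show the alternating pattern
        shift = (i % 2) and input_size > window_size
        print(f"stage {stage_idx}, block {i}: input_size={input_size}, shift={bool(shift)}")
# In the last stage input_size == window_size == 7, so shift is False for every block.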
tests/test_swin.py (23 changes: 8 additions & 15 deletions)
@@ -1,27 +1,20 @@
import timm
import torch

from vision_toolbox.backbones import SwinTransformer
from vision_toolbox.backbones.swin import window_partition, window_unpartition


def test_window_partition():
img = torch.randn(1, 224, 280, 3)
windows, nH, nW = window_partition(img, 7)
_img = window_unpartition(windows, 7, nH, nW)
torch.testing.assert_close(img, _img)


def test_forward():
m = SwinTransformer.from_config("T", 224)
m(torch.randn(1, 3, 224, 224))


# def test_from_pretrained():
# m = ViT.from_config("Ti", 16, 224, True).eval()
# x = torch.randn(1, 3, 224, 224)
# out = m(x)
def test_from_pretrained():
m = SwinTransformer.from_config("T", 224, True).eval()
x = torch.randn(1, 3, 224, 224)
out = m(x)

# m_timm = timm.create_model("vit_tiny_patch16_224.augreg_in21k", pretrained=True, num_classes=0).eval()
# out_timm = m_timm(x)
m_timm = timm.create_model("swin_tiny_patch4_window7_224.ms_in22k", pretrained=True, num_classes=0).eval()
out_timm = m_timm(x)

# torch.testing.assert_close(out, out_timm, rtol=2e-5, atol=2e-5)
torch.testing.assert_close(out, out_timm, rtol=2e-5, atol=2e-5)
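The round-trip test above relies on window_partition / window_unpartition, which this diff only imports. As a point of reference, here is a minimal reshape-based sketch of how such a pair is commonly written; the shapes follow the comments in swin.py below, but the actual vision_toolbox implementation may differ:

import torch
from torch import Tensor


def window_partition(x: Tensor, window_size: int) -> tuple[Tensor, int, int]:
    # (B, H, W, C) -> (B * nH * nW, window_size * window_size, C)
    B, H, W, C = x.shape
    nH, nW = H // window_size, W // window_size
    x = x.view(B, nH, window_size, nW, window_size, C).permute(0, 1, 3, 2, 4, 5)
    return x.reshape(B * nH * nW, window_size * window_size, C), nH, nW


def window_unpartition(windows: Tensor, window_size: int, nH: int, nW: int) -> Tensor:
    # inverse of window_partition: (B * nH * nW, window_size * window_size, C) -> (B, H, W, C)
    B, C = windows.shape[0] // (nH * nW), windows.shape[-1]
    x = windows.view(B, nH, nW, window_size, window_size, C).permute(0, 1, 3, 2, 4, 5)
    return x.reshape(B, nH * window_size, nW * window_size, C)


img = torch.randn(1, 224, 280, 3)           # same non-square input as test_window_partition
windows, nH, nW = window_partition(img, 7)  # nH = 32, nW = 40
torch.testing.assert_close(window_unpartition(windows, 7, nH, nW), img)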
vision_toolbox/backbones/swin.py (45 changes: 28 additions & 17 deletions)
@@ -49,11 +49,11 @@ def __init__(
img_mask = torch.zeros(1, input_size, input_size, 1)
slices = (slice(0, -window_size), slice(-window_size, -self.shift), slice(-self.shift, None))
for i, (h_slice, w_slice) in enumerate(itertools.product(slices, slices)):
img_mask[0, h_slice, w_slice, 0] = i
img_mask[:, h_slice, w_slice, :] = i

windows_mask, _, _ = window_partition(img_mask, window_size) # (nH * nW, win_size * win_size, 1)
attn_mask = windows_mask - windows_mask.transpose(1, 2)
self.register_buffer("attn_mask", (attn_mask != 0).unsqueeze(1) * (-100), False)
attn_mask = windows_mask.transpose(1, 2) - windows_mask
self.register_buffer("attn_mask", (attn_mask != 0) * (-100.0), False)
self.attn_mask: Tensor

else:
@@ -74,11 +74,11 @@ def forward(self, x: Tensor) -> Tensor:
attn_bias = self.relative_pe_table[..., self.relative_pe_index]
if self.shift > 0:
x = x.roll((-self.shift, -self.shift), (1, 2))
attn_bias = attn_bias + self.attn_mask
attn_bias = attn_bias + self.attn_mask.unsqueeze(1) # add n_heads dim

x, nH, nW = window_partition(x, self.window_size) # (B * nH * nW, win_size * win_size, C)
x = super().forward(x, attn_bias)
x = window_unpartition(x, self.window_size, nH, nW)
x = window_unpartition(x, self.window_size, nH, nW) # (B, H, W, C)

if self.shift > 0:
x = x.roll((self.shift, self.shift), (1, 2))
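Putting the two hunks above together (my reading, with made-up toy sizes): the mask built in __init__ marks, for each window, which token pairs come from different regions of the cyclically shifted image, and in forward it is broadcast against the per-head relative position bias. A small self-contained sketch, with an inline reshape-based partition standing in for the library helper:

import itertools
import torch

# Toy sizes (hypothetical, chosen to be easy to inspect): 4x4 map, 2x2 windows, shift 1, 3 heads.
input_size, window_size, shift, n_heads = 4, 2, 1, 3

img_mask = torch.zeros(1, input_size, input_size, 1)
slices = (slice(0, -window_size), slice(-window_size, -shift), slice(-shift, None))
for i, (h_slice, w_slice) in enumerate(itertools.product(slices, slices)):
    img_mask[:, h_slice, w_slice, :] = i  # 9 regions produced by the cyclic shift

# Inline window partition standing in for vision_toolbox's helper.
nH = nW = input_size // window_size
windows = (
    img_mask.view(1, nH, window_size, nW, window_size, 1)
    .permute(0, 1, 3, 2, 4, 5)
    .reshape(nH * nW, window_size * window_size, 1)
)

attn_mask = windows.transpose(1, 2) - windows  # nonzero where tokens come from different regions
attn_mask = (attn_mask != 0) * (-100.0)        # (nH * nW, win*win, win*win)

L = window_size * window_size
attn_bias = torch.zeros(n_heads, L, L)          # stand-in for the relative position bias
combined = attn_bias + attn_mask.unsqueeze(1)   # broadcasts to (nH * nW, n_heads, L, L)
print(combined.shape)                           # torch.Size([4, 3, 4, 4])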
@@ -159,11 +159,13 @@ def __init__(
n_heads *= 2
else:
downsample = nn.Identity()
stage.add_module("downsample", downsample)
stage.append(downsample)

for i in range(depth):
blk = SwinBlock(input_size, d_model, n_heads, window_size, i % 2, mlp_ratio, bias, dropout, norm, act)
stage.add_module(f"block_{i}", blk)
shift = (i % 2) and input_size > window_size
block = SwinBlock(input_size, d_model, n_heads, window_size, shift, mlp_ratio, bias, dropout, norm, act)
stage.append(block)

self.stages.append(stage)

self.head_norm = norm(d_model)
@@ -210,27 +212,36 @@ def from_config(variant: str, img_size: int, pretrained: bool = False) -> SwinTransformer:
@torch.no_grad()
def load_official_ckpt(self, state_dict: dict[str, Tensor]) -> None:
def copy_(m: nn.Linear | nn.LayerNorm, prefix: str) -> None:
m.weight.copy_(state_dict[prefix + ".weight"])
m.bias.copy_(state_dict[prefix + ".bias"])
m.weight.copy_(state_dict.pop(prefix + ".weight"))
if m.bias is not None:
m.bias.copy_(state_dict.pop(prefix + ".bias"))

copy_(self.patch_embed, "patch_embed.proj")
copy_(self.norm, "patch_embed.norm")

for stage_i, stage in enumerate(self.stages):
if stage_i > 0:
stage.downsample.reduction.weight.copy_(state_dict[f"layers.{stage_i-1}.downsample.reduction.weight"])
for stage_idx, stage in enumerate(self.stages):
if stage_idx > 0:
prefix = f"layers.{stage_idx-1}.downsample."
copy_(stage[0].norm, prefix + "norm")
copy_(stage[0].reduction, prefix + "reduction")

for block_idx in range(len(stage) - 1):
block: SwinBlock = getattr(stage, f"block_{block_idx}")
prefix = f"layers.{stage_i}.blocks.{block_idx}."
for block_idx in range(1, len(stage)):
block: SwinBlock = stage[block_idx]
prefix = f"layers.{stage_idx}.blocks.{block_idx - 1}."

if block.mha.attn_mask is not None:
torch.testing.assert_close(block.mha.attn_mask, state_dict.pop(prefix + "attn_mask"))
torch.testing.assert_close(
block.mha.relative_pe_index, state_dict.pop(prefix + "attn.relative_position_index")
)
copy_(block.norm1, prefix + "norm1")
copy_(block.mha.in_proj, prefix + "attn.qkv")
copy_(block.mha.out_proj, prefix + "attn.proj")
block.mha.relative_pe_table.copy_(state_dict[prefix + "attn.relative_position_bias_table"].T)
block.mha.relative_pe_table.copy_(state_dict.pop(prefix + "attn.relative_position_bias_table").T)
copy_(block.norm2, prefix + "norm2")
copy_(block.mlp.linear1, prefix + "mlp.fc1")
copy_(block.mlp.linear2, prefix + "mlp.fc2")

copy_(self.head_norm, "norm")
assert len(state_dict) == 2 # head.weight and head.bias
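A note on the loader above (my reading of the diff, not text from the commit): each stage is now a plain nn.Sequential whose element 0 is the downsample (or nn.Identity for the first stage) and whose elements 1..depth are the blocks, so stage[block_idx] maps to the official key layers.{stage_idx}.blocks.{block_idx - 1}. Because keys are popped as they are consumed, the final assert verifies that only the classifier head was left unused. A toy sketch of the same pop-and-count pattern with stand-in tensors:

import torch

# Hypothetical checkpoint with two backbone keys and a classifier head.
state_dict = {
    "norm.weight": torch.ones(8),
    "norm.bias": torch.zeros(8),
    "head.weight": torch.randn(10, 8),
    "head.bias": torch.zeros(10),
}

norm = torch.nn.LayerNorm(8)
with torch.no_grad():
    norm.weight.copy_(state_dict.pop("norm.weight"))
    norm.bias.copy_(state_dict.pop("norm.bias"))

assert len(state_dict) == 2  # only head.weight and head.bias remain unconsumed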
vision_toolbox/backbones/vit.py (2 changes: 1 addition & 1 deletion)
@@ -27,7 +27,7 @@ def __init__(self, d_model: int, n_heads: int, bias: bool = True, dropout: float

def forward(self, x: Tensor, attn_bias: Tensor | None = None) -> Tensor:
qkv = self.in_proj(x)
q, k, v = qkv.unflatten(-1, (3, self.n_heads, -1)).transpose(-2, -4).unbind(-3)
q, k, v = qkv.unflatten(-1, (3, self.n_heads, -1)).transpose(-2, -4).unbind(-3) # (B, n_heads, L, head_dim)
if hasattr(F, "scaled_dot_product_attention"):
out = F.scaled_dot_product_attention(q, k, v, attn_bias, self.dropout if self.training else 0.0)
else:
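The only change in vit.py is the added shape comment. Tracing it with hypothetical sizes (batch 2, one 7x7 window as the sequence, 4 heads of dimension 8):

import torch

B, L, n_heads, head_dim = 2, 49, 4, 8
qkv = torch.randn(B, L, 3 * n_heads * head_dim)  # output of in_proj

# (B, L, 3 * n_heads * head_dim) -> (B, L, 3, n_heads, head_dim) -> (B, n_heads, 3, L, head_dim)
q, k, v = qkv.unflatten(-1, (3, n_heads, -1)).transpose(-2, -4).unbind(-3)
print(q.shape)  # torch.Size([2, 4, 49, 8]) == (B, n_heads, L, head_dim)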
