Add Swin Transformer #16

Merged
merged 13 commits on Aug 19, 2023
20 changes: 20 additions & 0 deletions tests/test_swin.py
@@ -0,0 +1,20 @@
import timm
import torch

from vision_toolbox.backbones import SwinTransformer


def test_forward():
m = SwinTransformer.from_config("T", 224)
m(torch.randn(1, 3, 224, 224))


def test_from_pretrained():
m = SwinTransformer.from_config("T", 224, True).eval()
x = torch.randn(1, 3, 224, 224)
out = m(x)

m_timm = timm.create_model("swin_tiny_patch4_window7_224.ms_in22k", pretrained=True, num_classes=0).eval()
out_timm = m_timm(x)

torch.testing.assert_close(out, out_timm, rtol=2e-5, atol=2e-5)
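
The tests above only exercise the pooled forward output. For reference, a minimal sketch of the multi-scale API (assuming get_feature_maps returns one channels-last (B, H, W, C) map per stage, as implemented in swin.py below; the shapes follow from Swin-T's stage widths and are not separately verified):

import torch
from vision_toolbox.backbones import SwinTransformer

m = SwinTransformer.from_config("T", 224)
for f in m.get_feature_maps(torch.randn(1, 3, 224, 224)):
    print(tuple(f.shape))
# expected: (1, 56, 56, 96), (1, 28, 28, 192), (1, 14, 14, 384), (1, 7, 7, 768)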
1 change: 1 addition & 0 deletions vision_toolbox/backbones/__init__.py
@@ -1,6 +1,7 @@
from .darknet import Darknet, DarknetYOLOv5
from .mlp_mixer import MLPMixer
from .patchconvnet import PatchConvNet
from .swin import SwinTransformer
from .torchvision_models import EfficientNetExtractor, MobileNetExtractor, RegNetExtractor, ResNetExtractor
from .vit import ViT
from .vovnet import VoVNet
19 changes: 8 additions & 11 deletions vision_toolbox/backbones/mlp_mixer.py
@@ -12,14 +12,7 @@

from ..utils import torch_hub_download
from .base import _act, _norm
-
-
-class MLP(nn.Sequential):
-    def __init__(self, in_dim: int, hidden_dim: float, act: _act = nn.GELU) -> None:
-        super().__init__()
-        self.linear1 = nn.Linear(in_dim, hidden_dim)
-        self.act = act()
-        self.linear2 = nn.Linear(hidden_dim, in_dim)
+from .vit import MLP


class MixerBlock(nn.Module):
@@ -28,15 +21,16 @@ def __init__(
n_tokens: int,
d_model: int,
        mlp_ratio: tuple[float, float] = (0.5, 4.0),
+        dropout: float = 0.0,
norm: _norm = partial(nn.LayerNorm, eps=1e-6),
act: _act = nn.GELU,
) -> None:
tokens_mlp_dim, channels_mlp_dim = [int(d_model * ratio) for ratio in mlp_ratio]
super().__init__()
self.norm1 = norm(d_model)
-        self.token_mixing = MLP(n_tokens, tokens_mlp_dim, act)
+        self.token_mixing = MLP(n_tokens, tokens_mlp_dim, dropout, act)
self.norm2 = norm(d_model)
-        self.channel_mixing = MLP(d_model, channels_mlp_dim, act)
+        self.channel_mixing = MLP(d_model, channels_mlp_dim, dropout, act)

def forward(self, x: Tensor) -> Tensor:
# x -> (B, n_tokens, d_model)
@@ -53,14 +47,17 @@ def __init__(
patch_size: int,
img_size: int,
mlp_ratio: tuple[float, float] = (0.5, 4.0),
+        dropout: float = 0.0,
norm: _norm = partial(nn.LayerNorm, eps=1e-6),
act: _act = nn.GELU,
) -> None:
assert img_size % patch_size == 0
super().__init__()
self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
n_tokens = (img_size // patch_size) ** 2
-        self.layers = nn.Sequential(*[MixerBlock(n_tokens, d_model, mlp_ratio, norm, act) for _ in range(n_layers)])
+        self.layers = nn.Sequential(
+            *[MixerBlock(n_tokens, d_model, mlp_ratio, dropout, norm, act) for _ in range(n_layers)]
+        )
self.norm = norm(d_model)

def forward(self, x: Tensor) -> Tensor:
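The net change in this file: MixerBlock and MLPMixer now take a dropout argument and pass it to the shared vit.MLP, whose signature is assumed to be MLP(in_dim, hidden_dim, dropout, act) to match the call sites above. A minimal construction sketch; the pooled output shape is an assumption based on the other backbones in this repo:

import torch
from vision_toolbox.backbones import MLPMixer

m = MLPMixer(n_layers=8, d_model=512, patch_size=16, img_size=224, dropout=0.1)
out = m(torch.randn(1, 3, 224, 224))
print(out.shape)  # assumed (1, 512): token features pooled to one vector per image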
251 changes: 251 additions & 0 deletions vision_toolbox/backbones/swin.py
@@ -0,0 +1,251 @@
# https://arxiv.org/abs/2103.14030
# https://github.com/microsoft/Swin-Transformer

from __future__ import annotations

import itertools

import torch
from torch import Tensor, nn

from .base import BaseBackbone, _act, _norm
from .vit import MHA, MLP


def window_partition(x: Tensor, window_size: int) -> tuple[Tensor, int, int]:
B, H, W, C = x.shape
nH, nW = H // window_size, W // window_size
x = x.view(B, nH, window_size, nW, window_size, C)
x = x.transpose(2, 3).reshape(B * nH * nW, window_size * window_size, C)
return x, nH, nW


def window_unpartition(x: Tensor, window_size: int, nH: int, nW: int) -> Tensor:
B = x.shape[0] // (nH * nW)
C = x.shape[2]
x = x.view(B, nH, nW, window_size, window_size, C)
x = x.transpose(2, 3).reshape(B, nH * window_size, nW * window_size, C)
return x
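

A round-trip sketch for the two helpers above, in the channels-last (B, H, W, C) layout they expect; window_unpartition exactly inverts window_partition:

import torch

x = torch.randn(2, 14, 14, 32)            # H and W must be divisible by window_size
windows, nH, nW = window_partition(x, 7)  # (B * nH * nW, 7 * 7, C)
assert windows.shape == (8, 49, 32)
assert torch.equal(window_unpartition(windows, 7, nH, nW), x)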


class WindowAttention(MHA):
def __init__(
self,
input_size: int,
d_model: int,
n_heads: int,
window_size: int = 7,
shift: bool = False,
bias: bool = True,
dropout: float = 0.0,
) -> None:
super().__init__(d_model, n_heads, bias, dropout)
self.input_size = input_size
self.window_size = window_size

if shift:
self.shift = window_size // 2

img_mask = torch.zeros(1, input_size, input_size, 1)
slices = (slice(0, -window_size), slice(-window_size, -self.shift), slice(-self.shift, None))
for i, (h_slice, w_slice) in enumerate(itertools.product(slices, slices)):
img_mask[:, h_slice, w_slice, :] = i

windows_mask, _, _ = window_partition(img_mask, window_size) # (nH * nW, win_size * win_size, 1)
            attn_mask = windows_mask.transpose(1, 2) - windows_mask  # nonzero iff two positions come from different shifted regions
            self.register_buffer("attn_mask", (attn_mask != 0) * (-100.0), False)  # -100 on the logits suppresses cross-region attention
self.attn_mask: Tensor

else:
self.shift = 0
self.attn_mask = None

self.relative_pe_table = nn.Parameter(torch.empty(1, n_heads, (2 * window_size - 1) ** 2))
nn.init.trunc_normal_(self.relative_pe_table, 0, 0.02)

xy = torch.cartesian_prod(torch.arange(window_size), torch.arange(window_size)) # all possible (x,y) pairs
diff = xy.unsqueeze(1) - xy.unsqueeze(0) # difference between all (x,y) pairs
index = (diff[:, :, 0] + window_size - 1) * (2 * window_size - 1) + diff[:, :, 1] + window_size - 1
self.register_buffer("relative_pe_index", index, False)
self.relative_pe_index: Tensor

def forward(self, x: Tensor) -> Tensor:
assert x.shape[1] == self.input_size, (x.shape[1], self.input_size)
attn_bias = self.relative_pe_table[..., self.relative_pe_index]
if self.shift > 0:
x = x.roll((-self.shift, -self.shift), (1, 2))
attn_bias = attn_bias + self.attn_mask.unsqueeze(1) # add n_heads dim

x, nH, nW = window_partition(x, self.window_size) # (B * nH * nW, win_size * win_size, C)
x = super().forward(x, attn_bias)
x = window_unpartition(x, self.window_size, nH, nW) # (B, H, W, C)

if self.shift > 0:
x = x.roll((self.shift, self.shift), (1, 2))
return x
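

The relative position bias assigns each head one learned scalar per (dy, dx) offset between two positions in a window, hence the (2 * window_size - 1) ** 2 table. A standalone sketch of the index construction for window_size = 2:

import torch

window_size = 2
xy = torch.cartesian_prod(torch.arange(window_size), torch.arange(window_size))
diff = xy.unsqueeze(1) - xy.unsqueeze(0)  # (4, 4, 2) pairwise offsets, each component in [-1, 1]
index = (diff[:, :, 0] + window_size - 1) * (2 * window_size - 1) + diff[:, :, 1] + window_size - 1
print(index)  # (4, 4) values in [0, 8], indexing a 9-entry bias table
assert int(index.max()) == (2 * window_size - 1) ** 2 - 1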


class SwinBlock(nn.Module):
def __init__(
self,
input_size: int,
d_model: int,
n_heads: int,
window_size: int = 7,
shift: bool = False,
mlp_ratio: float = 4.0,
bias: bool = True,
dropout: float = 0.0,
norm: _norm = nn.LayerNorm,
act: _act = nn.GELU,
) -> None:
super().__init__()
self.norm1 = norm(d_model)
self.mha = WindowAttention(input_size, d_model, n_heads, window_size, shift, bias, dropout)
self.norm2 = norm(d_model)
self.mlp = MLP(d_model, int(d_model * mlp_ratio), dropout, act)

def forward(self, x: Tensor) -> Tensor:
x = x + self.mha(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
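

SwinBlock is a standard pre-norm residual pair (windowed attention, then MLP), so it preserves the (B, H, W, C) layout. A shape-check sketch, assuming the MHA base class from vit.py accepts the attn_bias produced above:

import torch

block = SwinBlock(input_size=56, d_model=96, n_heads=3, window_size=7, shift=True)
y = block(torch.randn(1, 56, 56, 96))
assert y.shape == (1, 56, 56, 96)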


class PatchMerging(nn.Module):
def __init__(self, d_model: int, norm: _norm = nn.LayerNorm) -> None:
super().__init__()
self.norm = norm(d_model * 4)
self.reduction = nn.Linear(d_model * 4, d_model * 2, False)

def forward(self, x: Tensor) -> Tensor:
B, H, W, C = x.shape
x = x.view(B, H // 2, 2, W // 2, 2, C).transpose(2, 3).flatten(-3)
x = self.reduction(self.norm(x))
x = x.view(B, H // 2, W // 2, C * 2)
return x
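

PatchMerging folds each 2x2 neighborhood into the channel dimension (4C) and projects it to 2C, halving the resolution and doubling the width. A shape sketch:

import torch

pm = PatchMerging(96)
y = pm(torch.randn(1, 56, 56, 96))
assert y.shape == (1, 28, 28, 192)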


class SwinTransformer(BaseBackbone):
def __init__(
self,
img_size: int,
d_model: int,
n_heads: int,
depths: tuple[int, ...],
window_sizes: tuple[int, ...],
patch_size: int = 4,
mlp_ratio: float = 4.0,
bias: bool = True,
dropout: float = 0.0,
norm: _norm = nn.LayerNorm,
act: _act = nn.GELU,
) -> None:
assert img_size % patch_size == 0
assert d_model % n_heads == 0
super().__init__()
self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
self.norm = norm(d_model)
self.dropout = nn.Dropout(dropout)

input_size = img_size // patch_size
self.stages = nn.Sequential()
for i, (depth, window_size) in enumerate(zip(depths, window_sizes)):
stage = nn.Sequential()
if i > 0:
downsample = PatchMerging(d_model, norm)
input_size //= 2
d_model *= 2
n_heads *= 2
else:
downsample = nn.Identity()
stage.append(downsample)

for i in range(depth):
shift = (i % 2) and input_size > window_size
block = SwinBlock(input_size, d_model, n_heads, window_size, shift, mlp_ratio, bias, dropout, norm, act)
stage.append(block)

self.stages.append(stage)

self.head_norm = norm(d_model)

def get_feature_maps(self, x: Tensor) -> list[Tensor]:
out = [self.dropout(self.norm(self.patch_embed(x).permute(0, 2, 3, 1)))]
for stage in self.stages:
out.append(stage(out[-1]))
return out[1:]

def forward(self, x: Tensor) -> Tensor:
return self.head_norm(self.get_feature_maps(x)[-1]).mean((1, 2))

def resize_pe(self, img_size: int) -> None:
raise NotImplementedError()

@staticmethod
def from_config(variant: str, img_size: int, pretrained: bool = False) -> SwinTransformer:
d_model, n_heads, depths, window_sizes, ckpt = {
# Sub-section 3.3 in https://arxiv.org/pdf/2103.14030.pdf
"T": (96, 3, (2, 2, 6, 2), (7, 7, 7, 7), "v1.0.8/swin_tiny_patch4_window7_224_22k.pth"),
"S": (96, 3, (2, 2, 18, 2), (7, 7, 7, 7), "v1.0.8/swin_small_patch4_window7_224_22k.pth"),
"B": (128, 4, (2, 2, 18, 2), (7, 7, 7, 7), "v1.0.0/swin_base_patch4_window7_224_22k.pth"),
"L": (192, 6, (2, 2, 18, 2), (7, 7, 7, 7), "v1.0.0/swin_large_patch4_window7_224_22k.pth"),
# https://github.com/microsoft/Cream/blob/main/AutoFormerV2/configs
"S3-T": (96, 3, (2, 2, 6, 2), (7, 7, 14, 7), "supernet-tiny.pth"),
"S3-S": (96, 3, (2, 2, 18, 2), (14, 14, 14, 14), "supernet-small.pth"),
"S3-B": (96, 3, (2, 2, 30, 2), (7, 7, 14, 7), "supernet-base.pth"),
}[variant]
m = SwinTransformer(224 if pretrained else img_size, d_model, n_heads, depths, window_sizes)

if pretrained:
if variant.startswith("S3"):
base_url = "https://github.com/silent-chen/AutoFormer-model-zoo/releases/download/v1.0/"
else:
base_url = "https://github.com/SwinTransformer/storage/releases/download/"
state_dict = torch.hub.load_state_dict_from_url(base_url + ckpt)["model"]
m.load_official_ckpt(state_dict)
if img_size != 224:
m.resize_pe(img_size)

return m
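

Usage mirrors the other backbones. Note that with pretrained=True the model is always built at 224, so loading pretrained weights at another resolution depends on resize_pe, which is still a NotImplementedError stub above. A hedged sanity check (the exact count is not verified here; the paper reports 29M parameters for Swin-T including the classification head, which this backbone omits):

m = SwinTransformer.from_config("T", 224)
n_params = sum(p.numel() for p in m.parameters())
print(f"{n_params / 1e6:.1f}M parameters")  # should land in the high 20s of millions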

@torch.no_grad()
def load_official_ckpt(self, state_dict: dict[str, Tensor]) -> None:
        def copy_(m: nn.Conv2d | nn.Linear | nn.LayerNorm, prefix: str) -> None:
m.weight.copy_(state_dict.pop(prefix + ".weight"))
m.bias.copy_(state_dict.pop(prefix + ".bias"))

copy_(self.patch_embed, "patch_embed.proj")
copy_(self.norm, "patch_embed.norm")

for stage_idx, stage in enumerate(self.stages):
if stage_idx > 0:
prefix = f"layers.{stage_idx-1}.downsample."

                def rearrange(p):
                    # the official ckpt concatenates the 2x2 sub-grids as (0,0),(1,0),(0,1),(1,1),
                    # while PatchMerging above flattens them as (0,0),(0,1),(1,0),(1,1):
                    # swap the middle two chunks to match.
                    p1, p2, p3, p4 = p.chunk(4, -1)
                    return torch.cat((p1, p3, p2, p4), -1)

stage[0].norm.weight.copy_(rearrange(state_dict.pop(prefix + "norm.weight")))
stage[0].norm.bias.copy_(rearrange(state_dict.pop(prefix + "norm.bias")))
stage[0].reduction.weight.copy_(rearrange(state_dict.pop(prefix + "reduction.weight")))

for block_idx in range(1, len(stage)):
block: SwinBlock = stage[block_idx]
prefix = f"layers.{stage_idx}.blocks.{block_idx - 1}."

if block.mha.attn_mask is not None:
torch.testing.assert_close(block.mha.attn_mask, state_dict.pop(prefix + "attn_mask"))
torch.testing.assert_close(
block.mha.relative_pe_index, state_dict.pop(prefix + "attn.relative_position_index")
)
copy_(block.norm1, prefix + "norm1")
copy_(block.mha.in_proj, prefix + "attn.qkv")
copy_(block.mha.out_proj, prefix + "attn.proj")
block.mha.relative_pe_table.copy_(state_dict.pop(prefix + "attn.relative_position_bias_table").T)
copy_(block.norm2, prefix + "norm2")
copy_(block.mlp.linear1, prefix + "mlp.fc1")
copy_(block.mlp.linear2, prefix + "mlp.fc2")

copy_(self.head_norm, "norm")
assert len(state_dict) == 2 # head.weight and head.bias
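
The rearrange helper compensates for a channel-ordering difference: the official PatchMerging concatenates the four 2x2 sub-grids as (0,0), (1,0), (0,1), (1,1), while the view/transpose/flatten above produces (0,0), (0,1), (1,0), (1,1), so the middle two chunks of every 4C-sized weight dimension must swap. A tiny numeric check of that claim:

import torch

x = torch.arange(4.0).view(1, 2, 2, 1)  # pixel values: (0,0)=0, (0,1)=1, (1,0)=2, (1,1)=3
ours = x.view(1, 1, 2, 1, 2, 1).transpose(2, 3).flatten(-3).flatten()
official = torch.cat([x[:, 0::2, 0::2], x[:, 1::2, 0::2], x[:, 0::2, 1::2], x[:, 1::2, 1::2]], -1).flatten()
print(ours.tolist(), official.tolist())  # [0, 1, 2, 3] vs [0, 2, 1, 3]: middle chunks swapped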