make swin run. fix attn_bias
gau-nernst committed Aug 9, 2023 · 1 parent 9fcb837 · commit bfc91e5
Showing 4 changed files with 77 additions and 49 deletions.
27 changes: 27 additions & 0 deletions tests/test_swin.py
@@ -0,0 +1,27 @@
+import torch
+
+from vision_toolbox.backbones import SwinTransformer
+from vision_toolbox.backbones.swin import window_partition, window_unpartition
+
+
+def test_window_partition():
+    img = torch.randn(1, 224, 280, 3)
+    windows, nH, nW = window_partition(img, 7)
+    _img = window_unpartition(windows, 7, nH, nW)
+    torch.testing.assert_close(img, _img)
+
+
+def test_forward():
+    m = SwinTransformer.from_config("T", 224)
+    m(torch.randn(1, 3, 224, 224))
+
+
+# def test_from_pretrained():
+#     m = ViT.from_config("Ti", 16, 224, True).eval()
+#     x = torch.randn(1, 3, 224, 224)
+#     out = m(x)
+
+#     m_timm = timm.create_model("vit_tiny_patch16_224.augreg_in21k", pretrained=True, num_classes=0).eval()
+#     out_timm = m_timm(x)
+
+#     torch.testing.assert_close(out, out_timm, rtol=2e-5, atol=2e-5)
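For reference, the round trip exercised by test_window_partition can be reproduced with a minimal standalone sketch. The (B, H, W, C) channels-last layout and the (windows, nH, nW) return convention are inferred from this commit; the real helpers in vision_toolbox.backbones.swin may differ in details.

import torch
from torch import Tensor

def window_partition(x: Tensor, window_size: int) -> tuple[Tensor, int, int]:
    # (B, H, W, C) -> (B * nH * nW, window_size * window_size, C)
    B, H, W, C = x.shape
    nH, nW = H // window_size, W // window_size
    x = x.view(B, nH, window_size, nW, window_size, C).transpose(2, 3)
    return x.reshape(B * nH * nW, window_size * window_size, C), nH, nW

def window_unpartition(windows: Tensor, window_size: int, nH: int, nW: int) -> Tensor:
    # inverse of the above: (B * nH * nW, window_size * window_size, C) -> (B, H, W, C)
    B, C = windows.shape[0] // (nH * nW), windows.shape[-1]
    x = windows.view(B, nH, nW, window_size, window_size, C).transpose(2, 3)
    return x.reshape(B, nH * window_size, nW * window_size, C)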
1 change: 1 addition & 0 deletions vision_toolbox/backbones/__init__.py
@@ -4,3 +4,4 @@
 from .torchvision_models import EfficientNetExtractor, MobileNetExtractor, RegNetExtractor, ResNetExtractor
 from .vit import ViT
 from .vovnet import VoVNet
+from .swin import SwinTransformer
93 changes: 47 additions & 46 deletions vision_toolbox/backbones/swin.py
@@ -41,6 +41,7 @@ def __init__(
         dropout: float = 0.0,
     ) -> None:
         super().__init__(d_model, n_heads, bias, dropout)
+        self.input_size = input_size
         self.window_size = window_size
 
         if shift:
@@ -51,26 +52,28 @@
             for i, (h_slice, w_slice) in enumerate(itertools.product(slices, slices)):
                 img_mask[0, h_slice, w_slice, 0] = i
 
-            windows_mask, _, _ = window_partition(img_mask)
-            attn_mask = windows_mask.unsqueeze(1) - windows_mask.unsqueeze(2)
+            windows_mask, _, _ = window_partition(img_mask, window_size)  # (nH * nW, win_size * win_size, 1)
+            windows_mask = windows_mask.transpose(1, 2)  # (nH * nW, 1, win_size * win_size)
+            attn_mask = windows_mask.unsqueeze(2) - windows_mask.unsqueeze(3)
             self.register_buffer("attn_mask", (attn_mask != 0) * (-100), False)
             self.attn_mask: Tensor
 
         else:
             self.shift = 0
             self.attn_mask = None
 
-        self.relative_pe_table = nn.Parameter(torch.empty(n_heads, (2 * window_size - 1) ** 2))
+        self.relative_pe_table = nn.Parameter(torch.empty(1, n_heads, (2 * window_size - 1) ** 2))
         nn.init.trunc_normal_(self.relative_pe_table, 0, 0.02)
 
         xy = torch.cartesian_prod(torch.arange(window_size), torch.arange(window_size))  # all possible (x,y) pairs
         diff = xy.unsqueeze(1) - xy.unsqueeze(0)  # difference between all (x,y) pairs
         index = (diff[:, :, 0] + window_size - 1) * (2 * window_size - 1) + diff[:, :, 1] + window_size - 1
-        self.register_buffer("relative_pe_index", index.flatten(), False)
+        self.register_buffer("relative_pe_index", index, False)
         self.relative_pe_index: Tensor
 
     def forward(self, x: Tensor) -> Tensor:
-        attn_bias = self.relative_pe_table[:, self.relative_pe_index].unsqueeze(0)
+        assert x.shape[1] == self.input_size, (x.shape[1], self.input_size)
+        attn_bias = self.relative_pe_table[..., self.relative_pe_index]
         if self.shift > 0:
             x = x.roll((self.shift, self.shift), (1, 2))
             attn_bias = attn_bias + self.attn_mask
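This hunk is the attn_bias fix named in the commit title. Previously the (n_heads, (2 * window_size - 1) ** 2) table was indexed with a flattened index and unsqueezed to (1, n_heads, 2401), which cannot broadcast against (..., 49, 49) attention logits; the new code keeps a 2-D index and gathers along the last table dim. A standalone shape check of the fixed path (window_size, n_heads, and the zero stand-in mask below are arbitrary, not from the commit):

import torch

window_size, n_heads = 7, 3
xy = torch.cartesian_prod(torch.arange(window_size), torch.arange(window_size))
diff = xy.unsqueeze(1) - xy.unsqueeze(0)
index = (diff[:, :, 0] + window_size - 1) * (2 * window_size - 1) + diff[:, :, 1] + window_size - 1

table = torch.randn(1, n_heads, (2 * window_size - 1) ** 2)
attn_bias = table[..., index]
print(attn_bias.shape)  # torch.Size([1, 3, 49, 49]): one (49, 49) bias per head

attn_mask = torch.zeros(4, 1, 49, 49)  # stand-in for the registered (nH * nW, 1, 49, 49) buffer
print((attn_bias + attn_mask).shape)  # torch.Size([4, 3, 49, 49]): broadcasts per window and per head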
@@ -100,9 +103,9 @@ def __init__(
     ) -> None:
         super().__init__()
         self.norm1 = norm(d_model)
-        self.mha = WindowAttention(input_size, d_model, window_size, shift, n_heads, bias, dropout)
+        self.mha = WindowAttention(input_size, d_model, n_heads, window_size, shift, bias, dropout)
         self.norm2 = norm(d_model)
-        self.mlp = MLP(d_model, int(d_model * mlp_ratio), act)
+        self.mlp = MLP(d_model, int(d_model * mlp_ratio), dropout, act)
 
     def forward(self, x: Tensor) -> Tensor:
         x = self.mha(self.norm1(x))
@@ -117,39 +120,20 @@ def __init__(self, d_model: int, norm: _norm = nn.LayerNorm) -> None:
         self.reduction = nn.Linear(d_model * 4, d_model * 2, False)
 
     def forward(self, x: Tensor) -> Tensor:
-        x, _, _ = window_partition(x, 2)
-        return self.reduction(self.norm(x))
-
-
-class SwinStage(nn.Sequential):
-    def __init__(
-        self,
-        input_size: int,
-        d_model: int,
-        n_heads: int,
-        depth: int,
-        downsample: bool = False,
-        window_size: int = 7,
-        mlp_ratio: float = 4.0,
-        bias: bool = True,
-        dropout: float = 0.0,
-        norm: _norm = nn.LayerNorm,
-        act: _act = nn.GELU,
-    ) -> None:
-        super().__init__()
-        for i in range(depth):
-            blk = SwinBlock(input_size, d_model, n_heads, window_size, i % 2 == 1, mlp_ratio, bias, dropout, norm, act)
-            self.append(blk)
-        self.downsample = PatchMerging(d_model, norm) if downsample else None
+        B, H, W, C = x.shape
+        x = x.view(B, H // 2, 2, W // 2, 2, C).transpose(2, 3).flatten(-3)
+        x = self.reduction(self.norm(x))
+        x = x.view(B, H // 2, W // 2, C * 2)
+        return x
 
 
-class SwinTransformer(BaseBackbone):
+class SwinTransformer(nn.Module):
     def __init__(
         self,
         img_size: int,
         d_model: int,
         n_heads: int,
-        depths: tuple[int, int, int, int],
+        depths: tuple[int, ...],
         patch_size: int = 4,
         window_size: int = 7,
         mlp_ratio: float = 4.0,
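The rewritten PatchMerging.forward above does the 2x2 downsampling explicitly: the old version ran x through window_partition(x, 2), which apparently left a trailing dim of C rather than the 4 * d_model features that self.reduction expects. A quick shape walk-through of the new path (tensor values are arbitrary):

import torch
from torch import nn

B, H, W, C = 1, 8, 8, 96
x = torch.randn(B, H, W, C)
x = x.view(B, H // 2, 2, W // 2, 2, C).transpose(2, 3).flatten(-3)
print(x.shape)  # torch.Size([1, 4, 4, 384]): each 2x2 neighborhood stacked on channels
x = nn.Linear(4 * C, 2 * C, False)(nn.LayerNorm(4 * C)(x))
print(x.shape)  # torch.Size([1, 4, 4, 192]): resolution halves, channels double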
@@ -158,29 +142,46 @@ def __init__(
         norm: _norm = nn.LayerNorm,
         act: _act = nn.GELU,
     ) -> None:
+        assert img_size % patch_size == 0
+        assert img_size % window_size == 0
+        assert d_model % n_heads == 0
         super().__init__()
         self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
         self.norm = norm(d_model)
+        self.dropout = nn.Dropout(dropout)
 
+        input_size = img_size // patch_size
+        self.stages = nn.Sequential()
+        for i, depth in enumerate(depths):
+            stage = nn.Sequential()
+            for i in range(depth):
+                blk = SwinBlock(input_size, d_model, n_heads, window_size, i % 2, mlp_ratio, bias, dropout, norm, act)
+                stage.append(blk)
+
+            if i < len(depths) - 1:
+                stage.append(PatchMerging(d_model, norm))
+                input_size //= 2
+                d_model *= 2
+                n_heads *= 2
+
-        self.stages = nn.ModuleList()
-        for depth in depths:
-            stage = SwinStage(img_size, d_model, n_heads, depth, window_size, mlp_ratio, bias, dropout, norm, act)
-            self.stages.append(stage)
-            img_size //= 2
-            d_model *= 2
-            n_heads *= 2
-
-    def forward_features(self, x: Tensor) -> Tensor:
-        x = self.norm(self.patch_embed(x).permute(0, 2, 3, 1))
-        for stage in self.stages:
-            x = stage(x)
+        self.head_norm = norm(d_model)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.dropout(self.norm(self.patch_embed(x).permute(0, 2, 3, 1)))
+        x = self.stages(x)
+        x = self.head_norm(x).mean((1, 2))
+        return x
 
     @staticmethod
-    def from_config(variant: str, pretrained: bool = False) -> SwinTransformer:
-        d_model, n_heads, n_layers = dict(
+    def from_config(variant: str, img_size: int, pretrained: bool = False) -> SwinTransformer:
+        d_model, n_heads, depths = dict(
             T=(96, 3, (2, 2, 6, 2)),
             S=(96, 3, (2, 2, 18, 2)),
             B=(128, 4, (2, 2, 18, 2)),
             L=(192, 6, (2, 2, 18, 2)),
         )[variant]
-        m = SwinTransformer(d_model, n_heads, n_layers)
+        m = SwinTransformer(img_size, d_model, n_heads, depths)
 
         return m
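With these changes the backbone builds and runs end to end, mirroring test_forward above. A usage sketch; the printed shape is the expected result for the T variant (96 * 2 ** 3 = 768 channels after the three PatchMerging steps), not an output recorded in the commit:

import torch
from vision_toolbox.backbones import SwinTransformer

m = SwinTransformer.from_config("T", 224)
out = m(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 768])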
5 changes: 2 additions & 3 deletions vision_toolbox/backbones/vit.py
@@ -28,14 +28,13 @@ def __init__(self, d_model: int, n_heads: int, bias: bool = True, dropout: float
     def forward(self, x: Tensor, attn_bias: Tensor | None = None) -> Tensor:
         qkv = self.in_proj(x)
         q, k, v = qkv.unflatten(-1, (3, self.n_heads, -1)).transpose(-2, -4).unbind(-3)
-
         if hasattr(F, "scaled_dot_product_attention"):
             out = F.scaled_dot_product_attention(q, k, v, attn_bias, self.dropout if self.training else 0.0)
         else:
-            attn = torch.softmax(q @ (k * self.scale).transpose(-1, -2), -1)
+            attn = q @ (k * self.scale).transpose(-1, -2)
             if attn_bias is not None:
                 attn = attn + attn_bias
-            out = F.dropout(attn, self.dropout, self.training) @ v
+            out = F.dropout(torch.softmax(attn, -1), self.dropout, self.training) @ v
 
         out = out.transpose(-2, -3).flatten(-2)
         out = self.out_proj(out)
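The vit.py change fixes the manual fallback path so the bias is added to the raw logits before the softmax, which is what F.scaled_dot_product_attention does internally. A standalone equivalence check (shapes and tolerances here are arbitrary choices, not from the commit):

import torch
import torch.nn.functional as F

q, k, v = torch.randn(3, 2, 4, 49, 32).unbind(0)  # (B, n_heads, seq_len, head_dim)
attn_bias = torch.randn(1, 4, 49, 49)

ref = F.scaled_dot_product_attention(q, k, v, attn_bias)

scale = q.shape[-1] ** -0.5
attn = q @ (k * scale).transpose(-1, -2)
attn = attn + attn_bias            # bias on the logits, then softmax
out = torch.softmax(attn, -1) @ v  # the old code applied softmax before adding the bias

torch.testing.assert_close(out, ref, rtol=1e-5, atol=1e-5)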
