diff --git a/vision_toolbox/backbones/swin_transformer.py b/vision_toolbox/backbones/swin_transformer.py
index e09f8da..36e30b6 100644
--- a/vision_toolbox/backbones/swin_transformer.py
+++ b/vision_toolbox/backbones/swin_transformer.py
@@ -3,16 +3,13 @@
 from __future__ import annotations
 
-from functools import partial
-from typing import Mapping
+import itertools
 
-import numpy as np
 import torch
-import torch.nn.functional as F
 from torch import Tensor, nn
 
 from ..utils import torch_hub_download
-from .base import _act, _norm
+from .base import BaseBackbone, _act, _norm
 from .vit import MHA, MLP
@@ -34,11 +31,34 @@ def window_unpartition(x: Tensor, window_size: int, nH: int, nW: int) -> Tensor:
 class WindowAttention(MHA):
     def __init__(
-        self, d_model: int, window_size: int, shift: int, n_heads: int, bias: bool = True, dropout: float = 0.0
+        self,
+        input_size: int,
+        d_model: int,
+        n_heads: int,
+        window_size: int = 7,
+        shift: bool = False,
+        bias: bool = True,
+        dropout: float = 0.0,
     ) -> None:
         super().__init__(d_model, n_heads, bias, dropout)
         self.window_size = window_size
-        self.shift = shift
+
+        if shift:
+            self.shift = window_size // 2
+
+            img_mask = torch.zeros(1, input_size, input_size, 1)
+            slices = (slice(0, -window_size), slice(-window_size, -self.shift), slice(-self.shift, None))
+            for i, (h_slice, w_slice) in enumerate(itertools.product(slices, slices)):
+                img_mask[0, h_slice, w_slice, 0] = i
+
+            windows_mask, _, _ = window_partition(img_mask, window_size)
+            attn_mask = windows_mask.unsqueeze(1) - windows_mask.unsqueeze(2)
+            self.register_buffer("attn_mask", (attn_mask != 0) * (-100), False)
+            self.attn_mask: Tensor
+
+        else:
+            self.shift = 0
+            self.attn_mask = None
 
         self.relative_pe_table = nn.Parameter(torch.empty(n_heads, (2 * window_size - 1) ** 2))
         nn.init.trunc_normal_(self.relative_pe_table, 0, 0.02)
@@ -50,24 +70,28 @@ def __init__(
         self.relative_pe_index: Tensor
 
     def forward(self, x: Tensor) -> Tensor:
+        attn_bias = self.relative_pe_table[:, self.relative_pe_index].unsqueeze(0)
         if self.shift > 0:
             x = x.roll((self.shift, self.shift), (1, 2))
-        x, nH, nW = window_partition(x, self.window_size)  # (B * nH * nW, win_size * win_size, C)
+            attn_bias = attn_bias + self.attn_mask
 
-        attn_bias = self.relative_pe_table[:, self.relative_pe_index].unsqueeze(0)
+        x, nH, nW = window_partition(x, self.window_size)  # (B * nH * nW, win_size * win_size, C)
         x = super().forward(x, attn_bias)
         x = window_unpartition(x, self.window_size, nH, nW)
+
+        if self.shift > 0:
+            x = x.roll((-self.shift, -self.shift), (1, 2))
         return x
 
 
 class SwinBlock(nn.Module):
     def __init__(
         self,
+        input_size: int,
         d_model: int,
         n_heads: int,
-        window_size: int,
-        shift: int,
+        window_size: int = 7,
+        shift: bool = False,
         mlp_ratio: float = 4.0,
         bias: bool = True,
         dropout: float = 0.0,
@@ -76,7 +100,7 @@ def __init__(
     ) -> None:
         super().__init__()
         self.norm1 = norm(d_model)
-        self.mha = WindowAttention(d_model, window_size, shift, n_heads, bias, dropout)
+        self.mha = WindowAttention(input_size, d_model, n_heads, window_size, shift, bias, dropout)
         self.norm2 = norm(d_model)
         self.mlp = MLP(d_model, int(d_model * mlp_ratio), act)
@@ -86,16 +110,77 @@ def forward(self, x: Tensor) -> Tensor:
         return x
 
 
-class SwinTransformer(nn.Module):
-    def __init__(self, d_model: int, n_layers: tuple[int, int, int, int]) -> None:
+class PatchMerging(nn.Module):
+    def __init__(self, d_model: int, norm: _norm = nn.LayerNorm) -> None:
+        super().__init__()
+        self.norm = norm(d_model * 4)
+        self.reduction = nn.Linear(d_model * 4, d_model * 2, False)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x, _, _ = window_partition(x, 2)
+        return self.reduction(self.norm(x))
+
+
+class SwinStage(nn.Sequential):
+    def __init__(
+        self,
+        input_size: int,
+        d_model: int,
+        n_heads: int,
+        depth: int,
+        downsample: bool = False,
+        window_size: int = 7,
+        mlp_ratio: float = 4.0,
+        bias: bool = True,
+        dropout: float = 0.0,
+        norm: _norm = nn.LayerNorm,
+        act: _act = nn.GELU,
+    ) -> None:
+        super().__init__()
+        for i in range(depth):
+            blk = SwinBlock(input_size, d_model, n_heads, window_size, i % 2 == 1, mlp_ratio, bias, dropout, norm, act)
+            self.append(blk)
+        self.downsample = PatchMerging(d_model, norm) if downsample else None
+
+
+class SwinTransformer(BaseBackbone):
+    def __init__(
+        self,
+        img_size: int,
+        d_model: int,
+        n_heads: int,
+        depths: tuple[int, int, int, int],
+        patch_size: int = 4,
+        window_size: int = 7,
+        mlp_ratio: float = 4.0,
+        bias: bool = True,
+        dropout: float = 0.0,
+        norm: _norm = nn.LayerNorm,
+        act: _act = nn.GELU,
+    ) -> None:
         super().__init__()
+        self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
+        self.norm = norm(d_model)
+
+        self.stages = nn.ModuleList()
+        for i, depth in enumerate(depths):
+            downsample = i < len(depths) - 1  # no PatchMerging after the last stage
+            stage = SwinStage(img_size, d_model, n_heads, depth, downsample, window_size, mlp_ratio, bias, dropout, norm, act)
+            self.stages.append(stage)
+            img_size //= 2
+            d_model *= 2
+            n_heads *= 2
+
+    def forward_features(self, x: Tensor) -> Tensor:
+        x = self.norm(self.patch_embed(x).permute(0, 2, 3, 1))
+        for stage in self.stages:
+            x = stage(x)
+        return x
 
     @staticmethod
     def from_config(variant: str, pretrained: bool = False) -> SwinTransformer:
-        d_model, n_layers = dict(
-            T=(96, (2, 2, 6, 2)),
-            S=(96, (2, 2, 18, 2)),
-            B=(128, (2, 2, 18, 2)),
-            L=(192, (2, 2, 18, 2)),
+        d_model, n_heads, n_layers = dict(
+            T=(96, 3, (2, 2, 6, 2)),
+            S=(96, 3, (2, 2, 18, 2)),
+            B=(128, 4, (2, 2, 18, 2)),
+            L=(192, 6, (2, 2, 18, 2)),
         )[variant]
-        m = SwinTransformer(d_model, n_layers)
+        m = SwinTransformer(d_model, n_heads, n_layers)
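
For context, below is a small self-contained sketch of the attention-mask construction that the new `WindowAttention.__init__` performs for shifted blocks. It is not part of the diff: the `window_partition` helper here is a stand-in written only for this example (the repo's own helper is defined earlier in the file and not shown above), and the toy sizes are arbitrary. The logic mirrors the added lines: label the regions produced by the cyclic shift, partition the labels into windows, and turn every pair of tokens with different labels into a large negative attention bias.

```python
import itertools

import torch
from torch import Tensor


def window_partition(x: Tensor, window_size: int) -> tuple[Tensor, int, int]:
    # Stand-in for the repo's helper (its body is not part of this diff):
    # (B, H, W, C) -> (B * nH * nW, window_size * window_size, C)
    B, H, W, C = x.shape
    nH, nW = H // window_size, W // window_size
    x = x.view(B, nH, window_size, nW, window_size, C).transpose(2, 3)
    return x.reshape(B * nH * nW, window_size * window_size, C), nH, nW


input_size, window_size = 8, 4        # toy feature-map and window sizes
shift = window_size // 2

# Label the regions created by the cyclic shift; tokens coming from different
# regions must not attend to each other inside the same window.
img_mask = torch.zeros(1, input_size, input_size, 1)
slices = (slice(0, -window_size), slice(-window_size, -shift), slice(-shift, None))
for i, (h_slice, w_slice) in enumerate(itertools.product(slices, slices)):
    img_mask[0, h_slice, w_slice, 0] = i

windows_mask, _, _ = window_partition(img_mask, window_size)       # (nH * nW, ws * ws, 1)
attn_mask = windows_mask.unsqueeze(1) - windows_mask.unsqueeze(2)  # pairwise region difference
attn_mask = (attn_mask != 0) * (-100)                              # large negative bias blocks attention

print(attn_mask.shape)         # torch.Size([4, 16, 16, 1]) for this toy size
print((attn_mask != 0).sum())  # number of blocked token pairs
```

Registering the resulting mask with `register_buffer(..., False)`, as the diff does, keeps it out of the state dict while still moving it with the module across devices.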