diff --git a/vision_toolbox/backbones/swin.py b/vision_toolbox/backbones/swin.py
index 4ca7e7d..04a1d65 100644
--- a/vision_toolbox/backbones/swin.py
+++ b/vision_toolbox/backbones/swin.py
@@ -134,8 +134,8 @@ def __init__(
         d_model: int,
         n_heads: int,
         depths: tuple[int, ...],
+        window_sizes: tuple[int, ...],
         patch_size: int = 4,
-        window_size: int = 7,
         mlp_ratio: float = 4.0,
         bias: bool = True,
         dropout: float = 0.0,
@@ -143,7 +143,6 @@ def __init__(
         act: _act = nn.GELU,
     ) -> None:
         assert img_size % patch_size == 0
-        assert img_size % window_size == 0
         assert d_model % n_heads == 0
         super().__init__()
         self.patch_embed = nn.Conv2d(3, d_model, patch_size, patch_size)
@@ -152,7 +151,7 @@ def __init__(
 
         input_size = img_size // patch_size
         self.stages = nn.Sequential()
-        for i, depth in enumerate(depths):
+        for i, (depth, window_size) in enumerate(zip(depths, window_sizes)):
             stage = nn.Sequential()
             for i in range(depth):
                 blk = SwinBlock(input_size, d_model, n_heads, window_size, i % 2, mlp_ratio, bias, dropout, norm, act)
@@ -176,12 +175,17 @@ def forward(self, x: Tensor) -> Tensor:
 
     @staticmethod
     def from_config(variant: str, img_size: int, pretrained: bool = False) -> SwinTransformer:
-        d_model, n_heads, depths = dict(
-            T=(96, 3, (2, 2, 6, 2)),
-            S=(96, 3, (2, 2, 18, 2)),
-            B=(128, 4, (2, 2, 18, 2)),
-            L=(192, 6, (2, 2, 18, 2)),
-        )[variant]
-        m = SwinTransformer(img_size, d_model, n_heads, depths)
+        d_model, n_heads, depths, window_sizes = {
+            # Sub-section 3.3 in https://arxiv.org/pdf/2103.14030.pdf
+            "T": (96, 3, (2, 2, 6, 2), (7, 7, 7, 7)),
+            "S": (96, 3, (2, 2, 18, 2), (7, 7, 7, 7)),
+            "B": (128, 4, (2, 2, 18, 2), (7, 7, 7, 7)),
+            "L": (192, 6, (2, 2, 18, 2), (7, 7, 7, 7)),
+            # https://github.com/microsoft/Cream/blob/main/AutoFormerV2/configs
+            "S3-T": (96, 3, (2, 2, 6, 2), (7, 7, 14, 7)),
+            "S3-S": (96, 3, (2, 2, 18, 2), (14, 14, 14, 14)),
+            "S3-B": (96, 3, (2, 2, 30, 2), (7, 7, 14, 7)),
+        }[variant]
+        m = SwinTransformer(img_size, d_model, n_heads, depths, window_sizes)
         return m
 