From ccbd0aa029bf516482c383462c392ec5f80b90a4 Mon Sep 17 00:00:00 2001
From: Thien Tran
Date: Tue, 25 Jul 2023 06:48:22 +0800
Subject: [PATCH] simplify

---
 vision_toolbox/backbones/patchconvnet.py | 146 +++++++++--------------
 1 file changed, 54 insertions(+), 92 deletions(-)

diff --git a/vision_toolbox/backbones/patchconvnet.py b/vision_toolbox/backbones/patchconvnet.py
index deb51ac..ccc4ca7 100644
--- a/vision_toolbox/backbones/patchconvnet.py
+++ b/vision_toolbox/backbones/patchconvnet.py
@@ -1,58 +1,31 @@
 # https://arxiv.org/abs/2112.13692
 # https://github.com/facebookresearch/deit/blob/main/patchconvnet_models.py
-import warnings
-import torch
-from torch import Tensor, nn
+from __future__ import annotations
+from functools import partial
-try:
-    from torchvision.ops import StochasticDepth
-    from torchvision.ops.misc import SqueezeExcitation
-except ImportError:
-    warnings.warn("torchvision.ops.misc.SqueezeExcitation is not available. Please update your torchvision")
-    SqueezeExcitation = None
+import torch
+from torch import Tensor, nn
+from torchvision.ops import StochasticDepth
+from torchvision.ops.misc import SqueezeExcitation
 from .base import BaseBackbone
-__all__ = [
-    "AttentionPooling",
-    "PatchConvNet",
-    "patchconvnet_s60",
-    "patchconvnet_s120",
-    "patchconvnet_b60",
-    "patchconvnet_b120",
-    "patchconvnet_l60",
-    "patchconvnet_l120",
-]
-
-
-_base = {"mlp_ratio": 3, "drop_path": 0.3, "layer_scale_init": 1e-6}
-_S_embed_dim = 384
-_B_embed_dim = 768
-_L_embed_dim = 1024
-configs = {
-    "PatchConvNet-S60": dict(**_base, embed_dim=_S_embed_dim, depth=60),
-    "PatchConvNet-S120": dict(**_base, embed_dim=_S_embed_dim, depth=120),
-    "PatchConvNet-B60": dict(**_base, embed_dim=_B_embed_dim, depth=60),
-    "PatchConvNet-B120": dict(**_base, embed_dim=_B_embed_dim, depth=120),
-    "PatchConvNet-L60": dict(**_base, embed_dim=_L_embed_dim, depth=60),
-    "PatchConvNet-L120": dict(**_base, embed_dim=_L_embed_dim, depth=120),
-}
-
-
 class Permute(nn.Module):
     def __init__(self, *dims):
         super().__init__()
         self.dims = dims

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         return torch.permute(x, self.dims)


 class PatchConvBlock(nn.Module):
-    def __init__(self, embed_dim, drop_path=0.3, layer_scale_init=1e-6, norm_type="bn"):
+    def __init__(
+        self, embed_dim: int, drop_path: float = 0.3, layer_scale_init: float = 1e-6, norm_type: str = "bn"
+    ) -> None:
         assert norm_type in ("bn", "ln")
         super().__init__()
         if norm_type == "ln":
@@ -69,7 +42,7 @@ def __init__(self, embed_dim, drop_path=0.3, layer_scale_init=1e-6, norm_type="b
                 Permute(0, 2, 3, 1),  # (N, C, H, W) -> (N, H, W, C)
                 nn.Linear(embed_dim, embed_dim),
             )
-            self.layer_scale = nn.Parameter(torch.ones(embed_dim) * layer_scale_init)
+            self.layer_scale = nn.Parameter(torch.full((embed_dim,), layer_scale_init))
         else:
             # BatchNorm version. Primary format is (N, C, H, W)
@@ -82,31 +55,35 @@ def __init__(self, embed_dim, drop_path=0.3, layer_scale_init=1e-6, norm_type="b
                 SqueezeExcitation(embed_dim, embed_dim // 4),
                 nn.Conv2d(embed_dim, embed_dim, 1),
             )
-            self.layer_scale = nn.Parameter(torch.ones(embed_dim, 1, 1) * layer_scale_init)
+            self.layer_scale = nn.Parameter(torch.full((embed_dim, 1, 1), layer_scale_init))
         self.drop_path = StochasticDepth(drop_path, "row") if drop_path > 0 else nn.Identity()

-    def forward(self, x: Tensor):
+    def forward(self, x: Tensor) -> Tensor:
         return x + self.drop_path(self.layers(x) * self.layer_scale)


 class AttentionPooling(nn.Module):
-    def __init__(self, embed_dim, mlp_ratio, drop_path=0.3, layer_scale_init=1e-6):
+    def __init__(
+        self, embed_dim: int, mlp_ratio: int = 3, drop_path: float = 0.3, layer_scale_init: float = 1e-6
+    ) -> None:
         super().__init__()
-        self.drop_path = StochasticDepth(drop_path, "row") if drop_path > 0 else nn.Identity()
         self.cls_token = nn.Parameter(torch.zeros(embed_dim))
-        self.norm_1 = nn.LayerNorm(embed_dim)
+        self.norm_1 = nn.LayerNorm(embed_dim)
         self.attn = nn.MultiheadAttention(embed_dim, 1, batch_first=True)
         self.layer_scale_1 = nn.Parameter(torch.ones(embed_dim) * layer_scale_init)
-        self.norm_2 = nn.LayerNorm(embed_dim)
+        self.drop_path1 = StochasticDepth(drop_path, "row") if drop_path > 0 else nn.Identity()
+        self.norm_2 = nn.LayerNorm(embed_dim)
         mlp_dim = int(embed_dim * mlp_ratio)
         self.mlp = nn.Sequential(nn.Linear(embed_dim, mlp_dim), nn.GELU(), nn.Linear(mlp_dim, embed_dim))
         self.layer_scale_2 = nn.Parameter(torch.ones(embed_dim) * layer_scale_init)
+        self.drop_path2 = StochasticDepth(drop_path, "row") if drop_path > 0 else nn.Identity()
+        self.norm_3 = nn.LayerNorm(embed_dim)

-    def forward(self, x: Tensor):
+    def forward(self, x: Tensor) -> Tensor:
         # (N, HW, C)
         cls_token = self.cls_token.expand(x.shape[0], 1, -1)
         out = torch.cat((cls_token, x), dim=1)
@@ -114,19 +91,26 @@ def forward(self, x: Tensor):
         # attention pooling. q = cls_token. k = v = (cls_token, x)
         out = self.norm_1(out)
         out = self.attn(out[:, :1], out, out, need_weights=False)[0]
-        cls_token = cls_token + self.drop_path(out * self.layer_scale_1)  # residual + layer scale + dropout
+        cls_token = cls_token + self.drop_path1(out * self.layer_scale_1)  # residual + layer scale + dropout

         # mlp
-        out = self.norm_2(cls_token)
-        out = self.mlp(out)
-        cls_token = cls_token + self.drop_path(out * self.layer_scale_2)
+        out = self.mlp(self.norm_2(cls_token))
+        cls_token = cls_token + self.drop_path2(out * self.layer_scale_2)

         out = self.norm_3(cls_token).squeeze(1)  # (N, 1, C) -> (N, C)
         return out


 class PatchConvNet(BaseBackbone):
-    def __init__(self, embed_dim, depth, mlp_ratio, drop_path, layer_scale_init, norm_type="bn"):
+    def __init__(
+        self,
+        embed_dim: int,
+        depth: int,
+        mlp_ratio: int = 3,
+        drop_path: float = 0.3,
+        layer_scale_init: float = 1e-6,
+        norm_type: str = "bn",
+    ) -> None:
         assert norm_type in ("bn", "ln")
         super().__init__()
         self.norm_type = norm_type
@@ -134,20 +118,23 @@ def __init__(self, embed_dim, depth, mlp_ratio, drop_path, layer_scale_init, nor
         # stem has no bias and no last activation layer
         # https://github.com/facebookresearch/deit/issues/151
-        kwargs = dict(kernel_size=3, stride=2, padding=1, bias=False)
+        conv3x3_s2 = partial(nn.Conv2d, kernel_size=3, stride=2, padding=1, bias=False)
         self.stem = nn.Sequential(
-            nn.Conv2d(3, embed_dim // 8, **kwargs),
+            conv3x3_s2(3, embed_dim // 8),
             nn.GELU(),
-            nn.Conv2d(embed_dim // 8, embed_dim // 4, **kwargs),
+            conv3x3_s2(embed_dim // 8, embed_dim // 4),
             nn.GELU(),
-            nn.Conv2d(embed_dim // 4, embed_dim // 2, **kwargs),
+            conv3x3_s2(embed_dim // 4, embed_dim // 2),
             nn.GELU(),
-            nn.Conv2d(embed_dim // 2, embed_dim, **kwargs),
+            conv3x3_s2(embed_dim // 2, embed_dim),
         )
-        kwargs = dict(drop_path=drop_path, layer_scale_init=layer_scale_init)
-        self.trunk = nn.Sequential(*[PatchConvBlock(embed_dim, norm_type=norm_type, **kwargs) for _ in range(depth)])
-        self.pool = AttentionPooling(embed_dim, mlp_ratio, **kwargs)
+        self.trunk = nn.Sequential(
+            Permute(0, 2, 3, 1) if norm_type == "ln" else nn.Identity(),
+            *[PatchConvBlock(embed_dim, drop_path, layer_scale_init, norm_type) for _ in range(depth)],
+            Permute(0, 2, 3, 1) if norm_type == "bn" else nn.Identity(),
+        )
+        self.pool = AttentionPooling(embed_dim, mlp_ratio, drop_path, layer_scale_init)

         # weight initialization
         nn.init.trunc_normal_(self.pool.cls_token, std=0.02)
@@ -161,40 +148,15 @@ def __init__(self, embed_dim, depth, mlp_ratio, drop_path, layer_scale_init, nor

     def forward(self, x: Tensor):
         out = self.stem(x)
-
-        if self.norm_type == "ln":
-            # layer norm
-            out = torch.permute(out, (0, 2, 3, 1))  # (N, C, H, W) -> (N, H, W, C)
-            out = self.trunk(out)
-            out = torch.flatten(out, 1, 2)  # (N, H, W, C) -> (N, HW, C)
-        else:
-            # batch norm
-            out = self.trunk(out)
-            out = out.flatten(2).transpose(1, 2)  # (N, C, H, W) -> (N, HW, C)
-
+        out = self.trunk(out)
+        out = out.flatten(1, 2)
         out = self.pool(out)
         return out
-
-def patchconvnet_s60(pretrained=False, **kwargs):
-    return PatchConvNet.from_config(configs["PatchConvNet-S60"], pretrained=pretrained, **kwargs)
-
-
-def patchconvnet_s120(pretrained=False, **kwargs):
-    return PatchConvNet.from_config(configs["PatchConvNet-S120"], pretrained=pretrained, **kwargs)
-
-
-def patchconvnet_b60(pretrained=False, **kwargs):
-    return PatchConvNet.from_config(configs["PatchConvNet-B60"], pretrained=pretrained, **kwargs)
-
-
-def patchconvnet_b120(pretrained=False, **kwargs):
-    return PatchConvNet.from_config(configs["PatchConvNet-B120"], pretrained=pretrained, **kwargs)
-
-
-def patchconvnet_l60(pretrained=False, **kwargs):
-    return PatchConvNet.from_config(configs["PatchConvNet-L60"], pretrained=pretrained, **kwargs)
-
-
-def patchconvnet_l120(pretrained=False, **kwargs):
-    return PatchConvNet.from_config(configs["PatchConvNet-L120"], pretrained=pretrained, **kwargs)
+
+    @staticmethod
+    def from_config(variant: str, depth: int, pretrained: bool = False) -> PatchConvNet:
+        embed_dim = dict(S=384, B=768, L=1024)[variant]
+        m = PatchConvNet(embed_dim, depth)
+        if pretrained:
+            raise ValueError("pretrained weights are not available for PatchConvNet")
+        return m
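
With the per-variant factory functions removed, callers go through PatchConvNet.from_config directly. A minimal usage sketch of the new call site, assuming the patch above is applied as-is and the module path vision_toolbox.backbones.patchconvnet is unchanged; the batch size and input resolution are only examples:

    import torch
    from vision_toolbox.backbones.patchconvnet import PatchConvNet

    # "S" maps to embed_dim=384 in from_config; depth=60 reproduces the old S60 preset.
    model = PatchConvNet.from_config("S", 60)

    x = torch.randn(2, 3, 224, 224)  # (N, C, H, W)
    feats = model(x)                 # attention-pooled features, shape (N, embed_dim) = (2, 384)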