From 7ce7b4ef4cf46c67453e25213087ed087fa8baae Mon Sep 17 00:00:00 2001
From: Thien Tran
Date: Tue, 25 Jul 2023 07:18:55 +0800
Subject: [PATCH] update patchconvnet

---
 tests/test_backbones.py                  |  2 +
 vision_toolbox/backbones/__init__.py     |  2 +-
 vision_toolbox/backbones/patchconvnet.py | 84 +++++++++++++-----------
 3 files changed, 47 insertions(+), 41 deletions(-)

diff --git a/tests/test_backbones.py b/tests/test_backbones.py
index 60a6daf..5a7f53a 100644
--- a/tests/test_backbones.py
+++ b/tests/test_backbones.py
@@ -9,6 +9,7 @@
     DarknetYOLOv5,
     EfficientNetExtractor,
     MobileNetExtractor,
+    PatchConvNet,
     RegNetExtractor,
     ResNetExtractor,
     VoVNet,
@@ -27,6 +28,7 @@ def inputs():
         partial(VoVNet.from_config, x, y, z)
         for x, y, z in ((27, True, False), (39, False, False), (19, True, True), (57, False, True))
     ],
+    # partial(PatchConvNet.from_config, "S", 60),
     partial(ResNetExtractor, "resnet18"),
     partial(RegNetExtractor, "regnet_x_400mf"),
     partial(MobileNetExtractor, "mobilenet_v2"),
diff --git a/vision_toolbox/backbones/__init__.py b/vision_toolbox/backbones/__init__.py
index 1425c61..3f18f45 100644
--- a/vision_toolbox/backbones/__init__.py
+++ b/vision_toolbox/backbones/__init__.py
@@ -1,5 +1,5 @@
 from .darknet import Darknet, DarknetYOLOv5
-from .patchconvnet import *
+from .patchconvnet import PatchConvNet
 from .torchvision_models import EfficientNetExtractor, MobileNetExtractor, RegNetExtractor, ResNetExtractor
 from .vit import ViT
 from .vovnet import VoVNet
diff --git a/vision_toolbox/backbones/patchconvnet.py b/vision_toolbox/backbones/patchconvnet.py
index ccc4ca7..ea8c363 100644
--- a/vision_toolbox/backbones/patchconvnet.py
+++ b/vision_toolbox/backbones/patchconvnet.py
@@ -22,41 +22,43 @@ def forward(self, x: Tensor) -> Tensor:
         return torch.permute(x, self.dims)
 
 
-class PatchConvBlock(nn.Module):
-    def __init__(
-        self, embed_dim: int, drop_path: float = 0.3, layer_scale_init: float = 1e-6, norm_type: str = "bn"
-    ) -> None:
-        assert norm_type in ("bn", "ln")
+class PatchConvBlockLN(nn.Module):
+    def __init__(self, embed_dim: int, drop_path: float = 0.3, layer_scale_init: float = 1e-6) -> None:
         super().__init__()
-        if norm_type == "ln":
-            # LayerNorm version. Primary format is (N, H, W, C)
-            # follow this approach https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py
-            self.layers = nn.Sequential(
-                nn.LayerNorm(embed_dim),
-                nn.Linear(embed_dim, embed_dim),
-                nn.GELU(),
-                Permute(0, 3, 1, 2),  # (N, H, W, C) -> (N, C, H, W)
-                nn.Conv2d(embed_dim, embed_dim, 3, padding=1, groups=embed_dim),  # dw-conv
-                nn.GELU(),
-                SqueezeExcitation(embed_dim, embed_dim // 4),
-                Permute(0, 2, 3, 1),  # (N, C, H, W) -> (N, H, W, C)
-                nn.Linear(embed_dim, embed_dim),
-            )
-            self.layer_scale = nn.Parameter(torch.full(embed_dim, layer_scale_init))
-
-        else:
-            # BatchNorm version. Primary format is (N, C, H, W)
-            self.layers = nn.Sequential(
-                nn.BatchNorm2d(embed_dim),
-                nn.Conv2d(embed_dim, embed_dim, 1),
-                nn.GELU(),
-                nn.Conv2d(embed_dim, embed_dim, 3, padding=1, groups=embed_dim),
-                nn.GELU(),
-                SqueezeExcitation(embed_dim, embed_dim // 4),
-                nn.Conv2d(embed_dim, embed_dim, 1),
-            )
-            self.layer_scale = nn.Parameter(torch.full((embed_dim, 1, 1), layer_scale_init))
+        # LayerNorm version. Primary format is (N, H, W, C)
+        # follow this approach https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py
+        self.layers = nn.Sequential(
+            nn.LayerNorm(embed_dim),
+            nn.Linear(embed_dim, embed_dim),
+            nn.GELU(),
+            Permute(0, 3, 1, 2),  # (N, H, W, C) -> (N, C, H, W)
+            nn.Conv2d(embed_dim, embed_dim, 3, padding=1, groups=embed_dim),  # dw-conv
+            nn.GELU(),
+            SqueezeExcitation(embed_dim, embed_dim // 4),
+            Permute(0, 2, 3, 1),  # (N, C, H, W) -> (N, H, W, C)
+            nn.Linear(embed_dim, embed_dim),
+        )
+        self.layer_scale = nn.Parameter(torch.full((embed_dim,), layer_scale_init))
+        self.drop_path = StochasticDepth(drop_path, "row") if drop_path > 0 else nn.Identity()
+
+    def forward(self, x: Tensor) -> Tensor:
+        return x + self.drop_path(self.layers(x) * self.layer_scale)
+
+
+class PatchConvBlockBN(nn.Module):
+    def __init__(self, embed_dim: int, drop_path: float = 0.3, layer_scale_init: float = 1e-6) -> None:
+        super().__init__()
+        # BatchNorm version. Primary format is (N, C, H, W)
+        self.layers = nn.Sequential(
+            nn.BatchNorm2d(embed_dim),
+            nn.Conv2d(embed_dim, embed_dim, 1),
+            nn.GELU(),
+            nn.Conv2d(embed_dim, embed_dim, 3, padding=1, groups=embed_dim),
+            nn.GELU(),
+            SqueezeExcitation(embed_dim, embed_dim // 4),
+            nn.Conv2d(embed_dim, embed_dim, 1),
+        )
+        self.layer_scale = nn.Parameter(torch.full((embed_dim, 1, 1), layer_scale_init))
         self.drop_path = StochasticDepth(drop_path, "row") if drop_path > 0 else nn.Identity()
 
     def forward(self, x: Tensor) -> Tensor:
         return x + self.drop_path(self.layers(x) * self.layer_scale)
@@ -72,13 +74,13 @@ def __init__(
 
         self.norm_1 = nn.LayerNorm(embed_dim)
         self.attn = nn.MultiheadAttention(embed_dim, 1, batch_first=True)
-        self.layer_scale_1 = nn.Parameter(torch.ones(embed_dim) * layer_scale_init)
+        self.layer_scale_1 = nn.Parameter(torch.full((embed_dim,), layer_scale_init))
         self.drop_path1 = StochasticDepth(drop_path, "row") if drop_path > 0 else nn.Identity()
 
         self.norm_2 = nn.LayerNorm(embed_dim)
         mlp_dim = int(embed_dim * mlp_ratio)
         self.mlp = nn.Sequential(nn.Linear(embed_dim, mlp_dim), nn.GELU(), nn.Linear(mlp_dim, embed_dim))
-        self.layer_scale_2 = nn.Parameter(torch.ones(embed_dim) * layer_scale_init)
+        self.layer_scale_2 = nn.Parameter(torch.full((embed_dim,), layer_scale_init))
         self.drop_path2 = StochasticDepth(drop_path, "row") if drop_path > 0 else nn.Identity()
 
         self.norm_3 = nn.LayerNorm(embed_dim)
@@ -114,7 +116,8 @@ def __init__(
         assert norm_type in ("bn", "ln")
         super().__init__()
         self.norm_type = norm_type
-        self.out_channels = (embed_dim,)
+        self.out_channels_list = (embed_dim,)
+        self.stride = 16
 
         # stem has no bias and no last activation layer
         # https://github.com/facebookresearch/deit/issues/151
@@ -129,9 +132,10 @@ def __init__(
             conv3x3_s2(embed_dim // 2, embed_dim),
         )
 
+        blk = PatchConvBlockLN if norm_type == "ln" else PatchConvBlockBN
         self.trunk = nn.Sequential(
             Permute(0, 2, 3, 1) if norm_type == "ln" else nn.Identity(),
-            *[PatchConvBlock(embed_dim, drop_path, layer_scale_init, norm_type) for _ in range(depth)],
+            *[blk(embed_dim, drop_path, layer_scale_init) for _ in range(depth)],
             Permute(0, 2, 3, 1) if norm_type == "bn" else nn.Identity(),
         )
         self.pool = AttentionPooling(embed_dim, mlp_ratio, drop_path, layer_scale_init)
@@ -146,12 +150,12 @@ def __init__(
                 if m.bias is not None:
                     nn.init.zeros_(m.bias)
 
-    def forward(self, x: Tensor):
+    def get_feature_maps(self, x: Tensor) -> list[Tensor]:
         out = self.stem(x)
-        out = self.trunk(x)
+        out = self.trunk(out)
         out = out.flatten(1, 2)
         out = self.pool(out)
-        return out
+        return [out]
 
     @staticmethod
     def from_config(variant: str, depth: int, pretrained: bool = False) -> PatchConvNet:
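
Editor's note (not part of the patch): a quick usage sketch for reviewers. The last hunk renames `forward` to `get_feature_maps`, fixes the bug where the trunk re-consumed the raw input (`self.trunk(x)` -> `self.trunk(out)`), and wraps the pooled output in a list to match the extractor-style API of the other backbones. The sketch assumes the `from_config` signature shown above and borrows the `"S"` / `60` arguments from the commented-out test entry; since the patch leaves that test disabled, treat these values as illustrative, not verified.

    import torch

    from vision_toolbox.backbones import PatchConvNet

    # Variant "S" with depth 60 mirrors the commented-out entry in
    # tests/test_backbones.py; both values are illustrative assumptions.
    model = PatchConvNet.from_config("S", 60)

    x = torch.randn(1, 3, 224, 224)    # (N, C, H, W) input
    feats = model.get_feature_maps(x)  # list containing one pooled feature
    print([f.shape for f in feats])

Also worth noting for review: both new block classes build `layer_scale` via `torch.full` with an explicit size tuple; the old LayerNorm branch passed a bare int (`torch.full(embed_dim, layer_scale_init)`), which fails because `torch.full` expects a size sequence as its first argument. And the `Permute` bookends in the trunk leave the output channels-last for both norm types, so `out.flatten(1, 2)` always folds H and W into the sequence dimension before attention pooling.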