From 6205cc30791c32f3eca3226e0a63bff1968b446f Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 19 Aug 2023 13:10:11 +0800 Subject: [PATCH] add initial convnext --- tests/test_convnext.py | 20 +++++ vision_toolbox/backbones/__init__.py | 1 + vision_toolbox/backbones/convnext.py | 105 +++++++++++++++++++++++++++ vision_toolbox/components.py | 28 +++++++ 4 files changed, 154 insertions(+) create mode 100644 tests/test_convnext.py create mode 100644 vision_toolbox/backbones/convnext.py diff --git a/tests/test_convnext.py b/tests/test_convnext.py new file mode 100644 index 0000000..97000d9 --- /dev/null +++ b/tests/test_convnext.py @@ -0,0 +1,20 @@ +import timm +import torch + +from vision_toolbox.backbones import ConvNeXt + + +def test_forward(): + m = ConvNeXt.from_config("T") + m(torch.randn(1, 3, 224, 224)) + + +# def test_from_pretrained(): +# m = ViT.from_config("Ti", 16, 224, True).eval() +# x = torch.randn(1, 3, 224, 224) +# out = m(x) + +# m_timm = timm.create_model("vit_tiny_patch16_224.augreg_in21k", pretrained=True, num_classes=0).eval() +# out_timm = m_timm(x) + +# torch.testing.assert_close(out, out_timm, rtol=2e-5, atol=2e-5) diff --git a/vision_toolbox/backbones/__init__.py b/vision_toolbox/backbones/__init__.py index 55a1857..a724f7c 100644 --- a/vision_toolbox/backbones/__init__.py +++ b/vision_toolbox/backbones/__init__.py @@ -1,3 +1,4 @@ +from .convnext import ConvNeXt from .darknet import Darknet, DarknetYOLOv5 from .mlp_mixer import MLPMixer from .patchconvnet import PatchConvNet diff --git a/vision_toolbox/backbones/convnext.py b/vision_toolbox/backbones/convnext.py new file mode 100644 index 0000000..2ed5b11 --- /dev/null +++ b/vision_toolbox/backbones/convnext.py @@ -0,0 +1,105 @@ +# https://arxiv.org/abs/2201.03545 + +from __future__ import annotations + +from functools import partial + +import torch +from torch import Tensor, nn + +from ..components import Permute, StochasticDepth +from .base import BaseBackbone, _act, _norm + + +class ConvNeXtBlock(nn.Module): + def __init__( + self, + d_model: int, + expansion_ratio: float = 4.0, + bias: bool = True, + layer_scale_init: float = 1e-6, + stochastic_depth: float = 0.0, + norm: _norm = partial(nn.LayerNorm, eps=1e-6), + act: _act = nn.GELU, + ) -> None: + super().__init__() + hidden_dim = int(d_model * expansion_ratio) + self.layers = nn.Sequential( + Permute(0, 3, 1, 2), + nn.Conv2d(d_model, d_model, 7, padding=3, groups=d_model, bias=bias), + Permute(0, 2, 3, 1), + norm(d_model), + nn.Linear(d_model, hidden_dim, bias=bias), + act(), + nn.Linear(hidden_dim, d_model, bias=bias), + ) + self.layer_scale = nn.Parameter(torch.full((d_model,), layer_scale_init)) if layer_scale_init > 0 else None + self.drop = StochasticDepth(stochastic_depth) + + def forward(self, x: Tensor) -> Tensor: + return x + self.drop(self.layers(x) * self.layer_scale) + + +class ConvNeXt(BaseBackbone): + def __init__( + self, + d_model: int, + depths: tuple[int, ...], + expansion_ratio: float = 4.0, + bias: bool = True, + layer_scale_init: float = 1e-6, + stochastic_depth: float = 0.0, + norm: _norm = partial(nn.LayerNorm, eps=1e-6), + act: _act = nn.GELU, + ) -> None: + super().__init__() + self.stem = nn.Sequential(nn.Conv2d(3, d_model, 4, 4), Permute(0, 2, 3, 1), norm(d_model)) + + self.stages = nn.Sequential() + for stage_idx, depth in enumerate(depths): + stage = nn.Sequential() + if stage_idx > 0: + # equivalent to PatchMerging in SwinTransformer + downsample = nn.Sequential( + norm(d_model), + Permute(0, 3, 1, 2), + nn.Conv2d(d_model, d_model * 2, 2, 2), + Permute(0, 2, 3, 1), + ) + d_model *= 2 + else: + downsample = nn.Identity() + stage.append(downsample) + + for _ in range(depth): + block = ConvNeXtBlock(d_model, expansion_ratio, bias, layer_scale_init, stochastic_depth, norm, act) + stage.append(block) + + self.stages.append(stage) + + self.head_norm = norm(d_model) + + def get_feature_maps(self, x: Tensor) -> list[Tensor]: + out = [self.stem(x)] + for stage in self.stages: + out.append(stage(out[-1])) + return out[-1:] + + def forward(self, x: Tensor) -> Tensor: + return self.head_norm(self.get_feature_maps(x)[-1].mean((1, 2))) + + @staticmethod + def from_config(variant: str, pretrained: bool = False) -> ConvNeXt: + d_model, depths = dict( + T=(96, (3, 3, 9, 3)), + S=(96, (3, 3, 27, 3)), + B=(128, (3, 3, 27, 3)), + L=(192, (3, 3, 27, 3)), + XL=(256, (3, 3, 27, 3)), + )[variant] + m = ConvNeXt(d_model, depths) + + if pretrained: + pass + + return m diff --git a/vision_toolbox/components.py b/vision_toolbox/components.py index f5c0cee..d813166 100644 --- a/vision_toolbox/components.py +++ b/vision_toolbox/components.py @@ -150,3 +150,31 @@ def forward(self, x: Tensor) -> Tensor: x = self.pool(x) outputs.append(x) return torch.cat(outputs, dim=1) + + +class Permute(nn.Module): + def __init__(self, *dims: int) -> None: + super().__init__() + self.dims = dims + + def forward(self, x: Tensor) -> Tensor: + return x.permute(self.dims) + + +# https://arxiv.org/pdf/1603.09382.pdf +class StochasticDepth(nn.Module): + def __init__(self, p: float) -> None: + assert 0.0 <= p <= 1.0 + super().__init__() + self.p = p + + def forward(self, x: Tensor) -> Tensor: + if not self.training or self.p == 0.0: + return x + + shape = [x.shape[0]] + [1] * (x.ndim - 1) + keep_p = 1.0 - self.p + return x * x.new_empty(shape).bernoulli_(keep_p).div_(keep_p) + + def extra_repr(self) -> str: + return f"p={self.p}"