From 8c1d2c4b42feb44df320efecfe3969d6e1d35d1c Mon Sep 17 00:00:00 2001
From: Thien Tran
Date: Sat, 19 Aug 2023 13:35:44 +0800
Subject: [PATCH] Add ConvNeXt (#17)

---
 tests/test_convnext.py               |  20 ++++
 vision_toolbox/backbones/__init__.py |   1 +
 vision_toolbox/backbones/convnext.py | 202 +++++++++++++++++++++++++++++++++++++++++++
 vision_toolbox/components.py         |  41 +++++++++
 4 files changed, 264 insertions(+)
 create mode 100644 tests/test_convnext.py
 create mode 100644 vision_toolbox/backbones/convnext.py

diff --git a/tests/test_convnext.py b/tests/test_convnext.py
new file mode 100644
index 0000000..c516420
--- /dev/null
+++ b/tests/test_convnext.py
@@ -0,0 +1,20 @@
+import timm
+import torch
+
+from vision_toolbox.backbones import ConvNeXt
+
+
+def test_forward():
+    m = ConvNeXt.from_config("T")
+    m(torch.randn(1, 3, 224, 224))
+
+
+def test_from_pretrained():
+    m = ConvNeXt.from_config("T", True).eval()
+    x = torch.randn(1, 3, 224, 224)
+    out = m(x)
+
+    m_timm = timm.create_model("convnext_tiny.fb_in22k", pretrained=True, num_classes=0).eval()
+    out_timm = m_timm(x)
+
+    torch.testing.assert_close(out, out_timm)
diff --git a/vision_toolbox/backbones/__init__.py b/vision_toolbox/backbones/__init__.py
index 55a1857..a724f7c 100644
--- a/vision_toolbox/backbones/__init__.py
+++ b/vision_toolbox/backbones/__init__.py
@@ -1,3 +1,4 @@
+from .convnext import ConvNeXt
 from .darknet import Darknet, DarknetYOLOv5
 from .mlp_mixer import MLPMixer
 from .patchconvnet import PatchConvNet
diff --git a/vision_toolbox/backbones/convnext.py b/vision_toolbox/backbones/convnext.py
new file mode 100644
index 0000000..86e346e
--- /dev/null
+++ b/vision_toolbox/backbones/convnext.py
@@ -0,0 +1,202 @@
+# https://arxiv.org/abs/2201.03545
+# https://github.com/facebookresearch/ConvNeXt
+
+from __future__ import annotations
+
+from functools import partial
+
+import torch
+from torch import Tensor, nn
+
+from ..components import Permute, StochasticDepth
+from .base import BaseBackbone, _act, _norm
+
+
+class ConvNeXtBlock(nn.Module):
+    """Residual block: 7x7 depthwise conv -> norm -> 2-layer MLP, on channels-last input.
+
+    The input and output are laid out as (N, H, W, C); the block permutes to
+    (N, C, H, W) only around the depthwise convolution.
+
+    Args:
+        d_model: Number of channels of the input and the output.
+        expansion_ratio: Width of the hidden MLP layer relative to ``d_model``.
+        bias: Whether the convolution and linear layers have a bias term.
+        layer_scale_init: Initial value of the learnable per-channel scale on
+            the residual branch. A value <= 0 disables the scale entirely.
+        stochastic_depth: Probability of dropping the residual branch of a sample.
+        norm: Factory called with ``d_model`` to build the normalization layer.
+        act: Factory called without arguments to build the activation.
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        expansion_ratio: float = 4.0,
+        bias: bool = True,
+        layer_scale_init: float = 1e-6,
+        stochastic_depth: float = 0.0,
+        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
+        act: _act = nn.GELU,
+    ) -> None:
+        super().__init__()
+        hidden_dim = int(d_model * expansion_ratio)
+        self.layers = nn.Sequential(
+            Permute(0, 3, 1, 2),
+            nn.Conv2d(d_model, d_model, 7, padding=3, groups=d_model, bias=bias),
+            Permute(0, 2, 3, 1),
+            norm(d_model),
+            nn.Linear(d_model, hidden_dim, bias=bias),
+            act(),
+            nn.Linear(hidden_dim, d_model, bias=bias),
+        )
+        self.layer_scale = nn.Parameter(torch.full((d_model,), layer_scale_init)) if layer_scale_init > 0 else None
+        self.drop = StochasticDepth(stochastic_depth)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Return ``x`` plus the (optionally scaled and dropped) residual branch."""
+        residual = self.layers(x)
+        # layer_scale is None when layer_scale_init <= 0; multiplying by None would raise.
+        if self.layer_scale is not None:
+            residual = residual * self.layer_scale
+        return x + self.drop(residual)
+
+
+class ConvNeXt(BaseBackbone):
+    """ConvNeXt backbone: a 4x4 stride-4 conv stem followed by stages of ``ConvNeXtBlock``.
+
+    Args:
+        d_model: Number of channels after the stem; doubled at every later stage.
+        depths: Number of blocks in each stage.
+        expansion_ratio: Forwarded to every ``ConvNeXtBlock``.
+        bias: Forwarded to every ``ConvNeXtBlock``.
+        layer_scale_init: Forwarded to every ``ConvNeXtBlock``.
+        stochastic_depth: Drop rate of the last block; rates grow linearly from
+            0 at the first block to this value.
+        norm: Factory called with a channel count to build normalization layers.
+        act: Forwarded to every ``ConvNeXtBlock``.
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        depths: tuple[int, ...],
+        expansion_ratio: float = 4.0,
+        bias: bool = True,
+        layer_scale_init: float = 1e-6,
+        stochastic_depth: float = 0.0,
+        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
+        act: _act = nn.GELU,
+    ) -> None:
+        super().__init__()
+        self.stem = nn.Sequential(nn.Conv2d(3, d_model, 4, 4), Permute(0, 2, 3, 1), norm(d_model))
+
+        # .tolist() yields plain Python floats, so each block stores a float
+        # drop rate (as annotated) rather than a 0-dim tensor.
+        stochastic_depth_rates = torch.linspace(0, stochastic_depth, sum(depths)).tolist()
+        self.stages = nn.Sequential()
+
+        global_block_idx = 0  # position of the next block across all stages
+        for stage_idx, depth in enumerate(depths):
+            stage = nn.Sequential()
+            if stage_idx > 0:
+                # equivalent to PatchMerging in SwinTransformer
+                downsample = nn.Sequential(
+                    norm(d_model),
+                    Permute(0, 3, 1, 2),
+                    nn.Conv2d(d_model, d_model * 2, 2, 2),
+                    Permute(0, 2, 3, 1),
+                )
+                d_model *= 2
+            else:
+                downsample = nn.Identity()
+            stage.append(downsample)
+
+            for _ in range(depth):
+                rate = stochastic_depth_rates[global_block_idx]
+                global_block_idx += 1
+                block = ConvNeXtBlock(d_model, expansion_ratio, bias, layer_scale_init, rate, norm, act)
+                stage.append(block)
+
+            self.stages.append(stage)
+
+        self.head_norm = norm(d_model)
+
+    def get_feature_maps(self, x: Tensor) -> list[Tensor]:
+        """Run the stem and all stages; return a one-element list with the last stage's output."""
+        out = [self.stem(x)]
+        for stage in self.stages:
+            out.append(stage(out[-1]))
+        return out[-1:]
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Average the final feature map over its two spatial axes and apply ``head_norm``."""
+        return self.head_norm(self.get_feature_maps(x)[-1].mean((1, 2)))
+
+    @staticmethod
+    def from_config(variant: str, pretrained: bool = False) -> ConvNeXt:
+        """Build a predefined variant, optionally loading the official ``*_22k_224`` checkpoint.
+
+        Args:
+            variant: One of ``"T"``, ``"S"``, ``"B"``, ``"L"``, ``"XL"``.
+            pretrained: If True, download the checkpoint and copy its weights in.
+
+        Raises:
+            KeyError: If ``variant`` is not a known configuration.
+        """
+        d_model, depths = dict(
+            T=(96, (3, 3, 9, 3)),
+            S=(96, (3, 3, 27, 3)),
+            B=(128, (3, 3, 27, 3)),
+            L=(192, (3, 3, 27, 3)),
+            XL=(256, (3, 3, 27, 3)),
+        )[variant]
+        m = ConvNeXt(d_model, depths)
+
+        if pretrained:
+            # TODO: also add torchvision checkpoints?
+            ckpt = dict(
+                T="convnext_tiny_22k_224.pth",
+                S="convnext_small_22k_224.pth",
+                B="convnext_base_22k_224.pth",
+                L="convnext_large_22k_224.pth",
+                XL="convnext_xlarge_22k_224.pth",
+            )[variant]
+            base_url = "https://dl.fbaipublicfiles.com/convnext/"
+            state_dict = torch.hub.load_state_dict_from_url(base_url + ckpt)["model"]
+            m.load_official_ckpt(state_dict)
+
+        return m
+
+    @torch.no_grad()
+    def load_official_ckpt(self, state_dict: dict[str, Tensor]) -> None:
+        """Copy weights from an official ConvNeXt ``state_dict`` into this model.
+
+        Consumed entries are popped, so ``state_dict`` is mutated in place.
+        """
+
+        def copy_(m: nn.Conv2d | nn.Linear | nn.LayerNorm, prefix: str):
+            m.weight.copy_(state_dict.pop(prefix + ".weight"))
+            m.bias.copy_(state_dict.pop(prefix + ".bias"))
+
+        copy_(self.stem[0], "downsample_layers.0.0")  # Conv2d
+        copy_(self.stem[2], "downsample_layers.0.1")  # LayerNorm
+
+        for stage_idx, stage in enumerate(self.stages):
+            if stage_idx > 0:
+                copy_(stage[0][0], f"downsample_layers.{stage_idx}.0")  # LayerNorm
+                copy_(stage[0][2], f"downsample_layers.{stage_idx}.1")  # Conv2d
+
+            for block_idx in range(1, len(stage)):
+                block: ConvNeXtBlock = stage[block_idx]
+                prefix = f"stages.{stage_idx}.{block_idx - 1}."
+
+                copy_(block.layers[1], prefix + "dwconv")
+                copy_(block.layers[3], prefix + "norm")
+                copy_(block.layers[4], prefix + "pwconv1")
+                copy_(block.layers[6], prefix + "pwconv2")
+                block.layer_scale.copy_(state_dict.pop(prefix + "gamma"))
+
+        copy_(self.head_norm, "norm")
+        # presumably the two leftovers are the classifier head's weight and bias -- TODO confirm
+        assert len(state_dict) == 2
diff --git a/vision_toolbox/components.py b/vision_toolbox/components.py
index f5c0cee..d813166 100644
--- a/vision_toolbox/components.py
+++ b/vision_toolbox/components.py
@@ -150,3 +150,44 @@ def forward(self, x: Tensor) -> Tensor:
             x = self.pool(x)
             outputs.append(x)
         return torch.cat(outputs, dim=1)
+
+
+class Permute(nn.Module):
+    """Module wrapper around ``Tensor.permute`` so it can be used inside ``nn.Sequential``."""
+
+    def __init__(self, *dims: int) -> None:
+        super().__init__()
+        self.dims = dims
+
+    def forward(self, x: Tensor) -> Tensor:
+        return x.permute(self.dims)
+
+
+# https://arxiv.org/pdf/1603.09382.pdf
+class StochasticDepth(nn.Module):
+    """Randomly zero whole samples of a batch during training, rescaling the survivors.
+
+    Args:
+        p: Probability of dropping each sample, in [0, 1].
+    """
+
+    def __init__(self, p: float) -> None:
+        assert 0.0 <= p <= 1.0
+        super().__init__()
+        self.p = float(p)  # accept 0-dim tensors too; store a plain Python float
+
+    def forward(self, x: Tensor) -> Tensor:
+        if not self.training or self.p == 0.0:
+            return x
+
+        # one Bernoulli draw per sample, broadcast over all remaining dimensions
+        shape = [x.shape[0]] + [1] * (x.ndim - 1)
+        keep_p = 1.0 - self.p
+        mask = x.new_empty(shape).bernoulli_(keep_p)
+        if keep_p > 0.0:
+            # skip the rescale when p == 1: dividing the all-zero mask by 0 would give NaN
+            mask.div_(keep_p)
+        return x * mask
+
+    def extra_repr(self) -> str:
+        return f"p={self.p}"