diff --git a/tests/test_convnext.py b/tests/test_convnext.py
index c516420..af31085 100644
--- a/tests/test_convnext.py
+++ b/tests/test_convnext.py
@@ -1,20 +1,24 @@
+import pytest
 import timm
 import torch
 
 from vision_toolbox.backbones import ConvNeXt
 
 
-def test_forward():
-    m = ConvNeXt.from_config("T")
+@pytest.mark.parametrize("v2", [False, True])
+def test_forward(v2):
+    m = ConvNeXt.from_config("T", v2)
     m(torch.randn(1, 3, 224, 224))
 
 
-def test_from_pretrained():
-    m = ConvNeXt.from_config("T", True).eval()
+@pytest.mark.parametrize("v2", [False, True])
+def test_from_pretrained(v2):
+    m = ConvNeXt.from_config("T", v2, True).eval()
     x = torch.randn(1, 3, 224, 224)
     out = m(x)
 
-    m_timm = timm.create_model("convnext_tiny.fb_in22k", pretrained=True, num_classes=0).eval()
+    model_name = "convnextv2_tiny.fcmae" if v2 else "convnext_tiny.fb_in22k"
+    m_timm = timm.create_model(model_name, pretrained=True, num_classes=0).eval()
     out_timm = m_timm(x)
 
     torch.testing.assert_close(out, out_timm)
diff --git a/vision_toolbox/backbones/convnext.py b/vision_toolbox/backbones/convnext.py
index 86e346e..de5b81d 100644
--- a/vision_toolbox/backbones/convnext.py
+++ b/vision_toolbox/backbones/convnext.py
@@ -12,6 +12,20 @@
 from .base import BaseBackbone, _act, _norm
 
 
+class GlobalResponseNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6) -> None:
+        super().__init__()
+        self.gamma = nn.Parameter(torch.zeros(dim))
+        self.beta = nn.Parameter(torch.zeros(dim))
+        self.eps = eps
+
+    def forward(self, x: Tensor) -> Tensor:
+        # x: shape (B, H, W, C)
+        gx = torch.linalg.vector_norm(x, dim=(1, 2), keepdim=True)  # (B, 1, 1, C)
+        nx = gx / gx.mean(-1, keepdim=True).add(self.eps)
+        return x + x * nx * self.gamma + self.beta
+
+
 class ConvNeXtBlock(nn.Module):
     def __init__(
         self,
@@ -22,6 +36,7 @@ def __init__(
         stochastic_depth: float = 0.0,
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
+        v2: bool = False,
     ) -> None:
         super().__init__()
         hidden_dim = int(d_model * expansion_ratio)
@@ -32,13 +47,19 @@ def __init__(
             norm(d_model),
             nn.Linear(d_model, hidden_dim, bias=bias),
             act(),
+            GlobalResponseNorm(hidden_dim) if v2 else nn.Identity(),
             nn.Linear(hidden_dim, d_model, bias=bias),
         )
-        self.layer_scale = nn.Parameter(torch.full((d_model,), layer_scale_init)) if layer_scale_init > 0 else None
+        self.layer_scale = (
+            nn.Parameter(torch.full((d_model,), layer_scale_init)) if layer_scale_init > 0 and not v2 else None
+        )
         self.drop = StochasticDepth(stochastic_depth)
 
     def forward(self, x: Tensor) -> Tensor:
-        return x + self.drop(self.layers(x) * self.layer_scale)
+        out = self.layers(x)
+        if self.layer_scale is not None:
+            out = out * self.layer_scale
+        return x + self.drop(out)
 
 
 class ConvNeXt(BaseBackbone):
@@ -52,6 +73,7 @@ def __init__(
         stochastic_depth: float = 0.0,
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
+        v2: bool = False,
     ) -> None:
         super().__init__()
         self.stem = nn.Sequential(nn.Conv2d(3, d_model, 4, 4), Permute(0, 2, 3, 1), norm(d_model))
@@ -76,7 +98,7 @@ def __init__(
 
             for block_idx in range(depth):
                 rate = stochastic_depth_rates[sum(depths[:stage_idx]) + block_idx]
-                block = ConvNeXtBlock(d_model, expansion_ratio, bias, layer_scale_init, rate, norm, act)
+                block = ConvNeXtBlock(d_model, expansion_ratio, bias, layer_scale_init, rate, norm, act, v2)
                 stage.append(block)
 
             self.stages.append(stage)
@@ -93,26 +115,44 @@ def forward(self, x: Tensor) -> Tensor:
         return self.head_norm(self.get_feature_maps(x)[-1].mean((1, 2)))
 
     @staticmethod
-    def from_config(variant: str, pretrained: bool = False) -> ConvNeXt:
+    def from_config(variant: str, v2: bool = False, pretrained: bool = False) -> ConvNeXt:
         d_model, depths = dict(
+            A=(40, (2, 2, 6, 2)),
+            F=(48, (2, 2, 6, 2)),
+            P=(64, (2, 2, 6, 2)),
+            N=(80, (2, 2, 8, 2)),
             T=(96, (3, 3, 9, 3)),
             S=(96, (3, 3, 27, 3)),
             B=(128, (3, 3, 27, 3)),
             L=(192, (3, 3, 27, 3)),
             XL=(256, (3, 3, 27, 3)),
+            H=(352, (3, 3, 27, 3)),
         )[variant]
-        m = ConvNeXt(d_model, depths)
+        m = ConvNeXt(d_model, depths, v2=v2)
 
         if pretrained:
             # TODO: also add torchvision checkpoints?
-            ckpt = dict(
-                T="convnext_tiny_22k_224.pth",
-                S="convnext_small_22k_224.pth",
-                B="convnext_base_22k_224.pth",
-                L="convnext_large_22k_224.pth",
-                XL="convnext_xlarge_22k_224.pth",
-            )[variant]
-            base_url = "https://dl.fbaipublicfiles.com/convnext/"
+            if not v2:
+                ckpt = dict(
+                    T="convnext_tiny_22k_224.pth",
+                    S="convnext_small_22k_224.pth",
+                    B="convnext_base_22k_224.pth",
+                    L="convnext_large_22k_224.pth",
+                    XL="convnext_xlarge_22k_224.pth",
+                )[variant]
+                base_url = "https://dl.fbaipublicfiles.com/convnext/"
+            else:
+                ckpt = dict(
+                    A="convnextv2_atto_1k_224_fcmae.pt",
+                    F="convnextv2_femto_1k_224_fcmae.pt",
+                    P="convnextv2_pico_1k_224_fcmae.pt",
+                    N="convnextv2_nano_1k_224_fcmae.pt",
+                    T="convnextv2_tiny_1k_224_fcmae.pt",
+                    B="convnextv2_base_1k_224_fcmae.pt",
+                    L="convnextv2_large_1k_224_fcmae.pt",
+                    H="convnextv2_huge_1k_224_fcmae.pt",
+                )[variant]
+                base_url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/"
             state_dict = torch.hub.load_state_dict_from_url(base_url + ckpt)["model"]
             m.load_official_ckpt(state_dict)
 
@@ -139,8 +179,18 @@ def copy_(m: nn.Conv2d | nn.Linear | nn.LayerNorm, prefix: str):
                 copy_(block.layers[1], prefix + "dwconv")
                 copy_(block.layers[3], prefix + "norm")
                 copy_(block.layers[4], prefix + "pwconv1")
-                copy_(block.layers[6], prefix + "pwconv2")
-                block.layer_scale.copy_(state_dict.pop(prefix + "gamma"))
+                if isinstance(block.layers[6], GlobalResponseNorm):  # v2
+                    block.layers[6].gamma.copy_(state_dict.pop(prefix + "grn.gamma").squeeze())
+                    block.layers[6].beta.copy_(state_dict.pop(prefix + "grn.beta").squeeze())
+
+                copy_(block.layers[7], prefix + "pwconv2")
+                if block.layer_scale is not None:
+                    block.layer_scale.copy_(state_dict.pop(prefix + "gamma"))
+
 
-        copy_(self.head_norm, "norm")
-        assert len(state_dict) == 2
+        # FCMAE checkpoints don't contain head norm
+        if "norm.weight" in state_dict:
+            copy_(self.head_norm, "norm")
+            assert len(state_dict) == 2
+        else:
+            assert len(state_dict) == 0
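
Usage sketch, not part of the patch: it smoke-tests the new v2 flag and checks GlobalResponseNorm against the GRN formula from the ConvNeXt V2 paper, out = x + gamma * (x * N(x)) + beta, where N(x) is the per-channel spatial L2 norm divided by its mean over channels. It assumes the patched vision_toolbox is importable; the GlobalResponseNorm import path follows the file modified above.

    import torch

    from vision_toolbox.backbones import ConvNeXt
    from vision_toolbox.backbones.convnext import GlobalResponseNorm

    # Smoke test: v2 variants use the same single-letter names (A/F/P/N/T/B/L/H).
    m = ConvNeXt.from_config("T", v2=True)
    m(torch.randn(1, 3, 224, 224))

    # GRN check. gamma/beta initialize to zero (the layer is an identity at init),
    # so randomize them to make the comparison non-trivial.
    grn = GlobalResponseNorm(8)
    with torch.no_grad():
        grn.gamma.normal_()
        grn.beta.normal_()

    x = torch.randn(2, 7, 7, 8)  # NHWC, the layout ConvNeXtBlock uses internally
    gx = x.pow(2).sum(dim=(1, 2), keepdim=True).sqrt()   # per-channel global L2 norm, (B, 1, 1, C)
    nx = gx / (gx.mean(dim=-1, keepdim=True) + grn.eps)  # divisive normalization across channels
    torch.testing.assert_close(grn(x), x + x * nx * grn.gamma + grn.beta)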