
Commit

add ConvNeXt-V2
gau-nernst committed Aug 19, 2023
1 parent 8c1d2c4 commit a513569
Showing 2 changed files with 76 additions and 22 deletions.
14 changes: 9 additions & 5 deletions tests/test_convnext.py
@@ -1,20 +1,24 @@
+import pytest
 import timm
 import torch
 
 from vision_toolbox.backbones import ConvNeXt
 
 
-def test_forward():
-    m = ConvNeXt.from_config("T")
+@pytest.mark.parametrize("v2", [False, True])
+def test_forward(v2):
+    m = ConvNeXt.from_config("T", v2)
     m(torch.randn(1, 3, 224, 224))
 
 
-def test_from_pretrained():
-    m = ConvNeXt.from_config("T", True).eval()
+@pytest.mark.parametrize("v2", [False, True])
+def test_from_pretrained(v2):
+    m = ConvNeXt.from_config("T", v2, True).eval()
     x = torch.randn(1, 3, 224, 224)
     out = m(x)
 
-    m_timm = timm.create_model("convnext_tiny.fb_in22k", pretrained=True, num_classes=0).eval()
+    model_name = "convnextv2_tiny.fcmae" if v2 else "convnext_tiny.fb_in22k"
+    m_timm = timm.create_model(model_name, pretrained=True, num_classes=0).eval()
    out_timm = m_timm(x)
 
     torch.testing.assert_close(out, out_timm)
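Note: with the parametrize decorators above, each test now runs once for V1 and once for V2, and the V2 parity case checks against timm's convnextv2_tiny.fcmae. Assuming pytest and timm are installed (the pretrained test downloads weights on first run), both variants can be exercised with:

python -m pytest tests/test_convnext.py -v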
84 changes: 67 additions & 17 deletions vision_toolbox/backbones/convnext.py
@@ -12,6 +12,20 @@
 from .base import BaseBackbone, _act, _norm
 
 
+class GlobalResponseNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6) -> None:
+        super().__init__()
+        self.gamma = nn.Parameter(torch.zeros(dim))
+        self.beta = nn.Parameter(torch.zeros(dim))
+        self.eps = eps
+
+    def forward(self, x: Tensor) -> Tensor:
+        # x: shape (B, H, W, C)
+        gx = torch.linalg.vector_norm(x, dim=(1, 2), keepdim=True)  # (B, 1, 1, C)
+        nx = gx / gx.mean(-1, keepdim=True).add(self.eps)
+        return x + x * nx * self.gamma + self.beta
+
+
 class ConvNeXtBlock(nn.Module):
     def __init__(
         self,
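A quick sanity check, not part of the commit: gamma and beta start at zero, so the new GlobalResponseNorm is an exact identity function at initialization and only learns to recalibrate channels during training. A minimal sketch, assuming the class is importable from the module path below:

import torch
from vision_toolbox.backbones.convnext import GlobalResponseNorm  # path assumed

grn = GlobalResponseNorm(64)
x = torch.randn(2, 7, 7, 64)  # channels-last (B, H, W, C), as in ConvNeXtBlock
torch.testing.assert_close(grn(x), x)  # gamma = beta = 0 -> identity at init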
@@ -22,6 +36,7 @@ def __init__(
         stochastic_depth: float = 0.0,
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
+        v2: bool = False,
     ) -> None:
         super().__init__()
         hidden_dim = int(d_model * expansion_ratio)
@@ -32,13 +47,19 @@ def __init__(
             norm(d_model),
             nn.Linear(d_model, hidden_dim, bias=bias),
             act(),
+            GlobalResponseNorm(hidden_dim) if v2 else nn.Identity(),
             nn.Linear(hidden_dim, d_model, bias=bias),
         )
-        self.layer_scale = nn.Parameter(torch.full((d_model,), layer_scale_init)) if layer_scale_init > 0 else None
+        self.layer_scale = (
+            nn.Parameter(torch.full((d_model,), layer_scale_init)) if layer_scale_init > 0 and not v2 else None
+        )
         self.drop = StochasticDepth(stochastic_depth)
 
     def forward(self, x: Tensor) -> Tensor:
-        return x + self.drop(self.layers(x) * self.layer_scale)
+        out = self.layers(x)
+        if self.layer_scale is not None:
+            out = out * self.layer_scale
+        return x + self.drop(out)
 
 
 class ConvNeXt(BaseBackbone):
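For orientation, here is the resulting layout of ConvNeXtBlock.layers as implied by this hunk and by the copy_ calls in load_official_ckpt further down; the entries at indices 0, 2, and 5 are inferred from unchanged parts of the file, not shown in this diff:

# ConvNeXtBlock.layers after this commit (indices used by load_official_ckpt):
#   [0] Permute                          # channels-last -> NCHW (inferred)
#   [1] nn.Conv2d                        # depthwise conv, "dwconv" in the checkpoint
#   [2] Permute                          # back to channels-last (inferred)
#   [3] norm(d_model)                    # "norm"
#   [4] nn.Linear                        # "pwconv1"
#   [5] act()                            # e.g. nn.GELU (inferred)
#   [6] GlobalResponseNorm / nn.Identity # v2 / v1 -- new slot
#   [7] nn.Linear                        # "pwconv2", shifted from index 6 to 7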
@@ -52,6 +73,7 @@ def __init__(
         stochastic_depth: float = 0.0,
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
+        v2: bool = False,
     ) -> None:
         super().__init__()
         self.stem = nn.Sequential(nn.Conv2d(3, d_model, 4, 4), Permute(0, 2, 3, 1), norm(d_model))
@@ -76,7 +98,7 @@ def __init__(
 
             for block_idx in range(depth):
                 rate = stochastic_depth_rates[sum(depths[:stage_idx]) + block_idx]
-                block = ConvNeXtBlock(d_model, expansion_ratio, bias, layer_scale_init, rate, norm, act)
+                block = ConvNeXtBlock(d_model, expansion_ratio, bias, layer_scale_init, rate, norm, act, v2)
                 stage.append(block)
 
             self.stages.append(stage)
@@ -93,26 +115,44 @@ def forward(self, x: Tensor) -> Tensor:
         return self.head_norm(self.get_feature_maps(x)[-1].mean((1, 2)))
 
     @staticmethod
-    def from_config(variant: str, pretrained: bool = False) -> ConvNeXt:
+    def from_config(variant: str, v2: bool = False, pretrained: bool = False) -> ConvNeXt:
         d_model, depths = dict(
+            A=(40, (2, 2, 6, 2)),
+            F=(48, (2, 2, 6, 2)),
+            P=(64, (2, 2, 6, 2)),
+            N=(80, (2, 2, 8, 2)),
             T=(96, (3, 3, 9, 3)),
             S=(96, (3, 3, 27, 3)),
             B=(128, (3, 3, 27, 3)),
             L=(192, (3, 3, 27, 3)),
             XL=(256, (3, 3, 27, 3)),
+            H=(352, (3, 3, 27, 3)),
         )[variant]
-        m = ConvNeXt(d_model, depths)
+        m = ConvNeXt(d_model, depths, v2=v2)
 
         if pretrained:
             # TODO: also add torchvision checkpoints?
-            ckpt = dict(
-                T="convnext_tiny_22k_224.pth",
-                S="convnext_small_22k_224.pth",
-                B="convnext_base_22k_224.pth",
-                L="convnext_large_22k_224.pth",
-                XL="convnext_xlarge_22k_224.pth",
-            )[variant]
-            base_url = "https://dl.fbaipublicfiles.com/convnext/"
+            if not v2:
+                ckpt = dict(
+                    T="convnext_tiny_22k_224.pth",
+                    S="convnext_small_22k_224.pth",
+                    B="convnext_base_22k_224.pth",
+                    L="convnext_large_22k_224.pth",
+                    XL="convnext_xlarge_22k_224.pth",
+                )[variant]
+                base_url = "https://dl.fbaipublicfiles.com/convnext/"
+            else:
+                ckpt = dict(
+                    A="convnextv2_atto_1k_224_fcmae.pt",
+                    F="convnextv2_femto_1k_224_fcmae.pt",
+                    P="convnextv2_pico_1k_224_fcmae.pt",
+                    N="convnextv2_nano_1k_224_fcmae.pt",
+                    T="convnextv2_tiny_1k_224_fcmae.pt",
+                    B="convnextv2_base_1k_224_fcmae.pt",
+                    L="convnextv2_large_1k_224_fcmae.pt",
+                    H="convnextv2_huge_1k_224_fcmae.pt",
+                )[variant]
+                base_url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/"
             state_dict = torch.hub.load_state_dict_from_url(base_url + ckpt)["model"]
             m.load_official_ckpt(state_dict)
 
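Usage sketch, not part of the commit: v2 is now the second positional parameter of from_config, which is why the updated test calls from_config("T", v2, True). Pretrained V1 weights are wired up for T/S/B/L/XL and pretrained V2 (FCMAE) weights for A/F/P/N/T/B/L/H:

from vision_toolbox.backbones import ConvNeXt

v1 = ConvNeXt.from_config("T", pretrained=True)           # ConvNeXt-V1 Tiny, IN-22k weights
v2 = ConvNeXt.from_config("T", v2=True, pretrained=True)  # ConvNeXt-V2 Tiny, FCMAE pre-training
atto = ConvNeXt.from_config("A", v2=True)                 # A/F/P/N/H size configs are new in this commit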
@@ -139,8 +179,18 @@ def copy_(m: nn.Conv2d | nn.Linear | nn.LayerNorm, prefix: str):
                 copy_(block.layers[1], prefix + "dwconv")
                 copy_(block.layers[3], prefix + "norm")
                 copy_(block.layers[4], prefix + "pwconv1")
-                copy_(block.layers[6], prefix + "pwconv2")
-                block.layer_scale.copy_(state_dict.pop(prefix + "gamma"))
+                if isinstance(block.layers[6], GlobalResponseNorm):  # v2
+                    block.layers[6].gamma.copy_(state_dict.pop(prefix + "grn.gamma").squeeze())
+                    block.layers[6].beta.copy_(state_dict.pop(prefix + "grn.beta").squeeze())
+
+                copy_(block.layers[7], prefix + "pwconv2")
+                if block.layer_scale is not None:
+                    block.layer_scale.copy_(state_dict.pop(prefix + "gamma"))
 
-        copy_(self.head_norm, "norm")
-        assert len(state_dict) == 2
+        # FCMAE checkpoints don't contain head norm
+        if "norm.weight" in state_dict:
+            copy_(self.head_norm, "norm")
+            assert len(state_dict) == 2
+        else:
+            assert len(state_dict) == 0
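Two inferences about this hunk, neither stated in the commit itself: the .squeeze() calls reconcile parameter shapes, since the official ConvNeXt-V2 implementation stores GRN affine parameters as (1, 1, 1, C) tensors while this port keeps them as (C,); and the final asserts guard against silently unmapped weights, with supervised V1 checkpoints expected to leave exactly two unused keys (presumably the classification head, which this backbone discards) and FCMAE checkpoints expected to be consumed completely.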
