Add ConvNeXt-V2 #18

Merged: 1 commit, Aug 19, 2023
14 changes: 9 additions & 5 deletions tests/test_convnext.py
@@ -1,20 +1,24 @@
 import pytest
 import timm
 import torch
 
 from vision_toolbox.backbones import ConvNeXt
 
 
-def test_forward():
-    m = ConvNeXt.from_config("T")
+@pytest.mark.parametrize("v2", [False, True])
+def test_forward(v2):
+    m = ConvNeXt.from_config("T", v2)
     m(torch.randn(1, 3, 224, 224))
 
 
-def test_from_pretrained():
-    m = ConvNeXt.from_config("T", True).eval()
+@pytest.mark.parametrize("v2", [False, True])
+def test_from_pretrained(v2):
+    m = ConvNeXt.from_config("T", v2, True).eval()
     x = torch.randn(1, 3, 224, 224)
     out = m(x)
 
-    m_timm = timm.create_model("convnext_tiny.fb_in22k", pretrained=True, num_classes=0).eval()
+    model_name = "convnextv2_tiny.fcmae" if v2 else "convnext_tiny.fb_in22k"
+    m_timm = timm.create_model(model_name, pretrained=True, num_classes=0).eval()
     out_timm = m_timm(x)
 
     torch.testing.assert_close(out, out_timm)
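For reference, the v2 case of test_from_pretrained above is equivalent to the following standalone check (a minimal sketch using the same timm model name as the diff; it assumes network access so both checkpoints can be downloaded):

    import timm
    import torch

    from vision_toolbox.backbones import ConvNeXt

    m = ConvNeXt.from_config("T", v2=True, pretrained=True).eval()
    m_timm = timm.create_model("convnextv2_tiny.fcmae", pretrained=True, num_classes=0).eval()
    x = torch.randn(1, 3, 224, 224)
    torch.testing.assert_close(m(x), m_timm(x))  # backbone features match the timm reference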
84 changes: 67 additions & 17 deletions vision_toolbox/backbones/convnext.py
@@ -12,6 +12,20 @@
 from .base import BaseBackbone, _act, _norm
 
 
+class GlobalResponseNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6) -> None:
+        super().__init__()
+        self.gamma = nn.Parameter(torch.zeros(dim))
+        self.beta = nn.Parameter(torch.zeros(dim))
+        self.eps = eps
+
+    def forward(self, x: Tensor) -> Tensor:
+        # x: shape (B, H, W, C)
+        gx = torch.linalg.vector_norm(x, dim=(1, 2), keepdim=True)  # (B, 1, 1, C)
+        nx = gx / gx.mean(-1, keepdim=True).add(self.eps)
+        return x + x * nx * self.gamma + self.beta
+
+
 class ConvNeXtBlock(nn.Module):
     def __init__(
         self,
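A note on the new module above: GlobalResponseNorm takes a channels-last tensor, computes a per-channel L2 norm over the spatial dimensions, normalizes it by its mean across channels, and uses the result to reweight the input through the learnable gamma and beta. Because gamma and beta are initialized to zero, a freshly constructed GRN acts as an identity, which a minimal sketch can verify (GlobalResponseNorm as defined above):

    import torch

    grn = GlobalResponseNorm(8)
    x = torch.randn(2, 7, 7, 8)  # (B, H, W, C), channels-last as inside ConvNeXtBlock
    torch.testing.assert_close(grn(x), x)  # gamma = beta = 0 at init, so the module is an identity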
@@ -22,6 +36,7 @@ def __init__(
         stochastic_depth: float = 0.0,
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
+        v2: bool = False,
     ) -> None:
         super().__init__()
         hidden_dim = int(d_model * expansion_ratio)
@@ -32,13 +47,19 @@ def __init__(
             norm(d_model),
             nn.Linear(d_model, hidden_dim, bias=bias),
             act(),
+            GlobalResponseNorm(hidden_dim) if v2 else nn.Identity(),
             nn.Linear(hidden_dim, d_model, bias=bias),
         )
-        self.layer_scale = nn.Parameter(torch.full((d_model,), layer_scale_init)) if layer_scale_init > 0 else None
+        self.layer_scale = (
+            nn.Parameter(torch.full((d_model,), layer_scale_init)) if layer_scale_init > 0 and not v2 else None
+        )
         self.drop = StochasticDepth(stochastic_depth)
 
     def forward(self, x: Tensor) -> Tensor:
-        return x + self.drop(self.layers(x) * self.layer_scale)
+        out = self.layers(x)
+        if self.layer_scale is not None:
+            out = out * self.layer_scale
+        return x + self.drop(out)
 
 
 class ConvNeXt(BaseBackbone):
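Inside the block, the v2 flag only does two things: it places GlobalResponseNorm between the activation and the second pointwise linear (nn.Identity for v1), and it disables layer scale. A small sketch of the resulting structure, assuming the constructor arguments not shown in this hunk (expansion_ratio, bias, layer_scale_init) keep their defaults:

    blk = ConvNeXtBlock(96, v2=True)
    assert isinstance(blk.layers[6], GlobalResponseNorm)  # GRN sits between act() and pwconv2
    assert blk.layer_scale is None                        # layer scale is turned off for v2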
@@ -52,6 +73,7 @@ def __init__(
         stochastic_depth: float = 0.0,
         norm: _norm = partial(nn.LayerNorm, eps=1e-6),
         act: _act = nn.GELU,
+        v2: bool = False,
     ) -> None:
         super().__init__()
         self.stem = nn.Sequential(nn.Conv2d(3, d_model, 4, 4), Permute(0, 2, 3, 1), norm(d_model))
@@ -76,7 +98,7 @@ def __init__(
 
             for block_idx in range(depth):
                 rate = stochastic_depth_rates[sum(depths[:stage_idx]) + block_idx]
-                block = ConvNeXtBlock(d_model, expansion_ratio, bias, layer_scale_init, rate, norm, act)
+                block = ConvNeXtBlock(d_model, expansion_ratio, bias, layer_scale_init, rate, norm, act, v2)
                 stage.append(block)
 
             self.stages.append(stage)
@@ -93,26 +115,44 @@ def forward(self, x: Tensor) -> Tensor:
         return self.head_norm(self.get_feature_maps(x)[-1].mean((1, 2)))
 
     @staticmethod
-    def from_config(variant: str, pretrained: bool = False) -> ConvNeXt:
+    def from_config(variant: str, v2: bool = False, pretrained: bool = False) -> ConvNeXt:
         d_model, depths = dict(
+            A=(40, (2, 2, 6, 2)),
+            F=(48, (2, 2, 6, 2)),
+            P=(64, (2, 2, 6, 2)),
+            N=(80, (2, 2, 8, 2)),
             T=(96, (3, 3, 9, 3)),
             S=(96, (3, 3, 27, 3)),
             B=(128, (3, 3, 27, 3)),
             L=(192, (3, 3, 27, 3)),
             XL=(256, (3, 3, 27, 3)),
+            H=(352, (3, 3, 27, 3)),
         )[variant]
-        m = ConvNeXt(d_model, depths)
+        m = ConvNeXt(d_model, depths, v2=v2)
 
         if pretrained:
             # TODO: also add torchvision checkpoints?
-            ckpt = dict(
-                T="convnext_tiny_22k_224.pth",
-                S="convnext_small_22k_224.pth",
-                B="convnext_base_22k_224.pth",
-                L="convnext_large_22k_224.pth",
-                XL="convnext_xlarge_22k_224.pth",
-            )[variant]
-            base_url = "https://dl.fbaipublicfiles.com/convnext/"
+            if not v2:
+                ckpt = dict(
+                    T="convnext_tiny_22k_224.pth",
+                    S="convnext_small_22k_224.pth",
+                    B="convnext_base_22k_224.pth",
+                    L="convnext_large_22k_224.pth",
+                    XL="convnext_xlarge_22k_224.pth",
+                )[variant]
+                base_url = "https://dl.fbaipublicfiles.com/convnext/"
+            else:
+                ckpt = dict(
+                    A="convnextv2_atto_1k_224_fcmae.pt",
+                    F="convnextv2_femto_1k_224_fcmae.pt",
+                    P="convnextv2_pico_1k_224_fcmae.pt",
+                    N="convnextv2_nano_1k_224_fcmae.pt",
+                    T="convnextv2_tiny_1k_224_fcmae.pt",
+                    B="convnextv2_base_1k_224_fcmae.pt",
+                    L="convnextv2_large_1k_224_fcmae.pt",
+                    H="convnextv2_huge_1k_224_fcmae.pt",
+                )[variant]
+                base_url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/"
             state_dict = torch.hub.load_state_dict_from_url(base_url + ckpt)["model"]
             m.load_official_ckpt(state_dict)
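The new variant letters follow the ConvNeXt-V2 sizes (A = atto, F = femto, P = pico, N = nano, H = huge), and the v2 weights are the FCMAE pre-trained checkpoints, resolved as base_url + ckpt. For example, a sketch of loading the nano variant:

    m = ConvNeXt.from_config("N", v2=True, pretrained=True)
    # fetches https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_nano_1k_224_fcmae.pt
    # via torch.hub.load_state_dict_from_url() and remaps it with load_official_ckpt()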

@@ -139,8 +179,18 @@ def copy_(m: nn.Conv2d | nn.Linear | nn.LayerNorm, prefix: str):
                 copy_(block.layers[1], prefix + "dwconv")
                 copy_(block.layers[3], prefix + "norm")
                 copy_(block.layers[4], prefix + "pwconv1")
-                copy_(block.layers[6], prefix + "pwconv2")
-                block.layer_scale.copy_(state_dict.pop(prefix + "gamma"))
-
-        copy_(self.head_norm, "norm")
-        assert len(state_dict) == 2
+                if isinstance(block.layers[6], GlobalResponseNorm):  # v2
+                    block.layers[6].gamma.copy_(state_dict.pop(prefix + "grn.gamma").squeeze())
+                    block.layers[6].beta.copy_(state_dict.pop(prefix + "grn.beta").squeeze())
+
+                copy_(block.layers[7], prefix + "pwconv2")
+                if block.layer_scale is not None:
+                    block.layer_scale.copy_(state_dict.pop(prefix + "gamma"))
+
+        # FCMAE checkpoints don't contain head norm
+        if "norm.weight" in state_dict:
+            copy_(self.head_norm, "norm")
+            assert len(state_dict) == 2
+        else:
+            assert len(state_dict) == 0
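For context on the loader changes: the upstream ConvNeXt-V2 checkpoints appear to store the GRN parameters with singleton dimensions (assumed (1, 1, 1, C) here), so grn.gamma and grn.beta are squeezed to (C,) before being copied into the parameters defined above, and the FCMAE pre-training checkpoints ship without the final head norm, hence the branch on "norm.weight". A minimal sketch of the shape handling:

    import torch

    ckpt_gamma = torch.zeros(1, 1, 1, 96)       # hypothetical "grn.gamma" entry for a 96-channel block
    assert ckpt_gamma.squeeze().shape == (96,)  # matches the (dim,) shape of GlobalResponseNorm.gamma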