Add ConvNeXt #17

Merged
merged 4 commits on Aug 19, 2023
20 changes: 20 additions & 0 deletions tests/test_convnext.py
@@ -0,0 +1,20 @@
import timm
import torch

from vision_toolbox.backbones import ConvNeXt


def test_forward():
    m = ConvNeXt.from_config("T")
    m(torch.randn(1, 3, 224, 224))


def test_from_pretrained():
    m = ConvNeXt.from_config("T", True).eval()
    x = torch.randn(1, 3, 224, 224)
    out = m(x)

    m_timm = timm.create_model("convnext_tiny.fb_in22k", pretrained=True, num_classes=0).eval()
    out_timm = m_timm(x)

    torch.testing.assert_close(out, out_timm)
1 change: 1 addition & 0 deletions vision_toolbox/backbones/__init__.py
@@ -1,3 +1,4 @@
from .convnext import ConvNeXt
from .darknet import Darknet, DarknetYOLOv5
from .mlp_mixer import MLPMixer
from .patchconvnet import PatchConvNet
146 changes: 146 additions & 0 deletions vision_toolbox/backbones/convnext.py
@@ -0,0 +1,146 @@
# https://arxiv.org/abs/2201.03545
# https://github.com/facebookresearch/ConvNeXt

from __future__ import annotations

from functools import partial

import torch
from torch import Tensor, nn

from ..components import Permute, StochasticDepth
from .base import BaseBackbone, _act, _norm


class ConvNeXtBlock(nn.Module):
    def __init__(
        self,
        d_model: int,
        expansion_ratio: float = 4.0,
        bias: bool = True,
        layer_scale_init: float = 1e-6,
        stochastic_depth: float = 0.0,
        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
        act: _act = nn.GELU,
    ) -> None:
        super().__init__()
        hidden_dim = int(d_model * expansion_ratio)
        # depthwise 7x7 conv runs in NCHW; the rest of the block uses NHWC
        self.layers = nn.Sequential(
            Permute(0, 3, 1, 2),
            nn.Conv2d(d_model, d_model, 7, padding=3, groups=d_model, bias=bias),
            Permute(0, 2, 3, 1),
            norm(d_model),
            nn.Linear(d_model, hidden_dim, bias=bias),
            act(),
            nn.Linear(hidden_dim, d_model, bias=bias),
        )
        self.layer_scale = nn.Parameter(torch.full((d_model,), layer_scale_init)) if layer_scale_init > 0 else None
        self.drop = StochasticDepth(stochastic_depth)

    def forward(self, x: Tensor) -> Tensor:
        out = self.layers(x)
        if self.layer_scale is not None:  # layer_scale is None when layer_scale_init <= 0
            out = out * self.layer_scale
        return x + self.drop(out)


class ConvNeXt(BaseBackbone):
    def __init__(
        self,
        d_model: int,
        depths: tuple[int, ...],
        expansion_ratio: float = 4.0,
        bias: bool = True,
        layer_scale_init: float = 1e-6,
        stochastic_depth: float = 0.0,
        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
        act: _act = nn.GELU,
    ) -> None:
        super().__init__()
        self.stem = nn.Sequential(nn.Conv2d(3, d_model, 4, 4), Permute(0, 2, 3, 1), norm(d_model))

        # stochastic depth rate increases linearly over block index, across all stages
        stochastic_depth_rates = torch.linspace(0, stochastic_depth, sum(depths))
        self.stages = nn.Sequential()

        for stage_idx, depth in enumerate(depths):
            stage = nn.Sequential()
            if stage_idx > 0:
                # equivalent to PatchMerging in SwinTransformer
                downsample = nn.Sequential(
                    norm(d_model),
                    Permute(0, 3, 1, 2),
                    nn.Conv2d(d_model, d_model * 2, 2, 2),
                    Permute(0, 2, 3, 1),
                )
                d_model *= 2
            else:
                downsample = nn.Identity()
            stage.append(downsample)

            for block_idx in range(depth):
                rate = stochastic_depth_rates[sum(depths[:stage_idx]) + block_idx]
                block = ConvNeXtBlock(d_model, expansion_ratio, bias, layer_scale_init, rate, norm, act)
                stage.append(block)

            self.stages.append(stage)

        self.head_norm = norm(d_model)

    def get_feature_maps(self, x: Tensor) -> list[Tensor]:
        out = [self.stem(x)]
        for stage in self.stages:
            out.append(stage(out[-1]))
        return out  # stem output + one NHWC map per stage

    def forward(self, x: Tensor) -> Tensor:
        return self.head_norm(self.get_feature_maps(x)[-1].mean((1, 2)))

    @staticmethod
    def from_config(variant: str, pretrained: bool = False) -> ConvNeXt:
        d_model, depths = dict(
            T=(96, (3, 3, 9, 3)),
            S=(96, (3, 3, 27, 3)),
            B=(128, (3, 3, 27, 3)),
            L=(192, (3, 3, 27, 3)),
            XL=(256, (3, 3, 27, 3)),
        )[variant]
        m = ConvNeXt(d_model, depths)

        if pretrained:
            # TODO: also add torchvision checkpoints?
            ckpt = dict(
                T="convnext_tiny_22k_224.pth",
                S="convnext_small_22k_224.pth",
                B="convnext_base_22k_224.pth",
                L="convnext_large_22k_224.pth",
                XL="convnext_xlarge_22k_224.pth",
            )[variant]
            base_url = "https://dl.fbaipublicfiles.com/convnext/"
            state_dict = torch.hub.load_state_dict_from_url(base_url + ckpt)["model"]
            m.load_official_ckpt(state_dict)

        return m

    @torch.no_grad()
    def load_official_ckpt(self, state_dict: dict[str, Tensor]) -> None:
        def copy_(m: nn.Conv2d | nn.Linear | nn.LayerNorm, prefix: str):
            m.weight.copy_(state_dict.pop(prefix + ".weight"))
            m.bias.copy_(state_dict.pop(prefix + ".bias"))

        copy_(self.stem[0], "downsample_layers.0.0")  # Conv2d
        copy_(self.stem[2], "downsample_layers.0.1")  # LayerNorm

        for stage_idx, stage in enumerate(self.stages):
            if stage_idx > 0:
                copy_(stage[0][0], f"downsample_layers.{stage_idx}.0")  # LayerNorm
                copy_(stage[0][2], f"downsample_layers.{stage_idx}.1")  # Conv2d

            for block_idx in range(1, len(stage)):
                block: ConvNeXtBlock = stage[block_idx]
                prefix = f"stages.{stage_idx}.{block_idx - 1}."

                copy_(block.layers[1], prefix + "dwconv")
                copy_(block.layers[3], prefix + "norm")
                copy_(block.layers[4], prefix + "pwconv1")
                copy_(block.layers[6], prefix + "pwconv2")
                block.layer_scale.copy_(state_dict.pop(prefix + "gamma"))

        copy_(self.head_norm, "norm")
        assert len(state_dict) == 2  # only the classification head (head.weight, head.bias) should remain
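
For reviewers, a minimal usage sketch of the new backbone (not part of the diff). It assumes the package is importable; the shapes in the comments follow from the stride-4 stem and three 2x downsampling stages of the "T" config with a 224x224 input:

import torch

from vision_toolbox.backbones import ConvNeXt

m = ConvNeXt.from_config("T").eval()  # or from_config("T", True) for pretrained weights
x = torch.randn(1, 3, 224, 224)

feats = m.get_feature_maps(x)
# NHWC maps: (1, 56, 56, 96), (1, 56, 56, 96), (1, 28, 28, 192),
# (1, 14, 14, 384), (1, 7, 7, 768)
print([tuple(f.shape) for f in feats])

out = m(x)  # spatially pooled then layer-normalized: (1, 768)
print(out.shape)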
28 changes: 28 additions & 0 deletions vision_toolbox/components.py
@@ -150,3 +150,31 @@ def forward(self, x: Tensor) -> Tensor:
            x = self.pool(x)
            outputs.append(x)
        return torch.cat(outputs, dim=1)


class Permute(nn.Module):
    def __init__(self, *dims: int) -> None:
        super().__init__()
        self.dims = dims

    def forward(self, x: Tensor) -> Tensor:
        return x.permute(self.dims)


# https://arxiv.org/pdf/1603.09382.pdf
class StochasticDepth(nn.Module):
    def __init__(self, p: float) -> None:
        assert 0.0 <= p <= 1.0
        super().__init__()
        self.p = p

    def forward(self, x: Tensor) -> Tensor:
        if not self.training or self.p == 0.0:
            return x

        # one Bernoulli draw per batch element: drop whole samples, then
        # rescale survivors by 1 / keep_p so the expected output is unchanged
        shape = [x.shape[0]] + [1] * (x.ndim - 1)
        keep_p = 1.0 - self.p
        return x * x.new_empty(shape).bernoulli_(keep_p).div_(keep_p)

    def extra_repr(self) -> str:
        return f"p={self.p}"
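
A quick sanity check of the StochasticDepth behavior (again a sketch, not part of the diff): in training mode each sample is zeroed with probability p and the survivors are scaled by 1/(1-p), so the batch mean is preserved in expectation, while in eval mode the module is the identity:

import torch

from vision_toolbox.components import StochasticDepth

drop = StochasticDepth(0.2).train()
x = torch.ones(10000, 8)
y = drop(x)

# each row is either all zeros or all 1/0.8 = 1.25, so the mean stays near 1.0
print(y.mean().item())          # ~1.0
print(drop.eval()(x).equal(x))  # True: identity at inference time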