Add ConvNeXt (#17)
gau-nernst committed Aug 19, 2023
1 parent 788bfe8 commit 8c1d2c4
Showing 4 changed files with 195 additions and 0 deletions.
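
A minimal usage sketch of the backbone added in this commit (based on the new test and from_config below; the variant, input size, and shape comments are illustrative assumptions, not part of the diff):

import torch
from vision_toolbox.backbones import ConvNeXt

model = ConvNeXt.from_config("T", pretrained=True).eval()  # downloads the official ImageNet-22k weights
x = torch.randn(1, 3, 224, 224)
embedding = model(x)                  # (1, 768): globally pooled, LayerNorm-ed features
features = model.get_feature_maps(x)  # channels-last feature maps, one per stage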
20 changes: 20 additions & 0 deletions tests/test_convnext.py
@@ -0,0 +1,20 @@
import timm
import torch

from vision_toolbox.backbones import ConvNeXt


def test_forward():
    m = ConvNeXt.from_config("T")
    m(torch.randn(1, 3, 224, 224))


def test_from_pretrained():
    m = ConvNeXt.from_config("T", True).eval()
    x = torch.randn(1, 3, 224, 224)
    out = m(x)

    m_timm = timm.create_model("convnext_tiny.fb_in22k", pretrained=True, num_classes=0).eval()
    out_timm = m_timm(x)

    torch.testing.assert_close(out, out_timm)
1 change: 1 addition & 0 deletions vision_toolbox/backbones/__init__.py
@@ -1,3 +1,4 @@
from .convnext import ConvNeXt
from .darknet import Darknet, DarknetYOLOv5
from .mlp_mixer import MLPMixer
from .patchconvnet import PatchConvNet
146 changes: 146 additions & 0 deletions vision_toolbox/backbones/convnext.py
@@ -0,0 +1,146 @@
# https://arxiv.org/abs/2201.03545
# https://github.com/facebookresearch/ConvNeXt

from __future__ import annotations

from functools import partial

import torch
from torch import Tensor, nn

from ..components import Permute, StochasticDepth
from .base import BaseBackbone, _act, _norm


class ConvNeXtBlock(nn.Module):
    def __init__(
        self,
        d_model: int,
        expansion_ratio: float = 4.0,
        bias: bool = True,
        layer_scale_init: float = 1e-6,
        stochastic_depth: float = 0.0,
        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
        act: _act = nn.GELU,
    ) -> None:
        super().__init__()
        hidden_dim = int(d_model * expansion_ratio)
        self.layers = nn.Sequential(
            Permute(0, 3, 1, 2),  # NHWC -> NCHW for the depthwise conv
            nn.Conv2d(d_model, d_model, 7, padding=3, groups=d_model, bias=bias),
            Permute(0, 2, 3, 1),  # back to NHWC for LayerNorm and the pointwise Linear layers
            norm(d_model),
            nn.Linear(d_model, hidden_dim, bias=bias),
            act(),
            nn.Linear(hidden_dim, d_model, bias=bias),
        )
        self.layer_scale = nn.Parameter(torch.full((d_model,), layer_scale_init)) if layer_scale_init > 0 else None
        self.drop = StochasticDepth(stochastic_depth)

    def forward(self, x: Tensor) -> Tensor:
        out = self.layers(x)
        if self.layer_scale is not None:  # layer scale is disabled when layer_scale_init <= 0
            out = out * self.layer_scale
        return x + self.drop(out)


class ConvNeXt(BaseBackbone):
    def __init__(
        self,
        d_model: int,
        depths: tuple[int, ...],
        expansion_ratio: float = 4.0,
        bias: bool = True,
        layer_scale_init: float = 1e-6,
        stochastic_depth: float = 0.0,
        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
        act: _act = nn.GELU,
    ) -> None:
        super().__init__()
        self.stem = nn.Sequential(nn.Conv2d(3, d_model, 4, 4), Permute(0, 2, 3, 1), norm(d_model))

        # per-block drop rates increase linearly from 0 to stochastic_depth over all blocks
        stochastic_depth_rates = torch.linspace(0, stochastic_depth, sum(depths))
        self.stages = nn.Sequential()

        for stage_idx, depth in enumerate(depths):
            stage = nn.Sequential()
            if stage_idx > 0:
                # equivalent to PatchMerging in SwinTransformer
                downsample = nn.Sequential(
                    norm(d_model),
                    Permute(0, 3, 1, 2),
                    nn.Conv2d(d_model, d_model * 2, 2, 2),
                    Permute(0, 2, 3, 1),
                )
                d_model *= 2
            else:
                downsample = nn.Identity()
            stage.append(downsample)

            for block_idx in range(depth):
                rate = stochastic_depth_rates[sum(depths[:stage_idx]) + block_idx]
                block = ConvNeXtBlock(d_model, expansion_ratio, bias, layer_scale_init, rate, norm, act)
                stage.append(block)

            self.stages.append(stage)

        self.head_norm = norm(d_model)

    def get_feature_maps(self, x: Tensor) -> list[Tensor]:
        out = [self.stem(x)]
        for stage in self.stages:
            out.append(stage(out[-1]))
        return out[1:]  # one channels-last feature map per stage

    def forward(self, x: Tensor) -> Tensor:
        return self.head_norm(self.get_feature_maps(x)[-1].mean((1, 2)))

    @staticmethod
    def from_config(variant: str, pretrained: bool = False) -> ConvNeXt:
        d_model, depths = dict(
            T=(96, (3, 3, 9, 3)),
            S=(96, (3, 3, 27, 3)),
            B=(128, (3, 3, 27, 3)),
            L=(192, (3, 3, 27, 3)),
            XL=(256, (3, 3, 27, 3)),
        )[variant]
        m = ConvNeXt(d_model, depths)

        if pretrained:
            # TODO: also add torchvision checkpoints?
            ckpt = dict(
                T="convnext_tiny_22k_224.pth",
                S="convnext_small_22k_224.pth",
                B="convnext_base_22k_224.pth",
                L="convnext_large_22k_224.pth",
                XL="convnext_xlarge_22k_224.pth",
            )[variant]
            base_url = "https://dl.fbaipublicfiles.com/convnext/"
            state_dict = torch.hub.load_state_dict_from_url(base_url + ckpt)["model"]
            m.load_official_ckpt(state_dict)

        return m

    @torch.no_grad()
    def load_official_ckpt(self, state_dict: dict[str, Tensor]) -> None:
        def copy_(m: nn.Conv2d | nn.Linear | nn.LayerNorm, prefix: str):
            m.weight.copy_(state_dict.pop(prefix + ".weight"))
            m.bias.copy_(state_dict.pop(prefix + ".bias"))

        copy_(self.stem[0], "downsample_layers.0.0")  # Conv2d
        copy_(self.stem[2], "downsample_layers.0.1")  # LayerNorm

        for stage_idx, stage in enumerate(self.stages):
            if stage_idx > 0:
                copy_(stage[0][0], f"downsample_layers.{stage_idx}.0")  # LayerNorm
                copy_(stage[0][2], f"downsample_layers.{stage_idx}.1")  # Conv2d

            for block_idx in range(1, len(stage)):
                block: ConvNeXtBlock = stage[block_idx]
                prefix = f"stages.{stage_idx}.{block_idx - 1}."

                copy_(block.layers[1], prefix + "dwconv")
                copy_(block.layers[3], prefix + "norm")
                copy_(block.layers[4], prefix + "pwconv1")
                copy_(block.layers[6], prefix + "pwconv2")
                block.layer_scale.copy_(state_dict.pop(prefix + "gamma"))

        copy_(self.head_norm, "norm")
        assert len(state_dict) == 2  # only the classification head (head.weight and head.bias) is left unused
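
A side note on the stochastic depth schedule in ConvNeXt.__init__ above: a single torch.linspace spans all blocks, so drop rates grow linearly with depth. A small sketch, assuming the "T" depths and a non-default stochastic_depth of 0.1:

import torch

depths = (3, 3, 9, 3)                        # ConvNeXt-T
rates = torch.linspace(0, 0.1, sum(depths))  # 18 rates from 0.0 (first block) to 0.1 (last block)
print(rates[:3])   # stage 0 blocks: tensor([0.0000, 0.0059, 0.0118])
print(rates[-1])   # last block of stage 3: tensor(0.1000)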
28 changes: 28 additions & 0 deletions vision_toolbox/components.py
@@ -150,3 +150,31 @@ def forward(self, x: Tensor) -> Tensor:
            x = self.pool(x)
            outputs.append(x)
        return torch.cat(outputs, dim=1)


class Permute(nn.Module):
    def __init__(self, *dims: int) -> None:
        super().__init__()
        self.dims = dims

    def forward(self, x: Tensor) -> Tensor:
        return x.permute(self.dims)


# https://arxiv.org/pdf/1603.09382.pdf
class StochasticDepth(nn.Module):
    def __init__(self, p: float) -> None:
        assert 0.0 <= p <= 1.0
        super().__init__()
        self.p = p

    def forward(self, x: Tensor) -> Tensor:
        if not self.training or self.p == 0.0:
            return x

        # drop each sample in the batch independently, and rescale the kept
        # samples so the expected value is unchanged
        shape = [x.shape[0]] + [1] * (x.ndim - 1)
        keep_p = 1.0 - self.p
        return x * x.new_empty(shape).bernoulli_(keep_p).div_(keep_p)

    def extra_repr(self) -> str:
        return f"p={self.p}"
