Skip to content

Commit

Permalink
add initial convnext
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst committed Aug 19, 2023
1 parent 788bfe8 commit 6205cc3
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 0 deletions.
20 changes: 20 additions & 0 deletions tests/test_convnext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import timm
import torch

from vision_toolbox.backbones import ConvNeXt


def test_forward():
    """Run the tiny ConvNeXt variant on one ImageNet-sized image and check the output shape."""
    m = ConvNeXt.from_config("T")
    out = m(torch.randn(1, 3, 224, 224))
    # The "T" variant starts at 96 channels and doubles the width at each of
    # the three later stages, so the pooled embedding has 96 * 2**3 = 768 features.
    assert out.shape == (1, 768)


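# TODO: the test below is a disabled placeholder copied from the ViT test suite.
# Port it to ConvNeXt (model constructor and matching timm checkpoint name)
# once `ConvNeXt.from_config(..., pretrained=True)` actually loads weights.
# def test_from_pretrained():
# m = ViT.from_config("Ti", 16, 224, True).eval()
# x = torch.randn(1, 3, 224, 224)
# out = m(x)

# m_timm = timm.create_model("vit_tiny_patch16_224.augreg_in21k", pretrained=True, num_classes=0).eval()
# out_timm = m_timm(x)

# torch.testing.assert_close(out, out_timm, rtol=2e-5, atol=2e-5)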
1 change: 1 addition & 0 deletions vision_toolbox/backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .convnext import ConvNeXt
from .darknet import Darknet, DarknetYOLOv5
from .mlp_mixer import MLPMixer
from .patchconvnet import PatchConvNet
Expand Down
105 changes: 105 additions & 0 deletions vision_toolbox/backbones/convnext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# https://arxiv.org/abs/2201.03545

from __future__ import annotations

from functools import partial

import torch
from torch import Tensor, nn

from ..components import Permute, StochasticDepth
from .base import BaseBackbone, _act, _norm


class ConvNeXtBlock(nn.Module):
    """Residual ConvNeXt block: depthwise 7x7 conv -> norm -> MLP, with optional layer scale.

    The block takes and returns channels-last tensors: the input is permuted to
    channels-first only for the depthwise convolution and permuted back before
    the normalisation and the two linear layers, which act on the last dimension.

    Args:
        d_model: number of channels of the input and output.
        expansion_ratio: width multiplier of the hidden linear layer.
        bias: whether the convolution and linear layers use a bias term.
        layer_scale_init: initial value of the learnable per-channel scale applied
            to the residual branch. A value ``<= 0`` disables layer scale.
        stochastic_depth: drop probability passed to ``StochasticDepth`` for the
            residual branch.
        norm: factory called with the channel count to build the normalisation layer.
        act: factory called without arguments to build the activation layer.
    """

    def __init__(
        self,
        d_model: int,
        expansion_ratio: float = 4.0,
        bias: bool = True,
        layer_scale_init: float = 1e-6,
        stochastic_depth: float = 0.0,
        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
        act: _act = nn.GELU,
    ) -> None:
        super().__init__()
        hidden_dim = int(d_model * expansion_ratio)
        self.layers = nn.Sequential(
            Permute(0, 3, 1, 2),  # channels-last -> channels-first for Conv2d
            nn.Conv2d(d_model, d_model, 7, padding=3, groups=d_model, bias=bias),
            Permute(0, 2, 3, 1),  # back to channels-last for norm / linear layers
            norm(d_model),
            nn.Linear(d_model, hidden_dim, bias=bias),
            act(),
            nn.Linear(hidden_dim, d_model, bias=bias),
        )
        # None means "layer scale disabled"; forward() must check for it.
        self.layer_scale = nn.Parameter(torch.full((d_model,), layer_scale_init)) if layer_scale_init > 0 else None
        self.drop = StochasticDepth(stochastic_depth)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` plus the (optionally scaled and randomly dropped) residual branch."""
        residual = self.layers(x)
        # Multiplying by None would raise a TypeError when layer scale is disabled.
        if self.layer_scale is not None:
            residual = residual * self.layer_scale
        return x + self.drop(residual)


class ConvNeXt(BaseBackbone):
    """ConvNeXt backbone (https://arxiv.org/abs/2201.03545).

    The stem converts the image to a channels-last tensor, and every stage
    keeps that layout. Each stage after the first begins with a downsampling
    layer that halves the spatial size and doubles the channel count.

    Args:
        d_model: channel count of the first stage.
        depths: number of ``ConvNeXtBlock`` in each stage.
        expansion_ratio: forwarded to every ``ConvNeXtBlock``.
        bias: forwarded to every ``ConvNeXtBlock``.
        layer_scale_init: forwarded to every ``ConvNeXtBlock``.
        stochastic_depth: forwarded to every ``ConvNeXtBlock``.
        norm: factory called with the channel count to build normalisation layers.
        act: activation factory forwarded to every ``ConvNeXtBlock``.
    """

    def __init__(
        self,
        d_model: int,
        depths: tuple[int, ...],
        expansion_ratio: float = 4.0,
        bias: bool = True,
        layer_scale_init: float = 1e-6,
        stochastic_depth: float = 0.0,
        norm: _norm = partial(nn.LayerNorm, eps=1e-6),
        act: _act = nn.GELU,
    ) -> None:
        super().__init__()
        # 4x4 stride-4 patchify conv, then switch to channels-last for the norm.
        self.stem = nn.Sequential(nn.Conv2d(3, d_model, 4, 4), Permute(0, 2, 3, 1), norm(d_model))

        self.stages = nn.Sequential()
        for stage_idx, depth in enumerate(depths):
            stage = nn.Sequential()
            if stage_idx > 0:
                # equivalent to PatchMerging in SwinTransformer
                downsample = nn.Sequential(
                    norm(d_model),
                    Permute(0, 3, 1, 2),
                    nn.Conv2d(d_model, d_model * 2, 2, 2),
                    Permute(0, 2, 3, 1),
                )
                d_model *= 2
            else:
                # Keep index 0 of every stage reserved for the downsample slot.
                downsample = nn.Identity()
            stage.append(downsample)

            for _ in range(depth):
                block = ConvNeXtBlock(d_model, expansion_ratio, bias, layer_scale_init, stochastic_depth, norm, act)
                stage.append(block)

            self.stages.append(stage)

        self.head_norm = norm(d_model)

    def get_feature_maps(self, x: Tensor) -> list[Tensor]:
        """Return the output of every stage, in order, as channels-last tensors.

        The stem output is only an intermediate value and is not included.
        """
        out = [self.stem(x)]
        for stage in self.stages:
            out.append(stage(out[-1]))
        # Drop the stem output; previously only the last stage was returned,
        # which discarded all intermediate maps collected above.
        return out[1:]

    def forward(self, x: Tensor) -> Tensor:
        """Return the normalised, spatially averaged output of the last stage."""
        # Run the stages directly so the result does not depend on how many
        # feature maps get_feature_maps() exposes.
        features = self.stages(self.stem(x))
        return self.head_norm(features.mean((1, 2)))

    @staticmethod
    def from_config(variant: str, pretrained: bool = False) -> ConvNeXt:
        """Build a ConvNeXt from a named variant.

        Args:
            variant: one of ``"T"``, ``"S"``, ``"B"``, ``"L"``, ``"XL"``.
            pretrained: requests pre-trained weights. Loading is not implemented
                yet, so a warning is emitted and the model keeps its random
                initialisation.

        Raises:
            KeyError: if ``variant`` is not a known configuration.
        """
        configs = {
            "T": (96, (3, 3, 9, 3)),
            "S": (96, (3, 3, 27, 3)),
            "B": (128, (3, 3, 27, 3)),
            "L": (192, (3, 3, 27, 3)),
            "XL": (256, (3, 3, 27, 3)),
        }
        d_model, depths = configs[variant]
        m = ConvNeXt(d_model, depths)

        if pretrained:
            # Do not fail silently: the caller would otherwise believe the
            # returned model carries trained weights.
            import warnings

            warnings.warn(
                "ConvNeXt pretrained weights are not supported yet; returning a randomly initialised model.",
                stacklevel=2,
            )

        return m
28 changes: 28 additions & 0 deletions vision_toolbox/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,31 @@ def forward(self, x: Tensor) -> Tensor:
x = self.pool(x)
outputs.append(x)
return torch.cat(outputs, dim=1)


class Permute(nn.Module):
    """Module wrapper around tensor permutation, usable inside ``nn.Sequential``.

    Args:
        *dims: the desired ordering of the input's dimensions.
    """

    def __init__(self, *dims: int) -> None:
        super().__init__()
        self.dims = dims

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` with its dimensions reordered according to ``self.dims``."""
        order = self.dims
        return torch.permute(x, order)


# https://arxiv.org/pdf/1603.09382.pdf
class StochasticDepth(nn.Module):
    """Randomly zero the whole input of individual samples during training.

    One keep/drop decision is drawn per index of the first dimension and
    broadcast over the remaining dimensions. Kept samples are rescaled by
    ``1 / (1 - p)``. In eval mode, or when ``p == 0``, the input is returned
    unchanged.

    Args:
        p: drop probability, in ``[0, 1]``.
    """

    def __init__(self, p: float) -> None:
        assert 0.0 <= p <= 1.0
        super().__init__()
        self.p = p

    def forward(self, x: Tensor) -> Tensor:
        """Apply stochastic depth to ``x``."""
        if not self.training or self.p == 0.0:
            return x

        # One Bernoulli draw per sample, broadcast over all other dimensions.
        shape = [x.shape[0]] + [1] * (x.ndim - 1)
        keep_p = 1.0 - self.p
        mask = x.new_empty(shape).bernoulli_(keep_p)
        # With p == 1 the mask is all zeros; dividing by keep_p == 0 would turn
        # it into NaN (0 / 0), so only rescale when something can be kept.
        if keep_p > 0.0:
            mask.div_(keep_p)
        return x * mask

    def extra_repr(self) -> str:
        return f"p={self.p}"

0 comments on commit 6205cc3

Please sign in to comment.