Skip to content

Commit

Permalink
update VoVNet
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst committed Jul 24, 2023
1 parent b224ff0 commit 6ca4bbf
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 136 deletions.
9 changes: 5 additions & 4 deletions tests/test_backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from torch import Tensor, nn

from vision_toolbox.backbones import Darknet, DarknetYOLOv5
from vision_toolbox.backbones import Darknet, DarknetYOLOv5, VoVNet


@pytest.fixture
Expand All @@ -22,12 +22,13 @@ def inputs():


def partial_list(fn, args_list):
return [partial(fn, x) for x in args_list]
return [partial(fn, *args) for args in args_list]


factory_list = [
*partial_list(Darknet.from_config, ("darknet19", "cspdarknet53")),
*partial_list(DarknetYOLOv5.from_config, ("n", "l")),
*partial_list(Darknet.from_config, (("darknet19",), ("cspdarknet53",))),
*partial_list(DarknetYOLOv5.from_config, (("n",), ("l",))),
*partial_list(VoVNet.from_config, ((27, True), (39,), (19, True, True), (57, False, True))),
]


Expand Down
2 changes: 1 addition & 1 deletion vision_toolbox/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
from .patchconvnet import *
from .torchvision_models import *
from .vit import *
from .vovnet import *
from .vovnet import VoVNet
4 changes: 2 additions & 2 deletions vision_toolbox/backbones/darknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class DarknetYOLOv5(BaseBackbone):
def __init__(self, stem_channels: int, stage_configs: list[DarknetStageConfig | tuple[int, int]]) -> None:
super().__init__()
self.out_channels_list = (stem_channels,) + tuple(cfg[1] for cfg in stage_configs)
self.stride = 32
self.stride = 2 ** len(self.out_channels_list)

self.stem = ConvNormAct(3, stem_channels, 6, 2)
self.stages = nn.ModuleList()
Expand All @@ -125,7 +125,7 @@ def from_config(variant: str, pretrained: bool = False) -> DarknetYOLOv5:
n=(1 / 3, 1 / 4, "darknet_yolov5n-68f182f1.pth"),
s=(1 / 3, 1 / 2, "darknet_yolov5s-175f7462.pth"),
m=(2 / 3, 3 / 4, "darknet_yolov5m-9866aa40.pth"),
l=(1, 1, "darknet_yolov5l-8e25d388.pth"),
l=(1 / 1, 1 / 1, "darknet_yolov5l-8e25d388.pth"),
x=(4 / 3, 5 / 4, "darknet_yolov5x-0ed0c035.pth"),
)[variant]
stage_configs = [
Expand Down
186 changes: 57 additions & 129 deletions vision_toolbox/backbones/vovnet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# Papers:
# VoVNetV1: https://arxiv.org/abs/1904.09730
# VoVNetV2: https://arxiv.org/abs/1911.06667 (CenterMask)
# https://github.com/youngwanLEE/vovnet-detectron2/blob/master/vovnet/vovnet.py

from __future__ import annotations

from typing import NamedTuple

import torch
from torch import Tensor, nn
Expand All @@ -9,80 +14,7 @@
from .base import BaseBackbone


__all__ = [
"VoVNet",
"vovnet27_slim",
"vovnet39",
"vovnet57",
"vovnet19_slim_ese",
"vovnet19_ese",
"vovnet39_ese",
"vovnet57_ese",
"vovnet99_ese",
]

# https://github.com/youngwanLEE/vovnet-detectron2/blob/master/vovnet/vovnet.py
_base = dict(
stem_channels=128,
stage_channels_list=(128, 160, 192, 224),
out_channels_list=(256, 512, 768, 1024),
num_layers_list=(5, 5, 5, 5),
)
_slim = dict(
stage_channels_list=(64, 80, 96, 112),
out_channels_list=(128, 256, 384, 512),
)
configs = {
# VoVNetV1
"vovnet-27-slim": {
**_base,
**_slim,
"num_blocks_list": (1, 1, 1, 1),
"ese": False,
"weights": "https://github.com/gau-nernst/vision-toolbox/releases/download/v0.0.1/vovnet27_slim-dd43306a.pth",
},
"vovnet-39": dict(
**_base,
num_blocks_list=(1, 1, 2, 2),
ese=False,
weights="https://github.com/gau-nernst/vision-toolbox/releases/download/v0.0.1/vovnet39-4c79d629.pth",
),
"vovnet-57": dict(
**_base,
num_blocks_list=(1, 1, 4, 3),
ese=False,
weights="https://github.com/gau-nernst/vision-toolbox/releases/download/v0.0.1/vovnet57-ecb9cc34.pth",
),
# VoVNetV2
"vovnet-19-slim-ese": {
**_base,
**_slim,
"num_layers_list": (3, 3, 3, 3),
"num_blocks_list": (1, 1, 1, 1),
"weights": "https://github.com/gau-nernst/vision-toolbox/releases/download/v0.0.1/vovnet19_slim_ese-f8075640.pth",
},
"vovnet-19-ese": {
**_base,
"num_layers_list": (3, 3, 3, 3),
"num_blocks_list": (1, 1, 1, 1),
"weights": "https://github.com/gau-nernst/vision-toolbox/releases/download/v0.0.1/vovnet19_ese-a077657e.pth",
},
"vovnet-39-ese": dict(
**_base,
num_blocks_list=(1, 1, 2, 2),
weights="https://github.com/gau-nernst/vision-toolbox/releases/download/v0.0.1/vovnet39_ese-9ce81b0d.pth",
),
"vovnet-57-ese": dict(
**_base,
num_blocks_list=(1, 1, 4, 3),
weights="https://github.com/gau-nernst/vision-toolbox/releases/download/v0.0.1/vovnet57_ese-ae1a7f89.pth",
),
"vovnet-99-ese": dict(
**_base,
num_blocks_list=(1, 3, 9, 3),
weights="https://github.com/gau-nernst/vision-toolbox/releases/download/v0.0.1/vovnet99_ese-713f3062.pth",
),
}
_BASE_URL = "https://github.com/gau-nernst/vision-toolbox/releases/download/v0.0.1/"


class ESEBlock(nn.Module):
Expand All @@ -103,7 +35,6 @@ def __init__(
mid_channels: int,
num_layers: int,
out_channels: int,
residual: bool = True,
ese: bool = True,
) -> None:
super().__init__()
Expand All @@ -114,15 +45,12 @@ def __init__(
self.out_conv = ConvNormAct(concat_channels, out_channels, 1)

self.ese = ESEBlock(out_channels) if ese else None
self.residual = residual and (in_channels == out_channels)
self.residual = in_channels == out_channels

def forward(self, x: Tensor) -> Tensor:
outputs = []
out = x
outputs.append(out)
for conv_layer in self.convs:
out = conv_layer(out)
outputs.append(out)
outputs = [x]
for conv in self.convs:
outputs.append(conv(outputs[-1]))

out = torch.cat(outputs, dim=1)
out = self.out_conv(out)
Expand All @@ -135,38 +63,38 @@ def forward(self, x: Tensor) -> Tensor:
return out


class VoVNetStageConfig(NamedTuple):
n_blocks: int
mid_channels: int
n_layers: int
out_channels: int


class VoVNet(BaseBackbone):
def __init__(
self,
stem_channels: int,
num_blocks_list: list[int],
stage_channels_list: list[int],
num_layers_list: list[int],
out_channels_list: list[int],
residual: bool = True,
stage_configs: list[VoVNetStageConfig | tuple[int, int, int, int]],
ese: bool = True,
) -> None:
super().__init__()
self.out_channels_list = (stem_channels,) + tuple(out_channels_list)
self.stride = 32
self.out_channels_list = (stem_channels,) + tuple(cfg[3] for cfg in stage_configs)
self.stride = 2 ** len(self.out_channels_list)

self.stem = nn.Sequential(
ConvNormAct(3, stem_channels // 2, stride=2),
ConvNormAct(3, stem_channels // 2, 3, 2),
ConvNormAct(stem_channels // 2, stem_channels // 2),
ConvNormAct(stem_channels // 2, stem_channels),
)

self.stages = nn.ModuleList()
in_c = stem_channels
for n, stage_c, n_l, out_c in zip(num_blocks_list, stage_channels_list, num_layers_list, out_channels_list):
in_ch = stem_channels
for n_blocks, mid_ch, n_layers, out_ch in stage_configs:
stage = nn.Sequential()
stage.add_module("max_pool", nn.MaxPool2d(3, 2, 1))
for i in range(n):
stage.add_module(
f"module_{i}",
OSABlock(in_c, stage_c, n_l, out_c, residual=residual, ese=ese),
)
in_c = out_c
for i in range(n_blocks):
stage.add_module(f"module_{i}", OSABlock(in_ch, mid_ch, n_layers, out_ch, ese))
in_ch = out_ch
self.stages.append(stage)

def get_feature_maps(self, x: Tensor) -> list[Tensor]:
Expand All @@ -175,34 +103,34 @@ def get_feature_maps(self, x: Tensor) -> list[Tensor]:
outputs.append(s(outputs[-1]))
return outputs


def vovnet27_slim(pretrained=False, **kwargs):
return VoVNet.from_config(configs["vovnet-27-slim"], pretrained=pretrained, **kwargs)


def vovnet39(pretrained=False, **kwargs):
return VoVNet.from_config(configs["vovnet-39"], pretrained=pretrained, **kwargs)


def vovnet57(pretrained=False, **kwargs):
return VoVNet.from_config(configs["vovnet-57"], pretrained=pretrained, **kwargs)


def vovnet19_slim_ese(pretrained=False, **kwargs):
return VoVNet.from_config(configs["vovnet-19-slim-ese"], pretrained=pretrained, **kwargs)


def vovnet19_ese(pretrained=False, **kwargs):
return VoVNet.from_config(configs["vovnet-19-ese"], pretrained=pretrained, **kwargs)


def vovnet39_ese(pretrained=False, **kwargs):
return VoVNet.from_config(configs["vovnet-39-ese"], pretrained=pretrained, **kwargs)


def vovnet57_ese(pretrained=False, **kwargs):
return VoVNet.from_config(configs["vovnet-57-ese"], pretrained=pretrained, **kwargs)


def vovnet99_ese(pretrained=False, **kwargs):
return VoVNet.from_config(configs["vovnet-99-ese"], pretrained=pretrained, **kwargs)
@staticmethod
def from_config(variant: int, slim: bool = False, ese: bool = False, pretrained: bool = False) -> VoVNet:
stem_channels = 128
mid_channels_list = (64, 80, 96, 112) if slim else (128, 160, 192, 224)
out_channels_list = (128, 256, 384, 512) if slim else (256, 512, 768, 1024)
n_blocks_list, n_layers_list = {
19: ((1, 1, 1, 1), (3, 3, 3, 3)),
27: ((1, 1, 1, 1), (5, 5, 5, 5)),
39: ((1, 1, 2, 2), (5, 5, 5, 5)),
57: ((1, 1, 4, 3), (5, 5, 5, 5)),
99: ((1, 3, 9, 3), (5, 5, 5, 5)),
}[variant]
stage_configs = list(zip(n_blocks_list, mid_channels_list, n_layers_list, out_channels_list))
m = VoVNet(stem_channels, stage_configs, ese)

if pretrained:
ckpt = {
# VoVNetV1
(27, True, False): "vovnet27_slim-dd43306a.pth",
(39, False, False): "vovnet39-4c79d629.pth",
(57, False, False): "vovnet57-ecb9cc34.pth",
# VoVNetV2
(19, True, True): "vovnet19_slim_ese-f8075640.pth",
(19, False, True): "vovnet19_ese-a077657e.pth",
(39, False, True): "vovnet39_ese-9ce81b0d.pth",
(57, False, True): "vovnet57_ese-ae1a7f89.pth",
(99, False, True): "vovnet99_ese-713f3062.pth",
}[(variant, slim, ese)]
m._load_state_dict_from_url(_BASE_URL + ckpt)

return m

0 comments on commit 6ca4bbf

Please sign in to comment.