Skip to content

Commit

Permalink
update constructor structure
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst committed Jul 24, 2023
1 parent bd65bd7 commit 03a82fd
Showing 1 changed file with 32 additions and 24 deletions.
56 changes: 32 additions & 24 deletions vision_toolbox/backbones/darknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from __future__ import annotations

from typing import Callable
from typing import Callable, NamedTuple

import torch
from torch import Tensor, nn
Expand All @@ -14,6 +14,9 @@
from .base import BaseBackbone


_BASE_URL = "https://github.com/gau-nernst/vision-toolbox/releases/download/v0.0.1/"


class DarknetBlock(nn.Module):
def __init__(self, in_channels: int, expansion: float = 0.5) -> None:
super().__init__()
Expand Down Expand Up @@ -52,24 +55,30 @@ def forward(self, x: Tensor) -> Tensor:
return out


class DarknetStageConfig(NamedTuple):
n_blocks: int
out_channels: int


class Darknet(BaseBackbone):
def __init__(
self,
stem_channels: int,
n_blocks_list: list[int],
out_channels_list: list[int],
stage_configs: list[DarknetStageConfig],
stage_cls: Callable[..., nn.Module] = DarknetStage,
):
assert len(stage_configs) > 0
super().__init__()
self.out_channels_list = tuple(out_channels_list)
self.out_channels_list = tuple(cfg.out_channels for cfg in stage_configs)
self.stride = 32

self.stem = ConvNormAct(3, stem_channels)
self.stages = nn.ModuleList()
in_c = stem_channels
for n, c in zip(n_blocks_list, out_channels_list):
self.stages.append(stage_cls(n, in_c, c) if n > 0 else ConvNormAct(in_c, c, stride=2))
in_c = c
in_ch = stem_channels
for n_blocks, out_ch in stage_configs:
stage = stage_cls(n_blocks, in_ch, out_ch) if n_blocks else ConvNormAct(in_ch, out_ch, 3, 2)
self.stages.append(stage)
in_ch = out_ch

def get_feature_maps(self, x: Tensor) -> list[Tensor]:
outputs = [self.stem(x)]
Expand All @@ -84,25 +93,25 @@ def from_config(variant: str, pretrained: bool = False) -> Darknet:
darknet53=((1, 2, 8, 8, 4), DarknetStage, "darknet53-94427f5b.pth"), # YOLOv3
cspdarknet53=((1, 2, 8, 8, 4), CSPDarknetStage, "cspdarknet53-3bfa0423.pth"), # CSPNet/YOLOv4
)[variant]
m = Darknet(32, n_blocks_list, (64, 128, 256, 512, 1024), stage_cls)
stage_configs = list(map(DarknetStageConfig, n_blocks_list, (64, 128, 256, 512, 1024)))
m = Darknet(32, stage_configs, stage_cls)
if pretrained:
base_url = "https://github.com/gau-nernst/vision-toolbox/releases/download/v0.0.1/"
m._load_state_dict_from_url(base_url + ckpt)
m._load_state_dict_from_url(_BASE_URL + ckpt)
return m


class DarknetYOLOv5(BaseBackbone):
def __init__(self, stem_channels: int, n_blocks_list: list[int], out_channels_list: list[int]) -> None:
def __init__(self, stem_channels: int, stage_configs: list[DarknetStageConfig]) -> None:
super().__init__()
self.out_channels_list = (stem_channels,) + tuple(out_channels_list)
self.out_channels_list = (stem_channels,) + tuple(cfg.out_channels for cfg in stage_configs)
self.stride = 32

self.stem = ConvNormAct(3, stem_channels, 6, 2)
self.stages = nn.ModuleList()
in_c = stem_channels
for n, c in zip(n_blocks_list, out_channels_list):
self.stages.append(CSPDarknetStage(n, in_c, c))
in_c = c
in_ch = stem_channels
for n_blocks, out_ch in stage_configs:
self.stages.append(CSPDarknetStage(n_blocks, in_ch, out_ch))
in_ch = out_ch

def get_feature_maps(self, x: Tensor) -> list[Tensor]:
outputs = [self.stem(x)]
Expand All @@ -119,12 +128,11 @@ def from_config(variant: str, pretrained: bool = False) -> DarknetYOLOv5:
l=(1, 1, "darknet_yolov5l-8e25d388.pth"),
x=(4 / 3, 5 / 4, "darknet_yolov5x-0ed0c035.pth"),
)[variant]
m = DarknetYOLOv5(
int(64 * width_scale),
[int(d * depth_scale) for d in (3, 6, 9, 3)],
[int(w * width_scale) for w in (128, 256, 512, 1024)],
)
stage_configs = [
DarknetStageConfig(int(d * depth_scale), int(w * width_scale))
for d, w in zip((3, 6, 9, 3), (128, 256, 512, 1024))
]
m = DarknetYOLOv5(int(64 * width_scale), stage_configs)
if pretrained:
base_url = "https://github.com/gau-nernst/vision-toolbox/releases/download/v0.0.1/"
m._load_state_dict_from_url(base_url + ckpt)
m._load_state_dict_from_url(_BASE_URL + ckpt)
return m

0 comments on commit 03a82fd

Please sign in to comment.