Skip to content

Commit

Permalink
duck typing
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst committed Jul 24, 2023
1 parent 03a82fd commit b224ff0
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions vision_toolbox/backbones/darknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ class Darknet(BaseBackbone):
def __init__(
self,
stem_channels: int,
stage_configs: list[DarknetStageConfig],
stage_configs: list[DarknetStageConfig | tuple[int, int]],
stage_cls: Callable[..., nn.Module] = DarknetStage,
):
assert len(stage_configs) > 0
super().__init__()
self.out_channels_list = tuple(cfg.out_channels for cfg in stage_configs)
self.out_channels_list = tuple(cfg[1] for cfg in stage_configs)
self.stride = 32

self.stem = ConvNormAct(3, stem_channels)
Expand All @@ -93,17 +93,17 @@ def from_config(variant: str, pretrained: bool = False) -> Darknet:
darknet53=((1, 2, 8, 8, 4), DarknetStage, "darknet53-94427f5b.pth"), # YOLOv3
cspdarknet53=((1, 2, 8, 8, 4), CSPDarknetStage, "cspdarknet53-3bfa0423.pth"), # CSPNet/YOLOv4
)[variant]
stage_configs = list(map(DarknetStageConfig, n_blocks_list, (64, 128, 256, 512, 1024)))
stage_configs = list(zip(n_blocks_list, (64, 128, 256, 512, 1024)))
m = Darknet(32, stage_configs, stage_cls)
if pretrained:
m._load_state_dict_from_url(_BASE_URL + ckpt)
return m


class DarknetYOLOv5(BaseBackbone):
def __init__(self, stem_channels: int, stage_configs: list[DarknetStageConfig]) -> None:
def __init__(self, stem_channels: int, stage_configs: list[DarknetStageConfig | tuple[int, int]]) -> None:
super().__init__()
self.out_channels_list = (stem_channels,) + tuple(cfg.out_channels for cfg in stage_configs)
self.out_channels_list = (stem_channels,) + tuple(cfg[1] for cfg in stage_configs)
self.stride = 32

self.stem = ConvNormAct(3, stem_channels, 6, 2)
Expand All @@ -129,8 +129,7 @@ def from_config(variant: str, pretrained: bool = False) -> DarknetYOLOv5:
x=(4 / 3, 5 / 4, "darknet_yolov5x-0ed0c035.pth"),
)[variant]
stage_configs = [
DarknetStageConfig(int(d * depth_scale), int(w * width_scale))
for d, w in zip((3, 6, 9, 3), (128, 256, 512, 1024))
(int(d * depth_scale), int(w * width_scale)) for d, w in zip((3, 6, 9, 3), (128, 256, 512, 1024))
]
m = DarknetYOLOv5(int(64 * width_scale), stage_configs)
if pretrained:
Expand Down

0 comments on commit b224ff0

Please sign in to comment.