diff --git a/vision_toolbox/backbones/darknet.py b/vision_toolbox/backbones/darknet.py index ccc93ef..d7a85f1 100644 --- a/vision_toolbox/backbones/darknet.py +++ b/vision_toolbox/backbones/darknet.py @@ -64,12 +64,12 @@ class Darknet(BaseBackbone): def __init__( self, stem_channels: int, - stage_configs: list[DarknetStageConfig], + stage_configs: list[DarknetStageConfig | tuple[int, int]], stage_cls: Callable[..., nn.Module] = DarknetStage, ): assert len(stage_configs) > 0 super().__init__() - self.out_channels_list = tuple(cfg.out_channels for cfg in stage_configs) + self.out_channels_list = tuple(cfg[1] for cfg in stage_configs) self.stride = 32 self.stem = ConvNormAct(3, stem_channels) @@ -93,7 +93,7 @@ def from_config(variant: str, pretrained: bool = False) -> Darknet: darknet53=((1, 2, 8, 8, 4), DarknetStage, "darknet53-94427f5b.pth"), # YOLOv3 cspdarknet53=((1, 2, 8, 8, 4), CSPDarknetStage, "cspdarknet53-3bfa0423.pth"), # CSPNet/YOLOv4 )[variant] - stage_configs = list(map(DarknetStageConfig, n_blocks_list, (64, 128, 256, 512, 1024))) + stage_configs = list(zip(n_blocks_list, (64, 128, 256, 512, 1024))) m = Darknet(32, stage_configs, stage_cls) if pretrained: m._load_state_dict_from_url(_BASE_URL + ckpt) @@ -101,9 +101,9 @@ def from_config(variant: str, pretrained: bool = False) -> Darknet: class DarknetYOLOv5(BaseBackbone): - def __init__(self, stem_channels: int, stage_configs: list[DarknetStageConfig]) -> None: + def __init__(self, stem_channels: int, stage_configs: list[DarknetStageConfig | tuple[int, int]]) -> None: super().__init__() - self.out_channels_list = (stem_channels,) + tuple(cfg.out_channels for cfg in stage_configs) + self.out_channels_list = (stem_channels,) + tuple(cfg[1] for cfg in stage_configs) self.stride = 32 self.stem = ConvNormAct(3, stem_channels, 6, 2) @@ -129,8 +129,7 @@ def from_config(variant: str, pretrained: bool = False) -> DarknetYOLOv5: x=(4 / 3, 5 / 4, "darknet_yolov5x-0ed0c035.pth"), )[variant] stage_configs = [ - DarknetStageConfig(int(d * depth_scale), int(w * width_scale)) - for d, w in zip((3, 6, 9, 3), (128, 256, 512, 1024)) + (int(d * depth_scale), int(w * width_scale)) for d, w in zip((3, 6, 9, 3), (128, 256, 512, 1024)) ] m = DarknetYOLOv5(int(64 * width_scale), stage_configs) if pretrained: