diff --git a/tests/test_backbones.py b/tests/test_backbones.py index 0074c81..60a6daf 100644 --- a/tests/test_backbones.py +++ b/tests/test_backbones.py @@ -2,9 +2,17 @@ import pytest import torch -from torch import Tensor, nn +from torch import Tensor -from vision_toolbox.backbones import Darknet, DarknetYOLOv5, VoVNet +from vision_toolbox.backbones import ( + Darknet, + DarknetYOLOv5, + EfficientNetExtractor, + MobileNetExtractor, + RegNetExtractor, + ResNetExtractor, + VoVNet, +) @pytest.fixture @@ -12,23 +20,17 @@ def inputs(): return torch.rand(1, 3, 224, 224) -vovnet_v1_models = [f"vovnet{x}" for x in ["27_slim", 39, 57]] -vovnet_v2_models = [f"vovnet{x}_ese" for x in ["19_slim", 19, 39, 57, 99]] -darknet_models = ["darknet19", "darknet53", "cspdarknet53"] -darknet_yolov5_models = [f"darknet_yolov5{x}" for x in ("n", "s", "m", "l", "x")] -torchvision_models = ["resnet18", "mobilenet_v2", "efficientnet_b0", "regnet_x_400mf"] - -all_models = vovnet_v1_models + vovnet_v2_models + darknet_models + darknet_yolov5_models + torchvision_models - - -def partial_list(fn, args_list): - return [partial(fn, *args) for args in args_list] - - factory_list = [ - *partial_list(Darknet.from_config, (("darknet19",), ("cspdarknet53",))), - *partial_list(DarknetYOLOv5.from_config, (("n",), ("l",))), - *partial_list(VoVNet.from_config, ((27, True), (39,), (19, True, True), (57, False, True))), + *[partial(Darknet.from_config, x) for x in ("darknet19", "cspdarknet53")], + *[partial(DarknetYOLOv5.from_config, x) for x in ("n", "l")], + *[ + partial(VoVNet.from_config, x, y, z) + for x, y, z in ((27, True, False), (39, False, False), (19, True, True), (57, False, True)) + ], + partial(ResNetExtractor, "resnet18"), + partial(RegNetExtractor, "regnet_x_400mf"), + partial(MobileNetExtractor, "mobilenet_v2"), + partial(EfficientNetExtractor, "efficientnet_b0"), ] diff --git a/vision_toolbox/backbones/__init__.py b/vision_toolbox/backbones/__init__.py index 3b75e98..1425c61 100644 --- a/vision_toolbox/backbones/__init__.py +++ b/vision_toolbox/backbones/__init__.py @@ -1,6 +1,5 @@ -from .base import * from .darknet import Darknet, DarknetYOLOv5 from .patchconvnet import * -from .torchvision_models import * -from .vit import * +from .torchvision_models import EfficientNetExtractor, MobileNetExtractor, RegNetExtractor, ResNetExtractor +from .vit import ViT from .vovnet import VoVNet diff --git a/vision_toolbox/backbones/torchvision_models.py b/vision_toolbox/backbones/torchvision_models.py index b35c914..208612b 100644 --- a/vision_toolbox/backbones/torchvision_models.py +++ b/vision_toolbox/backbones/torchvision_models.py @@ -1,62 +1,13 @@ -import warnings +from __future__ import annotations import torch from torch import Tensor, nn -from torchvision.models import mobilenet, resnet - - -try: - from torchvision.models import efficientnet, regnet - from torchvision.models.feature_extraction import create_feature_extractor -except ImportError: - warnings.warn("torchvision < 0.11.0. torchvision models won't be available") - regnet = efficientnet = create_feature_extractor = None +from torchvision.models import efficientnet, mobilenet, regnet, resnet +from torchvision.models.feature_extraction import create_feature_extractor from .base import BaseBackbone -__all__ = [ - "ResNetExtractor", - "RegNetExtractor", - "MobileNetExtractor", - "EfficientNetExtractor", - "resnet18", - "resnet34", - "resnet50", - "resnet101", - "resnet152", - "resnext50_32x4d", - "resnext101_32x8d", - "wide_resnet50_2", - "wide_resnet101_2", - "regnet_x_400mf", - "regnet_x_800mf", - "regnet_x_1_6gf", - "regnet_x_3_2gf", - "regnet_x_8gf", - "regnet_x_16gf", - "regnet_x_32gf", - "regnet_y_400mf", - "regnet_y_800mf", - "regnet_y_1_6gf", - "regnet_y_3_2gf", - "regnet_y_8gf", - "regnet_y_16gf", - "regnet_y_32gf", - "mobilenet_v2", - "mobilenet_v3_large", - "mobilenet_v3_small", - "efficientnet_b0", - "efficientnet_b1", - "efficientnet_b2", - "efficientnet_b3", - "efficientnet_b4", - "efficientnet_b5", - "efficientnet_b6", - "efficientnet_b7", -] - - class _ExtractorBackbone(BaseBackbone): def __init__(self, backbone: nn.Module, node_names: list[str]): super().__init__() @@ -102,139 +53,3 @@ def __init__(self, name: str, pretrained: bool = False): stage_indices = [2, 3, 4, 6] node_names = [f"features.{i}.0.block.0" for i in stage_indices] + [f"features.{len(backbone.features)-1}"] super().__init__(backbone, node_names) - - -def resnet18(pretrained=False, **kwargs): - return ResNetExtractor("resnet18", pretrained=pretrained, **kwargs) - - -def resnet34(pretrained=False, **kwargs): - return ResNetExtractor("resnet34", pretrained=pretrained, **kwargs) - - -def resnet50(pretrained=False, **kwargs): - return ResNetExtractor("resnet50", pretrained=pretrained, **kwargs) - - -def resnet101(pretrained=False, **kwargs): - return ResNetExtractor("resnet101", pretrained=pretrained, **kwargs) - - -def resnet152(pretrained=False, **kwargs): - return ResNetExtractor("resnet152", pretrained=pretrained, **kwargs) - - -def resnext50_32x4d(pretrained=False, **kwargs): - return ResNetExtractor("resnext50_32x4d", pretrained=pretrained, **kwargs) - - -def resnext101_32x8d(pretrained=False, **kwargs): - return ResNetExtractor("resnext101_32x8d", pretrained=pretrained, **kwargs) - - -def wide_resnet50_2(pretrained=False, **kwargs): - return ResNetExtractor("wide_resnet50_2", pretrained=pretrained, **kwargs) - - -def wide_resnet101_2(pretrained=False, **kwargs): - return ResNetExtractor("wide_resnet101_2", pretrained=pretrained, **kwargs) - - -def mobilenet_v2(pretrained=False, **kwargs): - return MobileNetExtractor("mobilenet_v2", pretrained=pretrained, **kwargs) - - -def mobilenet_v3_large(pretrained=False, **kwargs): - return MobileNetExtractor("mobilenet_v3_large", pretrained=pretrained, **kwargs) - - -def mobilenet_v3_small(pretrained=False, **kwargs): - return MobileNetExtractor("mobilenet_v3_small", pretrained=pretrained, **kwargs) - - -def efficientnet_b0(pretrained=False, **kwargs): - return EfficientNetExtractor("efficientnet_b0", pretrained=pretrained, **kwargs) - - -def efficientnet_b1(pretrained=False, **kwargs): - return EfficientNetExtractor("efficientnet_b1", pretrained=pretrained, **kwargs) - - -def efficientnet_b2(pretrained=False, **kwargs): - return EfficientNetExtractor("efficientnet_b2", pretrained=pretrained, **kwargs) - - -def efficientnet_b3(pretrained=False, **kwargs): - return EfficientNetExtractor("efficientnet_b3", pretrained=pretrained, **kwargs) - - -def efficientnet_b4(pretrained=False, **kwargs): - return EfficientNetExtractor("efficientnet_b4", pretrained=pretrained, **kwargs) - - -def efficientnet_b5(pretrained=False, **kwargs): - return EfficientNetExtractor("efficientnet_b5", pretrained=pretrained, **kwargs) - - -def efficientnet_b6(pretrained=False, **kwargs): - return EfficientNetExtractor("efficientnet_b6", pretrained=pretrained, **kwargs) - - -def efficientnet_b7(pretrained=False, **kwargs): - return EfficientNetExtractor("efficientnet_b7", pretrained=pretrained, **kwargs) - - -def regnet_x_400mf(pretrained=False, **kwargs): - return RegNetExtractor("regnet_x_400mf", pretrained=pretrained, **kwargs) - - -def regnet_x_800mf(pretrained=False, **kwargs): - return RegNetExtractor("regnet_x_800mf", pretrained=pretrained, **kwargs) - - -def regnet_x_1_6gf(pretrained=False, **kwargs): - return RegNetExtractor("regnet_x_1_6gf", pretrained=pretrained, **kwargs) - - -def regnet_x_3_2gf(pretrained=False, **kwargs): - return RegNetExtractor("regnet_x_3_2gf", pretrained=pretrained, **kwargs) - - -def regnet_x_8gf(pretrained=False, **kwargs): - return RegNetExtractor("regnet_x_8gf", pretrained=pretrained, **kwargs) - - -def regnet_x_16gf(pretrained=False, **kwargs): - return RegNetExtractor("regnet_x_16gf", pretrained=pretrained, **kwargs) - - -def regnet_x_32gf(pretrained=False, **kwargs): - return RegNetExtractor("regnet_x_32gf", pretrained=pretrained, **kwargs) - - -def regnet_y_400mf(pretrained=False, **kwargs): - return RegNetExtractor("regnet_y_400mf", pretrained=pretrained, **kwargs) - - -def regnet_y_800mf(pretrained=False, **kwargs): - return RegNetExtractor("regnet_y_800mf", pretrained=pretrained, **kwargs) - - -def regnet_y_1_6gf(pretrained=False, **kwargs): - return RegNetExtractor("regnet_y_1_6gf", pretrained=pretrained, **kwargs) - - -def regnet_y_3_2gf(pretrained=False, **kwargs): - return RegNetExtractor("regnet_y_3_2gf", pretrained=pretrained, **kwargs) - - -def regnet_y_8gf(pretrained=False, **kwargs): - return RegNetExtractor("regnet_y_8gf", pretrained=pretrained, **kwargs) - - -def regnet_y_16gf(pretrained=False, **kwargs): - return RegNetExtractor("regnet_y_16gf", pretrained=pretrained, **kwargs) - - -def regnet_y_32gf(pretrained=False, **kwargs): - return RegNetExtractor("regnet_y_32gf", pretrained=pretrained, **kwargs)