Skip to content

Commit

Permalink
fix torchvision
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst committed Jul 24, 2023
1 parent d7755d7 commit 51b1b9a
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 209 deletions.
38 changes: 20 additions & 18 deletions tests/test_backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,35 @@

import pytest
import torch
from torch import Tensor, nn
from torch import Tensor

from vision_toolbox.backbones import Darknet, DarknetYOLOv5, VoVNet
from vision_toolbox.backbones import (
Darknet,
DarknetYOLOv5,
EfficientNetExtractor,
MobileNetExtractor,
RegNetExtractor,
ResNetExtractor,
VoVNet,
)


@pytest.fixture
def inputs():
return torch.rand(1, 3, 224, 224)


vovnet_v1_models = [f"vovnet{x}" for x in ["27_slim", 39, 57]]
vovnet_v2_models = [f"vovnet{x}_ese" for x in ["19_slim", 19, 39, 57, 99]]
darknet_models = ["darknet19", "darknet53", "cspdarknet53"]
darknet_yolov5_models = [f"darknet_yolov5{x}" for x in ("n", "s", "m", "l", "x")]
torchvision_models = ["resnet18", "mobilenet_v2", "efficientnet_b0", "regnet_x_400mf"]

all_models = vovnet_v1_models + vovnet_v2_models + darknet_models + darknet_yolov5_models + torchvision_models


def partial_list(fn, args_list):
return [partial(fn, *args) for args in args_list]


factory_list = [
*partial_list(Darknet.from_config, (("darknet19",), ("cspdarknet53",))),
*partial_list(DarknetYOLOv5.from_config, (("n",), ("l",))),
*partial_list(VoVNet.from_config, ((27, True), (39,), (19, True, True), (57, False, True))),
*[partial(Darknet.from_config, x) for x in ("darknet19", "cspdarknet53")],
*[partial(DarknetYOLOv5.from_config, x) for x in ("n", "l")],
*[
partial(VoVNet.from_config, x, y, z)
for x, y, z in ((27, True, False), (39, False, False), (19, True, True), (57, False, True))
],
partial(ResNetExtractor, "resnet18"),
partial(RegNetExtractor, "regnet_x_400mf"),
partial(MobileNetExtractor, "mobilenet_v2"),
partial(EfficientNetExtractor, "efficientnet_b0"),
]


Expand Down
5 changes: 2 additions & 3 deletions vision_toolbox/backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .base import *
from .darknet import Darknet, DarknetYOLOv5
from .patchconvnet import *
from .torchvision_models import *
from .vit import *
from .torchvision_models import EfficientNetExtractor, MobileNetExtractor, RegNetExtractor, ResNetExtractor
from .vit import ViT
from .vovnet import VoVNet
191 changes: 3 additions & 188 deletions vision_toolbox/backbones/torchvision_models.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,13 @@
import warnings
from __future__ import annotations

import torch
from torch import Tensor, nn
from torchvision.models import mobilenet, resnet


try:
from torchvision.models import efficientnet, regnet
from torchvision.models.feature_extraction import create_feature_extractor
except ImportError:
warnings.warn("torchvision < 0.11.0. torchvision models won't be available")
regnet = efficientnet = create_feature_extractor = None
from torchvision.models import efficientnet, mobilenet, regnet, resnet
from torchvision.models.feature_extraction import create_feature_extractor

from .base import BaseBackbone


__all__ = [
"ResNetExtractor",
"RegNetExtractor",
"MobileNetExtractor",
"EfficientNetExtractor",
"resnet18",
"resnet34",
"resnet50",
"resnet101",
"resnet152",
"resnext50_32x4d",
"resnext101_32x8d",
"wide_resnet50_2",
"wide_resnet101_2",
"regnet_x_400mf",
"regnet_x_800mf",
"regnet_x_1_6gf",
"regnet_x_3_2gf",
"regnet_x_8gf",
"regnet_x_16gf",
"regnet_x_32gf",
"regnet_y_400mf",
"regnet_y_800mf",
"regnet_y_1_6gf",
"regnet_y_3_2gf",
"regnet_y_8gf",
"regnet_y_16gf",
"regnet_y_32gf",
"mobilenet_v2",
"mobilenet_v3_large",
"mobilenet_v3_small",
"efficientnet_b0",
"efficientnet_b1",
"efficientnet_b2",
"efficientnet_b3",
"efficientnet_b4",
"efficientnet_b5",
"efficientnet_b6",
"efficientnet_b7",
]


class _ExtractorBackbone(BaseBackbone):
def __init__(self, backbone: nn.Module, node_names: list[str]):
super().__init__()
Expand Down Expand Up @@ -102,139 +53,3 @@ def __init__(self, name: str, pretrained: bool = False):
stage_indices = [2, 3, 4, 6]
node_names = [f"features.{i}.0.block.0" for i in stage_indices] + [f"features.{len(backbone.features)-1}"]
super().__init__(backbone, node_names)


def resnet18(pretrained=False, **kwargs):
return ResNetExtractor("resnet18", pretrained=pretrained, **kwargs)


def resnet34(pretrained=False, **kwargs):
return ResNetExtractor("resnet34", pretrained=pretrained, **kwargs)


def resnet50(pretrained=False, **kwargs):
return ResNetExtractor("resnet50", pretrained=pretrained, **kwargs)


def resnet101(pretrained=False, **kwargs):
return ResNetExtractor("resnet101", pretrained=pretrained, **kwargs)


def resnet152(pretrained=False, **kwargs):
return ResNetExtractor("resnet152", pretrained=pretrained, **kwargs)


def resnext50_32x4d(pretrained=False, **kwargs):
return ResNetExtractor("resnext50_32x4d", pretrained=pretrained, **kwargs)


def resnext101_32x8d(pretrained=False, **kwargs):
return ResNetExtractor("resnext101_32x8d", pretrained=pretrained, **kwargs)


def wide_resnet50_2(pretrained=False, **kwargs):
return ResNetExtractor("wide_resnet50_2", pretrained=pretrained, **kwargs)


def wide_resnet101_2(pretrained=False, **kwargs):
return ResNetExtractor("wide_resnet101_2", pretrained=pretrained, **kwargs)


def mobilenet_v2(pretrained=False, **kwargs):
return MobileNetExtractor("mobilenet_v2", pretrained=pretrained, **kwargs)


def mobilenet_v3_large(pretrained=False, **kwargs):
return MobileNetExtractor("mobilenet_v3_large", pretrained=pretrained, **kwargs)


def mobilenet_v3_small(pretrained=False, **kwargs):
return MobileNetExtractor("mobilenet_v3_small", pretrained=pretrained, **kwargs)


def efficientnet_b0(pretrained=False, **kwargs):
return EfficientNetExtractor("efficientnet_b0", pretrained=pretrained, **kwargs)


def efficientnet_b1(pretrained=False, **kwargs):
return EfficientNetExtractor("efficientnet_b1", pretrained=pretrained, **kwargs)


def efficientnet_b2(pretrained=False, **kwargs):
return EfficientNetExtractor("efficientnet_b2", pretrained=pretrained, **kwargs)


def efficientnet_b3(pretrained=False, **kwargs):
return EfficientNetExtractor("efficientnet_b3", pretrained=pretrained, **kwargs)


def efficientnet_b4(pretrained=False, **kwargs):
return EfficientNetExtractor("efficientnet_b4", pretrained=pretrained, **kwargs)


def efficientnet_b5(pretrained=False, **kwargs):
return EfficientNetExtractor("efficientnet_b5", pretrained=pretrained, **kwargs)


def efficientnet_b6(pretrained=False, **kwargs):
return EfficientNetExtractor("efficientnet_b6", pretrained=pretrained, **kwargs)


def efficientnet_b7(pretrained=False, **kwargs):
return EfficientNetExtractor("efficientnet_b7", pretrained=pretrained, **kwargs)


def regnet_x_400mf(pretrained=False, **kwargs):
return RegNetExtractor("regnet_x_400mf", pretrained=pretrained, **kwargs)


def regnet_x_800mf(pretrained=False, **kwargs):
return RegNetExtractor("regnet_x_800mf", pretrained=pretrained, **kwargs)


def regnet_x_1_6gf(pretrained=False, **kwargs):
return RegNetExtractor("regnet_x_1_6gf", pretrained=pretrained, **kwargs)


def regnet_x_3_2gf(pretrained=False, **kwargs):
return RegNetExtractor("regnet_x_3_2gf", pretrained=pretrained, **kwargs)


def regnet_x_8gf(pretrained=False, **kwargs):
return RegNetExtractor("regnet_x_8gf", pretrained=pretrained, **kwargs)


def regnet_x_16gf(pretrained=False, **kwargs):
return RegNetExtractor("regnet_x_16gf", pretrained=pretrained, **kwargs)


def regnet_x_32gf(pretrained=False, **kwargs):
return RegNetExtractor("regnet_x_32gf", pretrained=pretrained, **kwargs)


def regnet_y_400mf(pretrained=False, **kwargs):
return RegNetExtractor("regnet_y_400mf", pretrained=pretrained, **kwargs)


def regnet_y_800mf(pretrained=False, **kwargs):
return RegNetExtractor("regnet_y_800mf", pretrained=pretrained, **kwargs)


def regnet_y_1_6gf(pretrained=False, **kwargs):
return RegNetExtractor("regnet_y_1_6gf", pretrained=pretrained, **kwargs)


def regnet_y_3_2gf(pretrained=False, **kwargs):
return RegNetExtractor("regnet_y_3_2gf", pretrained=pretrained, **kwargs)


def regnet_y_8gf(pretrained=False, **kwargs):
return RegNetExtractor("regnet_y_8gf", pretrained=pretrained, **kwargs)


def regnet_y_16gf(pretrained=False, **kwargs):
return RegNetExtractor("regnet_y_16gf", pretrained=pretrained, **kwargs)


def regnet_y_32gf(pretrained=False, **kwargs):
return RegNetExtractor("regnet_y_32gf", pretrained=pretrained, **kwargs)

0 comments on commit 51b1b9a

Please sign in to comment.