diff --git a/keras_retinanet/models/__init__.py b/keras_retinanet/models/__init__.py index 24a58e3..d6428ee 100644 --- a/keras_retinanet/models/__init__.py +++ b/keras_retinanet/models/__init__.py @@ -49,16 +49,16 @@ def preprocess_image(self, inputs): def backbone(backbone_name): """ Returns a backbone object for the given backbone. """ - if 'seresnext' in backbone_name or 'seresnet' in backbone_name or 'senet' in backbone_name: + if 'densenet' in backbone_name: + from .densenet import DenseNetBackbone as b + elif 'seresnext' in backbone_name or 'seresnet' in backbone_name or 'senet' in backbone_name: from .senet import SeBackbone as b elif 'resnet' in backbone_name: from .resnet import ResNetBackbone as b elif 'mobilenet' in backbone_name: from .mobilenet import MobileNetBackbone as b elif 'vgg' in backbone_name: - from .vgg import VGGBackbone as b - elif 'densenet' in backbone_name: - from .densenet import DenseNetBackbone as b + from .vgg import VGGBackbone as b elif 'EfficientNet' in backbone_name: from .effnet import EfficientNetBackbone as b else: