Skip to content

Commit

Permalink
Make Backbone parent classes to allow handling all subclass presets
Browse files Browse the repository at this point in the history
  • Loading branch information
smitlg committed Apr 16, 2024
1 parent c60112e commit a94af37
Show file tree
Hide file tree
Showing 40 changed files with 330 additions and 895 deletions.
57 changes: 18 additions & 39 deletions keras_cv/models/backbones/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@

from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras
from keras_cv.utils.preset_utils import check_preset_class
from keras_cv.utils.preset_utils import check_config_class
from keras_cv.utils.preset_utils import list_presets
from keras_cv.utils.preset_utils import list_subclasses
from keras_cv.utils.preset_utils import load_from_preset
from keras_cv.utils.python_utils import classproperty
from keras_cv.utils.python_utils import format_docstring


@keras_cv_export("keras_cv.models.Backbone")
Expand Down Expand Up @@ -64,12 +65,18 @@ def from_config(cls, config):
@classproperty
def presets(cls):
"""Dictionary of preset names and configs."""
return {}
presets = list_presets(cls)
for subclass in list_subclasses(cls):
presets.update(subclass.presets)
return presets

@classproperty
def presets_with_weights(cls):
"""Dictionary of preset names and configs that include weights."""
return {}
presets = list_presets(cls, with_weights=True)
for subclass in list_subclasses(cls):
presets.update(subclass.presets)
return presets

@classproperty
def presets_without_weights(cls):
Expand Down Expand Up @@ -109,47 +116,19 @@ def from_preset(
load_weights=False,
```
"""
# We support short IDs for official presets, e.g. `"bert_base_en"`.
# Map these to a Kaggle Models handle.
if preset in cls.presets:
preset = cls.presets[preset]["kaggle_handle"]

check_preset_class(preset, cls)
preset_cls = check_config_class(preset)
if not issubclass(preset_cls, cls):
raise ValueError(
f"Preset has type `{preset_cls.__name__}` which is not a "
f"a subclass of calling class `{cls.__name__}`. Call "
f"`from_preset` directly on `{preset_cls.__name__}` instead."
)
return load_from_preset(
preset,
load_weights=load_weights,
config_overrides=kwargs,
)

def __init_subclass__(cls, **kwargs):
# Use __init_subclass__ to set up a correct docstring for from_preset.
super().__init_subclass__(**kwargs)

# If the subclass does not define from_preset, assign a wrapper so that
# each class can have a distinct docstring.
if "from_preset" not in cls.__dict__:

def from_preset(calling_cls, *args, **kwargs):
return super(cls, calling_cls).from_preset(*args, **kwargs)

cls.from_preset = classmethod(from_preset)

if not cls.presets:
cls.from_preset.__func__.__doc__ = """Not implemented.
No presets available for this class.
"""

# Format and assign the docstring unless the subclass has overridden it.
if cls.from_preset.__doc__ is None:
cls.from_preset.__func__.__doc__ = Backbone.from_preset.__doc__
format_docstring(
model_name=cls.__name__,
example_preset_name=next(iter(cls.presets_with_weights), ""),
preset_names='", "'.join(cls.presets),
preset_with_weights_names='", "'.join(cls.presets_with_weights),
)(cls.from_preset.__func__)

@property
def pyramid_level_inputs(self):
"""Intermediate model outputs for feature extraction.
Expand Down
3 changes: 3 additions & 0 deletions keras_cv/models/backbones/backbone_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@
from keras_cv.models.backbones.video_swin import video_swin_backbone_presets
from keras_cv.models.backbones.vit_det import vit_det_backbone_presets
from keras_cv.models.object_detection.yolo_v8 import yolo_v8_backbone_presets
from keras_cv.models.object_detection_3d import center_pillar_backbone_presets

backbone_presets_no_weights = {
**center_pillar_backbone_presets.backbone_presets_no_weights,
**resnet_v1_backbone_presets.backbone_presets_no_weights,
**resnet_v2_backbone_presets.backbone_presets_no_weights,
**mobilenet_v3_backbone_presets.backbone_presets_no_weights,
Expand All @@ -47,6 +49,7 @@
}

backbone_presets_with_weights = {
**center_pillar_backbone_presets.backbone_presets_with_weights,
**resnet_v1_backbone_presets.backbone_presets_with_weights,
**resnet_v2_backbone_presets.backbone_presets_with_weights,
**mobilenet_v3_backbone_presets.backbone_presets_with_weights,
Expand Down
19 changes: 19 additions & 0 deletions keras_cv/models/backbones/csp_darknet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_cv.models.backbones.csp_darknet.csp_darknet_backbone_presets import (
backbone_presets_no_weights, backbone_presets_with_weights,
)
from keras_cv.models.backbones.csp_darknet.csp_darknet_backbone import (
CSPDarkNetBackbone,
)
from keras_cv.models.backbones.csp_darknet.csp_darknet_aliases import (
CSPDarkNetTinyBackbone, CSPDarkNetLBackbone,
)
from keras_cv.utils.preset_utils import register_presets, register_preset

register_presets(backbone_presets_no_weights, (CSPDarkNetBackbone, ), with_weights=False)
register_presets(backbone_presets_with_weights, (CSPDarkNetBackbone, ), with_weights=True)
register_presets(backbone_presets_with_weights, (CSPDarkNetBackbone, ), with_weights=True)
register_preset("csp_darknet_tiny_imagenet", backbone_presets_with_weights["csp_darknet_tiny_imagenet"],
(CSPDarkNetTinyBackbone,), with_weights=True)
register_preset("csp_darknet_l_imagenet", backbone_presets_with_weights["csp_darknet_l_imagenet"],
(CSPDarkNetLBackbone,), with_weights=True)
67 changes: 0 additions & 67 deletions keras_cv/models/backbones/csp_darknet/csp_darknet_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

from keras_cv.api_export import keras_cv_export
from keras_cv.models.backbones.csp_darknet.csp_darknet_backbone import (
CSPDarkNetBackbone,
)
from keras_cv.models.backbones.csp_darknet.csp_darknet_backbone_presets import (
backbone_presets,
)
from keras_cv.utils.python_utils import classproperty

ALIAS_DOCSTRING = """CSPDarkNetBackbone model with {stackwise_channels} channels
and {stackwise_depth} depths.
Expand Down Expand Up @@ -71,21 +66,6 @@ def __new__(
)
return CSPDarkNetBackbone.from_preset("csp_darknet_tiny", **kwargs)

@classproperty
def presets(cls):
"""Dictionary of preset names and configurations."""
return {
"csp_darknet_tiny_imagenet": copy.deepcopy(
backbone_presets["csp_darknet_tiny_imagenet"]
)
}

@classproperty
def presets_with_weights(cls):
"""Dictionary of preset names and configurations that include
weights."""
return cls.presets


@keras_cv_export("keras_cv.models.CSPDarkNetSBackbone")
class CSPDarkNetSBackbone(CSPDarkNetBackbone):
Expand All @@ -106,17 +86,6 @@ def __new__(
)
return CSPDarkNetBackbone.from_preset("csp_darknet_s", **kwargs)

@classproperty
def presets(cls):
"""Dictionary of preset names and configurations."""
return {}

@classproperty
def presets_with_weights(cls):
"""Dictionary of preset names and configurations that include
weights."""
return {}


@keras_cv_export("keras_cv.models.CSPDarkNetMBackbone")
class CSPDarkNetMBackbone(CSPDarkNetBackbone):
Expand All @@ -137,17 +106,6 @@ def __new__(
)
return CSPDarkNetBackbone.from_preset("csp_darknet_m", **kwargs)

@classproperty
def presets(cls):
"""Dictionary of preset names and configurations."""
return {}

@classproperty
def presets_with_weights(cls):
"""Dictionary of preset names and configurations that include
weights."""
return {}


@keras_cv_export("keras_cv.models.CSPDarkNetLBackbone")
class CSPDarkNetLBackbone(CSPDarkNetBackbone):
Expand All @@ -168,21 +126,6 @@ def __new__(
)
return CSPDarkNetBackbone.from_preset("csp_darknet_l", **kwargs)

@classproperty
def presets(cls):
"""Dictionary of preset names and configurations."""
return {
"csp_darknet_l_imagenet": copy.deepcopy(
backbone_presets["csp_darknet_l_imagenet"]
)
}

@classproperty
def presets_with_weights(cls):
"""Dictionary of preset names and configurations that include
weights."""
return cls.presets


@keras_cv_export("keras_cv.models.CSPDarkNetXLBackbone")
class CSPDarkNetXLBackbone(CSPDarkNetBackbone):
Expand All @@ -203,16 +146,6 @@ def __new__(
)
return CSPDarkNetBackbone.from_preset("csp_darknet_xl", **kwargs)

@classproperty
def presets(cls):
"""Dictionary of preset names and configurations."""
return {}

@classproperty
def presets_with_weights(cls):
"""Dictionary of preset names and configurations that include
weights."""
return {}


setattr(
Expand Down
20 changes: 0 additions & 20 deletions keras_cv/models/backbones/csp_darknet/csp_darknet_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,11 @@
# limitations under the License.

"""CSPDarkNet backbone model. """
import copy

from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras
from keras_cv.models import utils
from keras_cv.models.backbones.backbone import Backbone
from keras_cv.models.backbones.csp_darknet.csp_darknet_backbone_presets import (
backbone_presets,
)
from keras_cv.models.backbones.csp_darknet.csp_darknet_backbone_presets import (
backbone_presets_with_weights,
)
from keras_cv.models.backbones.csp_darknet.csp_darknet_utils import (
CrossStagePartial,
)
Expand All @@ -38,8 +31,6 @@
from keras_cv.models.backbones.csp_darknet.csp_darknet_utils import (
SpatialPyramidPoolingBottleneck,
)
from keras_cv.utils.python_utils import classproperty


@keras_cv_export("keras_cv.models.CSPDarkNetBackbone")
class CSPDarkNetBackbone(Backbone):
Expand Down Expand Up @@ -169,14 +160,3 @@ def get_config(self):
}
)
return config

@classproperty
def presets(cls):
"""Dictionary of preset names and configurations."""
return copy.deepcopy(backbone_presets)

@classproperty
def presets_with_weights(cls):
"""Dictionary of preset names and configurations that include
weights."""
return copy.deepcopy(backbone_presets_with_weights)
23 changes: 23 additions & 0 deletions keras_cv/models/backbones/densenet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_cv.models.backbones.densenet.densenet_backbone_presets import (
backbone_presets_no_weights,
)
from keras_cv.models.backbones.densenet.densenet_backbone_presets import (
backbone_presets_with_weights,
)
from keras_cv.models.backbones.densenet.densenet_backbone import (
DenseNetBackbone,
)
from keras_cv.models.backbones.densenet.densenet_aliases import (
DenseNet121Backbone, DenseNet169Backbone, DenseNet201Backbone
)
from keras_cv.utils.preset_utils import register_presets, register_preset

register_presets(backbone_presets_no_weights, (DenseNetBackbone, ), with_weights=False)
register_presets(backbone_presets_with_weights, (DenseNetBackbone, ), with_weights=True)
register_preset("densenet121_imagenet", backbone_presets_with_weights["densenet121_imagenet"],
(DenseNet121Backbone,), with_weights=True)
register_preset("densenet169_imagenet", backbone_presets_with_weights["densenet169_imagenet"],
(DenseNet169Backbone,), with_weights=True)
register_preset("densenet201_imagenet", backbone_presets_with_weights["densenet201_imagenet"],
(DenseNet201Backbone,), with_weights=True)
49 changes: 0 additions & 49 deletions keras_cv/models/backbones/densenet/densenet_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

from keras_cv.api_export import keras_cv_export
from keras_cv.models.backbones.densenet.densenet_backbone import (
DenseNetBackbone,
)
from keras_cv.models.backbones.densenet.densenet_backbone_presets import (
backbone_presets,
)
from keras_cv.utils.python_utils import classproperty

ALIAS_DOCSTRING = """DenseNetBackbone model with {num_layers} layers.
Expand Down Expand Up @@ -69,21 +63,6 @@ def __new__(
)
return DenseNetBackbone.from_preset("densenet121", **kwargs)

@classproperty
def presets(cls):
"""Dictionary of preset names and configurations."""
return {
"densenet121_imagenet": copy.deepcopy(
backbone_presets["densenet121_imagenet"]
),
}

@classproperty
def presets_with_weights(cls):
"""Dictionary of preset names and configurations that include weights.""" # noqa: E501
return cls.presets


@keras_cv_export("keras_cv.models.DenseNet169Backbone")
class DenseNet169Backbone(DenseNetBackbone):
def __new__(
Expand All @@ -103,20 +82,6 @@ def __new__(
)
return DenseNetBackbone.from_preset("densenet169", **kwargs)

@classproperty
def presets(cls):
"""Dictionary of preset names and configurations."""
return {
"densenet169_imagenet": copy.deepcopy(
backbone_presets["densenet169_imagenet"]
),
}

@classproperty
def presets_with_weights(cls):
"""Dictionary of preset names and configurations that include weights.""" # noqa: E501
return cls.presets


@keras_cv_export("keras_cv.models.DenseNet201Backbone")
class DenseNet201Backbone(DenseNetBackbone):
Expand All @@ -137,20 +102,6 @@ def __new__(
)
return DenseNetBackbone.from_preset("densenet201", **kwargs)

@classproperty
def presets(cls):
"""Dictionary of preset names and configurations."""
return {
"densenet201_imagenet": copy.deepcopy(
backbone_presets["densenet201_imagenet"]
),
}

@classproperty
def presets_with_weights(cls):
"""Dictionary of preset names and configurations that include weights.""" # noqa: E501
return cls.presets


setattr(DenseNet121Backbone, "__doc__", ALIAS_DOCSTRING.format(num_layers=121))
setattr(DenseNet169Backbone, "__doc__", ALIAS_DOCSTRING.format(num_layers=169))
Expand Down
Loading

0 comments on commit a94af37

Please sign in to comment.