Skip to content
This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Commit

Permalink
Add standard resnet models (#405)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #405

It is annoying to configure the ResNet blocks all the time. Add the standard
models to the code so we can refer to them by name

Reviewed By: aadcock

Differential Revision: D20050757

fbshipit-source-id: 41321a323e1a4b0259c76d525b59e66caeb2ce0e
  • Loading branch information
vreis authored and facebook-github-bot committed Feb 24, 2020
1 parent 4715b0a commit e47a18d
Showing 1 changed file with 113 additions and 8 deletions.
121 changes: 113 additions & 8 deletions classy_vision/models/resnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"""

import math
from typing import Any, Dict
from typing import Any, Dict, List, Optional, Tuple, Union

import torch.nn as nn
from classy_vision.generic.util import is_pos_int
Expand Down Expand Up @@ -228,13 +228,13 @@ class ResNeXt(ClassyModel):
def __init__(
self,
num_blocks,
init_planes,
reduction,
small_input,
zero_init_bn_residuals,
base_width_and_cardinality,
basic_layer,
final_bn_relu,
init_planes: int = 64,
reduction: int = 4,
small_input: bool = False,
zero_init_bn_residuals: bool = False,
base_width_and_cardinality: Optional[Union[Tuple, List]] = None,
basic_layer: bool = False,
final_bn_relu: bool = True,
):
"""
Implementation of `ResNeXt <https://arxiv.org/pdf/1611.05431.pdf>`_.
Expand Down Expand Up @@ -414,3 +414,108 @@ def output_shape(self):
@property
def model_depth(self):
return sum(self.num_blocks)


@register_model("resnet18")
class ResNet18(ResNeXt):
def __init__(self):
super().__init__(
num_blocks=[2, 2, 2, 2], basic_layer=True, zero_init_bn_residuals=True
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


@register_model("resnet34")
class ResNet34(ResNeXt):
def __init__(self):
super().__init__(
num_blocks=[3, 4, 6, 3], basic_layer=True, zero_init_bn_residuals=True
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


@register_model("resnet50")
class ResNet50(ResNeXt):
def __init__(self):
super().__init__(
num_blocks=[3, 4, 6, 3], basic_layer=False, zero_init_bn_residuals=True
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


@register_model("resnet101")
class ResNet101(ResNeXt):
def __init__(self):
super().__init__(
num_blocks=[3, 4, 23, 3], basic_layer=False, zero_init_bn_residuals=True
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


@register_model("resnet152")
class ResNet152(ResNeXt):
def __init__(self):
super().__init__(
num_blocks=[3, 8, 36, 3], basic_layer=False, zero_init_bn_residuals=True
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


@register_model("resnext50_32x4d")
class ResNeXt50(ResNeXt):
def __init__(self):
super().__init__(
num_blocks=[3, 4, 6, 3],
basic_layer=False,
zero_init_bn_residuals=True,
base_width_and_cardinality=(4, 32),
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


@register_model("resnext101_32x4d")
class ResNeXt101(ResNeXt):
def __init__(self):
super().__init__(
num_blocks=[3, 4, 23, 3],
basic_layer=False,
zero_init_bn_residuals=True,
base_width_and_cardinality=(4, 32),
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


@register_model("resnext152_32x4d")
class ResNeXt152(ResNeXt):
def __init__(self):
super().__init__(
num_blocks=[3, 8, 36, 3],
basic_layer=False,
zero_init_bn_residuals=True,
base_width_and_cardinality=(4, 32),
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()

0 comments on commit e47a18d

Please sign in to comment.