Skip to content

Commit

Permalink
add timm-MobileNetV3 as an Encoder (#355)
Browse files Browse the repository at this point in the history
* add timm-mobilenetv3 as encoder

* fix import bug

Co-authored-by: Pavel Yakubovskiy <[email protected]>
  • Loading branch information
markson14 and qubvel authored Jul 4, 2021
1 parent 23a54b4 commit 225823b
Show file tree
Hide file tree
Showing 4 changed files with 202 additions and 1 deletion.
18 changes: 17 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ The main features of this library are:

- High level API (just two lines to create a neural network)
- 9 model architectures for binary and multi-class segmentation (including legendary Unet)
- 109 available encoders
- 115 available encoders
- All encoders have pre-trained weights for faster and better convergence

### [📚 Project Documentation 📚](http://smp.readthedocs.io/)
Expand Down Expand Up @@ -337,6 +337,22 @@ The following is a list of supported encoders in the SMP. Select the appropriate
</div>
</details>

<details>
<summary style="margin-left: 25px;">MobileNetV3</summary>
<div style="margin-left: 25px;">

|Encoder |Weights |Params, M |
|--------------------------------|:------------------------------:|:------------------------------:|
|timm-mobilenetv3_large_075 |imagenet |1.78M |
|timm-mobilenetv3_large_100 |imagenet |2.97M |
|timm-mobilenetv3_large_minimal_100|imagenet |1.41M |
|timm-mobilenetv3_small_075 |imagenet |0.57M |
|timm-mobilenetv3_small_100 |imagenet |0.93M |
|timm-mobilenetv3_small_minimal_100|imagenet |0.43M |

</div>
</details>


\* `ssl`, `swsl` - semi-supervised and weakly-supervised learning on ImageNet ([repo](https://github.com/facebookresearch/semi-supervised-ImageNet1K-models)).

Expand Down
19 changes: 19 additions & 0 deletions docs/encoders.rst
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,22 @@ VGG
+-------------+------------+-------------+
| vgg19\_bn | imagenet | 20M |
+-------------+------------+-------------+

MobileNetV3
~~~~~~~~~~~

+-----------------------------------+------------+-------------+
| Encoder                           | Weights    | Params, M   |
+===================================+============+=============+
| timm-mobilenetv3_large_075        | imagenet   | 1.78M       |
+-----------------------------------+------------+-------------+
| timm-mobilenetv3_large_100        | imagenet   | 2.97M       |
+-----------------------------------+------------+-------------+
| timm-mobilenetv3_large_minimal_100| imagenet   | 1.41M       |
+-----------------------------------+------------+-------------+
| timm-mobilenetv3_small_075        | imagenet   | 0.57M       |
+-----------------------------------+------------+-------------+
| timm-mobilenetv3_small_100        | imagenet   | 0.93M       |
+-----------------------------------+------------+-------------+
| timm-mobilenetv3_small_minimal_100| imagenet   | 0.43M       |
+-----------------------------------+------------+-------------+
2 changes: 2 additions & 0 deletions segmentation_models_pytorch/encoders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .timm_res2net import timm_res2net_encoders
from .timm_regnet import timm_regnet_encoders
from .timm_sknet import timm_sknet_encoders
from .timm_mobilenetv3 import timm_mobilenetv3_encoders
try:
from .timm_gernet import timm_gernet_encoders
except ImportError as e:
Expand All @@ -43,6 +44,7 @@
encoders.update(timm_res2net_encoders)
encoders.update(timm_regnet_encoders)
encoders.update(timm_sknet_encoders)
encoders.update(timm_mobilenetv3_encoders)
encoders.update(timm_gernet_encoders)


Expand Down
164 changes: 164 additions & 0 deletions segmentation_models_pytorch/encoders/timm_mobilenetv3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
from timm import create_model
import torch.nn as nn
from ._base import EncoderMixin


def make_divisible(x, divisible_by=8):
    """Round ``x`` up to the nearest multiple of ``divisible_by``.

    Args:
        x: Scalar value to round (int or float).
        divisible_by: Step the result must be a multiple of. Defaults to 8.

    Returns:
        int: The smallest multiple of ``divisible_by`` that is >= ``x``.
    """
    # The standard library is enough for a scalar ceil; keeping the import
    # local leaves the module's top-level import block untouched.
    import math

    return int(math.ceil(x / divisible_by) * divisible_by)


class MobileNetV3Encoder(nn.Module, EncoderMixin):
    """Multi-stage feature encoder built from a timm MobileNetV3 model.

    The timm model is created with ``features_only=True`` and only its
    ``conv_stem``, ``bn1``, ``act1`` and ``blocks`` members are kept. The
    blocks are grouped into stages by :meth:`get_stages`; the grouping depends
    on whether the model name contains ``'small'`` or ``'large'``.
    """

    def __init__(self, model, width_mult, depth=5, **kwargs):
        """Build the encoder.

        Args:
            model: timm model name; must contain ``'small'`` or ``'large'``.
            width_mult: Channel width multiplier used to compute the
                per-stage output channel counts.
            depth: Number of stages (after the identity stage) evaluated in
                :meth:`forward`. Defaults to 5.
            **kwargs: Accepted but not used by this constructor.

        Raises:
            ValueError: If ``model`` names neither a small nor a large variant.
        """
        super().__init__()
        self._depth = depth

        name_str = str(model)
        if 'small' in name_str:
            self.mode = 'small'
            base_channels = (16, 16, 24, 48, 576)
        elif 'large' in name_str:
            self.mode = 'large'
            base_channels = (16, 24, 40, 112, 960)
        else:
            # Report the offending model name so the error is actionable.
            raise ValueError(
                'MobileNetV3 mode should be small or large, got {}'.format(model))

        # Scale the base widths, round each up to a multiple of 8, and prepend
        # the 3-channel entry that corresponds to the identity (input) stage.
        self._out_channels = (3,) + tuple(
            make_divisible(channels * width_mult) for channels in base_channels
        )
        self._in_channels = 3

        # minimal models replace hardswish with relu
        backbone = create_model(
            model_name=model,
            scriptable=True,  # torch.jit scriptable
            exportable=True,  # onnx export
            features_only=True,
        )
        self.conv_stem = backbone.conv_stem
        self.bn1 = backbone.bn1
        self.act1 = backbone.act1
        self.blocks = backbone.blocks

    def get_stages(self):
        """Return the six modules applied in sequence by :meth:`forward`.

        The first entry is an identity; the remaining entries are the stem and
        groups of ``self.blocks``. For the ``'large'`` variant the first block
        is folded into the stem stage.

        Raises:
            ValueError: If ``self.mode`` is neither ``'small'`` nor ``'large'``.
        """
        stem = (self.conv_stem, self.bn1, self.act1)
        if self.mode == 'small':
            return [
                nn.Identity(),
                nn.Sequential(*stem),
                self.blocks[0],
                self.blocks[1],
                self.blocks[2:4],
                self.blocks[4:],
            ]
        if self.mode == 'large':
            return [
                nn.Identity(),
                nn.Sequential(*stem, self.blocks[0]),
                self.blocks[1],
                self.blocks[2],
                self.blocks[3:5],
                self.blocks[5:],
            ]
        # Previously the exception was constructed but never raised, so an
        # unexpected mode made this method return None.
        raise ValueError(
            'MobileNetV3 mode should be small or large, got {}'.format(self.mode))

    def forward(self, x):
        """Run ``x`` through the first ``depth + 1`` stages.

        Returns:
            list: Output of every evaluated stage, in order, starting with the
            identity stage's output.
        """
        stages = self.get_stages()

        features = []
        for stage in stages[:self._depth + 1]:
            x = stage(x)
            features.append(x)

        return features

    def load_state_dict(self, state_dict, **kwargs):
        """Load weights after dropping the head entries this encoder lacks.

        The encoder keeps only the stem and blocks, so ``conv_head.*`` and
        ``classifier.*`` entries are removed before delegating to
        ``nn.Module.load_state_dict``. Missing head keys are tolerated, so a
        state dict that was already stripped can be loaded as well.

        Note:
            ``state_dict`` is modified in place.

        Returns:
            Whatever ``nn.Module.load_state_dict`` returns.
        """
        head_keys = (
            'conv_head.weight',
            'conv_head.bias',
            'classifier.weight',
            'classifier.bias',
        )
        for key in head_keys:
            state_dict.pop(key, None)
        return super().load_state_dict(state_dict, **kwargs)


# Checkpoint URLs per timm model name and weight source. Every file sits under
# the same release folder and is named "<model name>-<digest>.pth".
mobilenetv3_weights = {
    variant: {
        'imagenet': (
            'https://github.com/rwightman/pytorch-image-models/releases/'
            'download/v0.1-weights/{}-{}.pth'.format(variant, digest)
        ),
    }
    for variant, digest in (
        ('tf_mobilenetv3_large_075', '150ee8b0'),
        ('tf_mobilenetv3_large_100', '427764d5'),
        ('tf_mobilenetv3_large_minimal_100', '8596ae28'),
        ('tf_mobilenetv3_small_075', 'da427f52'),
        ('tf_mobilenetv3_small_100', '37f49e2b'),
        ('tf_mobilenetv3_small_minimal_100', '922a7843'),
    )
}

# Preprocessing settings for every (model, weight source) pair listed in
# mobilenetv3_weights; each pair gets its own freshly built settings dict.
# NOTE(review): the same mean/std is used for all checkpoints here; timm's own
# config for the tf_mobilenetv3 weights may specify different normalization
# statistics - confirm against timm before relying on these values.
pretrained_settings = {}
for arch_name, weight_sources in mobilenetv3_weights.items():
    pretrained_settings[arch_name] = {
        weights_tag: {
            "url": weights_url,
            'input_range': [0, 1],
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
            'input_space': 'RGB',
        }
        for weights_tag, weights_url in weight_sources.items()
    }


# Encoder registry entries. The public encoder name is "timm-<variant>", while
# the timm model name and the pretrained-settings key are both "tf_<variant>".
timm_mobilenetv3_encoders = {
    'timm-' + variant: {
        'encoder': MobileNetV3Encoder,
        'pretrained_settings': pretrained_settings['tf_' + variant],
        'params': {
            'model': 'tf_' + variant,
            'width_mult': multiplier,
        },
    }
    for variant, multiplier in (
        ('mobilenetv3_large_075', 0.75),
        ('mobilenetv3_large_100', 1.0),
        ('mobilenetv3_large_minimal_100', 1.0),
        ('mobilenetv3_small_075', 0.75),
        ('mobilenetv3_small_100', 1.0),
        ('mobilenetv3_small_minimal_100', 1.0),
    )
}

0 comments on commit 225823b

Please sign in to comment.