Release 0.2.0 (#430)
* new in_channels != 3 initialization
* docs fixes 
* version resolving
qubvel authored Jul 5, 2021
1 parent 225823b commit 914f2bf
Showing 32 changed files with 233 additions and 366 deletions.
1 change: 0 additions & 1 deletion .github/workflows/tests.yml
@@ -29,7 +29,6 @@ jobs:
python -m pip install codecov pytest mock
pip3 install torch==1.9.0+cpu torchvision==0.10.0+cpu torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install .
pip install -U git+https://github.com/rwightman/pytorch-image-models
- name: Test
run: |
python -m pytest -s tests
31 changes: 10 additions & 21 deletions README.md
@@ -12,7 +12,7 @@ The main features of this library are:

- High level API (just two lines to create a neural network)
- 9 models architectures for binary and multi class segmentation (including legendary Unet)
- 115 available encoders
- 113 available encoders
- All encoders have pre-trained weights for faster and better convergence

### [📚 Project Documentation 📚](http://smp.readthedocs.io/)
@@ -297,8 +297,12 @@ The following is a list of supported encoders in the SMP. Select the appropriate
|Encoder |Weights |Params, M |
|--------------------------------|:------------------------------:|:------------------------------:|
|mobilenet_v2 |imagenet |2M |
|mobilenet_v3_large |imagenet |3M |
|mobilenet_v3_small |imagenet |1M |
|timm-mobilenetv3_large_075 |imagenet |1.78M |
|timm-mobilenetv3_large_100 |imagenet |2.97M |
|timm-mobilenetv3_large_minimal_100|imagenet |1.41M |
|timm-mobilenetv3_small_075 |imagenet |0.57M |
|timm-mobilenetv3_small_100 |imagenet |0.93M |
|timm-mobilenetv3_small_minimal_100|imagenet |0.43M |

</div>
</details>
@@ -337,22 +341,6 @@ The following is a list of supported encoders in the SMP. Select the appropriate
</div>
</details>

<details>
<summary style="margin-left: 25px;">MobileNetV3</summary>
<div style="margin-left: 25px;">

|Encoder |Weights |Params, M |
|--------------------------------|:------------------------------:|:------------------------------:|
|timm-mobilenetv3_large_075 |imagenet |1.78M |
|timm-mobilenetv3_large_100 |imagenet |2.97M |
|timm-mobilenetv3_large_minimal_100|imagenet |1.41M |
|timm-mobilenetv3_small_075 |imagenet |0.57M |
|timm-mobilenetv3_small_100 |imagenet |0.93M |
|timm-mobilenetv3_small_minimal_100|imagenet |0.43M |

</div>
</details>


\* `ssl`, `swsl` - semi-supervised and weakly-supervised learning on ImageNet ([repo](https://github.com/facebookresearch/semi-supervised-ImageNet1K-models)).

@@ -367,8 +355,9 @@ The following is a list of supported encoders in the SMP. Select the appropriate

##### Input channels
Input channels parameter allows you to create models, which process tensors with arbitrary number of channels.
If you use pretrained weights from imagenet - weights of first convolution will be reused for
1- or 2- channels inputs, for input channels > 4 weights of first convolution will be initialized randomly.
If you use pretrained weights from imagenet - weights of first convolution will be reused. For
1-channel case it would be a sum of weights of first convolution layer, otherwise channels would be
populated with weights like `new_weight[:, i] = pretrained_weight[:, i % 3]` and then scaled with `new_weight * 3 / new_in_channels`.
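As a rough illustration of the rule above, the sketch below applies it to a single kernel position, using plain Python lists instead of tensors. The helper name `inflate_channels` and the sample weights are made up for this example and are not part of the library.

```python
def inflate_channels(pretrained, new_in_channels):
    """Toy version of the channel-population rule for one kernel position.

    `pretrained` holds one weight per pretrained input channel
    (3 values for an ImageNet RGB model).
    """
    default_in_channels = len(pretrained)
    if new_in_channels == 1:
        # 1-channel case: sum the pretrained weights over the channel axis
        return [sum(pretrained)]
    scale = default_in_channels / new_in_channels
    # Cycle through the pretrained channels (i % 3), then rescale
    return [pretrained[i % default_in_channels] * scale
            for i in range(new_in_channels)]


weights = [1.0, 2.0, 4.0]  # hypothetical R, G, B weights
print(inflate_channels(weights, 1))  # [7.0]
print(inflate_channels(weights, 6))  # [0.5, 1.0, 2.0, 0.5, 1.0, 2.0]
```

With six input channels the pretrained weights appear twice, and the `3 / 6` factor halves each copy, so the total weight over all input channels stays the same as in the pretrained layer.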
```python
model = smp.FPN('resnet34', in_channels=1)
mask = model(torch.ones([1, 1, 64, 64]))
45 changes: 17 additions & 28 deletions docs/encoders.rst
@@ -265,15 +265,23 @@ EfficientNet
MobileNet
~~~~~~~~~

+---------------------+------------+-------------+
| Encoder | Weights | Params, M |
+=====================+============+=============+
| mobilenet\_v2 | imagenet | 2M |
+---------------------+------------+-------------+
| mobilenet\_v3_large | imagenet | 3M |
+---------------------+------------+-------------+
| mobilenet\_v2_small | imagenet | 1M |
+---------------------+------------+-------------+
+---------------------------------------+------------+-------------+
| Encoder | Weights | Params, M |
+=======================================+============+=============+
| mobilenet\_v2 | imagenet | 2M |
+---------------------------------------+------------+-------------+
| timm-mobilenetv3\_large\_075 | imagenet | 1.78M |
+---------------------------------------+------------+-------------+
| timm-mobilenetv3\_large\_100 | imagenet | 2.97M |
+---------------------------------------+------------+-------------+
| timm-mobilenetv3\_large\_minimal\_100 | imagenet | 1.41M |
+---------------------------------------+------------+-------------+
| timm-mobilenetv3\_small\_075 | imagenet | 0.57M |
+---------------------------------------+------------+-------------+
| timm-mobilenetv3\_small\_100 | imagenet | 0.93M |
+---------------------------------------+------------+-------------+
| timm-mobilenetv3\_small\_minimal\_100 | imagenet | 0.43M |
+---------------------------------------+------------+-------------+

DPN
~~~
@@ -316,22 +324,3 @@ VGG
+-------------+------------+-------------+
| vgg19\_bn | imagenet | 20M |
+-------------+------------+-------------+

MobileNetV3
~~~~~~~~~

+-----------------------------------+------------+-------------+
| Encoder | Weights | Params, M |
+===================================+============+=============+
| timm-mobilenetv3_large_075 | imagenet | 1.78M |
+-----------------------------------+------------+-------------+
| timm-mobilenetv3_large_100 | imagenet | 2.97M |
+-----------------------------------+------------+-------------+
| timm-mobilenetv3_large_minimal_100| imagenet | 1.41M |
+-----------------------------------+------------+-------------+
| timm-mobilenetv3_small_075 | imagenet | 0.57M |
+-----------------------------------+------------+-------------+
| timm-mobilenetv3_small_100 | imagenet | 0.93M |
+-----------------------------------+------------+-------------+
| timm-mobilenetv3_small_minimal_100| imagenet | 0.43M |
+-----------------------------------+------------+-------------+
4 changes: 4 additions & 0 deletions docs/losses.rst
@@ -17,6 +17,10 @@ DiceLoss
~~~~~~~~
.. autoclass:: segmentation_models_pytorch.losses.DiceLoss

TverskyLoss
~~~~~~~~~~~
.. autoclass:: segmentation_models_pytorch.losses.TverskyLoss

FocalLoss
~~~~~~~~~
.. autoclass:: segmentation_models_pytorch.losses.FocalLoss
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
torchvision>=0.9.0
torchvision>=0.5.0
pretrainedmodels==0.7.4
efficientnet-pytorch==0.6.3
timm==0.4.12
2 changes: 1 addition & 1 deletion segmentation_models_pytorch/__version__.py
@@ -1,3 +1,3 @@
VERSION = (0, 1, 3)
VERSION = (0, 2, 0)

__version__ = '.'.join(map(str, VERSION))
11 changes: 2 additions & 9 deletions segmentation_models_pytorch/encoders/__init__.py
@@ -10,20 +10,14 @@
from .inceptionv4 import inceptionv4_encoders
from .efficientnet import efficient_net_encoders
from .mobilenet import mobilenet_encoders
from .mobilenet_v3 import mobilenet_v3_encoders
from .xception import xception_encoders
from .timm_efficientnet import timm_efficientnet_encoders
from .timm_resnest import timm_resnest_encoders
from .timm_res2net import timm_res2net_encoders
from .timm_regnet import timm_regnet_encoders
from .timm_sknet import timm_sknet_encoders
from .timm_mobilenetv3 import timm_mobilenetv3_encoders
try:
from .timm_gernet import timm_gernet_encoders
except ImportError as e:
timm_gernet_encoders = {}
print("Current timm version doesn't support GERNet."
"If GERNet support is needed please update timm")
from .timm_gernet import timm_gernet_encoders

from ._preprocessing import preprocess_input

@@ -37,7 +31,6 @@
encoders.update(inceptionv4_encoders)
encoders.update(efficient_net_encoders)
encoders.update(mobilenet_encoders)
encoders.update(mobilenet_v3_encoders)
encoders.update(xception_encoders)
encoders.update(timm_efficientnet_encoders)
encoders.update(timm_resnest_encoders)
@@ -68,7 +61,7 @@ def get_encoder(name, in_channels=3, depth=5, weights=None):
))
encoder.load_state_dict(model_zoo.load_url(settings["url"]))

encoder.set_in_channels(in_channels)
encoder.set_in_channels(in_channels, pretrained=weights is not None)

return encoder

4 changes: 2 additions & 2 deletions segmentation_models_pytorch/encoders/_base.py
@@ -17,7 +17,7 @@ def out_channels(self):
"""Return channels dimensions for each tensor of forward output of encoder"""
return self._out_channels[: self._depth + 1]

def set_in_channels(self, in_channels):
def set_in_channels(self, in_channels, pretrained=True):
"""Change first convolution channels"""
if in_channels == 3:
return
@@ -26,7 +26,7 @@ def set_in_channels(self, in_channels):
if self._out_channels[0] == 3:
self._out_channels = tuple([in_channels] + list(self._out_channels)[1:])

utils.patch_first_conv(model=self, in_channels=in_channels)
utils.patch_first_conv(model=self, new_in_channels=in_channels, pretrained=pretrained)

def get_stages(self):
"""Method should be overridden in encoder"""
43 changes: 26 additions & 17 deletions segmentation_models_pytorch/encoders/_utils.py
@@ -2,7 +2,7 @@
import torch.nn as nn


def patch_first_conv(model, in_channels):
def patch_first_conv(model, new_in_channels, default_in_channels=3, pretrained=True):
"""Change first convolution layer input channels.
In case:
in_channels == 1 or in_channels == 2 -> reuse original weights
@@ -11,29 +11,38 @@ def patch_first_conv(model, in_channels):

# get first conv
for module in model.modules():
if isinstance(module, nn.Conv2d):
if isinstance(module, nn.Conv2d) and module.in_channels == default_in_channels:
break

# change input channels for first conv
module.in_channels = in_channels

weight = module.weight.detach()
reset = False

if in_channels == 1:
weight = weight.sum(1, keepdim=True)
elif in_channels == 2:
weight = weight[:, :2] * (3.0 / 2.0)
module.in_channels = new_in_channels

if not pretrained:
module.weight = nn.parameter.Parameter(
torch.Tensor(
module.out_channels,
new_in_channels // module.groups,
*module.kernel_size
)
)
module.reset_parameters()

elif new_in_channels == 1:
new_weight = weight.sum(1, keepdim=True)
module.weight = nn.parameter.Parameter(new_weight)

else:
reset = True
weight = torch.Tensor(
new_weight = torch.Tensor(
module.out_channels,
module.in_channels // module.groups,
new_in_channels // module.groups,
*module.kernel_size
)

module.weight = nn.parameter.Parameter(weight)
if reset:
module.reset_parameters()
for i in range(new_in_channels):
new_weight[:, i] = weight[:, i % default_in_channels]

new_weight = new_weight * (default_in_channels / new_in_channels)
module.weight = nn.parameter.Parameter(new_weight)


def replace_strides_with_dilation(module, dilation_rate):
4 changes: 2 additions & 2 deletions segmentation_models_pytorch/encoders/densenet.py
@@ -96,8 +96,8 @@ def load_state_dict(self, state_dict):
del state_dict[key]

# remove linear
state_dict.pop("classifier.bias")
state_dict.pop("classifier.weight")
state_dict.pop("classifier.bias", None)
state_dict.pop("classifier.weight", None)

super().load_state_dict(state_dict)

4 changes: 2 additions & 2 deletions segmentation_models_pytorch/encoders/dpn.py
@@ -68,8 +68,8 @@ def forward(self, x):
return features

def load_state_dict(self, state_dict, **kwargs):
state_dict.pop("last_linear.bias")
state_dict.pop("last_linear.weight")
state_dict.pop("last_linear.bias", None)
state_dict.pop("last_linear.weight", None)
super().load_state_dict(state_dict, **kwargs)


4 changes: 2 additions & 2 deletions segmentation_models_pytorch/encoders/efficientnet.py
@@ -77,8 +77,8 @@ def forward(self, x):
return features

def load_state_dict(self, state_dict, **kwargs):
state_dict.pop("_fc.bias")
state_dict.pop("_fc.weight")
state_dict.pop("_fc.bias", None)
state_dict.pop("_fc.weight", None)
super().load_state_dict(state_dict, **kwargs)


4 changes: 2 additions & 2 deletions segmentation_models_pytorch/encoders/inceptionresnetv2.py
@@ -76,8 +76,8 @@ def forward(self, x):
return features

def load_state_dict(self, state_dict, **kwargs):
state_dict.pop("last_linear.bias")
state_dict.pop("last_linear.weight")
state_dict.pop("last_linear.bias", None)
state_dict.pop("last_linear.weight", None)
super().load_state_dict(state_dict, **kwargs)


4 changes: 2 additions & 2 deletions segmentation_models_pytorch/encoders/inceptionv4.py
@@ -75,8 +75,8 @@ def forward(self, x):
return features

def load_state_dict(self, state_dict, **kwargs):
state_dict.pop("last_linear.bias")
state_dict.pop("last_linear.weight")
state_dict.pop("last_linear.bias", None)
state_dict.pop("last_linear.weight", None)
super().load_state_dict(state_dict, **kwargs)


4 changes: 2 additions & 2 deletions segmentation_models_pytorch/encoders/mobilenet.py
@@ -59,8 +59,8 @@ def forward(self, x):
return features

def load_state_dict(self, state_dict, **kwargs):
state_dict.pop("classifier.1.bias")
state_dict.pop("classifier.1.weight")
state_dict.pop("classifier.1.bias", None)
state_dict.pop("classifier.1.weight", None)
super().load_state_dict(state_dict, **kwargs)
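
The change repeated across the encoder files above swaps `state_dict.pop(key)` for `state_dict.pop(key, None)`, presumably so that checkpoints which no longer contain the classifier keys still load. A minimal sketch of the difference, using a plain dict as a stand-in for a checkpoint (the keys and values are illustrative only):

```python
# A plain dict standing in for a checkpoint whose classifier head is already gone
state_dict = {"features.0.weight": [0.1, 0.2]}

try:
    # Old form: raises KeyError when the key is absent
    state_dict.pop("classifier.1.bias")
except KeyError as err:
    print("KeyError:", err)

# New form: returns the supplied default instead of raising
removed = state_dict.pop("classifier.1.bias", None)
print(removed)  # None
```

The remaining entries are untouched in both cases; only the behaviour for a missing key differs.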

