
Update timm to 1.* version and support more encoders #885

Closed · wants to merge 16 commits
2 changes: 1 addition & 1 deletion .github/workflows/pypi.yml
@@ -12,7 +12,7 @@ jobs:
      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: '3.7'
          python-version: '3.8'
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -38,7 +38,7 @@ jobs:
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: 3.7
          python-version: 3.8
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
2 changes: 1 addition & 1 deletion README.md
@@ -12,7 +12,7 @@ Segmentation based on [PyTorch](https://pytorch.org/).**
[![PyPI - Downloads](https://img.shields.io/pypi/dm/segmentation-models-pytorch?style=for-the-badge&color=blue)](https://pepy.tech/project/segmentation-models-pytorch)
<br>
[![PyTorch - Version](https://img.shields.io/badge/PYTORCH-1.4+-red?style=for-the-badge&logo=pytorch)](https://pepy.tech/project/segmentation-models-pytorch)
[![Python - Version](https://img.shields.io/badge/PYTHON-3.7+-red?style=for-the-badge&logo=python&logoColor=white)](https://pepy.tech/project/segmentation-models-pytorch)
[![Python - Version](https://img.shields.io/badge/PYTHON-3.8+-red?style=for-the-badge&logo=python&logoColor=white)](https://pepy.tech/project/segmentation-models-pytorch)

</div>

381 changes: 381 additions & 0 deletions examples/configure_timm_encoder.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,8 +1,9 @@
torchvision>=0.5.0
pretrainedmodels==0.7.4
efficientnet-pytorch==0.7.1
timm==0.9.7
timm>1.0.0

tqdm
pillow
six
loguru
6 changes: 6 additions & 0 deletions segmentation_models_pytorch/base/model.py
@@ -1,10 +1,16 @@
import torch
from typing import TypeVar

from . import initialization as init
from .hub_mixin import SMPHubMixin

T = TypeVar("T", bound="SegmentationModel")


class SegmentationModel(torch.nn.Module, SMPHubMixin):
    def __new__(cls, *args, **kwargs) -> T:
        return super().__new__(cls, *args, **kwargs)

    def initialize(self):
        init.initialize_decoder(self.decoder)
        init.initialize_head(self.segmentation_head)
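For context, a minimal self-contained sketch of what the `TypeVar` bound buys; the `Unet` subclass is illustrative, and `cls` is annotated explicitly here (the PR itself leaves it implicit):

```python
from typing import Type, TypeVar

import torch

T = TypeVar("T", bound="SegmentationModel")


class SegmentationModel(torch.nn.Module):
    def __new__(cls: Type[T], *args, **kwargs) -> T:
        # The bound TypeVar lets a static type checker infer the concrete
        # subclass: Unet(...) is typed as Unet, not SegmentationModel.
        return super().__new__(cls)


class Unet(SegmentationModel):  # illustrative subclass
    pass


model = Unet()  # a type checker infers `Unet` here
assert isinstance(model, SegmentationModel)
```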
24 changes: 13 additions & 11 deletions segmentation_models_pytorch/datasets/oxford_pet.py
@@ -60,22 +60,24 @@ def _read_split(self):
        return filenames

    @staticmethod
    def download(root):
    def download(root, force_reload=False):
        # load images
        filepath = os.path.join(root, "images.tar.gz")
        download_url(
            url="https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz",
            filepath=filepath,
        )
        extract_archive(filepath)
        if not os.path.exists(filepath) or force_reload:
            download_url(
                url="https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz",
                filepath=filepath,
            )
            extract_archive(filepath)

        # load annotations
        filepath = os.path.join(root, "annotations.tar.gz")
        download_url(
            url="https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz",
            filepath=filepath,
        )
        extract_archive(filepath)
        if not os.path.exists(filepath) or force_reload:
            download_url(
                url="https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz",
                filepath=filepath,
            )
            extract_archive(filepath)


class SimpleOxfordPetDataset(OxfordPetDataset):
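A usage sketch for the new flag, assuming `OxfordPetDataset` is importable from `segmentation_models_pytorch.datasets` and keeps the static signature above (the `./data` path is arbitrary):

```python
from segmentation_models_pytorch.datasets import OxfordPetDataset

root = "./data/oxford-pet"  # arbitrary target directory

OxfordPetDataset.download(root)                     # skips archives that already exist
OxfordPetDataset.download(root, force_reload=True)  # re-downloads and re-extracts
```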
9 changes: 5 additions & 4 deletions segmentation_models_pytorch/decoders/manet/decoder.py
@@ -140,11 +140,12 @@ def __init__(
    ):
        super().__init__()

        if n_blocks != len(decoder_channels):
        if n_blocks < len(decoder_channels):
            decoder_channels = decoder_channels[-n_blocks:]
        elif n_blocks > len(decoder_channels):
            raise ValueError(
                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                    n_blocks, len(decoder_channels)
                )
                f"Specified `encoder_depth={n_blocks}`, but provided only {len(decoder_channels)} "
                f"`decoder_channels={decoder_channels}`. Please provide a list of channels for all {n_blocks} blocks."
            )

        # remove first skip with same spatial resolution
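The same depth/channels reconciliation is applied to the Unet and Unet++ decoders below; a small sketch of the new truncation branch:

```python
# When fewer blocks are requested than channels provided, the *last*
# n_blocks entries are kept, so the channel width that feeds the
# segmentation head stays the same.
decoder_channels = (256, 128, 64, 32, 16)
n_blocks = 3

if n_blocks < len(decoder_channels):
    decoder_channels = decoder_channels[-n_blocks:]

print(decoder_channels)  # (64, 32, 16)
```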
9 changes: 5 additions & 4 deletions segmentation_models_pytorch/decoders/unet/decoder.py
@@ -76,11 +76,12 @@ def __init__(
    ):
        super().__init__()

        if n_blocks != len(decoder_channels):
        if n_blocks < len(decoder_channels):
            decoder_channels = decoder_channels[-n_blocks:]
        elif n_blocks > len(decoder_channels):
            raise ValueError(
                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                    n_blocks, len(decoder_channels)
                )
                f"Specified `encoder_depth={n_blocks}`, but provided only {len(decoder_channels)} "
                f"`decoder_channels={decoder_channels}`. Please provide a list of channels for all {n_blocks} blocks."
            )

        # remove first skip with same spatial resolution
21 changes: 17 additions & 4 deletions segmentation_models_pytorch/decoders/unet/model.py
@@ -1,4 +1,4 @@
from typing import Optional, Union, List
from typing import Optional, Union, List, Sequence, Callable

from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base import (
@@ -24,6 +24,15 @@ class Unet(SegmentationModel):
            Default is 5
        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
            other pretrained weights (see table with available weights for each encoder_name)
        encoder_indices: The indices of the encoder features that will be used in the decoder.
            If **"first"**, only the first `encoder_depth` features will be used.
            If **"last"**, only the last `encoder_depth` features will be used.
            If a list of integers, the encoder features at exactly those indices are used.
            If **None**, defaults to **"first"**.
        encoder_channels: A list of integers that specify the number of output channels for each encoder layer.
            If **None**, the number of encoder output channels stays the same as for the specified `encoder_name`.
            If a list of integers, the encoder outputs are projected to the given channel counts
            by 1x1 convolutions without non-linearity.
        decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder.
            Length of the list should be the same as **encoder_depth**
        decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers
@@ -58,12 +67,14 @@ def __init__(
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        encoder_indices: Optional[Union[str, List[int]]] = None,
        encoder_channels: Optional[List[int]] = None,
        decoder_use_batchnorm: bool = True,
        decoder_channels: List[int] = (256, 128, 64, 32, 16),
        decoder_channels: Sequence[int] = (256, 128, 64, 32, 16),
        decoder_attention_type: Optional[str] = None,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, callable]] = None,
        activation: Optional[Union[str, Callable]] = None,
        aux_params: Optional[dict] = None,
    ):
        super().__init__()
@@ -73,12 +84,14 @@ def __init__(
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
            out_indices=encoder_indices,
            out_channels=encoder_channels,
        )

        self.decoder = UnetDecoder(
            encoder_channels=self.encoder.out_channels,
            decoder_channels=decoder_channels,
            n_blocks=encoder_depth,
            n_blocks=self.encoder.depth,
            use_batchnorm=decoder_use_batchnorm,
            center=True if encoder_name.startswith("vgg") else False,
            attention_type=decoder_attention_type,
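Putting the new arguments together, a hypothetical construction call; the `tu-convnext_tiny` encoder name and all channel values below are illustrative, not taken from the PR:

```python
import segmentation_models_pytorch as smp

# Take the last 4 feature maps of a timm encoder and project each stage
# to a fixed width with 1x1 convolutions before the decoder consumes them.
model = smp.Unet(
    encoder_name="tu-convnext_tiny",
    encoder_depth=4,
    encoder_indices="last",
    encoder_channels=[96, 192, 384, 768],
    decoder_channels=(256, 128, 64, 32),
    classes=2,
)
```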
9 changes: 5 additions & 4 deletions segmentation_models_pytorch/decoders/unetplusplus/decoder.py
@@ -76,11 +76,12 @@ def __init__(
    ):
        super().__init__()

        if n_blocks != len(decoder_channels):
        if n_blocks < len(decoder_channels):
            decoder_channels = decoder_channels[-n_blocks:]
        elif n_blocks > len(decoder_channels):
            raise ValueError(
                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                    n_blocks, len(decoder_channels)
                )
                f"Specified `encoder_depth={n_blocks}`, but provided only {len(decoder_channels)} "
                f"`decoder_channels={decoder_channels}`. Please provide a list of channels for all {n_blocks} blocks."
            )

        # remove first skip with same spatial resolution
17 changes: 17 additions & 0 deletions segmentation_models_pytorch/encoders/__init__.py
@@ -1,6 +1,7 @@
import timm
import functools
import torch.utils.model_zoo as model_zoo
from loguru import logger

from .resnet import resnet_encoders
from .dpn import dpn_encoders
Expand Down Expand Up @@ -51,6 +52,10 @@
def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):
    if name.startswith("tu-"):
        name = name[3:]

        if "encoder_indices" in kwargs and kwargs["encoder_indices"] is None:
            kwargs["encoder_indices"] = "first"
Comment on lines +56 to +57

Suggested change
-        if "encoder_indices" in kwargs and kwargs["encoder_indices"] is None:
-            kwargs["encoder_indices"] = "first"
+        if "out_indices" in kwargs and kwargs["out_indices"] is None:
+            kwargs["out_indices"] = "first"

While it's called `encoder_indices` in the function `select_feature_indices`, in `TimmUniversalEncoder` it is still called `out_indices`.


        encoder = TimmUniversalEncoder(
            name=name,
            in_channels=in_channels,
@@ -61,6 +66,18 @@ def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **
        )
        return encoder

    encoder_indices = kwargs.pop("encoder_indices", None)
    if encoder_indices is not None:
        logger.warning(
            "Argument `encoder_indices` is supported only for `tu-` encoders (Timm) and will be ignored."
        )
Comment on lines +69 to +73

Suggested change
-    encoder_indices = kwargs.pop("encoder_indices", None)
-    if encoder_indices is not None:
-        logger.warning(
-            "Argument `encoder_indices` is supported only for `tu-` encoders (Timm) and will be ignored."
-        )
+    out_indices = kwargs.pop("encoder_indices", None)
+    if out_indices is not None:
+        logger.warning(
+            "Argument `out_indices` is supported only for `tu-` encoders (Timm) and will be ignored."
+        )


    encoder_channels = kwargs.pop("encoder_channels", None)
    if encoder_channels is not None:
        logger.warning(
            "Argument `encoder_channels` is supported only for `tu-` encoders (Timm) and will be ignored."
        )

    try:
        Encoder = encoders[name]["encoder"]
    except KeyError:
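For non-timm encoders the new keywords are popped and ignored with a warning rather than raising; a usage sketch of the resulting behavior:

```python
from segmentation_models_pytorch.encoders import get_encoder

encoder = get_encoder(
    "resnet34",
    in_channels=3,
    depth=5,
    weights=None,
    encoder_indices=[1, 2, 3, 4],          # ignored, loguru warning emitted
    encoder_channels=[64, 128, 256, 512],  # ignored, loguru warning emitted
)
```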
11 changes: 11 additions & 0 deletions segmentation_models_pytorch/encoders/_base.py
@@ -18,6 +18,17 @@ def out_channels(self):
    def output_stride(self):
        return min(self._output_stride, 2**self._depth)

    @property
    def depth(self):
        return self._depth

    @property
    def features_info_str(self):
        """Return a string with information about intermediate and output tensor shapes"""
        raise NotImplementedError(
            "Method is only implemented for `timm` encoders ('tu-' prefix)"
        )

    def set_in_channels(self, in_channels, pretrained=True):
        """Change first convolution channels"""
        if in_channels == 3:
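A sketch of the new introspection surface; it assumes `features_info_str` is implemented by the timm encoder wrapper, since the base class shown above only raises:

```python
from segmentation_models_pytorch.encoders import get_encoder

encoder = get_encoder("tu-resnet34", in_channels=3, depth=5, weights=None)
print(encoder.depth)              # 5, the new read-only property
print(encoder.output_stride)      # min(32, 2**5) = 32
print(encoder.features_info_str)  # per-stage tensor shapes (timm encoders only)

plain = get_encoder("resnet34", in_channels=3, depth=5, weights=None)
try:
    plain.features_info_str
except NotImplementedError:
    print("only implemented for 'tu-' encoders")
```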