Skip to content

Commit

Permalink
Fix encoder depth & output stride on DeeplabV3 & DeeplabV3+ (#991)
Browse files Browse the repository at this point in the history
* fix encoder depth & output stride

* fix ruff style

* Revert "fix ruff style"

This reverts commit 79d5568.

* fix encoder depth & output stride

* fix ruff style

* update deeplabv3+ doc

* restored aux_params
  • Loading branch information
brianhou0208 authored Dec 9, 2024
1 parent 589583e commit 7c64aa5
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 35 deletions.
41 changes: 14 additions & 27 deletions segmentation_models_pytorch/decoders/deeplabv3/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def __init__(
nn.BatchNorm2d(out_channels),
nn.ReLU(),
)
self.out_channels = out_channels

def forward(self, *features):
return super().forward(features[-1])
Expand All @@ -79,17 +78,12 @@ def __init__(
aspp_dropout: float,
):
super().__init__()
if encoder_depth not in (3, 4, 5):
if encoder_depth < 3:
raise ValueError(
"Encoder depth should be 3, 4 or 5, got {}.".format(encoder_depth)
"Encoder depth for DeepLabV3Plus decoder cannot be less than 3, got {}.".format(
encoder_depth
)
)
if output_stride not in (8, 16):
raise ValueError(
"Output stride should be 8 or 16, got {}.".format(output_stride)
)

self.out_channels = out_channels
self.output_stride = output_stride

self.aspp = nn.Sequential(
ASPP(
Expand All @@ -106,17 +100,10 @@ def __init__(
nn.ReLU(),
)

scale_factor = 2 if output_stride == 8 else 4
scale_factor = 4 if output_stride == 16 and encoder_depth > 3 else 2
self.up = nn.UpsamplingBilinear2d(scale_factor=scale_factor)

if encoder_depth == 3 and output_stride == 8:
self.highres_input_index = -2
elif encoder_depth == 3 or encoder_depth == 4:
self.highres_input_index = -3
else:
self.highres_input_index = -4

highres_in_channels = encoder_channels[self.highres_input_index]
highres_in_channels = encoder_channels[2]
highres_out_channels = 48 # proposed by authors of paper
self.block1 = nn.Sequential(
nn.Conv2d(
Expand All @@ -140,7 +127,7 @@ def __init__(
def forward(self, *features):
aspp_features = self.aspp(features[-1])
aspp_features = self.up(aspp_features)
high_res_features = self.block1(features[self.highres_input_index])
high_res_features = self.block1(features[2])
concat_features = torch.cat([aspp_features, high_res_features], dim=1)
fused_features = self.block2(concat_features)
return fused_features
Expand Down Expand Up @@ -240,13 +227,13 @@ def forward(self, x):
class SeparableConv2d(nn.Sequential):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
bias=True,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
padding: int = 0,
dilation: int = 1,
bias: bool = True,
):
dephtwise_conv = nn.Conv2d(
in_channels,
Expand Down
38 changes: 30 additions & 8 deletions segmentation_models_pytorch/decoders/deeplabv3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,16 @@ class DeepLabV3(SegmentationModel):
Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
**callable** and **None**.
Default is **None**
upsampling: Final upsampling factor (should have the same value as ``encoder_output_stride`` to preserve input-output spatial shape identity).
upsampling: Final upsampling factor. Default is **None** to preserve input-output spatial shape identity
aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
on top of encoder if **aux_params** is not **None** (default). Supported params:
- classes (int): A number of classes
- pooling (str): One of "max", "avg". Default is "avg"
- dropout (float): Dropout factor in [0, 1)
- activation (str): An activation function to apply "sigmoid"/"softmax"
(could be **None** to return logits)
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models.
Keys with ``None`` values are pruned before passing.
Returns:
``torch.nn.Module``: **DeepLabV3**
Expand Down Expand Up @@ -72,6 +73,12 @@ def __init__(
):
super().__init__()

if encoder_output_stride not in [8, 16]:
raise ValueError(
"DeeplabV3 support output stride 8 or 16, got {}.".format(
encoder_output_stride
)
)
self.encoder = get_encoder(
encoder_name,
in_channels=in_channels,
Expand All @@ -81,6 +88,14 @@ def __init__(
**kwargs,
)

if upsampling is None:
if encoder_depth <= 3:
scale_factor = 2**encoder_depth
else:
scale_factor = encoder_output_stride
else:
scale_factor = upsampling

self.decoder = DeepLabV3Decoder(
in_channels=self.encoder.out_channels[-1],
out_channels=decoder_channels,
Expand All @@ -90,11 +105,11 @@ def __init__(
)

self.segmentation_head = SegmentationHead(
in_channels=self.decoder.out_channels,
in_channels=decoder_channels,
out_channels=classes,
activation=activation,
kernel_size=1,
upsampling=encoder_output_stride if upsampling is None else upsampling,
upsampling=scale_factor,
)

if aux_params is not None:
Expand Down Expand Up @@ -129,16 +144,16 @@ class DeepLabV3Plus(SegmentationModel):
Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
**callable** and **None**.
Default is **None**
upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity. In case
**encoder_depth** and **encoder_output_stride** are 3 and 16 resp., set **upsampling** to 2 to preserve.
upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity.
aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
on top of encoder if **aux_params** is not **None** (default). Supported params:
- classes (int): A number of classes
- pooling (str): One of "max", "avg". Default is "avg"
- dropout (float): Dropout factor in [0, 1)
- activation (str): An activation function to apply "sigmoid"/"softmax"
(could be **None** to return logits)
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models.
Keys with ``None`` values are pruned before passing.
Returns:
``torch.nn.Module``: **DeepLabV3Plus**
Expand Down Expand Up @@ -167,6 +182,13 @@ def __init__(
):
super().__init__()

if encoder_output_stride not in [8, 16]:
raise ValueError(
"DeeplabV3Plus support output stride 8 or 16, got {}.".format(
encoder_output_stride
)
)

self.encoder = get_encoder(
encoder_name,
in_channels=in_channels,
Expand All @@ -187,7 +209,7 @@ def __init__(
)

self.segmentation_head = SegmentationHead(
in_channels=self.decoder.out_channels,
in_channels=decoder_channels,
out_channels=classes,
activation=activation,
kernel_size=1,
Expand Down

0 comments on commit 7c64aa5

Please sign in to comment.