Skip to content

Commit

Permalink
Fix DeepLabV3Plus encoder depth (#986)
Browse files Browse the repository at this point in the history
* fix issue #377

* modify docstring for upsampling of DeepLabV3Plus

* modify type hint and value check
  • Loading branch information
munehiro-k authored Nov 29, 2024
1 parent d490cdf commit cc482aa
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
20 changes: 16 additions & 4 deletions segmentation_models_pytorch/decoders/deeplabv3/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,20 @@ def forward(self, *features):
class DeepLabV3PlusDecoder(nn.Module):
def __init__(
self,
encoder_channels: Sequence[int, ...],
encoder_channels: Sequence[int],
encoder_depth: Literal[3, 4, 5],
out_channels: int,
atrous_rates: Iterable[int],
output_stride: Literal[8, 16],
aspp_separable: bool,
aspp_dropout: float,
):
super().__init__()
if output_stride not in {8, 16}:
if encoder_depth not in (3, 4, 5):
raise ValueError(
"Encoder depth should be 3, 4 or 5, got {}.".format(encoder_depth)
)
if output_stride not in (8, 16):
raise ValueError(
"Output stride should be 8 or 16, got {}.".format(output_stride)
)
Expand All @@ -104,7 +109,14 @@ def __init__(
scale_factor = 2 if output_stride == 8 else 4
self.up = nn.UpsamplingBilinear2d(scale_factor=scale_factor)

highres_in_channels = encoder_channels[-4]
if encoder_depth == 3 and output_stride == 8:
self.highres_input_index = -2
elif encoder_depth == 3 or encoder_depth == 4:
self.highres_input_index = -3
else:
self.highres_input_index = -4

highres_in_channels = encoder_channels[self.highres_input_index]
highres_out_channels = 48 # proposed by authors of paper
self.block1 = nn.Sequential(
nn.Conv2d(
Expand All @@ -128,7 +140,7 @@ def __init__(
def forward(self, *features):
aspp_features = self.aspp(features[-1])
aspp_features = self.up(aspp_features)
high_res_features = self.block1(features[-4])
high_res_features = self.block1(features[self.highres_input_index])
concat_features = torch.cat([aspp_features, high_res_features], dim=1)
fused_features = self.block2(concat_features)
return fused_features
Expand Down
6 changes: 4 additions & 2 deletions segmentation_models_pytorch/decoders/deeplabv3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ class DeepLabV3Plus(SegmentationModel):
Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
**callable** and **None**.
Default is **None**
upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity
upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity. When
**encoder_depth** is 3 and **encoder_output_stride** is 16, set **upsampling** to 2 to preserve that identity.
aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
on top of encoder if **aux_params** is not **None** (default). Supported params:
- classes (int): A number of classes
Expand All @@ -150,7 +151,7 @@ class DeepLabV3Plus(SegmentationModel):
def __init__(
self,
encoder_name: str = "resnet34",
encoder_depth: int = 5,
encoder_depth: Literal[3, 4, 5] = 5,
encoder_weights: Optional[str] = "imagenet",
encoder_output_stride: Literal[8, 16] = 16,
decoder_channels: int = 256,
Expand All @@ -177,6 +178,7 @@ def __init__(

self.decoder = DeepLabV3PlusDecoder(
encoder_channels=self.encoder.out_channels,
encoder_depth=encoder_depth,
out_channels=decoder_channels,
atrous_rates=decoder_atrous_rates,
output_stride=encoder_output_stride,
Expand Down

0 comments on commit cc482aa

Please sign in to comment.