diff --git a/segmentation_models_pytorch/decoders/pan/decoder.py b/segmentation_models_pytorch/decoders/pan/decoder.py
index ab8f8675..fa0bb261 100644
--- a/segmentation_models_pytorch/decoders/pan/decoder.py
+++ b/segmentation_models_pytorch/decoders/pan/decoder.py
@@ -1,3 +1,6 @@
+from collections.abc import Sequence
+from typing import Literal
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -44,7 +47,9 @@ def forward(self, x):
 
 
 class FPABlock(nn.Module):
-    def __init__(self, in_channels, out_channels, upscale_mode="bilinear"):
+    def __init__(
+        self, in_channels: int, out_channels: int, upscale_mode: str = "bilinear"
+    ):
         super(FPABlock, self).__init__()
 
         self.upscale_mode = upscale_mode
@@ -175,34 +180,56 @@ def forward(self, x, y):
 
 class PANDecoder(nn.Module):
     def __init__(
-        self, encoder_channels, decoder_channels, upscale_mode: str = "bilinear"
+        self,
+        encoder_channels: Sequence[int],
+        encoder_depth: Literal[3, 4, 5],
+        decoder_channels: int,
+        upscale_mode: str = "bilinear",
    ):
         super().__init__()
 
+        if encoder_depth < 3:
+            raise ValueError(
+                "Encoder depth for PAN decoder cannot be less than 3, got {}.".format(
+                    encoder_depth
+                )
+            )
+
+        encoder_channels = encoder_channels[2:]  # keep stages at 1/4 resolution and deeper
+
         self.fpa = FPABlock(
             in_channels=encoder_channels[-1], out_channels=decoder_channels
         )
-        self.gau3 = GAUBlock(
-            in_channels=encoder_channels[-2],
-            out_channels=decoder_channels,
-            upscale_mode=upscale_mode,
-        )
-        self.gau2 = GAUBlock(
-            in_channels=encoder_channels[-3],
-            out_channels=decoder_channels,
-            upscale_mode=upscale_mode,
-        )
-        self.gau1 = GAUBlock(
-            in_channels=encoder_channels[-4],
-            out_channels=decoder_channels,
-            upscale_mode=upscale_mode,
-        )
+
+        if encoder_depth == 5:
+            self.gau3 = GAUBlock(
+                in_channels=encoder_channels[2],
+                out_channels=decoder_channels,
+                upscale_mode=upscale_mode,
+            )
+        if encoder_depth >= 4:
+            self.gau2 = GAUBlock(
+                in_channels=encoder_channels[1],
+                out_channels=decoder_channels,
+                upscale_mode=upscale_mode,
+            )
+        if encoder_depth >= 3:
+            self.gau1 = GAUBlock(
+                in_channels=encoder_channels[0],
+                out_channels=decoder_channels,
+                upscale_mode=upscale_mode,
+            )
 
     def forward(self, *features):
-        bottleneck = features[-1]
-        x5 = self.fpa(bottleneck)  # 1/32
-        x4 = self.gau3(features[-2], x5)  # 1/16
-        x3 = self.gau2(features[-3], x4)  # 1/8
-        x2 = self.gau1(features[-4], x3)  # 1/4
+        features = features[2:]  # drop the identity and 1/2 resolution features
+
+        out = self.fpa(features[-1])  # 1/16 or 1/32, depending on output stride
+
+        if hasattr(self, "gau3"):
+            out = self.gau3(features[2], out)  # 1/16
+        if hasattr(self, "gau2"):
+            out = self.gau2(features[1], out)  # 1/8
+        if hasattr(self, "gau1"):
+            out = self.gau1(features[0], out)  # 1/4
 
-        return x2
+        return out
diff --git a/segmentation_models_pytorch/decoders/pan/model.py b/segmentation_models_pytorch/decoders/pan/model.py
index 5c46f489..712541a5 100644
--- a/segmentation_models_pytorch/decoders/pan/model.py
+++ b/segmentation_models_pytorch/decoders/pan/model.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union
+from typing import Any, Callable, Literal, Optional, Union
 
 from segmentation_models_pytorch.base import (
     ClassificationHead,
@@ -20,6 +20,10 @@ class PAN(SegmentationModel):
     Args:
         encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
             to extract features of different spatial resolution
+        encoder_depth: Number of stages used in the encoder, in range [3, 5]. Each stage generates features
+            two times smaller in spatial dimensions than the previous one (e.g. for depth 0 we will have features
+            with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
+            Default is 5
         encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
             other pretrained weights (see table with available weights for each encoder_name)
         encoder_output_stride: 16 or 32, if 16 use dilation in encoder last layer.
@@ -52,12 +56,13 @@ class PAN(SegmentationModel):
     def __init__(
         self,
         encoder_name: str = "resnet34",
+        encoder_depth: Literal[3, 4, 5] = 5,
         encoder_weights: Optional[str] = "imagenet",
-        encoder_output_stride: int = 16,
+        encoder_output_stride: Literal[16, 32] = 16,
         decoder_channels: int = 32,
         in_channels: int = 3,
         classes: int = 1,
-        activation: Optional[Union[str, callable]] = None,
+        activation: Optional[Union[str, Callable]] = None,
         upsampling: int = 4,
         aux_params: Optional[dict] = None,
         **kwargs: dict[str, Any],
@@ -74,7 +79,7 @@ def __init__(
         self.encoder = get_encoder(
             encoder_name,
             in_channels=in_channels,
-            depth=5,
+            depth=encoder_depth,
             weights=encoder_weights,
             output_stride=encoder_output_stride,
             **kwargs,
@@ -82,6 +87,7 @@ def __init__(
 
         self.decoder = PANDecoder(
             encoder_channels=self.encoder.out_channels,
+            encoder_depth=encoder_depth,
             decoder_channels=decoder_channels,
         )
diff --git a/tests/test_models.py b/tests/test_models.py
index 10f697b8..68d12c43 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -31,9 +31,7 @@ def get_sample(model_class):
         smp.Segformer,
     ]:
         sample = torch.ones([1, 3, 64, 64])
-    elif model_class == smp.PAN:
-        sample = torch.ones([2, 3, 256, 256])
-    elif model_class in [smp.DeepLabV3, smp.DeepLabV3Plus]:
+    elif model_class in [smp.PAN, smp.DeepLabV3, smp.DeepLabV3Plus]:
         sample = torch.ones([2, 3, 128, 128])
     elif model_class in [smp.PSPNet, smp.UPerNet]:
         # Batch size 2 needed due to nn.BatchNorm2d not supporting (1, C, 1, 1) input
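For review, a quick sketch of how the depth-gated decoder composes. The channel tuple and feature shapes below are illustrative (mimicking a depth-4 resnet34) and are not part of the patch; FPA/GAU internals are unchanged by this diff:

```python
import torch

from segmentation_models_pytorch.decoders.pan.decoder import PANDecoder

# Depth-4 decoder: gau3 is never registered, so decoding starts from the 1/16 features.
decoder = PANDecoder(
    encoder_channels=(3, 64, 64, 128, 256),  # illustrative, resnet34-like
    encoder_depth=4,
    decoder_channels=32,
)
assert hasattr(decoder, "gau1") and hasattr(decoder, "gau2")
assert not hasattr(decoder, "gau3")

# Dummy pyramid for a 128x128 input: identity, 1/2, 1/4, 1/8, 1/16 features.
features = [
    torch.rand(2, 3, 128, 128),
    torch.rand(2, 64, 64, 64),
    torch.rand(2, 64, 32, 32),
    torch.rand(2, 128, 16, 16),
    torch.rand(2, 256, 8, 8),
]
decoder.eval()
with torch.no_grad():
    out = decoder(*features)
print(out.shape)  # expected: torch.Size([2, 32, 32, 32]) -- merged output at 1/4 resolution
```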
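At the model level, a minimal smoke test of the new `encoder_depth` parameter. `encoder_output_stride=32` is chosen here so no dilation is applied to the truncated encoder, and `encoder_weights=None` keeps the test offline; both values are assumptions of this sketch, not requirements of the patch:

```python
import torch

import segmentation_models_pytorch as smp

# PAN with a truncated (depth-4) encoder: the decoder skips gau3
# and starts from the 1/16 bottleneck features.
model = smp.PAN(
    encoder_name="resnet34",
    encoder_depth=4,
    encoder_weights=None,
    encoder_output_stride=32,
    decoder_channels=32,
    classes=2,
)
model.eval()

x = torch.rand(2, 3, 128, 128)
with torch.no_grad():
    mask = model(x)
print(mask.shape)  # expected: torch.Size([2, 2, 128, 128])
```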