From 9389d8e16e7f54029a2130c25f22a918f8f174ce Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Sun, 10 Nov 2024 12:24:55 +0800 Subject: [PATCH 1/6] update PAN model support encoder depth --- segmentation_models_pytorch/decoders/pan/model.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/segmentation_models_pytorch/decoders/pan/model.py b/segmentation_models_pytorch/decoders/pan/model.py index 5c46f489..eba28492 100644 --- a/segmentation_models_pytorch/decoders/pan/model.py +++ b/segmentation_models_pytorch/decoders/pan/model.py @@ -20,6 +20,10 @@ class PAN(SegmentationModel): Args: encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) to extract features of different spatial resolution + encoder_depth: A number of stages used in the encoder in range [3, 5]. Each stage generates features + two times smaller in spatial dimensions than the previous one (e.g. for depth 0 we will have features + with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). + Default is 5 encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and other pretrained weights (see table with available weights for each encoder_name) encoder_output_stride: 16 or 32, if 16 use dilation in encoder last layer. 
@@ -52,6 +56,7 @@ class PAN(SegmentationModel): def __init__( self, encoder_name: str = "resnet34", + encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", encoder_output_stride: int = 16, decoder_channels: int = 32, @@ -74,7 +79,7 @@ def __init__( self.encoder = get_encoder( encoder_name, in_channels=in_channels, - depth=5, + depth=encoder_depth, weights=encoder_weights, output_stride=encoder_output_stride, **kwargs, @@ -82,6 +87,7 @@ def __init__( self.decoder = PANDecoder( encoder_channels=self.encoder.out_channels, + encoder_depth=encoder_depth, decoder_channels=decoder_channels, ) From e8a18254c8d8103a37223a6e98a5a154f4358ffb Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Sun, 1 Dec 2024 01:51:48 +0800 Subject: [PATCH 2/6] update PAN decoder support encoder depth --- .../decoders/pan/decoder.py | 52 +++++++++++-------- 1 file changed, 29 insertions(+), 23 deletions(-) diff --git a/segmentation_models_pytorch/decoders/pan/decoder.py b/segmentation_models_pytorch/decoders/pan/decoder.py index ab8f8675..a1180d95 100644 --- a/segmentation_models_pytorch/decoders/pan/decoder.py +++ b/segmentation_models_pytorch/decoders/pan/decoder.py @@ -175,34 +175,40 @@ def forward(self, x, y): class PANDecoder(nn.Module): def __init__( - self, encoder_channels, decoder_channels, upscale_mode: str = "bilinear" + self, + encoder_channels, + encoder_depth, + decoder_channels, + upscale_mode: str = "bilinear", ): super().__init__() + if encoder_depth < 3: + raise ValueError( + "Encoder depth for PAN decoder cannot be less than 3, got {}.".format( + encoder_depth + ) + ) + + encoder_channels = encoder_channels[2:][::-1] + self.fpa = FPABlock( - in_channels=encoder_channels[-1], out_channels=decoder_channels - ) - self.gau3 = GAUBlock( - in_channels=encoder_channels[-2], - out_channels=decoder_channels, - upscale_mode=upscale_mode, - ) - self.gau2 = GAUBlock( - in_channels=encoder_channels[-3], - 
out_channels=decoder_channels, - upscale_mode=upscale_mode, - ) - self.gau1 = GAUBlock( - in_channels=encoder_channels[-4], - out_channels=decoder_channels, - upscale_mode=upscale_mode, + in_channels=encoder_channels[0], out_channels=decoder_channels ) + for i in range(1, len(encoder_channels)): + self.add_module(f"gau{len(encoder_channels)-i}", GAUBlock( + in_channels=encoder_channels[i], + out_channels=decoder_channels, + upscale_mode=upscale_mode, + )) + def forward(self, *features): - bottleneck = features[-1] - x5 = self.fpa(bottleneck) # 1/32 - x4 = self.gau3(features[-2], x5) # 1/16 - x3 = self.gau2(features[-3], x4) # 1/8 - x2 = self.gau1(features[-4], x3) # 1/4 + features = features[2:] # remove first and second skip + features = features[::-1] # reverse channels to start from head of encoder + + out = self.fpa(features[0]) - return x2 + for i in range(1, len(features)): + out = getattr(self, f"gau{len(features)-i}")(features[i], out) + return out From 5e6db7ef3ae2a9abc78b69f0059f4b9bde4129b1 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Wed, 4 Dec 2024 23:12:02 +0800 Subject: [PATCH 3/6] add typing and fix ruff style --- .../decoders/pan/decoder.py | 28 +++++++++++++------ .../decoders/pan/model.py | 8 +++--- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/segmentation_models_pytorch/decoders/pan/decoder.py b/segmentation_models_pytorch/decoders/pan/decoder.py index a1180d95..1361b654 100644 --- a/segmentation_models_pytorch/decoders/pan/decoder.py +++ b/segmentation_models_pytorch/decoders/pan/decoder.py @@ -1,3 +1,6 @@ +from collections.abc import Sequence +from typing import Literal + import torch import torch.nn as nn import torch.nn.functional as F @@ -44,7 +47,9 @@ def forward(self, x): class FPABlock(nn.Module): - def __init__(self, in_channels, out_channels, upscale_mode="bilinear"): + def __init__( + self, in_channels: int, out_channels: int, upscale_mode: str = "bilinear" + ): 
super(FPABlock, self).__init__() self.upscale_mode = upscale_mode @@ -118,7 +123,9 @@ def forward(self, x): mid = self.mid(x) x1 = self.down1(x) x2 = self.down2(x1) + print(x2.shape) x3 = self.down3(x2) + print(x3.shape) x3 = F.interpolate(x3, size=(h // 4, w // 4), **upscale_parameters) x2 = self.conv2(x2) @@ -176,9 +183,9 @@ def forward(self, x, y): class PANDecoder(nn.Module): def __init__( self, - encoder_channels, - encoder_depth, - decoder_channels, + encoder_channels: Sequence[int], + encoder_depth: Literal[3, 4, 5], + decoder_channels: int, upscale_mode: str = "bilinear", ): super().__init__() @@ -197,11 +204,14 @@ def __init__( ) for i in range(1, len(encoder_channels)): - self.add_module(f"gau{len(encoder_channels)-i}", GAUBlock( - in_channels=encoder_channels[i], - out_channels=decoder_channels, - upscale_mode=upscale_mode, - )) + self.add_module( + f"gau{len(encoder_channels)-i}", + GAUBlock( + in_channels=encoder_channels[i], + out_channels=decoder_channels, + upscale_mode=upscale_mode, + ), + ) def forward(self, *features): features = features[2:] # remove first and second skip diff --git a/segmentation_models_pytorch/decoders/pan/model.py b/segmentation_models_pytorch/decoders/pan/model.py index eba28492..712541a5 100644 --- a/segmentation_models_pytorch/decoders/pan/model.py +++ b/segmentation_models_pytorch/decoders/pan/model.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union +from typing import Any, Callable, Literal, Optional, Union from segmentation_models_pytorch.base import ( ClassificationHead, @@ -56,13 +56,13 @@ class PAN(SegmentationModel): def __init__( self, encoder_name: str = "resnet34", - encoder_depth: int = 5, + encoder_depth: Literal[3, 4, 5] = 5, encoder_weights: Optional[str] = "imagenet", - encoder_output_stride: int = 16, + encoder_output_stride: Literal[16, 32] = 16, decoder_channels: int = 32, in_channels: int = 3, classes: int = 1, - activation: Optional[Union[str, callable]] = None, + activation: 
Optional[Union[str, Callable]] = None, upsampling: int = 4, aux_params: Optional[dict] = None, **kwargs: dict[str, Any], From 1901eb731e0392e50ed8fba99b387d58812a49a7 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Wed, 4 Dec 2024 23:12:58 +0800 Subject: [PATCH 4/6] update PAN test sample size --- tests/test_models.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 10f697b8..68d12c43 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -31,9 +31,7 @@ def get_sample(model_class): smp.Segformer, ]: sample = torch.ones([1, 3, 64, 64]) - elif model_class == smp.PAN: - sample = torch.ones([2, 3, 256, 256]) - elif model_class in [smp.DeepLabV3, smp.DeepLabV3Plus]: + elif model_class in [smp.PAN, smp.DeepLabV3, smp.DeepLabV3Plus]: sample = torch.ones([2, 3, 128, 128]) elif model_class in [smp.PSPNet, smp.UPerNet]: # Batch size 2 needed due to nn.BatchNorm2d not supporting (1, C, 1, 1) input From 8d15c285ebfdb8abdd5de0fa0615200605fdd8d5 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Fri, 6 Dec 2024 03:46:46 +0800 Subject: [PATCH 5/6] del print --- segmentation_models_pytorch/decoders/pan/decoder.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/segmentation_models_pytorch/decoders/pan/decoder.py b/segmentation_models_pytorch/decoders/pan/decoder.py index 1361b654..0e832d3f 100644 --- a/segmentation_models_pytorch/decoders/pan/decoder.py +++ b/segmentation_models_pytorch/decoders/pan/decoder.py @@ -123,9 +123,7 @@ def forward(self, x): mid = self.mid(x) x1 = self.down1(x) x2 = self.down2(x1) - print(x2.shape) x3 = self.down3(x2) - print(x3.shape) x3 = F.interpolate(x3, size=(h // 4, w // 4), **upscale_parameters) x2 = self.conv2(x2) From aa1e0056ada125ea2e63ea7013af030dbafd9d6f Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Sun, 8 Dec 2024 03:03:28 +0800 
Subject: [PATCH 6/6] update decoder --- .../decoders/pan/decoder.py | 41 ++++++++++++------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/segmentation_models_pytorch/decoders/pan/decoder.py b/segmentation_models_pytorch/decoders/pan/decoder.py index 0e832d3f..fa0bb261 100644 --- a/segmentation_models_pytorch/decoders/pan/decoder.py +++ b/segmentation_models_pytorch/decoders/pan/decoder.py @@ -195,28 +195,41 @@ def __init__( ) ) - encoder_channels = encoder_channels[2:][::-1] + encoder_channels = encoder_channels[2:] self.fpa = FPABlock( - in_channels=encoder_channels[0], out_channels=decoder_channels + in_channels=encoder_channels[-1], out_channels=decoder_channels ) - for i in range(1, len(encoder_channels)): - self.add_module( - f"gau{len(encoder_channels)-i}", - GAUBlock( - in_channels=encoder_channels[i], - out_channels=decoder_channels, - upscale_mode=upscale_mode, - ), + if encoder_depth == 5: + self.gau3 = GAUBlock( + in_channels=encoder_channels[2], + out_channels=decoder_channels, + upscale_mode=upscale_mode, + ) + if encoder_depth >= 4: + self.gau2 = GAUBlock( + in_channels=encoder_channels[1], + out_channels=decoder_channels, + upscale_mode=upscale_mode, + ) + if encoder_depth >= 3: + self.gau1 = GAUBlock( + in_channels=encoder_channels[0], + out_channels=decoder_channels, + upscale_mode=upscale_mode, ) def forward(self, *features): features = features[2:] # remove first and second skip - features = features[::-1] # reverse channels to start from head of encoder - out = self.fpa(features[0]) + out = self.fpa(features[-1]) # 1/16 or 1/32 + + if hasattr(self, "gau3"): + out = self.gau3(features[2], out) # 1/16 + if hasattr(self, "gau2"): + out = self.gau2(features[1], out) # 1/8 + if hasattr(self, "gau1"): + out = self.gau1(features[0], out) # 1/4 - for i in range(1, len(features)): - out = getattr(self, f"gau{len(features)-i}")(features[i], out) return out