Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update PAN Decoder support encoder depth #999

Merged
merged 6 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 50 additions & 23 deletions segmentation_models_pytorch/decoders/pan/decoder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from collections.abc import Sequence
from typing import Literal

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -44,7 +47,9 @@ def forward(self, x):


class FPABlock(nn.Module):
def __init__(self, in_channels, out_channels, upscale_mode="bilinear"):
def __init__(
self, in_channels: int, out_channels: int, upscale_mode: str = "bilinear"
):
super(FPABlock, self).__init__()

self.upscale_mode = upscale_mode
Expand Down Expand Up @@ -175,34 +180,56 @@ def forward(self, x, y):

class PANDecoder(nn.Module):
def __init__(
self, encoder_channels, decoder_channels, upscale_mode: str = "bilinear"
self,
encoder_channels: Sequence[int],
encoder_depth: Literal[3, 4, 5],
decoder_channels: int,
upscale_mode: str = "bilinear",
):
super().__init__()

if encoder_depth < 3:
raise ValueError(
"Encoder depth for PAN decoder cannot be less than 3, got {}.".format(
encoder_depth
)
)

encoder_channels = encoder_channels[2:]

self.fpa = FPABlock(
in_channels=encoder_channels[-1], out_channels=decoder_channels
)
self.gau3 = GAUBlock(
in_channels=encoder_channels[-2],
out_channels=decoder_channels,
upscale_mode=upscale_mode,
)
self.gau2 = GAUBlock(
in_channels=encoder_channels[-3],
out_channels=decoder_channels,
upscale_mode=upscale_mode,
)
self.gau1 = GAUBlock(
in_channels=encoder_channels[-4],
out_channels=decoder_channels,
upscale_mode=upscale_mode,
)

if encoder_depth == 5:
self.gau3 = GAUBlock(
in_channels=encoder_channels[2],
out_channels=decoder_channels,
upscale_mode=upscale_mode,
)
if encoder_depth >= 4:
self.gau2 = GAUBlock(
in_channels=encoder_channels[1],
out_channels=decoder_channels,
upscale_mode=upscale_mode,
)
if encoder_depth >= 3:
self.gau1 = GAUBlock(
in_channels=encoder_channels[0],
out_channels=decoder_channels,
upscale_mode=upscale_mode,
)

def forward(self, *features):
bottleneck = features[-1]
x5 = self.fpa(bottleneck) # 1/32
x4 = self.gau3(features[-2], x5) # 1/16
x3 = self.gau2(features[-3], x4) # 1/8
x2 = self.gau1(features[-4], x3) # 1/4
features = features[2:] # remove first and second skip

out = self.fpa(features[-1]) # 1/16 or 1/32

if hasattr(self, "gau3"):
out = self.gau3(features[2], out) # 1/16
if hasattr(self, "gau2"):
out = self.gau2(features[1], out) # 1/8
if hasattr(self, "gau1"):
out = self.gau1(features[0], out) # 1/4

return x2
return out
14 changes: 10 additions & 4 deletions segmentation_models_pytorch/decoders/pan/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Optional, Union
from typing import Any, Callable, Literal, Optional, Union

from segmentation_models_pytorch.base import (
ClassificationHead,
Expand All @@ -20,6 +20,10 @@ class PAN(SegmentationModel):
Args:
encoder_name: Name of the classification model that will be used as an encoder (a.k.a. backbone)
to extract features of different spatial resolutions
encoder_depth: The number of stages used in the encoder, in range [3, 5]. Each stage generates features
two times smaller in spatial dimensions than the previous one (e.g. for depth 0 we will have features
with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
Default is 5.
encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
other pretrained weights (see table with available weights for each encoder_name)
encoder_output_stride: 16 or 32; if 16, dilation is used in the encoder's last layer.
Expand Down Expand Up @@ -52,12 +56,13 @@ class PAN(SegmentationModel):
def __init__(
self,
encoder_name: str = "resnet34",
encoder_depth: Literal[3, 4, 5] = 5,
encoder_weights: Optional[str] = "imagenet",
encoder_output_stride: int = 16,
encoder_output_stride: Literal[16, 32] = 16,
decoder_channels: int = 32,
in_channels: int = 3,
classes: int = 1,
activation: Optional[Union[str, callable]] = None,
activation: Optional[Union[str, Callable]] = None,
brianhou0208 marked this conversation as resolved.
Show resolved Hide resolved
upsampling: int = 4,
aux_params: Optional[dict] = None,
**kwargs: dict[str, Any],
Expand All @@ -74,14 +79,15 @@ def __init__(
self.encoder = get_encoder(
encoder_name,
in_channels=in_channels,
depth=5,
depth=encoder_depth,
weights=encoder_weights,
output_stride=encoder_output_stride,
**kwargs,
)

self.decoder = PANDecoder(
encoder_channels=self.encoder.out_channels,
encoder_depth=encoder_depth,
decoder_channels=decoder_channels,
)

Expand Down
4 changes: 1 addition & 3 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ def get_sample(model_class):
smp.Segformer,
]:
sample = torch.ones([1, 3, 64, 64])
elif model_class == smp.PAN:
sample = torch.ones([2, 3, 256, 256])
elif model_class in [smp.DeepLabV3, smp.DeepLabV3Plus]:
elif model_class in [smp.PAN, smp.DeepLabV3, smp.DeepLabV3Plus]:
sample = torch.ones([2, 3, 128, 128])
elif model_class in [smp.PSPNet, smp.UPerNet]:
# Batch size 2 needed due to nn.BatchNorm2d not supporting (1, C, 1, 1) input
Expand Down
Loading