DeepLabV3Plus is not compatible with encoder_depth=4 and swin models #54

Open
Akshay1-6180 opened this issue Jul 19, 2024 · 4 comments

Comments

@Akshay1-6180

I was working with both swinv2_tiny_window8_256 and swinv2_base_window12to16_192to256 and noticed that neither works with torchseg.DeepLabV3Plus:

model = torchseg.DeepLabV3Plus(
    "swinv2_base_window12to16_192to256",
    in_channels=1,
    classes=2,
    encoder_weights=True,
    encoder_depth=4,
    decoder_channels=256,
    encoder_output_stride=16,
    encoder_params={"img_size": 1024}  # need to define img size since swin is a ViT hybrid
)

Both models fail with the same error. Here is a minimal reproduction:

import torch
import torchseg

model = torchseg.DeepLabV3Plus(
    "swinv2_tiny_window8_256",
    in_channels=1,
    classes=2,
    encoder_weights=True,
    encoder_depth=4,
    decoder_channels=256,
    encoder_output_stride=16,
    encoder_params={"img_size": 1024}  # need to define img size since swin is a ViT hybrid
)
dummy_im = torch.randn(4, 1, 1024, 1024)
out = model.encoder(dummy_im)
dummy_dec_out = model.decoder(*out)

It fails with this error:
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 128 but got size 256 for tensor number 1 in the list.
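The two sizes in the error message can be reproduced with stride arithmetic alone. This is a rough sketch, not torchseg code: it assumes the Swin encoder returns four feature maps at strides 4, 8, 16 and 32 (the issue does not state this, but it is consistent with the reported sizes), and `decoder_branch_sizes` is a hypothetical helper name.

```python
# Assumed feature strides for a 4-stage Swin encoder (not confirmed by the issue).
ASSUMED_SWIN_REDUCTIONS = (4, 8, 16, 32)


def decoder_branch_sizes(input_size, reductions, scale_factor=4, highres_index=-4):
    """Return the spatial sizes of the two tensors the decoder concatenates.

    The first value is the ASPP branch (deepest feature, then upsampled),
    the second is the high-resolution skip feature.
    """
    aspp_size = input_size // reductions[-1] * scale_factor
    highres_size = input_size // reductions[highres_index]
    return aspp_size, highres_size


aspp_size, highres_size = decoder_branch_sizes(1024, ASSUMED_SWIN_REDUCTIONS)
print(aspp_size, highres_size)  # 128 256
```

Under that assumption, the ASPP branch ends up at 128x128 while features[-4] is 256x256, which matches "Expected size 128 but got size 256".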
The same error also occurs with resnet50 when encoder_depth=4:

model = torchseg.DeepLabV3Plus(
    "resnet50",
    in_channels=1,
    classes=2,
    encoder_weights=True,
    encoder_depth=4,
    decoder_channels=256,
    encoder_output_stride=16,
    #encoder_params={"img_size": 1024}  # need to define img size since swin is a ViT hybrid
)

So I changed encoder_depth to 5, and resnet50 then worked.
But Swin models have a maximum encoder depth of 4, which makes them incompatible with DeepLabV3Plus. Is there an easy fix for this?

@Akshay1-6180
Author

The issue lies in the decoder shown below. It is resolved by changing two lines:
highres_in_channels = encoder_channels[-3]
high_res_features = self.block1(features[-3])
I'm not sure whether this is a good workaround and would love to hear others' opinions.

class DeepLabV3PlusDecoder(nn.Module):
    def __init__(
        self,
        encoder_channels,
        out_channels=256,
        atrous_rates=(12, 24, 36),
        output_stride=16,
    ):
        super().__init__()
        if output_stride not in {8, 16}:
            raise ValueError(f"Output stride should be 8 or 16, got {output_stride}.")

        self.out_channels = out_channels
        self.output_stride = output_stride

        self.aspp = nn.Sequential(
            ASPP(encoder_channels[-1], out_channels, atrous_rates, separable=True),
            SeparableConv2d(
                out_channels, out_channels, kernel_size=3, padding=1, bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

        scale_factor = 2 if output_stride == 8 else 4
        self.up = nn.UpsamplingBilinear2d(scale_factor=scale_factor)

        highres_in_channels = encoder_channels[-3]  # Changed from -4 to -3
        highres_out_channels = 48  # proposed by authors of paper
        self.block1 = nn.Sequential(
            nn.Conv2d(
                highres_in_channels, highres_out_channels, kernel_size=1, bias=False
            ),
            nn.BatchNorm2d(highres_out_channels),
            nn.ReLU(),
        )
        self.block2 = nn.Sequential(
            SeparableConv2d(
                highres_out_channels + out_channels,
                out_channels,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, *features):
        aspp_features = self.aspp(features[-1])
        aspp_features = self.up(aspp_features)
        high_res_features = self.block1(features[-3])  # Changed from -4 to -3

        concat_features = torch.cat([aspp_features, high_res_features], dim=1)
        fused_features = self.block2(concat_features)
        return fused_features

@Akshay1-6180
Author

It is also resolved by setting the upsampling scale factor to 8, but I'm not sure about the far-reaching implications of this change during training; it needs to be tested empirically.
self.up = nn.UpsamplingBilinear2d(scale_factor=8)
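One difference between the two workarounds can be seen without training: they fuse features at different strides. The sketch below is plain arithmetic under the assumption that the Swin encoder yields features at strides 4, 8, 16 and 32; `fused_stride` is a hypothetical helper, not part of torchseg.

```python
# Assumed feature strides for a 4-stage Swin encoder (not confirmed by the issue).
ASSUMED_SWIN_REDUCTIONS = (4, 8, 16, 32)


def fused_stride(reductions, scale_factor, highres_index):
    """Return the stride at which the two decoder branches are fused.

    Returns None when the upsampled ASPP branch and the skip feature
    disagree, which is the case that raises the RuntimeError.
    """
    aspp_stride = reductions[-1] // scale_factor
    highres_stride = reductions[highres_index]
    return aspp_stride if aspp_stride == highres_stride else None


original = fused_stride(ASSUMED_SWIN_REDUCTIONS, scale_factor=4, highres_index=-4)
index_fix = fused_stride(ASSUMED_SWIN_REDUCTIONS, scale_factor=4, highres_index=-3)
scale_fix = fused_stride(ASSUMED_SWIN_REDUCTIONS, scale_factor=8, highres_index=-4)
print(original, index_fix, scale_fix)  # None 8 4
```

If the assumption holds, switching to features[-3] makes the decoder output stride 8 (128x128 for a 1024x1024 input), while scale_factor=8 keeps it at stride 4 but upsamples the ASPP output more aggressively. Any fixed upsampling applied after the decoder would need to account for which of the two is used.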

@Akshay1-6180
Author

But these changes would make the decoder incompatible with encoder_depth=5, so there should be a way to handle the different depth cases.
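One possible way to handle both depths is to choose the skip feature by stride instead of by a hard-coded index. The following is only a sketch of that idea: `pick_highres_index` is a hypothetical helper, and the stride tuples are assumptions about what each encoder returns (including a final stride of 16 for resnet50 when encoder_output_stride=16), not values taken from torchseg.

```python
def pick_highres_index(reductions, scale_factor=4):
    """Return the index of the encoder feature whose stride matches the
    ASPP branch after it has been upsampled by scale_factor."""
    target = reductions[-1] // scale_factor
    for index, reduction in enumerate(reductions):
        if reduction == target:
            return index
    raise ValueError(
        f"No encoder feature has stride {target}; strides are {list(reductions)}."
    )


# Assumed strides; the real values would have to come from the encoder.
swin_depth4 = (4, 8, 16, 32)
resnet_depth5_os16 = (2, 4, 8, 16, 16)

print(pick_highres_index(swin_depth4))         # 1, i.e. features[-3]
print(pick_highres_index(resnet_depth5_os16))  # 1, i.e. features[-4]
```

With an index chosen this way, the decoder could use encoder_channels[index] and features[index] in place of the fixed -4, so the depth-5 behaviour stays the same while depth-4 encoders pick the matching feature.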

@Akshay1-6180
Author

@isaaccorley, any idea how to go about this?
