
Fix encoder depth & output stride on DeeplabV3 & DeeplabV3+ #991

Merged · 8 commits merged into qubvel-org:main from brianhou0208:fix_deeplab on Dec 9, 2024

Conversation

@brianhou0208 (Contributor)

Hi @qubvel,

This PR solves issues in DeepLabV3 and DeepLabV3+ that may occur with different `encoder_depth` and `output_stride` configurations.

**Updates & Fixes**

  1. Added validation for the `output_stride` parameter (a minimal sketch of this check follows the list).
  2. Removed the redundant `out_channels` parameter in the decoder.
  3. Fixed `segmentation_head.upsampling` to ensure input and output sizes match in DeepLabV3 for all `encoder_depth` and `output_stride` settings.
  4. Adjusted `high_res_features` in DeepLabV3+ to always use 1/4 of the input resolution, ensuring no errors occur with `encoder_depth` values between 3 and 5.
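The validation in point 1 can be sketched as follows; the body is adapted from the model.py diff quoted later in this thread, and the wrapper function is hypothetical, added only so the snippet runs standalone:

```python
# Minimal sketch of the output-stride validation added in this PR
# (check_output_stride is a hypothetical wrapper; the body follows the diff below).
def check_output_stride(encoder_output_stride: int) -> None:
    if encoder_output_stride not in [8, 16]:
        raise ValueError(
            "DeeplabV3Plus support output stride 8 or 16, got {}.".format(
                encoder_output_stride
            )
        )

check_output_stride(16)   # passes silently
# check_output_stride(32) # would raise ValueError
```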

**Encoder Depth & Output Stride**

  • Test model: ResNet18
  • Test input shape: (256, 256)

**Encoder Output Shape**

| Output Stride \ Depth | 1 | 2 | 3 | 4 | 5 |
| --- | --- | --- | --- | --- | --- |
| 32 | (128, 128) | (64, 64) | (32, 32) | (16, 16) | (8, 8) |
| 16 | (128, 128) | (64, 64) | (32, 32) | (16, 16) | (16, 16) |
| 8 | (128, 128) | (64, 64) | (32, 32) | (32, 32) | (32, 32) |

**Downsample factor**

| Output Stride \ Depth | 1 | 2 | 3 | 4 | 5 |
| --- | --- | --- | --- | --- | --- |
| 32 | 1/2 | 1/4 | 1/8 | 1/16 | 1/32 |
| 16 | 1/2 | 1/4 | 1/8 | 1/16 | 1/16 |
| 8 | 1/2 | 1/4 | 1/8 | 1/8 | 1/8 |
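These shapes can be reproduced directly from the encoder. A minimal sketch, assuming `segmentation_models_pytorch`'s `get_encoder` helper (which accepts `depth` and `output_stride`):

```python
import torch
from segmentation_models_pytorch.encoders import get_encoder

# Print the spatial size of every encoder feature for one depth/stride setting.
encoder = get_encoder("resnet18", depth=5, weights=None, output_stride=16)
features = encoder(torch.rand(1, 3, 256, 256))
print([tuple(f.shape[2:]) for f in features])
# Expected, per the tables above (identity feature first):
# [(256, 256), (128, 128), (64, 64), (32, 32), (16, 16), (16, 16)]
```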

**DeeplabV3**

From the tables above:

  • When `encoder_depth` is between 1 and 3, the upsampling ratio between input and output is `scale_factor = 2**encoder_depth`.
  • When `encoder_depth` is 4 or 5, the upsampling ratio is `scale_factor = output_stride`.

The following code ensures input and output sizes match for all `encoder_depth` values in DeepLabV3:

```python
if upsampling is None:
    if encoder_depth <= 3:
        scale_factor = 2**encoder_depth
    else:
        scale_factor = encoder_output_stride
else:
    scale_factor = upsampling
```
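As a quick sanity check of this rule, here is a sketch for `encoder_depth=4`, `output_stride=16`, assuming a bilinear upsampling layer in the segmentation head (the decoder channel count of 256 is only illustrative):

```python
import torch
import torch.nn as nn

# Worked check: a 256x256 input at stride 16 yields a 16x16 decoder output,
# so upsampling by scale_factor == output_stride restores the input size.
encoder_depth, encoder_output_stride = 4, 16
scale_factor = 2**encoder_depth if encoder_depth <= 3 else encoder_output_stride
head_upsampling = nn.UpsamplingBilinear2d(scale_factor=scale_factor)
decoder_output = torch.rand(1, 256, 16, 16)
print(head_upsampling(decoder_output).shape[2:])  # torch.Size([256, 256])
```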

**Test Code**

```python
import torch
import segmentation_models_pytorch as smp

def test_depth(depth, output_stride):
    x = torch.rand(1, 3, 256, 256)
    model = smp.DeepLabV3(
        "resnet18",
        encoder_depth=depth,
        encoder_weights=None,
        encoder_output_stride=output_stride,
    ).eval()
    y = model(x)
    assert x.shape[2:] == y.shape[2:]
    print(
        f"Encoder Depth: {depth} | Output Stride: {output_stride} : "
        f"Input/Output Shape {tuple(x.shape[2:])}/{tuple(y.shape[2:])}"
    )

if __name__ == "__main__":
    for depth in [5, 4, 3, 2, 1]:
        for output_stride in [8, 16]:
            test_depth(depth, output_stride)
    print("all pass")
```

**Output**

```
Encoder Depth: 5 | Output Stride: 8 : Input/Output Shape (256, 256)/(256, 256)
Encoder Depth: 5 | Output Stride: 16 : Input/Output Shape (256, 256)/(256, 256)
Encoder Depth: 4 | Output Stride: 8 : Input/Output Shape (256, 256)/(256, 256)
Encoder Depth: 4 | Output Stride: 16 : Input/Output Shape (256, 256)/(256, 256)
Encoder Depth: 3 | Output Stride: 8 : Input/Output Shape (256, 256)/(256, 256)
Encoder Depth: 3 | Output Stride: 16 : Input/Output Shape (256, 256)/(256, 256)
Encoder Depth: 2 | Output Stride: 8 : Input/Output Shape (256, 256)/(256, 256)
Encoder Depth: 2 | Output Stride: 16 : Input/Output Shape (256, 256)/(256, 256)
Encoder Depth: 1 | Output Stride: 8 : Input/Output Shape (256, 256)/(256, 256)
Encoder Depth: 1 | Output Stride: 16 : Input/Output Shape (256, 256)/(256, 256)
all pass
```

**DeeplabV3+**

Originally, DeepLabV3+ only supported `encoder_depth` values of 4 and 5 because of this line:

```python
high_res_features = self.block1(features[-4])
```

  • With `encoder_depth=5`, there are no errors.
  • With `encoder_depth=4`, errors occur with MixVisionTransformer encoders (e.g., the dummy feature with 0 channels is selected).
  • With `encoder_depth=3`, it triggers a "list index out of range" error.

By fixing `high_res_features` to always use the 1/4-resolution feature and limiting `encoder_depth` to >= 3, these issues are resolved:

```python
high_res_features = self.block1(features[2])
```
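Why `features[2]` is stable: per the downsample-factor table above, for any `encoder_depth >= 3` the feature at index 2 stays at 1/4 of the input resolution. A minimal sketch, again assuming `get_encoder`:

```python
import torch
from segmentation_models_pytorch.encoders import get_encoder

# For every depth >= 3 the feature at index 2 is at 1/4 resolution,
# making it a stable choice for the high-resolution skip connection.
for depth in (3, 4, 5):
    encoder = get_encoder("resnet18", depth=depth, weights=None, output_stride=16)
    features = encoder(torch.rand(1, 3, 256, 256))
    assert tuple(features[2].shape[2:]) == (64, 64)  # 256 / 4
```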

**Test Code**

```python
import torch
import segmentation_models_pytorch as smp

def test_depth(depth, output_stride):
    x = torch.rand(1, 3, 256, 256)
    model = smp.DeepLabV3Plus(
        "mit_b0",
        encoder_depth=depth,
        encoder_weights=None,
        encoder_output_stride=output_stride,
    ).eval()
    y = model(x)
    assert x.shape[2:] == y.shape[2:]
    print(
        f"Encoder Depth: {depth} | Output Stride: {output_stride} : "
        f"Input/Output Shape {tuple(x.shape[2:])}/{tuple(y.shape[2:])}"
    )

if __name__ == "__main__":
    for depth in [5, 4, 3]:
        for output_stride in [8, 16]:
            test_depth(depth, output_stride)
    print("all pass")
```

**Output**

```
Encoder Depth: 5 | Output Stride: 8 : Input/Output Shape (256, 256)/(256, 256)
Encoder Depth: 5 | Output Stride: 16 : Input/Output Shape (256, 256)/(256, 256)
Encoder Depth: 4 | Output Stride: 8 : Input/Output Shape (256, 256)/(256, 256)
Encoder Depth: 4 | Output Stride: 16 : Input/Output Shape (256, 256)/(256, 256)
Encoder Depth: 3 | Output Stride: 8 : Input/Output Shape (256, 256)/(256, 256)
Encoder Depth: 3 | Output Stride: 16 : Input/Output Shape (256, 256)/(256, 256)
all pass
```

@brianhou0208 (Contributor, Author)

This is something I had tried to fix before, but some of the DeeplabV3+ problems have since been solved and merged in #986, so I had only kept this PR as a draft...

@munehiro-k (Contributor)

Hi, @brianhou0208.

For the DeeplabV3+ part, your fix may be better because it lets you set `upsampling` of DeepLabV3Plus to 4 regardless of the values of `encoder_depth` and `encoder_output_stride`, preserving in/out tensor sizes.

I made a trial merge commit because your PR conflicts with mine, which has already been merged. Can you please check it?

The commits being compared are d490cdf and munehiro-k/segmentation_models.pytorch@2f57c5d.

**Test Code Outputs**

Mine from #986 (note that the upsampling=4 setting always preserves in/out tensor sizes):

```
encoder_depth=3, encoder_output_stride= 8, upsampling=2
  output shape: (10, 1, 96, 64), preserve shape: False
encoder_depth=3, encoder_output_stride=16, upsampling=2
  output shape: (10, 1, 96, 64), preserve shape: False
encoder_depth=4, encoder_output_stride= 8, upsampling=2
  output shape: (10, 1, 96, 64), preserve shape: False
encoder_depth=4, encoder_output_stride=16, upsampling=2
  output shape: (10, 1, 96, 64), preserve shape: False
encoder_depth=5, encoder_output_stride= 8, upsampling=2
  output shape: (10, 1, 96, 64), preserve shape: False
encoder_depth=5, encoder_output_stride=16, upsampling=2
  output shape: (10, 1, 96, 64), preserve shape: False
encoder_depth=3, encoder_output_stride= 8, upsampling=4
  output shape: (10, 1, 192, 128), preserve shape: True
encoder_depth=3, encoder_output_stride=16, upsampling=4
  output shape: (10, 1, 192, 128), preserve shape: True
encoder_depth=4, encoder_output_stride= 8, upsampling=4
  output shape: (10, 1, 192, 128), preserve shape: True
encoder_depth=4, encoder_output_stride=16, upsampling=4
  output shape: (10, 1, 192, 128), preserve shape: True
encoder_depth=5, encoder_output_stride= 8, upsampling=4
  output shape: (10, 1, 192, 128), preserve shape: True
encoder_depth=5, encoder_output_stride=16, upsampling=4
  output shape: (10, 1, 192, 128), preserve shape: True
```

**DeeplabV3**

```
Encoder Depth: 5 | Output Stride: 8 : Input/Output Shape (256, 256)/(256, 256)
Encoder Depth: 5 | Output Stride: 16 : Input/Output Shape (256, 256)/(256, 256)
Encoder Depth: 4 | Output Stride: 8 : Input/Output Shape (256, 256)/(256, 256)
Encoder Depth: 4 | Output Stride: 16 : Input/Output Shape (256, 256)/(256, 256)
Encoder Depth: 3 | Output Stride: 8 : Input/Output Shape (256, 256)/(256, 256)
Encoder Depth: 3 | Output Stride: 16 : Input/Output Shape (256, 256)/(256, 256)
Encoder Depth: 2 | Output Stride: 8 : Input/Output Shape (256, 256)/(256, 256)
Encoder Depth: 2 | Output Stride: 16 : Input/Output Shape (256, 256)/(256, 256)
Encoder Depth: 1 | Output Stride: 8 : Input/Output Shape (256, 256)/(256, 256)
Encoder Depth: 1 | Output Stride: 16 : Input/Output Shape (256, 256)/(256, 256)
all pass
```

**DeeplabV3+**

```
Encoder Depth: 5 | Output Stride: 8 : Input/Output Shape (256, 256)/(256, 256)
Encoder Depth: 5 | Output Stride: 16 : Input/Output Shape (256, 256)/(256, 256)
Encoder Depth: 4 | Output Stride: 8 : Input/Output Shape (256, 256)/(256, 256)
Encoder Depth: 4 | Output Stride: 16 : Input/Output Shape (256, 256)/(256, 256)
Encoder Depth: 3 | Output Stride: 8 : Input/Output Shape (256, 256)/(256, 256)
Encoder Depth: 3 | Output Stride: 16 : Input/Output Shape (256, 256)/(256, 256)
all pass
```

@brianhou0208 (Contributor, Author)

Hi @munehiro-k,
I've resolved the conflicts caused by PR #986.

It seems you identified this issue earlier than I did, and your PR #986 also solved some problems with DeepLabV3+.
I hadn't noticed PR #561 before; I only came across this issue while attempting to improve the CI/test coverage for all models in test/test_models.py.

I observed that unless high_res_features is consistently fixed at 1/4 of the input resolution, using MixVisionTransformer or TIMM ConvNeXt/ConvNextV2 encoders with encoder_depth set to 3 or 4 results in errors.
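For illustration, the 0-channel placeholder can be seen directly on the encoder. A sketch; the exact `out_channels` tuple shown in the comment is an assumption about how smp reports mit encoder stages:

```python
from segmentation_models_pytorch.encoders import get_encoder

# MixVisionTransformer encoders report a 0-channel placeholder for the
# 1/2-resolution stage, so end-relative indexing like features[-4] can land on it.
encoder = get_encoder("mit_b0", depth=5, weights=None)
print(encoder.out_channels)  # expected something like (3, 0, 32, 64, 160, 256)
```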

@brianhou0208 changed the title from "[Draft] Fix encoder depth & output stride on DeeplabV3 & DeeplabV3+" to "Fix encoder depth & output stride on DeeplabV3 & DeeplabV3+" on Nov 30, 2024
@brianhou0208 marked this pull request as ready for review on November 30, 2024
@munehiro-k (Contributor) commented Nov 30, 2024

The merged codes (munehiro-k:merge/fix_deeplab and brianhou0208:fix_deeplab) are quite similar, but brianhou0208:fix_deeplab deletes part of the docstring. The deleted `aux_params: ...` line should be restored.

The diffs between my merge and brianhou0208's merge are below.

**legend**

```
< munehiro-k:merge/fix_deeplab
> brianhou0208:fix_deeplab
```

**model.py**

```
82d81
<
149d147
<         aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
184a183,189
>
>         if encoder_output_stride not in [8, 16]:
>             raise ValueError(
>                 "DeeplabV3Plus support output stride 8 or 16, got {}.".format(
>                     encoder_output_stride
>                 )
>             )
```

**decoder.py**

```
81c81
<         if encoder_depth not in (3, 4, 5):
---
>         if encoder_depth < 3:
83,87c83,85
<                 "Encoder depth should be 3, 4 or 5, got {}.".format(encoder_depth)
<             )
<         if output_stride not in (8, 16):
<             raise ValueError(
<                 "Output stride should be 8 or 16, got {}.".format(output_stride)
---
>                 "Encoder depth for DeepLabV3Plus decoder cannot be less than 3, got {}.".format(
>                     encoder_depth
>                 )
```

@qubvel (Collaborator) left a comment:

Thanks for the update!

```diff
-        self.highres_input_index = -4
-        highres_in_channels = encoder_channels[self.highres_input_index]
+        highres_in_channels = encoder_channels[2]
```
@qubvel (Collaborator):

The only concern here is speed, because the higher the resolution of the feature we take, the more we need to process.

@brianhou0208 (Contributor, Author):

I think there might have been some misunderstanding.

Currently, regardless of the encoder_depth, high_res_features is always fixed at 1/4 of the input resolution, which improves speed.

Before #986 and #991, the resolution of high_res_features varied depending on the encoder_depth:

  • encoder_depth=5: high_res_features was 1/4 of the input resolution.
  • encoder_depth=4: high_res_features was 1/2 of the input resolution.
  • encoder_depth=3: high_res_features matched the input resolution.

@munehiro-k (Contributor):

For your reference, in PR #986, I updated the high_res_features logic as follows:

  • When encoder_depth=5: encoder_channels[-4], which corresponds to index 2 (1/4 of the input resolution).
  • When encoder_depth=4: encoder_channels[-3], which corresponds to index 2 (1/4 of the input resolution).
  • When encoder_depth=3:
    • If encoder_output_stride=8: encoder_channels[-2], which corresponds to index 3 (1/2 of the input resolution).
    • If encoder_output_stride=16: encoder_channels[-3], which corresponds to index 2 (1/4 of the input resolution).

The only difference in high_res_features is when encoder_depth=3 and encoder_output_stride=8.
I think PR #991 is preferable because it allows you to consistently set upsampling to 4 to preserve the input/output tensor sizes. In contrast, PR #986 requires setting upsampling to 2 to maintain sizes when encoder_depth=3 and encoder_output_stride=8.
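To illustrate, a minimal sketch of the preserved sizes under #991's behavior, assuming the `upsampling` argument of `smp.DeepLabV3Plus` (whose default is 4):

```python
import torch
import segmentation_models_pytorch as smp

# With the fix from #991, upsampling=4 preserves in/out sizes even for the
# previously problematic encoder_depth=3, encoder_output_stride=8 combination.
model = smp.DeepLabV3Plus(
    "resnet18",
    encoder_depth=3,
    encoder_weights=None,
    encoder_output_stride=8,
    upsampling=4,
).eval()
x = torch.rand(1, 3, 256, 256)
assert model(x).shape[2:] == x.shape[2:]
```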

@qubvel (Collaborator):

Thanks for the explanations! It's clearer now 🙏

@qubvel merged commit 7c64aa5 into qubvel-org:main on Dec 9, 2024 · 12 checks passed