diff --git a/torchseg/decoders/unetplusplus/model.py b/torchseg/decoders/unetplusplus/model.py index 72aa5db9..ede695d7 100644 --- a/torchseg/decoders/unetplusplus/model.py +++ b/torchseg/decoders/unetplusplus/model.py @@ -46,6 +46,9 @@ class UnetPlusPlus(SegmentationModel): - dropout (float): Dropout factor in [0, 1) - activation (str): An activation function to apply "sigmoid"/"softmax" (could be **None** to return logits) + head_upsampling: Factor to upsample input to segmentation head. Defaults to 1. + This allows for use of U-Net decoder with models that need additional + upsampling to be at the original input image resolution. Reference: https://arxiv.org/abs/1807.10165 @@ -67,6 +70,7 @@ def __init__( activation: Callable = nn.Identity(), encoder_params: dict = {}, aux_params: Optional[dict] = None, + head_upsampling: int = 1, ): super().__init__() @@ -97,6 +101,7 @@ def __init__( out_channels=classes, activation=activation, kernel_size=3, + upsampling=head_upsampling, ) if aux_params is not None: