Skip to content

Commit 7342b84

Browse files
committed
Address mypy suggestions for type annotations in cablock.py, downsample.py, restormer.py and test_downsample_block.py.
Signed-off-by: Cano-Muniz, Santiago <[email protected]>
1 parent aeebc89 commit 7342b84

File tree

4 files changed

+23
-18
lines changed

4 files changed

+23
-18
lines changed

monai/networks/blocks/cablock.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
# limitations under the License.
1111
from __future__ import annotations
1212

13+
from typing import cast
14+
1315
import torch
1416
import torch.nn as nn
1517
import torch.nn.functional as F
@@ -70,7 +72,7 @@ def __init__(self, spatial_dims: int, dim: int, ffn_expansion_factor: float, bia
7072
def forward(self, x: torch.Tensor) -> torch.Tensor:
7173
x = self.project_in(x)
7274
x1, x2 = self.dwconv(x).chunk(2, dim=1)
73-
return self.project_out(F.gelu(x1) * x2)
75+
return cast(torch.Tensor, self.project_out(F.gelu(x1) * x2))
7476

7577

7678
class CABlock(nn.Module):
@@ -141,7 +143,7 @@ def _normal_attention(self, q, k, v):
141143
attn = attn.softmax(dim=-1)
142144
return attn @ v
143145

144-
def forward(self, x) -> torch.Tensor:
146+
def forward(self, x: torch.Tensor) -> torch.Tensor:
145147
"""Forward pass for MDTA attention.
146148
1. Apply depth-wise convolutions to Q, K, V
147149
2. Reshape Q, K, V for multi-head attention
@@ -177,4 +179,4 @@ def forward(self, x) -> torch.Tensor:
177179
**dict(zip(["h", "w"] if self.spatial_dims == 2 else ["d", "h", "w"], spatial_dims)),
178180
)
179181

180-
return self.project_out(out)
182+
return cast(torch.Tensor, self.project_out(out))

monai/networks/blocks/downsample.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def __init__(
9292
out_channels: int | None = None,
9393
scale_factor: Sequence[float] | float = 2,
9494
kernel_size: Sequence[float] | float | None = None,
95-
mode: str = "conv", # conv, convgroup, nontrainable, pixelunshuffle
95+
mode: DownsampleMode | str = DownsampleMode.CONV,
9696
pre_conv: nn.Module | str | None = "default",
9797
post_conv: nn.Module | None = None,
9898
bias: bool = True,
@@ -101,11 +101,11 @@ def __init__(
101101
Downsamples data by `scale_factor`.
102102
Supported modes are:
103103
104-
- "conv": uses a strided convolution for learnable downsampling.
105-
- "convgroup": uses a grouped strided convolution for efficient feature reduction.
106-
- "maxpool": uses maxpooling for non-learnable downsampling.
107-
- "avgpool": uses average pooling for non-learnable downsampling.
108-
- "pixelunshuffle": uses :py:class:`monai.networks.blocks.SubpixelDownsample`.
104+
- DownsampleMode.CONV: uses a strided convolution for learnable downsampling.
105+
- DownsampleMode.CONVGROUP: uses a grouped strided convolution for efficient feature reduction.
106+
- DownsampleMode.MAXPOOL: uses maxpooling for non-learnable downsampling.
107+
- DownsampleMode.AVGPOOL: uses average pooling for non-learnable downsampling.
108+
- DownsampleMode.PIXELUNSHUFFLE: uses :py:class:`monai.networks.blocks.SubpixelDownsample`.
109109
110110
This operation will cause non-deterministic behavior when ``mode`` is ``DownsampleMode.NONTRAINABLE``.
111111
Please check the link below for more details:
@@ -120,7 +120,8 @@ def __init__(
120120
out_channels: number of channels of the output image. Defaults to `in_channels`.
121121
scale_factor: multiplier for spatial size reduction. Has to match input size if it is a tuple. Defaults to 2.
122122
kernel_size: kernel size used during convolutions. Defaults to `scale_factor`.
123-
mode: {``"conv"``, ``"convgroup"``, ``"maxpool"``, ``"avgpool"``, ``"pixelunshuffle"``}. Defaults to ``"conv"``.
123+
mode: {``DownsampleMode.CONV``, ``DownsampleMode.CONVGROUP``, ``DownsampleMode.MAXPOOL``, ``DownsampleMode.AVGPOOL``,
124+
``DownsampleMode.PIXELUNSHUFFLE``}. Defaults to ``DownsampleMode.CONV``.
124125
pre_conv: a conv block applied before downsampling. Defaults to "default".
125126
When ``conv_block`` is ``"default"``, one reserved conv layer will be utilized.
126127
Only used in the "maxpool", "avgpool" or "pixelunshuffle" modes.
@@ -134,7 +135,7 @@ def __init__(
134135

135136
if not kernel_size:
136137
kernel_size_ = scale_factor_
137-
padding = 0
138+
padding = ensure_tuple_rep(0, spatial_dims)
138139
else:
139140
kernel_size_ = ensure_tuple_rep(kernel_size, spatial_dims)
140141
padding = tuple((k - 1) // 2 for k in kernel_size_)

monai/networks/nets/restormer.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@
1515

1616
from monai.networks.blocks.cablock import CABlock, FeedForward
1717
from monai.networks.blocks.convolutions import Convolution
18-
from monai.networks.blocks.downsample import DownSample, DownsampleMode
19-
from monai.networks.blocks.upsample import UpSample, UpsampleMode
18+
from monai.networks.blocks.downsample import DownSample
19+
from monai.networks.blocks.upsample import UpSample
2020
from monai.networks.layers.factories import Norm
21+
from monai.utils.enums import DownsampleMode, UpsampleMode
2122

2223

2324
class MDTATransformerBlock(nn.Module):
@@ -81,9 +82,9 @@ def __init__(self, spatial_dims: int, in_channels: int = 3, embed_dim: int = 48,
8182
conv_only=True,
8283
)
8384

84-
85-
def forward(self, x: torch.Tensor) -> torch.Tensor:
86-
return super().forward(x)
85+
def forward(self, x: torch.Tensor) -> torch.Tensor:
86+
x = super().forward(x)
87+
return x
8788

8889

8990
class Restormer(nn.Module):
@@ -290,7 +291,7 @@ def __init__(
290291
conv_only=True,
291292
)
292293

293-
def forward(self, x) -> torch.Tensor:
294+
def forward(self, x: torch.Tensor) -> torch.Tensor:
294295
"""Forward pass of Restormer.
295296
Processes input through encoder-decoder architecture with skip connections.
296297
Args:

tests/networks/blocks/test_downsample_block.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@ def __init__(self, n_feat: int):
146146

147147
def forward(self, x: torch.Tensor) -> torch.Tensor:
148148
x = self.conv(x)
149-
return self.pixelunshuffle(x)
149+
x = self.pixelunshuffle(x)
150+
return x
150151

151152
n_feat = 2
152153
x = torch.randn(1, n_feat, 64, 64)

0 commit comments

Comments
 (0)