Skip to content

Commit a9af877

Browse files
committed
Fix mypy errors
1 parent fc10369 commit a9af877

File tree

3 files changed

+25
-34
lines changed

3 files changed

+25
-34
lines changed

monai/networks/blocks/mednext_block.py

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,13 @@
2121
all = ["MedNeXtBlock", "MedNeXtDownBlock", "MedNeXtUpBlock", "MedNeXtOutBlock"]
2222

2323

24+
def get_conv_layer(spatial_dim: int = 3, transpose: bool = False):
25+
if spatial_dim == 2:
26+
return nn.ConvTranspose2d if transpose else nn.Conv2d
27+
else: # spatial_dim == 3
28+
return nn.ConvTranspose3d if transpose else nn.Conv3d
29+
30+
2431
class MedNeXtBlock(nn.Module):
2532

2633
def __init__(
@@ -39,18 +46,9 @@ def __init__(
3946

4047
self.do_res = use_residual_connection
4148

42-
assert dim in ["2d", "3d"]
4349
self.dim = dim
44-
if self.dim == "2d":
45-
conv = nn.Conv2d
46-
normalized_shape = [in_channels, kernel_size, kernel_size]
47-
grn_parameter_shape = (1, 1)
48-
elif self.dim == "3d":
49-
conv = nn.Conv3d
50-
normalized_shape = [in_channels, kernel_size, kernel_size, kernel_size]
51-
grn_parameter_shape = (1, 1, 1)
52-
else:
53-
raise ValueError("dim must be either '2d' or '3d'")
50+
conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3)
51+
grn_parameter_shape = (1,) * (2 if dim == "2d" else 3)
5452
# First convolution layer with DepthWise Convolutions
5553
self.conv1 = conv(
5654
in_channels=in_channels,
@@ -63,9 +61,11 @@ def __init__(
6361

6462
# Normalization Layer. GroupNorm is used by default.
6563
if norm_type == "group":
66-
self.norm = nn.GroupNorm(num_groups=in_channels, num_channels=in_channels)
64+
self.norm = nn.GroupNorm(num_groups=in_channels, num_channels=in_channels) # type: ignore
6765
elif norm_type == "layer":
68-
self.norm = nn.LayerNorm(normalized_shape=normalized_shape)
66+
self.norm = nn.LayerNorm(
67+
normalized_shape=[in_channels] + [kernel_size] * (2 if dim == "2d" else 3) # type: ignore
68+
)
6969
# Second convolution (Expansion) layer with Conv3D 1x1x1
7070
self.conv2 = conv(
7171
in_channels=in_channels, out_channels=expansion_ratio * in_channels, kernel_size=1, stride=1, padding=0
@@ -131,10 +131,7 @@ def __init__(
131131
grn=grn,
132132
)
133133

134-
if dim == "2d":
135-
conv = nn.Conv2d
136-
else:
137-
conv = nn.Conv3d
134+
conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3)
138135
self.resample_do_res = use_residual_connection
139136
if use_residual_connection:
140137
self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2)
@@ -186,10 +183,7 @@ def __init__(
186183
self.resample_do_res = use_residual_connection
187184

188185
self.dim = dim
189-
if dim == "2d":
190-
conv = nn.ConvTranspose2d
191-
else:
192-
conv = nn.ConvTranspose3d
186+
conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3, transpose=True)
193187
if use_residual_connection:
194188
self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2)
195189

@@ -228,10 +222,7 @@ class MedNeXtOutBlock(nn.Module):
228222
def __init__(self, in_channels, n_classes, dim):
229223
super().__init__()
230224

231-
if dim == "2d":
232-
conv = nn.ConvTranspose2d
233-
else:
234-
conv = nn.ConvTranspose3d
225+
conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3, transpose=True)
235226
self.conv_out = conv(in_channels, n_classes, kernel_size=1)
236227

237228
def forward(self, x):

monai/networks/nets/mednext.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ def __init__(
7272
init_filters: int = 32,
7373
in_channels: int = 1,
7474
out_channels: int = 2,
75-
encoder_expansion_ratio: int = 2,
76-
decoder_expansion_ratio: int = 2,
75+
encoder_expansion_ratio: Sequence[int] | int = 2,
76+
decoder_expansion_ratio: Sequence[int] | int = 2,
7777
bottleneck_expansion_ratio: int = 2,
7878
kernel_size: int = 7,
7979
deep_supervision: bool = False,
@@ -212,7 +212,7 @@ def __init__(
212212
out_blocks.reverse()
213213
self.out_blocks = nn.ModuleList(out_blocks)
214214

215-
def forward(self, x: torch.Tensor) -> torch.Tensor | list[torch.Tensor]:
215+
def forward(self, x: torch.Tensor) -> torch.Tensor | Sequence[torch.Tensor]:
216216
"""
217217
Forward pass of the MedNeXt model.
218218
@@ -227,7 +227,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor | list[torch.Tensor]:
227227
x (torch.Tensor): Input tensor.
228228
229229
Returns:
230-
torch.Tensor or list[torch.Tensor]: Output tensor(s).
230+
torch.Tensor or Sequence[torch.Tensor]: Output tensor(s).
231231
"""
232232
# Apply stem convolution
233233
x = self.stem(x)
@@ -311,7 +311,7 @@ def create_mednext(
311311
blocks_down=(2, 2, 2, 2),
312312
blocks_bottleneck=2,
313313
blocks_up=(2, 2, 2, 2),
314-
**common_args,
314+
**common_args, # type: ignore
315315
)
316316
elif variant.upper() == "B":
317317
return MedNeXt(
@@ -321,7 +321,7 @@ def create_mednext(
321321
blocks_down=(2, 2, 2, 2),
322322
blocks_bottleneck=2,
323323
blocks_up=(2, 2, 2, 2),
324-
**common_args,
324+
**common_args, # type: ignore
325325
)
326326
elif variant.upper() == "M":
327327
return MedNeXt(
@@ -331,7 +331,7 @@ def create_mednext(
331331
blocks_down=(3, 4, 4, 4),
332332
blocks_bottleneck=4,
333333
blocks_up=(4, 4, 4, 3),
334-
**common_args,
334+
**common_args, # type: ignore
335335
)
336336
elif variant.upper() == "L":
337337
return MedNeXt(
@@ -341,7 +341,7 @@ def create_mednext(
341341
blocks_down=(3, 4, 8, 8),
342342
blocks_bottleneck=8,
343343
blocks_up=(8, 8, 4, 3),
344-
**common_args,
344+
**common_args, # type: ignore
345345
)
346346
else:
347347
raise ValueError(f"Invalid MedNeXt variant: {variant}")

tests/test_mednext.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
for spatial_dims in range(2, 4):
6060
for out_channels in [1, 2]:
6161
test_case = [
62-
model,
62+
model, # type: ignore
6363
{"spatial_dims": spatial_dims, "in_channels": 1, "out_channels": out_channels},
6464
(2, 1, *([16] * spatial_dims)),
6565
(2, out_channels, *([16] * spatial_dims)),

0 commit comments

Comments
 (0)