 all = ["MedNeXtBlock", "MedNeXtDownBlock", "MedNeXtUpBlock", "MedNeXtOutBlock"]


+def get_conv_layer(spatial_dim: int = 3, transpose: bool = False):
+    if spatial_dim == 2:
+        return nn.ConvTranspose2d if transpose else nn.Conv2d
+    else:  # spatial_dim == 3
+        return nn.ConvTranspose3d if transpose else nn.Conv3d
+
+
 class MedNeXtBlock(nn.Module):

     def __init__(
@@ -39,18 +46,9 @@ def __init__(

         self.do_res = use_residual_connection

-        assert dim in ["2d", "3d"]
         self.dim = dim
-        if self.dim == "2d":
-            conv = nn.Conv2d
-            normalized_shape = [in_channels, kernel_size, kernel_size]
-            grn_parameter_shape = (1, 1)
-        elif self.dim == "3d":
-            conv = nn.Conv3d
-            normalized_shape = [in_channels, kernel_size, kernel_size, kernel_size]
-            grn_parameter_shape = (1, 1, 1)
-        else:
-            raise ValueError("dim must be either '2d' or '3d'")
+        conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3)
+        grn_parameter_shape = (1,) * (2 if dim == "2d" else 3)
         # First convolution layer with DepthWise Convolutions
         self.conv1 = conv(
             in_channels=in_channels,
@@ -63,9 +61,11 @@ def __init__(

         # Normalization Layer. GroupNorm is used by default.
         if norm_type == "group":
-            self.norm = nn.GroupNorm(num_groups=in_channels, num_channels=in_channels)
+            self.norm = nn.GroupNorm(num_groups=in_channels, num_channels=in_channels)  # type: ignore
         elif norm_type == "layer":
-            self.norm = nn.LayerNorm(normalized_shape=normalized_shape)
+            self.norm = nn.LayerNorm(
+                normalized_shape=[in_channels] + [kernel_size] * (2 if dim == "2d" else 3)  # type: ignore
+            )
         # Second convolution (Expansion) layer with Conv3D 1x1x1
         self.conv2 = conv(
             in_channels=in_channels, out_channels=expansion_ratio * in_channels, kernel_size=1, stride=1, padding=0
@@ -131,10 +131,7 @@ def __init__(
             grn=grn,
         )

-        if dim == "2d":
-            conv = nn.Conv2d
-        else:
-            conv = nn.Conv3d
+        conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3)
         self.resample_do_res = use_residual_connection
         if use_residual_connection:
             self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2)
@@ -186,10 +183,7 @@ def __init__(
         self.resample_do_res = use_residual_connection

         self.dim = dim
-        if dim == "2d":
-            conv = nn.ConvTranspose2d
-        else:
-            conv = nn.ConvTranspose3d
+        conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3, transpose=True)
         if use_residual_connection:
             self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2)

@@ -228,10 +222,7 @@ class MedNeXtOutBlock(nn.Module):
     def __init__(self, in_channels, n_classes, dim):
         super().__init__()

-        if dim == "2d":
-            conv = nn.ConvTranspose2d
-        else:
-            conv = nn.ConvTranspose3d
+        conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3, transpose=True)
         self.conv_out = conv(in_channels, n_classes, kernel_size=1)

     def forward(self, x):
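For reference, a minimal standalone sketch of how the new get_conv_layer helper dispatches on spatial dimension and transposition. The module path monai.networks.blocks.mednext_block is an assumption based on this PR; the asserted behavior simply mirrors the helper added in the diff above.

import torch.nn as nn

from monai.networks.blocks.mednext_block import get_conv_layer  # path assumed

# regular vs. transposed convolution classes, 2D and 3D
assert get_conv_layer(spatial_dim=2) is nn.Conv2d
assert get_conv_layer(spatial_dim=2, transpose=True) is nn.ConvTranspose2d
assert get_conv_layer() is nn.Conv3d  # 3D is the default
assert get_conv_layer(spatial_dim=3, transpose=True) is nn.ConvTranspose3d

Centralizing this dispatch in one helper is what lets the diff delete the four near-identical if/else blocks in MedNeXtBlock, MedNeXtDownBlock, MedNeXtUpBlock, and MedNeXtOutBlock.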