@@ -17,7 +17,8 @@
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
+
+__all__ = ["MedNeXtBlock", "MedNeXtDownBlock", "MedNeXtUpBlock", "MedNeXtOutBlock"]
 
 
 class MedNeXtBlock(nn.Module):
@@ -26,63 +27,65 @@ def __init__(
         self,
         in_channels: int,
         out_channels: int,
-        exp_r: int = 4,
+        expansion_ratio: int = 4,
         kernel_size: int = 7,
-        do_res: int = True,
+        use_residual_connection: int = True,
         norm_type: str = "group",
-        n_groups: int or None = None,
         dim="3d",
         grn=False,
     ):
 
         super().__init__()
 
-        self.do_res = do_res
+        self.do_res = use_residual_connection
 
         assert dim in ["2d", "3d"]
         self.dim = dim
         if self.dim == "2d":
             conv = nn.Conv2d
-        else:
+            normalized_shape = [in_channels, kernel_size, kernel_size]
+            grn_parameter_shape = (1, 1)
+        elif self.dim == "3d":
             conv = nn.Conv3d
-
+            normalized_shape = [in_channels, kernel_size, kernel_size, kernel_size]
+            grn_parameter_shape = (1, 1, 1)
+        else:
+            raise ValueError("dim must be either '2d' or '3d'")
         # First convolution layer with DepthWise Convolutions
         self.conv1 = conv(
             in_channels=in_channels,
             out_channels=in_channels,
             kernel_size=kernel_size,
             stride=1,
             padding=kernel_size // 2,
-            groups=in_channels if n_groups is None else n_groups,
+            groups=in_channels,
         )
 
         # Normalization Layer. GroupNorm is used by default.
         if norm_type == "group":
             self.norm = nn.GroupNorm(num_groups=in_channels, num_channels=in_channels)
         elif norm_type == "layer":
-            self.norm = LayerNorm(normalized_shape=in_channels, data_format="channels_first")
-
+            self.norm = nn.LayerNorm(normalized_shape=normalized_shape)
         # Second convolution (Expansion) layer with Conv3D 1x1x1
-        self.conv2 = conv(in_channels=in_channels, out_channels=exp_r * in_channels, kernel_size=1, stride=1, padding=0)
+        self.conv2 = conv(
+            in_channels=in_channels, out_channels=expansion_ratio * in_channels, kernel_size=1, stride=1, padding=0
+        )
 
         # GeLU activations
         self.act = nn.GELU()
 
         # Third convolution (Compression) layer with Conv3D 1x1x1
         self.conv3 = conv(
-            in_channels=exp_r * in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
+            in_channels=expansion_ratio * in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
         )
 
         self.grn = grn
         if self.grn:
-            if dim == "2d":
-                self.grn_beta = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1), requires_grad=True)
-                self.grn_gamma = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1), requires_grad=True)
-            else:
-                self.grn_beta = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1, 1), requires_grad=True)
-                self.grn_gamma = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1, 1), requires_grad=True)
+            grn_parameter_shape = (1, expansion_ratio * in_channels) + grn_parameter_shape
+            self.grn_beta = nn.Parameter(torch.zeros(grn_parameter_shape), requires_grad=True)
+            self.grn_gamma = nn.Parameter(torch.zeros(grn_parameter_shape), requires_grad=True)
 
-    def forward(self, x, dummy_tensor=None):
+    def forward(self, x):
 
         x1 = x
         x1 = self.conv1(x1)
@@ -106,19 +109,34 @@ def forward(self, x, dummy_tensor=None):
 class MedNeXtDownBlock(MedNeXtBlock):
 
     def __init__(
-        self, in_channels, out_channels, exp_r=4, kernel_size=7, do_res=False, norm_type="group", dim="3d", grn=False
+        self,
+        in_channels: int,
+        out_channels: int,
+        expansion_ratio: int = 4,
+        kernel_size: int = 7,
+        use_residual_connection: bool = False,
+        norm_type: str = "group",
+        dim: str = "3d",
+        grn: bool = False,
     ):
 
         super().__init__(
-            in_channels, out_channels, exp_r, kernel_size, do_res=False, norm_type=norm_type, dim=dim, grn=grn
+            in_channels,
+            out_channels,
+            expansion_ratio,
+            kernel_size,
+            use_residual_connection=False,
+            norm_type=norm_type,
+            dim=dim,
+            grn=grn,
         )
 
         if dim == "2d":
             conv = nn.Conv2d
         else:
             conv = nn.Conv3d
-        self.resample_do_res = do_res
-        if do_res:
+        self.resample_do_res = use_residual_connection
+        if use_residual_connection:
             self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2)
 
         self.conv1 = conv(
@@ -130,7 +148,7 @@ def __init__(
             groups=in_channels,
         )
 
-    def forward(self, x, dummy_tensor=None):
+    def forward(self, x):
 
         x1 = super().forward(x)
 
@@ -144,20 +162,35 @@ def forward(self, x, dummy_tensor=None):
 class MedNeXtUpBlock(MedNeXtBlock):
 
     def __init__(
-        self, in_channels, out_channels, exp_r=4, kernel_size=7, do_res=False, norm_type="group", dim="3d", grn=False
+        self,
+        in_channels: int,
+        out_channels: int,
+        expansion_ratio: int = 4,
+        kernel_size: int = 7,
+        use_residual_connection: bool = False,
+        norm_type: str = "group",
+        dim: str = "3d",
+        grn: bool = False,
     ):
         super().__init__(
-            in_channels, out_channels, exp_r, kernel_size, do_res=False, norm_type=norm_type, dim=dim, grn=grn
+            in_channels,
+            out_channels,
+            expansion_ratio,
+            kernel_size,
+            use_residual_connection=False,
+            norm_type=norm_type,
+            dim=dim,
+            grn=grn,
         )
 
-        self.resample_do_res = do_res
+        self.resample_do_res = use_residual_connection
 
         self.dim = dim
         if dim == "2d":
             conv = nn.ConvTranspose2d
         else:
             conv = nn.ConvTranspose3d
-        if do_res:
+        if use_residual_connection:
             self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2)
 
         self.conv1 = conv(
@@ -169,7 +202,7 @@ def __init__(
             groups=in_channels,
         )
 
-    def forward(self, x, dummy_tensor=None):
+    def forward(self, x):
 
         x1 = super().forward(x)
         # Asymmetry but necessary to match shape
@@ -190,7 +223,7 @@ def forward(self, x, dummy_tensor=None):
         return x1
 
 
-class OutBlock(nn.Module):
+class MedNeXtOutBlock(nn.Module):
 
     def __init__(self, in_channels, n_classes, dim):
         super().__init__()
@@ -201,33 +234,5 @@ def __init__(self, in_channels, n_classes, dim):
             conv = nn.ConvTranspose3d
         self.conv_out = conv(in_channels, n_classes, kernel_size=1)
 
-    def forward(self, x, dummy_tensor=None):
+    def forward(self, x):
         return self.conv_out(x)
-
-
-class LayerNorm(nn.Module):
-    """LayerNorm that supports two data formats: channels_last (default) or channels_first.
-    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
-    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
-    with shape (batch_size, channels, height, width).
-    """
-
-    def __init__(self, normalized_shape, eps=1e-5, data_format="channels_last"):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(normalized_shape))  # beta
-        self.bias = nn.Parameter(torch.zeros(normalized_shape))  # gamma
-        self.eps = eps
-        self.data_format = data_format
-        if self.data_format not in ["channels_last", "channels_first"]:
-            raise NotImplementedError
-        self.normalized_shape = (normalized_shape,)
-
-    def forward(self, x, dummy_tensor=False):
-        if self.data_format == "channels_last":
-            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
-        elif self.data_format == "channels_first":
-            u = x.mean(1, keepdim=True)
-            s = (x - u).pow(2).mean(1, keepdim=True)
-            x = (x - u) / torch.sqrt(s + self.eps)
-            x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None]
-            return x
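
Side note for reviewers on the GRN change in the MedNeXtBlock hunk: the old code duplicated the nn.Parameter construction across a 2d and a 3d branch, while the new code precomputes the trailing singleton dimensions per dimensionality and builds the final shape with one tuple concatenation. A minimal sketch of the resulting parameter shapes; the values for expansion_ratio and in_channels are illustrative, not taken from the diff:

    import torch
    import torch.nn as nn

    expansion_ratio, in_channels = 4, 32  # illustrative values, not from the diff

    # (1, 1) is the 2d case and (1, 1, 1) the 3d case, as set in __init__ above
    for spatial_singletons in ((1, 1), (1, 1, 1)):
        grn_parameter_shape = (1, expansion_ratio * in_channels) + spatial_singletons
        grn_beta = nn.Parameter(torch.zeros(grn_parameter_shape), requires_grad=True)
        print(tuple(grn_beta.shape))  # (1, 128, 1, 1), then (1, 128, 1, 1, 1)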
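
A small end-to-end smoke test for the renamed blocks follows. This is a sketch under two assumptions: that the classes are importable from monai.networks.blocks.mednext_block (the module path does not appear in this diff), and that the usual MedNeXt behavior holds, i.e. a plain block preserves the spatial size while the down/up blocks halve and restore it:

    import torch
    from monai.networks.blocks.mednext_block import (  # import path is an assumption
        MedNeXtBlock,
        MedNeXtDownBlock,
        MedNeXtOutBlock,
        MedNeXtUpBlock,
    )

    x = torch.randn(1, 8, 16, 16, 16)  # (batch, channels, D, H, W); sizes are illustrative

    block = MedNeXtBlock(
        in_channels=8, out_channels=8, expansion_ratio=4, kernel_size=7,
        use_residual_connection=True, norm_type="group", dim="3d", grn=True,
    )
    down = MedNeXtDownBlock(in_channels=8, out_channels=16, use_residual_connection=True, dim="3d")
    up = MedNeXtUpBlock(in_channels=16, out_channels=8, use_residual_connection=True, dim="3d")
    head = MedNeXtOutBlock(in_channels=8, n_classes=2, dim="3d")

    y = block(x)  # residual block: channels and spatial size unchanged
    y = down(y)  # strided depthwise conv: spatial size halved
    y = up(y)  # transposed conv plus shape-matching pad: spatial size restored
    logits = head(y)  # 1x1x1 transposed conv down to class logits
    print(logits.shape)  # expected: torch.Size([1, 2, 16, 16, 16])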