@@ -29,6 +29,19 @@ def get_conv_layer(spatial_dim: int = 3, transpose: bool = False):
 
 
 class MedNeXtBlock(nn.Module):
+    """
+    MedNeXtBlock class for the MedNeXt model.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        expansion_ratio (int): Expansion ratio for the block. Defaults to 4.
+        kernel_size (int): Kernel size for convolutions. Defaults to 7.
+        use_residual_connection (int): Whether to use residual connection. Defaults to True.
+        norm_type (str): Type of normalization to use. Defaults to "group".
+        dim (str): Dimension of the input. Can be "2d" or "3d". Defaults to "3d".
+        global_resp_norm (bool): Whether to use global response normalization. Defaults to False.
+    """
 
     def __init__(
         self,
@@ -39,7 +52,7 @@ def __init__(
         use_residual_connection: int = True,
         norm_type: str = "group",
         dim="3d",
-        grn=False,
+        global_resp_norm=False,
     ):
 
         super().__init__()
@@ -48,7 +61,7 @@ def __init__(
 
         self.dim = dim
         conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3)
-        grn_parameter_shape = (1,) * (2 if dim == "2d" else 3)
+        global_resp_norm_param_shape = (1,) * (2 if dim == "2d" else 3)
         # First convolution layer with DepthWise Convolutions
         self.conv1 = conv(
             in_channels=in_channels,
@@ -79,34 +92,55 @@ def __init__(
             in_channels=expansion_ratio * in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
         )
 
-        self.grn = grn
-        if self.grn:
-            grn_parameter_shape = (1, expansion_ratio * in_channels) + grn_parameter_shape
-            self.grn_beta = nn.Parameter(torch.zeros(grn_parameter_shape), requires_grad=True)
-            self.grn_gamma = nn.Parameter(torch.zeros(grn_parameter_shape), requires_grad=True)
+        self.global_resp_norm = global_resp_norm
+        if self.global_resp_norm:
+            global_resp_norm_param_shape = (1, expansion_ratio * in_channels) + global_resp_norm_param_shape
+            self.global_resp_beta = nn.Parameter(torch.zeros(global_resp_norm_param_shape), requires_grad=True)
+            self.global_resp_gamma = nn.Parameter(torch.zeros(global_resp_norm_param_shape), requires_grad=True)
 
     def forward(self, x):
+        """
+        Forward pass of the MedNeXtBlock.
 
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Output tensor.
+        """
         x1 = x
         x1 = self.conv1(x1)
         x1 = self.act(self.conv2(self.norm(x1)))
 
-        if self.grn:
+        if self.global_resp_norm:
             # gamma, beta: learnable affine transform parameters
             # X: input of shape (N,C,H,W,D)
             if self.dim == "2d":
                 gx = torch.norm(x1, p=2, dim=(-2, -1), keepdim=True)
             else:
                 gx = torch.norm(x1, p=2, dim=(-3, -2, -1), keepdim=True)
             nx = gx / (gx.mean(dim=1, keepdim=True) + 1e-6)
-            x1 = self.grn_gamma * (x1 * nx) + self.grn_beta + x1
+            x1 = self.global_resp_gamma * (x1 * nx) + self.global_resp_beta + x1
         x1 = self.conv3(x1)
         if self.do_res:
             x1 = x + x1
         return x1
 
 
 class MedNeXtDownBlock(MedNeXtBlock):
+    """
+    MedNeXtDownBlock class for downsampling in the MedNeXt model.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        expansion_ratio (int): Expansion ratio for the block. Defaults to 4.
+        kernel_size (int): Kernel size for convolutions. Defaults to 7.
+        use_residual_connection (bool): Whether to use residual connection. Defaults to False.
+        norm_type (str): Type of normalization to use. Defaults to "group".
+        dim (str): Dimension of the input. Can be "2d" or "3d". Defaults to "3d".
+        global_resp_norm (bool): Whether to use global response normalization. Defaults to False.
+    """
 
     def __init__(
         self,
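
The hunk above also renames the learnable parameters used by the global response normalization branch of `forward`. Below is a minimal standalone sketch of that step for the 3D case; the tensor sizes are illustrative, and the zero-initialized gamma/beta only mirror the `nn.Parameter` initialization in the diff, so treat this as an assumption-laden illustration rather than the module itself.

    import torch

    # Illustrative 5D feature map (batch, channels, three spatial dims); sizes are made up.
    x1 = torch.randn(2, 8, 4, 4, 4)
    global_resp_gamma = torch.zeros(1, 8, 1, 1, 1)  # zero-initialized, as in the diff
    global_resp_beta = torch.zeros(1, 8, 1, 1, 1)   # zero-initialized, as in the diff

    gx = torch.norm(x1, p=2, dim=(-3, -2, -1), keepdim=True)     # per-channel L2 norm
    nx = gx / (gx.mean(dim=1, keepdim=True) + 1e-6)              # normalize across channels
    x1 = global_resp_gamma * (x1 * nx) + global_resp_beta + x1   # scaled response plus residual
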
@@ -117,7 +151,7 @@ def __init__(
         use_residual_connection: bool = False,
         norm_type: str = "group",
         dim: str = "3d",
-        grn: bool = False,
+        global_resp_norm: bool = False,
     ):
 
         super().__init__(
@@ -128,7 +162,7 @@ def __init__(
             use_residual_connection=False,
             norm_type=norm_type,
             dim=dim,
-            grn=grn,
+            global_resp_norm=global_resp_norm,
         )
 
         conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3)
@@ -146,7 +180,15 @@ def __init__(
         )
 
     def forward(self, x):
+        """
+        Forward pass of the MedNeXtDownBlock.
+
+        Args:
+            x (torch.Tensor): Input tensor.
 
+        Returns:
+            torch.Tensor: Output tensor.
+        """
         x1 = super().forward(x)
 
         if self.resample_do_res:
@@ -157,6 +199,19 @@ def forward(self, x):
 
 
 class MedNeXtUpBlock(MedNeXtBlock):
+    """
+    MedNeXtUpBlock class for upsampling in the MedNeXt model.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        expansion_ratio (int): Expansion ratio for the block. Defaults to 4.
+        kernel_size (int): Kernel size for convolutions. Defaults to 7.
+        use_residual_connection (bool): Whether to use residual connection. Defaults to False.
+        norm_type (str): Type of normalization to use. Defaults to "group".
+        dim (str): Dimension of the input. Can be "2d" or "3d". Defaults to "3d".
+        global_resp_norm (bool): Whether to use global response normalization. Defaults to False.
+    """
 
     def __init__(
         self,
@@ -167,7 +222,7 @@ def __init__(
         use_residual_connection: bool = False,
         norm_type: str = "group",
         dim: str = "3d",
-        grn: bool = False,
+        global_resp_norm: bool = False,
     ):
         super().__init__(
             in_channels,
@@ -177,7 +232,7 @@ def __init__(
             use_residual_connection=False,
             norm_type=norm_type,
             dim=dim,
-            grn=grn,
+            global_resp_norm=global_resp_norm,
         )
 
         self.resample_do_res = use_residual_connection
@@ -197,7 +252,15 @@ def __init__(
         )
 
     def forward(self, x):
+        """
+        Forward pass of the MedNeXtUpBlock.
+
+        Args:
+            x (torch.Tensor): Input tensor.
 
+        Returns:
+            torch.Tensor: Output tensor.
+        """
         x1 = super().forward(x)
         # Asymmetry but necessary to match shape
 
@@ -218,6 +281,14 @@ def forward(self, x):
 
 
 class MedNeXtOutBlock(nn.Module):
+    """
+    MedNeXtOutBlock class for the output block in the MedNeXt model.
+
+    Args:
+        in_channels (int): Number of input channels.
+        n_classes (int): Number of output classes.
+        dim (str): Dimension of the input. Can be "2d" or "3d".
+    """
 
     def __init__(self, in_channels, n_classes, dim):
         super().__init__()
@@ -226,4 +297,13 @@ def __init__(self, in_channels, n_classes, dim):
         self.conv_out = conv(in_channels, n_classes, kernel_size=1)
 
     def forward(self, x):
+        """
+        Forward pass of the MedNeXtOutBlock.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Output tensor.
+        """
         return self.conv_out(x)
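
For reference, a minimal usage sketch of the renamed keyword argument. The import path and tensor shape below are assumptions for illustration, not part of this diff.

    import torch
    from monai.networks.blocks.mednext_block import MedNeXtBlock, MedNeXtOutBlock  # assumed module path

    block = MedNeXtBlock(
        in_channels=32,
        out_channels=32,
        kernel_size=7,
        use_residual_connection=True,
        norm_type="group",
        dim="3d",
        global_resp_norm=True,  # formerly grn=True
    )
    head = MedNeXtOutBlock(in_channels=32, n_classes=2, dim="3d")

    x = torch.randn(1, 32, 16, 16, 16)  # one 3D feature map
    y = head(block(x))                  # spatial size preserved: (1, 2, 16, 16, 16)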