@@ -266,128 +266,89 @@ def forward(self, x: torch.Tensor) -> torch.Tensor | list[torch.Tensor]:
266266
267267
268268# Define the MedNeXt variants as reported in 10.48550/arXiv.2303.09975
269- class MedNeXtSmall (MedNeXt ):
270- """MedNeXt Small (S) configuration"""
269+ def create_mednext (
270+ variant : str ,
271+ spatial_dims : int = 3 ,
272+ in_channels : int = 1 ,
273+ out_channels : int = 2 ,
274+ kernel_size : int = 3 ,
275+ deep_supervision : bool = False ,
276+ ) -> MedNeXt :
277+ """
278+ Factory method to create MedNeXt variants.
271279
272- def __init__ (
273- self ,
274- spatial_dims : int = 3 ,
275- in_channels : int = 1 ,
276- out_channels : int = 2 ,
277- kernel_size : int = 3 ,
278- deep_supervision : bool = False ,
279- ):
280- super ().__init__ (
281- spatial_dims = spatial_dims ,
282- init_filters = 32 ,
283- in_channels = in_channels ,
284- out_channels = out_channels ,
280+ Args:
281+ variant (str): The MedNeXt variant to create ('S', 'B', 'M', or 'L').
282+ spatial_dims (int): Number of spatial dimensions. Defaults to 3.
283+ in_channels (int): Number of input channels. Defaults to 1.
284+ out_channels (int): Number of output channels. Defaults to 2.
285+ kernel_size (int): Kernel size for convolutions. Defaults to 3.
286+ deep_supervision (bool): Whether to use deep supervision. Defaults to False.
287+
288+ Returns:
289+ MedNeXt: The specified MedNeXt variant.
290+
291+ Raises:
292+ ValueError: If an invalid variant is specified.
293+ """
294+ common_args = {
295+ "spatial_dims" : spatial_dims ,
296+ "in_channels" : in_channels ,
297+ "out_channels" : out_channels ,
298+ "kernel_size" : kernel_size ,
299+ "deep_supervision" : deep_supervision ,
300+ "use_residual_connection" : True ,
301+ "norm_type" : "group" ,
302+ "grn" : False ,
303+ "init_filters" : 32 ,
304+ }
305+
306+ if variant .upper () == "S" :
307+ return MedNeXt (
285308 encoder_expansion_ratio = 2 ,
286309 decoder_expansion_ratio = 2 ,
287310 bottleneck_expansion_ratio = 2 ,
288- kernel_size = kernel_size ,
289- deep_supervision = deep_supervision ,
290- use_residual_connection = True ,
291311 blocks_down = (2 , 2 , 2 , 2 ),
292312 blocks_bottleneck = 2 ,
293313 blocks_up = (2 , 2 , 2 , 2 ),
294- norm_type = "group" ,
295- grn = False ,
314+ ** common_args ,
296315 )
297-
298-
299- class MedNeXtBase (MedNeXt ):
300- """MedNeXt Base (B) configuration"""
301-
302- def __init__ (
303- self ,
304- spatial_dims : int = 3 ,
305- in_channels : int = 1 ,
306- out_channels : int = 2 ,
307- kernel_size : int = 3 ,
308- deep_supervision : bool = False ,
309- ):
310- super ().__init__ (
311- spatial_dims = spatial_dims ,
312- init_filters = 32 ,
313- in_channels = in_channels ,
314- out_channels = out_channels ,
316+ elif variant .upper () == "B" :
317+ return MedNeXt (
315318 encoder_expansion_ratio = (2 , 3 , 4 , 4 ),
316319 decoder_expansion_ratio = (4 , 4 , 3 , 2 ),
317320 bottleneck_expansion_ratio = 4 ,
318- kernel_size = kernel_size ,
319- deep_supervision = deep_supervision ,
320- use_residual_connection = True ,
321321 blocks_down = (2 , 2 , 2 , 2 ),
322322 blocks_bottleneck = 2 ,
323323 blocks_up = (2 , 2 , 2 , 2 ),
324- norm_type = "group" ,
325- grn = False ,
324+ ** common_args ,
326325 )
327-
328-
329- class MedNeXtMedium (MedNeXt ):
330- """MedNeXt Medium (M)"""
331-
332- def __init__ (
333- self ,
334- spatial_dims : int = 3 ,
335- in_channels : int = 1 ,
336- out_channels : int = 2 ,
337- kernel_size : int = 3 ,
338- deep_supervision : bool = False ,
339- ):
340- super ().__init__ (
341- spatial_dims = spatial_dims ,
342- init_filters = 32 ,
343- in_channels = in_channels ,
344- out_channels = out_channels ,
326+ elif variant .upper () == "M" :
327+ return MedNeXt (
345328 encoder_expansion_ratio = (2 , 3 , 4 , 4 ),
346329 decoder_expansion_ratio = (4 , 4 , 3 , 2 ),
347330 bottleneck_expansion_ratio = 4 ,
348- kernel_size = kernel_size ,
349- deep_supervision = deep_supervision ,
350- use_residual_connection = True ,
351331 blocks_down = (3 , 4 , 4 , 4 ),
352332 blocks_bottleneck = 4 ,
353333 blocks_up = (4 , 4 , 4 , 3 ),
354- norm_type = "group" ,
355- grn = False ,
334+ ** common_args ,
356335 )
357-
358-
359- class MedNeXtLarge (MedNeXt ):
360- """MedNeXt Large (L)"""
361-
362- def __init__ (
363- self ,
364- spatial_dims : int = 3 ,
365- in_channels : int = 1 ,
366- out_channels : int = 2 ,
367- kernel_size : int = 3 ,
368- deep_supervision : bool = False ,
369- ):
370- super ().__init__ (
371- spatial_dims = spatial_dims ,
372- init_filters = 32 ,
373- in_channels = in_channels ,
374- out_channels = out_channels ,
336+ elif variant .upper () == "L" :
337+ return MedNeXt (
375338 encoder_expansion_ratio = (3 , 4 , 8 , 8 ),
376339 decoder_expansion_ratio = (8 , 8 , 4 , 3 ),
377340 bottleneck_expansion_ratio = 8 ,
378- kernel_size = kernel_size ,
379- deep_supervision = deep_supervision ,
380- use_residual_connection = True ,
381341 blocks_down = (3 , 4 , 8 , 8 ),
382342 blocks_bottleneck = 8 ,
383343 blocks_up = (8 , 8 , 4 , 3 ),
384- norm_type = "group" ,
385- grn = False ,
344+ ** common_args ,
386345 )
346+ else :
347+ raise ValueError (f"Invalid MedNeXt variant: { variant } " )
387348
388349
389350MedNext = MedNeXt
390- MedNextS = MedNeXtS = MedNextSmall = MedNeXtSmall
391- MedNextB = MedNeXtB = MedNextBase = MedNeXtBase
392- MedNextM = MedNeXtM = MedNextMedium = MedNeXtMedium
393- MedNextL = MedNeXtL = MedNextLarge = MedNeXtLarge
351+ MedNextS = MedNeXtS = MedNextSmall = MedNeXtSmall = lambda ** kwargs : create_mednext ( "S" , ** kwargs )
352+ MedNextB = MedNeXtB = MedNextBase = MedNeXtBase = lambda ** kwargs : create_mednext ( "B" , ** kwargs )
353+ MedNextM = MedNeXtM = MedNextMedium = MedNeXtMedium = lambda ** kwargs : create_mednext ( "M" , ** kwargs )
354+ MedNextL = MedNeXtL = MedNextLarge = MedNeXtLarge = lambda ** kwargs : create_mednext ( "L" , ** kwargs )
0 commit comments