Skip to content

Commit e146aaf

Browse files
committed
Update mednext implementations
Signed-off-by: Suraj Pai <[email protected]>
1 parent 93e782f commit e146aaf

File tree

2 files changed

+82
-94
lines changed

2 files changed

+82
-94
lines changed

monai/networks/nets/mednext.py

Lines changed: 54 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -266,128 +266,89 @@ def forward(self, x: torch.Tensor) -> torch.Tensor | list[torch.Tensor]:
266266

267267

268268
# Define the MedNeXt variants as reported in 10.48550/arXiv.2303.09975
269-
class MedNeXtSmall(MedNeXt):
270-
"""MedNeXt Small (S) configuration"""
269+
def create_mednext(
270+
variant: str,
271+
spatial_dims: int = 3,
272+
in_channels: int = 1,
273+
out_channels: int = 2,
274+
kernel_size: int = 3,
275+
deep_supervision: bool = False,
276+
) -> MedNeXt:
277+
"""
278+
Factory method to create MedNeXt variants.
271279
272-
def __init__(
273-
self,
274-
spatial_dims: int = 3,
275-
in_channels: int = 1,
276-
out_channels: int = 2,
277-
kernel_size: int = 3,
278-
deep_supervision: bool = False,
279-
):
280-
super().__init__(
281-
spatial_dims=spatial_dims,
282-
init_filters=32,
283-
in_channels=in_channels,
284-
out_channels=out_channels,
280+
Args:
281+
variant (str): The MedNeXt variant to create ('S', 'B', 'M', or 'L').
282+
spatial_dims (int): Number of spatial dimensions. Defaults to 3.
283+
in_channels (int): Number of input channels. Defaults to 1.
284+
out_channels (int): Number of output channels. Defaults to 2.
285+
kernel_size (int): Kernel size for convolutions. Defaults to 3.
286+
deep_supervision (bool): Whether to use deep supervision. Defaults to False.
287+
288+
Returns:
289+
MedNeXt: The specified MedNeXt variant.
290+
291+
Raises:
292+
ValueError: If an invalid variant is specified.
293+
"""
294+
common_args = {
295+
"spatial_dims": spatial_dims,
296+
"in_channels": in_channels,
297+
"out_channels": out_channels,
298+
"kernel_size": kernel_size,
299+
"deep_supervision": deep_supervision,
300+
"use_residual_connection": True,
301+
"norm_type": "group",
302+
"grn": False,
303+
"init_filters": 32,
304+
}
305+
306+
if variant.upper() == "S":
307+
return MedNeXt(
285308
encoder_expansion_ratio=2,
286309
decoder_expansion_ratio=2,
287310
bottleneck_expansion_ratio=2,
288-
kernel_size=kernel_size,
289-
deep_supervision=deep_supervision,
290-
use_residual_connection=True,
291311
blocks_down=(2, 2, 2, 2),
292312
blocks_bottleneck=2,
293313
blocks_up=(2, 2, 2, 2),
294-
norm_type="group",
295-
grn=False,
314+
**common_args,
296315
)
297-
298-
299-
class MedNeXtBase(MedNeXt):
300-
"""MedNeXt Base (B) configuration"""
301-
302-
def __init__(
303-
self,
304-
spatial_dims: int = 3,
305-
in_channels: int = 1,
306-
out_channels: int = 2,
307-
kernel_size: int = 3,
308-
deep_supervision: bool = False,
309-
):
310-
super().__init__(
311-
spatial_dims=spatial_dims,
312-
init_filters=32,
313-
in_channels=in_channels,
314-
out_channels=out_channels,
316+
elif variant.upper() == "B":
317+
return MedNeXt(
315318
encoder_expansion_ratio=(2, 3, 4, 4),
316319
decoder_expansion_ratio=(4, 4, 3, 2),
317320
bottleneck_expansion_ratio=4,
318-
kernel_size=kernel_size,
319-
deep_supervision=deep_supervision,
320-
use_residual_connection=True,
321321
blocks_down=(2, 2, 2, 2),
322322
blocks_bottleneck=2,
323323
blocks_up=(2, 2, 2, 2),
324-
norm_type="group",
325-
grn=False,
324+
**common_args,
326325
)
327-
328-
329-
class MedNeXtMedium(MedNeXt):
330-
"""MedNeXt Medium (M)"""
331-
332-
def __init__(
333-
self,
334-
spatial_dims: int = 3,
335-
in_channels: int = 1,
336-
out_channels: int = 2,
337-
kernel_size: int = 3,
338-
deep_supervision: bool = False,
339-
):
340-
super().__init__(
341-
spatial_dims=spatial_dims,
342-
init_filters=32,
343-
in_channels=in_channels,
344-
out_channels=out_channels,
326+
elif variant.upper() == "M":
327+
return MedNeXt(
345328
encoder_expansion_ratio=(2, 3, 4, 4),
346329
decoder_expansion_ratio=(4, 4, 3, 2),
347330
bottleneck_expansion_ratio=4,
348-
kernel_size=kernel_size,
349-
deep_supervision=deep_supervision,
350-
use_residual_connection=True,
351331
blocks_down=(3, 4, 4, 4),
352332
blocks_bottleneck=4,
353333
blocks_up=(4, 4, 4, 3),
354-
norm_type="group",
355-
grn=False,
334+
**common_args,
356335
)
357-
358-
359-
class MedNeXtLarge(MedNeXt):
360-
"""MedNeXt Large (L)"""
361-
362-
def __init__(
363-
self,
364-
spatial_dims: int = 3,
365-
in_channels: int = 1,
366-
out_channels: int = 2,
367-
kernel_size: int = 3,
368-
deep_supervision: bool = False,
369-
):
370-
super().__init__(
371-
spatial_dims=spatial_dims,
372-
init_filters=32,
373-
in_channels=in_channels,
374-
out_channels=out_channels,
336+
elif variant.upper() == "L":
337+
return MedNeXt(
375338
encoder_expansion_ratio=(3, 4, 8, 8),
376339
decoder_expansion_ratio=(8, 8, 4, 3),
377340
bottleneck_expansion_ratio=8,
378-
kernel_size=kernel_size,
379-
deep_supervision=deep_supervision,
380-
use_residual_connection=True,
381341
blocks_down=(3, 4, 8, 8),
382342
blocks_bottleneck=8,
383343
blocks_up=(8, 8, 4, 3),
384-
norm_type="group",
385-
grn=False,
344+
**common_args,
386345
)
346+
else:
347+
raise ValueError(f"Invalid MedNeXt variant: {variant}")
387348

388349

389350
MedNext = MedNeXt
390-
MedNextS = MedNeXtS = MedNextSmall = MedNeXtSmall
391-
MedNextB = MedNeXtB = MedNextBase = MedNeXtBase
392-
MedNextM = MedNeXtM = MedNextMedium = MedNeXtMedium
393-
MedNextL = MedNeXtL = MedNextLarge = MedNeXtLarge
351+
MedNextS = MedNeXtS = MedNextSmall = MedNeXtSmall = lambda **kwargs: create_mednext("S", **kwargs)
352+
MedNextB = MedNeXtB = MedNextBase = MedNeXtBase = lambda **kwargs: create_mednext("B", **kwargs)
353+
MedNextM = MedNeXtM = MedNextMedium = MedNeXtMedium = lambda **kwargs: create_mednext("M", **kwargs)
354+
MedNextL = MedNeXtL = MedNextLarge = MedNeXtLarge = lambda **kwargs: create_mednext("L", **kwargs)

tests/test_mednext.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from parameterized import parameterized
1818

1919
from monai.networks import eval_mode
20-
from monai.networks.nets import MedNeXt
20+
from monai.networks.nets import MedNeXt, MedNeXtL, MedNeXtM, MedNeXtS
2121
from tests.utils import SkipIfBeforePyTorchVersion, test_script_save
2222

2323
device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -55,6 +55,18 @@
5555
]
5656
TEST_CASE_MEDNEXT_2.append(test_case)
5757

58+
TEST_CASE_MEDNEXT_VARIANTS = []
59+
for model in [MedNeXtS, MedNeXtM, MedNeXtL]:
60+
for spatial_dims in range(2, 4):
61+
for out_channels in [1, 2]:
62+
test_case = [
63+
model,
64+
{"spatial_dims": spatial_dims, "in_channels": 1, "out_channels": out_channels},
65+
(2, 1, *([16] * spatial_dims)),
66+
(2, out_channels, *([16] * spatial_dims)),
67+
]
68+
TEST_CASE_MEDNEXT_VARIANTS.append(test_case)
69+
5870

5971
class TestMedNeXt(unittest.TestCase):
6072

@@ -91,6 +103,21 @@ def test_ill_arg(self):
91103
with self.assertRaises(AssertionError):
92104
MedNeXt(spatial_dims=4)
93105

106+
@parameterized.expand(TEST_CASE_MEDNEXT_VARIANTS)
107+
def test_mednext_variants(self, model, input_param, input_shape, expected_shape):
108+
net = model(**input_param).to(device)
109+
110+
net.train()
111+
result = net(torch.randn(input_shape).to(device))
112+
assert isinstance(result, torch.Tensor)
113+
self.assertEqual(result.shape, expected_shape, msg=str(input_param))
114+
115+
net.eval()
116+
with torch.no_grad():
117+
result = net(torch.randn(input_shape).to(device))
118+
assert isinstance(result, torch.Tensor)
119+
self.assertEqual(result.shape, expected_shape, msg=str(input_param))
120+
94121

95122
if __name__ == "__main__":
96123
unittest.main()

0 commit comments

Comments
 (0)