From 93e782f4accd06a72e15e8e96f574f5f27bef658 Mon Sep 17 00:00:00 2001
From: Robin CREMESE
Date: Wed, 21 Aug 2024 16:50:57 +0200
Subject: [PATCH] Code formatting for Black and Flake8 checks to pass +
 integration of MedNeXt variants (S, B, M, L) + integration of remarks from
 @johnzilke (https://github.com/Project-MONAI/MONAI/pull/8004#pullrequestreview-2233276224)
 for renaming class arguments

- removal of self-defined LayerNorm
- linked residual connection for encoder and decoder

Signed-off-by: Robin CREMESE
---
 monai/networks/blocks/__init__.py      |   2 +-
 monai/networks/blocks/mednext_block.py | 123 ++++++++-------
 monai/networks/nets/__init__.py        |  20 ++-
 monai/networks/nets/mednext.py         | 207 +++++++++++++++++++++----
 tests/test_mednext.py                  |  24 ++-
 5 files changed, 273 insertions(+), 103 deletions(-)

diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py
index a535c0ab262..499caf2e0f7 100644
--- a/monai/networks/blocks/__init__.py
+++ b/monai/networks/blocks/__init__.py
@@ -26,7 +26,7 @@
 from .fcn import FCN, GCN, MCFCN, Refine
 from .feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool, LastLevelP6P7
 from .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock
-from .mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtUpBlock, OutBlock
+from .mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtOutBlock, MedNeXtUpBlock
 from .mlp import MLPBlock
 from .patchembedding import PatchEmbed, PatchEmbeddingBlock
 from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock
diff --git a/monai/networks/blocks/mednext_block.py b/monai/networks/blocks/mednext_block.py
index 42617d2e564..8d0f0433bd6 100644
--- a/monai/networks/blocks/mednext_block.py
+++ b/monai/networks/blocks/mednext_block.py
@@ -17,7 +17,8 @@

 import torch
 import torch.nn as nn
-import torch.nn.functional as F
+
+__all__ = ["MedNeXtBlock", "MedNeXtDownBlock", "MedNeXtUpBlock", "MedNeXtOutBlock"]


 class MedNeXtBlock(nn.Module):
@@ -26,26 +27,30 @@ def __init__(
         self,
         in_channels: int,
         out_channels: int,
-        exp_r: int = 4,
+        expansion_ratio: int = 4,
         kernel_size: int = 7,
-        do_res: int = True,
+        use_residual_connection: bool = True,
         norm_type: str = "group",
-        n_groups: int or None = None,
         dim="3d",
         grn=False,
     ):
         super().__init__()

-        self.do_res = do_res
+        self.do_res = use_residual_connection

         assert dim in ["2d", "3d"]
         self.dim = dim
         if self.dim == "2d":
             conv = nn.Conv2d
-        else:
+            normalized_shape = [in_channels, kernel_size, kernel_size]
+            grn_parameter_shape = (1, 1)
+        elif self.dim == "3d":
             conv = nn.Conv3d
-
+            normalized_shape = [in_channels, kernel_size, kernel_size, kernel_size]
+            grn_parameter_shape = (1, 1, 1)
+        else:
+            raise ValueError("dim must be either '2d' or '3d'")
         # First convolution layer with DepthWise Convolutions
         self.conv1 = conv(
             in_channels=in_channels,
@@ -53,36 +58,34 @@ def __init__(
             kernel_size=kernel_size,
             stride=1,
             padding=kernel_size // 2,
-            groups=in_channels if n_groups is None else n_groups,
+            groups=in_channels,
         )

         # Normalization Layer. GroupNorm is used by default.
         if norm_type == "group":
             self.norm = nn.GroupNorm(num_groups=in_channels, num_channels=in_channels)
         elif norm_type == "layer":
-            self.norm = LayerNorm(normalized_shape=in_channels, data_format="channels_first")
-
+            self.norm = nn.LayerNorm(normalized_shape=normalized_shape)
         # Second convolution (Expansion) layer with Conv3D 1x1x1
-        self.conv2 = conv(in_channels=in_channels, out_channels=exp_r * in_channels, kernel_size=1, stride=1, padding=0)
+        self.conv2 = conv(
+            in_channels=in_channels, out_channels=expansion_ratio * in_channels, kernel_size=1, stride=1, padding=0
+        )

         # GeLU activations
         self.act = nn.GELU()

         # Third convolution (Compression) layer with Conv3D 1x1x1
         self.conv3 = conv(
-            in_channels=exp_r * in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
+            in_channels=expansion_ratio * in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
         )

         self.grn = grn
         if self.grn:
-            if dim == "2d":
-                self.grn_beta = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1), requires_grad=True)
-                self.grn_gamma = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1), requires_grad=True)
-            else:
-                self.grn_beta = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1, 1), requires_grad=True)
-                self.grn_gamma = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1, 1), requires_grad=True)
+            grn_parameter_shape = (1, expansion_ratio * in_channels) + grn_parameter_shape
+            self.grn_beta = nn.Parameter(torch.zeros(grn_parameter_shape), requires_grad=True)
+            self.grn_gamma = nn.Parameter(torch.zeros(grn_parameter_shape), requires_grad=True)

-    def forward(self, x, dummy_tensor=None):
+    def forward(self, x):

         x1 = x
         x1 = self.conv1(x1)
@@ -106,19 +109,34 @@ def forward(self, x, dummy_tensor=None):
 class MedNeXtDownBlock(MedNeXtBlock):

     def __init__(
-        self, in_channels, out_channels, exp_r=4, kernel_size=7, do_res=False, norm_type="group", dim="3d", grn=False
+        self,
+        in_channels: int,
+        out_channels: int,
+        expansion_ratio: int = 4,
+        kernel_size: int = 7,
+        use_residual_connection: bool = False,
+        norm_type: str = "group",
+        dim: str = "3d",
+        grn: bool = False,
     ):

         super().__init__(
-            in_channels, out_channels, exp_r, kernel_size, do_res=False, norm_type=norm_type, dim=dim, grn=grn
+            in_channels,
+            out_channels,
+            expansion_ratio,
+            kernel_size,
+            use_residual_connection=False,
+            norm_type=norm_type,
+            dim=dim,
+            grn=grn,
         )

         if dim == "2d":
             conv = nn.Conv2d
         else:
             conv = nn.Conv3d
-        self.resample_do_res = do_res
-        if do_res:
+        self.resample_do_res = use_residual_connection
+        if use_residual_connection:
             self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2)

         self.conv1 = conv(
@@ -130,7 +148,7 @@ def __init__(
             groups=in_channels,
         )

-    def forward(self, x, dummy_tensor=None):
+    def forward(self, x):

         x1 = super().forward(x)

@@ -144,20 +162,35 @@ def forward(self, x, dummy_tensor=None):
 class MedNeXtUpBlock(MedNeXtBlock):

     def __init__(
-        self, in_channels, out_channels, exp_r=4, kernel_size=7, do_res=False, norm_type="group", dim="3d", grn=False
+        self,
+        in_channels: int,
+        out_channels: int,
+        expansion_ratio: int = 4,
+        kernel_size: int = 7,
+        use_residual_connection: bool = False,
+        norm_type: str = "group",
+        dim: str = "3d",
+        grn: bool = False,
     ):

         super().__init__(
-            in_channels, out_channels, exp_r, kernel_size, do_res=False, norm_type=norm_type, dim=dim, grn=grn
+            in_channels,
+            out_channels,
+            expansion_ratio,
+            kernel_size,
+            use_residual_connection=False,
+            norm_type=norm_type,
+            dim=dim,
+            grn=grn,
         )
-        self.resample_do_res = do_res
+        self.resample_do_res = use_residual_connection

         self.dim = dim
         if dim == "2d":
             conv = nn.ConvTranspose2d
         else:
             conv = nn.ConvTranspose3d
-        if do_res:
+        if use_residual_connection:
             self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2)

         self.conv1 = conv(
@@ -169,7 +202,7 @@ def __init__(
             groups=in_channels,
         )

-    def forward(self, x, dummy_tensor=None):
+    def forward(self, x):

         x1 = super().forward(x)
         # Asymmetry but necessary to match shape
@@ -190,7 +223,7 @@ def forward(self, x, dummy_tensor=None):
         return x1


-class OutBlock(nn.Module):
+class MedNeXtOutBlock(nn.Module):

     def __init__(self, in_channels, n_classes, dim):
         super().__init__()
@@ -201,33 +234,5 @@ def __init__(self, in_channels, n_classes, dim):
             conv = nn.ConvTranspose3d
         self.conv_out = conv(in_channels, n_classes, kernel_size=1)

-    def forward(self, x, dummy_tensor=None):
+    def forward(self, x):
         return self.conv_out(x)
-
-
-class LayerNorm(nn.Module):
-    """LayerNorm that supports two data formats: channels_last (default) or channels_first.
-    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
-    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
-    with shape (batch_size, channels, height, width).
-    """
-
-    def __init__(self, normalized_shape, eps=1e-5, data_format="channels_last"):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(normalized_shape))  # beta
-        self.bias = nn.Parameter(torch.zeros(normalized_shape))  # gamma
-        self.eps = eps
-        self.data_format = data_format
-        if self.data_format not in ["channels_last", "channels_first"]:
-            raise NotImplementedError
-        self.normalized_shape = (normalized_shape,)
-
-    def forward(self, x, dummy_tensor=False):
-        if self.data_format == "channels_last":
-            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
-        elif self.data_format == "channels_first":
-            u = x.mean(1, keepdim=True)
-            s = (x - u).pow(2).mean(1, keepdim=True)
-            x = (x - u) / torch.sqrt(s + self.eps)
-            x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None]
-            return x
diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py
index f62fe432fa2..6dde0b4cc68 100644
--- a/monai/networks/nets/__init__.py
+++ b/monai/networks/nets/__init__.py
@@ -53,7 +53,25 @@
 from .generator import Generator
 from .highresnet import HighResBlock, HighResNet
 from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet
-from .mednext import MedNeXt
+from .mednext import (
+    MedNeXt,
+    MedNext,
+    MedNextB,
+    MedNeXtB,
+    MedNextBase,
+    MedNextL,
+    MedNeXtL,
+    MedNeXtLarge,
+    MedNextLarge,
+    MedNextM,
+    MedNeXtM,
+    MedNeXtMedium,
+    MedNextMedium,
+    MedNextS,
+    MedNeXtS,
+    MedNeXtSmall,
+    MedNextSmall,
+)
 from .milmodel import MILModel
 from .netadapter import NetAdapter
 from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator
diff --git a/monai/networks/nets/mednext.py b/monai/networks/nets/mednext.py
index 293db7d443a..e4e68bea20b 100644
--- a/monai/networks/nets/mednext.py
+++ b/monai/networks/nets/mednext.py
@@ -15,10 +15,33 @@

 from __future__ import annotations

+from collections.abc import Sequence
+
 import torch
 import torch.nn as nn

-from monai.networks.blocks import MedNeXtBlock, MedNeXtDownBlock, MedNeXtUpBlock, OutBlock
+from monai.networks.blocks.mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtOutBlock, MedNeXtUpBlock
+
+__all__ = [
+    "MedNeXt",
+    "MedNeXtSmall",
+    "MedNeXtBase",
"MedNeXtMedium", + "MedNeXtLarge", + "MedNext", + "MedNextS", + "MedNeXtS", + "MedNextSmall", + "MedNextB", + "MedNeXtB", + "MedNextBase", + "MedNextM", + "MedNeXtM", + "MedNextMedium", + "MedNextL", + "MedNeXtL", + "MedNextLarge", +] class MedNeXt(nn.Module): @@ -30,13 +53,12 @@ class MedNeXt(nn.Module): init_filters: number of output channels for initial convolution layer. Defaults to 32. in_channels: number of input channels for the network. Defaults to 1. out_channels: number of output channels for the network. Defaults to 2. - enc_exp_r: expansion ratio for encoder blocks. Defaults to 2. - dec_exp_r: expansion ratio for decoder blocks. Defaults to 2. - bottlenec_exp_r: expansion ratio for bottleneck blocks. Defaults to 2. + encoder_expansion_ratio: expansion ratio for encoder blocks. Defaults to 2. + decoder_expansion_ratio: expansion ratio for decoder blocks. Defaults to 2. + bottleneck_expansion_ratio: expansion ratio for bottleneck blocks. Defaults to 2. kernel_size: kernel size for convolutions. Defaults to 7. deep_supervision: whether to use deep supervision. Defaults to False. - do_res: whether to use residual connections. Defaults to False. - do_res_up_down: whether to use residual connections in up and down blocks. Defaults to False. + use_residual_connection: whether to use residual connections in standard, down and up blocks. Defaults to False. blocks_down: number of blocks in each encoder stage. Defaults to [2, 2, 2, 2]. blocks_bottleneck: number of blocks in bottleneck stage. Defaults to 2. blocks_up: number of blocks in each decoder stage. Defaults to [2, 2, 2, 2]. @@ -50,16 +72,15 @@ def __init__( init_filters: int = 32, in_channels: int = 1, out_channels: int = 2, - enc_exp_r: int = 2, - dec_exp_r: int = 2, - bottlenec_exp_r: int = 2, + encoder_expansion_ratio: int = 2, + decoder_expansion_ratio: int = 2, + bottleneck_expansion_ratio: int = 2, kernel_size: int = 7, deep_supervision: bool = False, - do_res: bool = False, - do_res_up_down: bool = False, - blocks_down: list = [2, 2, 2, 2], + use_residual_connection: bool = False, + blocks_down: Sequence[int] = (2, 2, 2, 2), blocks_bottleneck: int = 2, - blocks_up: list = [2, 2, 2, 2], + blocks_up: Sequence[int] = (2, 2, 2, 2), norm_type: str = "group", grn: bool = False, ): @@ -80,11 +101,11 @@ def __init__( spatial_dims_str = f"{spatial_dims}d" enc_kernel_size = dec_kernel_size = kernel_size - if isinstance(enc_exp_r, int): - enc_exp_r = [enc_exp_r] * len(blocks_down) + if isinstance(encoder_expansion_ratio, int): + encoder_expansion_ratio = [encoder_expansion_ratio] * len(blocks_down) - if isinstance(dec_exp_r, int): - dec_exp_r = [dec_exp_r] * len(blocks_up) + if isinstance(decoder_expansion_ratio, int): + decoder_expansion_ratio = [decoder_expansion_ratio] * len(blocks_up) conv = nn.Conv2d if spatial_dims_str == "2d" else nn.Conv3d @@ -100,9 +121,9 @@ def __init__( MedNeXtBlock( in_channels=init_filters * (2**i), out_channels=init_filters * (2**i), - exp_r=enc_exp_r[i], + expansion_ratio=encoder_expansion_ratio[i], kernel_size=enc_kernel_size, - do_res=do_res, + use_residual_connection=use_residual_connection, norm_type=norm_type, dim=spatial_dims_str, grn=grn, @@ -116,9 +137,9 @@ def __init__( MedNeXtDownBlock( in_channels=init_filters * (2**i), out_channels=init_filters * (2 ** (i + 1)), - exp_r=enc_exp_r[i], + expansion_ratio=encoder_expansion_ratio[i], kernel_size=enc_kernel_size, - do_res=do_res_up_down, + use_residual_connection=use_residual_connection, norm_type=norm_type, dim=spatial_dims_str, ) @@ -132,9 
@@ -132,9 +153,9 @@ def __init__(
                 MedNeXtBlock(
                     in_channels=init_filters * (2 ** len(blocks_down)),
                     out_channels=init_filters * (2 ** len(blocks_down)),
-                    exp_r=bottlenec_exp_r,
+                    expansion_ratio=bottleneck_expansion_ratio,
                     kernel_size=dec_kernel_size,
-                    do_res=do_res,
+                    use_residual_connection=use_residual_connection,
                     norm_type=norm_type,
                     dim=spatial_dims_str,
                     grn=grn,
@@ -150,9 +171,9 @@ def __init__(
                 MedNeXtUpBlock(
                     in_channels=init_filters * (2 ** (len(blocks_up) - i)),
                     out_channels=init_filters * (2 ** (len(blocks_up) - i - 1)),
-                    exp_r=dec_exp_r[i],
+                    expansion_ratio=decoder_expansion_ratio[i],
                     kernel_size=dec_kernel_size,
-                    do_res=do_res_up_down,
+                    use_residual_connection=use_residual_connection,
                     norm_type=norm_type,
                     dim=spatial_dims_str,
                     grn=grn,
@@ -165,9 +186,9 @@ def __init__(
                         MedNeXtBlock(
                             in_channels=init_filters * (2 ** (len(blocks_up) - i - 1)),
                             out_channels=init_filters * (2 ** (len(blocks_up) - i - 1)),
-                            exp_r=dec_exp_r[i],
+                            expansion_ratio=decoder_expansion_ratio[i],
                             kernel_size=dec_kernel_size,
-                            do_res=do_res,
+                            use_residual_connection=use_residual_connection,
                             norm_type=norm_type,
                             dim=spatial_dims_str,
                             grn=grn,
@@ -180,11 +201,11 @@ def __init__(
         self.up_blocks = nn.ModuleList(up_blocks)
         self.dec_stages = nn.ModuleList(dec_stages)

-        self.out_0 = OutBlock(in_channels=init_filters, n_classes=out_channels, dim=spatial_dims_str)
+        self.out_0 = MedNeXtOutBlock(in_channels=init_filters, n_classes=out_channels, dim=spatial_dims_str)

         if deep_supervision:
             out_blocks = [
-                OutBlock(in_channels=init_filters * (2**i), n_classes=out_channels, dim=spatial_dims_str)
+                MedNeXtOutBlock(in_channels=init_filters * (2**i), n_classes=out_channels, dim=spatial_dims_str)
                 for i in range(1, len(blocks_up) + 1)
             ]

@@ -242,3 +263,131 @@ def forward(self, x: torch.Tensor) -> torch.Tensor | list[torch.Tensor]:
             return (x, *ds_outputs[::-1])
         else:
             return x
+
+
+# Define the MedNeXt variants as reported in 10.48550/arXiv.2303.09975
+class MedNeXtSmall(MedNeXt):
+    """MedNeXt Small (S) configuration"""
+
+    def __init__(
+        self,
+        spatial_dims: int = 3,
+        in_channels: int = 1,
+        out_channels: int = 2,
+        kernel_size: int = 3,
+        deep_supervision: bool = False,
+    ):
+        super().__init__(
+            spatial_dims=spatial_dims,
+            init_filters=32,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            encoder_expansion_ratio=2,
+            decoder_expansion_ratio=2,
+            bottleneck_expansion_ratio=2,
+            kernel_size=kernel_size,
+            deep_supervision=deep_supervision,
+            use_residual_connection=True,
+            blocks_down=(2, 2, 2, 2),
+            blocks_bottleneck=2,
+            blocks_up=(2, 2, 2, 2),
+            norm_type="group",
+            grn=False,
+        )
+
+
+class MedNeXtBase(MedNeXt):
+    """MedNeXt Base (B) configuration"""
+
+    def __init__(
+        self,
+        spatial_dims: int = 3,
+        in_channels: int = 1,
+        out_channels: int = 2,
+        kernel_size: int = 3,
+        deep_supervision: bool = False,
+    ):
+        super().__init__(
+            spatial_dims=spatial_dims,
+            init_filters=32,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            encoder_expansion_ratio=(2, 3, 4, 4),
+            decoder_expansion_ratio=(4, 4, 3, 2),
+            bottleneck_expansion_ratio=4,
+            kernel_size=kernel_size,
+            deep_supervision=deep_supervision,
+            use_residual_connection=True,
+            blocks_down=(2, 2, 2, 2),
+            blocks_bottleneck=2,
+            blocks_up=(2, 2, 2, 2),
+            norm_type="group",
+            grn=False,
+        )
+
+
+class MedNeXtMedium(MedNeXt):
+    """MedNeXt Medium (M) configuration"""
+
+    def __init__(
+        self,
+        spatial_dims: int = 3,
+        in_channels: int = 1,
+        out_channels: int = 2,
+        kernel_size: int = 3,
+        deep_supervision: bool = False,
+    ):
+        super().__init__(
+            spatial_dims=spatial_dims,
+            init_filters=32,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            encoder_expansion_ratio=(2, 3, 4, 4),
+            decoder_expansion_ratio=(4, 4, 3, 2),
+            bottleneck_expansion_ratio=4,
+            kernel_size=kernel_size,
+            deep_supervision=deep_supervision,
+            use_residual_connection=True,
+            blocks_down=(3, 4, 4, 4),
+            blocks_bottleneck=4,
+            blocks_up=(4, 4, 4, 3),
+            norm_type="group",
+            grn=False,
+        )
+
+
+class MedNeXtLarge(MedNeXt):
+    """MedNeXt Large (L) configuration"""
+
+    def __init__(
+        self,
+        spatial_dims: int = 3,
+        in_channels: int = 1,
+        out_channels: int = 2,
+        kernel_size: int = 3,
+        deep_supervision: bool = False,
+    ):
+        super().__init__(
+            spatial_dims=spatial_dims,
+            init_filters=32,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            encoder_expansion_ratio=(3, 4, 8, 8),
+            decoder_expansion_ratio=(8, 8, 4, 3),
+            bottleneck_expansion_ratio=8,
+            kernel_size=kernel_size,
+            deep_supervision=deep_supervision,
+            use_residual_connection=True,
+            blocks_down=(3, 4, 8, 8),
+            blocks_bottleneck=8,
+            blocks_up=(8, 8, 4, 3),
+            norm_type="group",
+            grn=False,
+        )
+
+
+MedNext = MedNeXt
+MedNextS = MedNeXtS = MedNextSmall = MedNeXtSmall
+MedNextB = MedNeXtB = MedNextBase = MedNeXtBase
+MedNextM = MedNeXtM = MedNextMedium = MedNeXtMedium
+MedNextL = MedNeXtL = MedNextLarge = MedNeXtLarge
diff --git a/tests/test_mednext.py b/tests/test_mednext.py
index e5f74118c32..e39e88f1088 100644
--- a/tests/test_mednext.py
+++ b/tests/test_mednext.py
@@ -27,19 +27,17 @@
     for init_filters in [8, 16]:
         for deep_supervision in [False, True]:
             for do_res in [False, True]:
-                for do_res_up_down in [False, True]:
-                    test_case = [
-                        {
-                            "spatial_dims": spatial_dims,
-                            "init_filters": init_filters,
-                            "deep_supervision": deep_supervision,
-                            "do_res": do_res,
-                            "do_res_up_down": do_res_up_down,
-                        },
-                        (2, 1, *([16] * spatial_dims)),
-                        (2, 2, *([16] * spatial_dims)),
-                    ]
-                    TEST_CASE_MEDNEXT.append(test_case)
+                test_case = [
+                    {
+                        "spatial_dims": spatial_dims,
+                        "init_filters": init_filters,
+                        "deep_supervision": deep_supervision,
+                        "use_residual_connection": do_res,
+                    },
+                    (2, 1, *([16] * spatial_dims)),
+                    (2, 2, *([16] * spatial_dims)),
+                ]
+                TEST_CASE_MEDNEXT.append(test_case)

 TEST_CASE_MEDNEXT_2 = []
 for spatial_dims in range(2, 4):
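
For reviewers who want to exercise the renamed block API, a minimal usage sketch
follows (expansion_ratio and use_residual_connection replace exp_r and do_res).
The tensor sizes below are illustrative assumptions, not values taken from the patch:

    import torch

    from monai.networks.blocks import MedNeXtBlock, MedNeXtDownBlock, MedNeXtOutBlock

    # Standard block: depthwise conv -> norm -> 1x1(x1) expansion -> GELU -> 1x1(x1) compression.
    block = MedNeXtBlock(
        in_channels=32,
        out_channels=32,
        expansion_ratio=4,             # formerly exp_r
        kernel_size=7,
        use_residual_connection=True,  # formerly do_res
        norm_type="group",
        dim="3d",
    )
    x = torch.randn(2, 32, 16, 16, 16)
    print(block(x).shape)  # torch.Size([2, 32, 16, 16, 16])

    # Down block: halves the spatial size while (here) doubling the channels.
    down = MedNeXtDownBlock(in_channels=32, out_channels=64, use_residual_connection=True, dim="3d")
    print(down(x).shape)  # torch.Size([2, 64, 8, 8, 8])

    # Out block: maps features to class logits with a kernel-size-1 transposed conv.
    head = MedNeXtOutBlock(in_channels=32, n_classes=2, dim="3d")
    print(head(x).shape)  # torch.Size([2, 2, 16, 16, 16])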
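At the network level the same renaming applies. The sketch below mirrors the
16-cube inputs used in tests/test_mednext.py; the deep-supervision output shapes
are inferred from the forward() context shown above and assume the module is in
training mode (the full deep-supervision return path is not visible in this diff):

    import torch

    from monai.networks.nets import MedNeXt, MedNeXtSmall

    net = MedNeXt(
        spatial_dims=3,
        init_filters=8,
        in_channels=1,
        out_channels=2,
        encoder_expansion_ratio=2,  # a single int, or one value per encoder stage
        decoder_expansion_ratio=2,
        bottleneck_expansion_ratio=2,
        use_residual_connection=True,
        deep_supervision=True,
    )
    net.train()
    outputs = net(torch.randn(2, 1, 16, 16, 16))
    # Full-resolution prediction first, then one output per deeper decoder scale:
    # (2, 2, 16, 16, 16), (2, 2, 8, 8, 8), (2, 2, 4, 4, 4), (2, 2, 2, 2, 2), (2, 2, 1, 1, 1)
    print([tuple(o.shape) for o in outputs])

    # Paper-defined variant: widths, depths and expansion ratios are fixed;
    # only spatial dims, I/O channels and kernel size remain configurable.
    small = MedNeXtSmall(spatial_dims=3, in_channels=1, out_channels=2, kernel_size=3)
    print(small(torch.randn(2, 1, 16, 16, 16)).shape)  # torch.Size([2, 2, 16, 16, 16])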