Commit bfff9a2

Author: Suraj Pai
Merge pull request #1 from rcremese/mednext
Mednext
2 parents: 53cdfad + 93e782f

File tree: 5 files changed, +273 -103 lines


monai/networks/blocks/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@
 from .fcn import FCN, GCN, MCFCN, Refine
 from .feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool, LastLevelP6P7
 from .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock
-from .mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtUpBlock, OutBlock
+from .mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtOutBlock, MedNeXtUpBlock
 from .mlp import MLPBlock
 from .patchembedding import PatchEmbed, PatchEmbeddingBlock
 from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock
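
Review note (not part of the diff): the export rename from OutBlock to MedNeXtOutBlock is a breaking change for downstream imports of the old name. A minimal before/after sketch:

    # before this commit:
    # from monai.networks.blocks import OutBlock
    # after this commit:
    from monai.networks.blocks import MedNeXtOutBlock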

monai/networks/blocks/mednext_block.py

Lines changed: 64 additions & 59 deletions
@@ -17,7 +17,8 @@

 import torch
 import torch.nn as nn
-import torch.nn.functional as F
+
+__all__ = ["MedNeXtBlock", "MedNeXtDownBlock", "MedNeXtUpBlock", "MedNeXtOutBlock"]


 class MedNeXtBlock(nn.Module):
@@ -26,63 +27,65 @@ def __init__(
         self,
         in_channels: int,
         out_channels: int,
-        exp_r: int = 4,
+        expansion_ratio: int = 4,
         kernel_size: int = 7,
-        do_res: int = True,
+        use_residual_connection: int = True,
         norm_type: str = "group",
-        n_groups: int or None = None,
         dim="3d",
         grn=False,
     ):

         super().__init__()

-        self.do_res = do_res
+        self.do_res = use_residual_connection

         assert dim in ["2d", "3d"]
         self.dim = dim
         if self.dim == "2d":
             conv = nn.Conv2d
-        else:
+            normalized_shape = [in_channels, kernel_size, kernel_size]
+            grn_parameter_shape = (1, 1)
+        elif self.dim == "3d":
             conv = nn.Conv3d
-
+            normalized_shape = [in_channels, kernel_size, kernel_size, kernel_size]
+            grn_parameter_shape = (1, 1, 1)
+        else:
+            raise ValueError("dim must be either '2d' or '3d'")
         # First convolution layer with DepthWise Convolutions
         self.conv1 = conv(
             in_channels=in_channels,
             out_channels=in_channels,
             kernel_size=kernel_size,
             stride=1,
             padding=kernel_size // 2,
-            groups=in_channels if n_groups is None else n_groups,
+            groups=in_channels,
         )

         # Normalization Layer. GroupNorm is used by default.
         if norm_type == "group":
             self.norm = nn.GroupNorm(num_groups=in_channels, num_channels=in_channels)
         elif norm_type == "layer":
-            self.norm = LayerNorm(normalized_shape=in_channels, data_format="channels_first")
-
+            self.norm = nn.LayerNorm(normalized_shape=normalized_shape)
         # Second convolution (Expansion) layer with Conv3D 1x1x1
-        self.conv2 = conv(in_channels=in_channels, out_channels=exp_r * in_channels, kernel_size=1, stride=1, padding=0)
+        self.conv2 = conv(
+            in_channels=in_channels, out_channels=expansion_ratio * in_channels, kernel_size=1, stride=1, padding=0
+        )

         # GeLU activations
         self.act = nn.GELU()

         # Third convolution (Compression) layer with Conv3D 1x1x1
         self.conv3 = conv(
-            in_channels=exp_r * in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
+            in_channels=expansion_ratio * in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
         )

         self.grn = grn
         if self.grn:
-            if dim == "2d":
-                self.grn_beta = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1), requires_grad=True)
-                self.grn_gamma = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1), requires_grad=True)
-            else:
-                self.grn_beta = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1, 1), requires_grad=True)
-                self.grn_gamma = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1, 1), requires_grad=True)
+            grn_parameter_shape = (1, expansion_ratio * in_channels) + grn_parameter_shape
+            self.grn_beta = nn.Parameter(torch.zeros(grn_parameter_shape), requires_grad=True)
+            self.grn_gamma = nn.Parameter(torch.zeros(grn_parameter_shape), requires_grad=True)

-    def forward(self, x, dummy_tensor=None):
+    def forward(self, x):

         x1 = x
         x1 = self.conv1(x1)
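
Review note (not part of the diff): a minimal usage sketch of the refactored MedNeXtBlock with the renamed keyword arguments; the tensor shape and values below are illustrative assumptions, not taken from the commit:

    import torch
    from monai.networks.blocks import MedNeXtBlock

    x = torch.randn(1, 32, 16, 16, 16)  # assumed input: (batch, channels, D, H, W)

    block = MedNeXtBlock(
        in_channels=32,
        out_channels=32,
        expansion_ratio=4,             # renamed from exp_r
        kernel_size=7,
        use_residual_connection=True,  # renamed from do_res
        norm_type="group",             # default; "layer" now maps to nn.LayerNorm over
                                       # [in_channels, k, k, k], which appears to assume the
                                       # spatial size equals kernel_size
        dim="3d",
        grn=False,                     # GRN parameters now use the precomputed grn_parameter_shape
    )
    y = block(x)     # depthwise conv -> norm -> 1x1x1 expansion -> GELU -> 1x1x1 compression (+ residual)
    print(y.shape)   # expected: torch.Size([1, 32, 16, 16, 16]); spatial size is preserved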
@@ -106,19 +109,34 @@ def forward(self, x, dummy_tensor=None):
 class MedNeXtDownBlock(MedNeXtBlock):

     def __init__(
-        self, in_channels, out_channels, exp_r=4, kernel_size=7, do_res=False, norm_type="group", dim="3d", grn=False
+        self,
+        in_channels: int,
+        out_channels: int,
+        expansion_ratio: int = 4,
+        kernel_size: int = 7,
+        use_residual_connection: bool = False,
+        norm_type: str = "group",
+        dim: str = "3d",
+        grn: bool = False,
     ):

         super().__init__(
-            in_channels, out_channels, exp_r, kernel_size, do_res=False, norm_type=norm_type, dim=dim, grn=grn
+            in_channels,
+            out_channels,
+            expansion_ratio,
+            kernel_size,
+            use_residual_connection=False,
+            norm_type=norm_type,
+            dim=dim,
+            grn=grn,
         )

         if dim == "2d":
             conv = nn.Conv2d
         else:
             conv = nn.Conv3d
-        self.resample_do_res = do_res
-        if do_res:
+        self.resample_do_res = use_residual_connection
+        if use_residual_connection:
             self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2)

         self.conv1 = conv(
@@ -130,7 +148,7 @@ def __init__(
             groups=in_channels,
         )

-    def forward(self, x, dummy_tensor=None):
+    def forward(self, x):

         x1 = super().forward(x)

@@ -144,20 +162,35 @@ def forward(self, x, dummy_tensor=None):
 class MedNeXtUpBlock(MedNeXtBlock):

     def __init__(
-        self, in_channels, out_channels, exp_r=4, kernel_size=7, do_res=False, norm_type="group", dim="3d", grn=False
+        self,
+        in_channels: int,
+        out_channels: int,
+        expansion_ratio: int = 4,
+        kernel_size: int = 7,
+        use_residual_connection: bool = False,
+        norm_type: str = "group",
+        dim: str = "3d",
+        grn: bool = False,
     ):
         super().__init__(
-            in_channels, out_channels, exp_r, kernel_size, do_res=False, norm_type=norm_type, dim=dim, grn=grn
+            in_channels,
+            out_channels,
+            expansion_ratio,
+            kernel_size,
+            use_residual_connection=False,
+            norm_type=norm_type,
+            dim=dim,
+            grn=grn,
         )

-        self.resample_do_res = do_res
+        self.resample_do_res = use_residual_connection

         self.dim = dim
         if dim == "2d":
             conv = nn.ConvTranspose2d
         else:
             conv = nn.ConvTranspose3d
-        if do_res:
+        if use_residual_connection:
             self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2)

         self.conv1 = conv(
@@ -169,7 +202,7 @@ def __init__(
             groups=in_channels,
         )

-    def forward(self, x, dummy_tensor=None):
+    def forward(self, x):

         x1 = super().forward(x)
         # Asymmetry but necessary to match shape
@@ -190,7 +223,7 @@ def forward(self, x, dummy_tensor=None):
         return x1


-class OutBlock(nn.Module):
+class MedNeXtOutBlock(nn.Module):

     def __init__(self, in_channels, n_classes, dim):
         super().__init__()
@@ -201,33 +234,5 @@ def __init__(self, in_channels, n_classes, dim):
             conv = nn.ConvTranspose3d
         self.conv_out = conv(in_channels, n_classes, kernel_size=1)

-    def forward(self, x, dummy_tensor=None):
+    def forward(self, x):
         return self.conv_out(x)
-
-
-class LayerNorm(nn.Module):
-    """LayerNorm that supports two data formats: channels_last (default) or channels_first.
-    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
-    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
-    with shape (batch_size, channels, height, width).
-    """
-
-    def __init__(self, normalized_shape, eps=1e-5, data_format="channels_last"):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(normalized_shape))  # beta
-        self.bias = nn.Parameter(torch.zeros(normalized_shape))  # gamma
-        self.eps = eps
-        self.data_format = data_format
-        if self.data_format not in ["channels_last", "channels_first"]:
-            raise NotImplementedError
-        self.normalized_shape = (normalized_shape,)
-
-    def forward(self, x, dummy_tensor=False):
-        if self.data_format == "channels_last":
-            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
-        elif self.data_format == "channels_first":
-            u = x.mean(1, keepdim=True)
-            s = (x - u).pow(2).mean(1, keepdim=True)
-            x = (x - u) / torch.sqrt(s + self.eps)
-            x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None]
-            return x
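
Review note (not part of the diff): with the dummy_tensor argument removed, all four blocks share the plain forward(x) signature. A sketch chaining the resampling blocks, assuming the stride-2 convolutions halve (down) and double (up) the spatial size as in the MedNeXt architecture:

    import torch
    from monai.networks.blocks import MedNeXtDownBlock, MedNeXtOutBlock, MedNeXtUpBlock

    x = torch.randn(1, 32, 16, 16, 16)  # assumed input shape

    down = MedNeXtDownBlock(in_channels=32, out_channels=64, use_residual_connection=True, dim="3d")
    x_down = down(x)     # stride-2 depthwise conv: expected (1, 64, 8, 8, 8)

    up = MedNeXtUpBlock(in_channels=64, out_channels=32, use_residual_connection=True, dim="3d")
    x_up = up(x_down)    # stride-2 transposed conv plus asymmetric padding: expected (1, 32, 16, 16, 16)

    head = MedNeXtOutBlock(in_channels=32, n_classes=3, dim="3d")
    logits = head(x_up)  # kernel-size-1 transposed conv head: expected (1, 3, 16, 16, 16)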

monai/networks/nets/__init__.py

Lines changed: 19 additions & 1 deletion
@@ -53,7 +53,25 @@
 from .generator import Generator
 from .highresnet import HighResBlock, HighResNet
 from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet
-from .mednext import MedNeXt
+from .mednext import (
+    MedNeXt,
+    MedNext,
+    MedNextB,
+    MedNeXtB,
+    MedNextBase,
+    MedNextL,
+    MedNeXtL,
+    MedNeXtLarge,
+    MedNextLarge,
+    MedNextM,
+    MedNeXtM,
+    MedNeXtMedium,
+    MedNextMedium,
+    MedNextS,
+    MedNeXtS,
+    MedNeXtSmall,
+    MedNextSmall,
+)
 from .milmodel import MILModel
 from .netadapter import NetAdapter
 from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator
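
Review note (not part of the diff): the widened import block exposes capitalization and size-name aliases (e.g. MedNeXtS vs MedNextSmall) next to the base MedNeXt class. A hedged sketch; the alias relationship and the constructor keywords below are assumptions for illustration and do not appear in this commit:

    from monai.networks.nets import MedNeXt, MedNeXtS, MedNextSmall

    # assumed: both spellings resolve to the same small-size preset
    print(MedNeXtS is MedNextSmall)

    # constructor keywords assumed, not shown in this diff:
    net = MedNeXtS(spatial_dims=3, in_channels=1, out_channels=2)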
