
Commit b0a968b

Code formatting so the Black and Flake8 checks pass + integration of MedNeXt variants (S, B, M, L) + integration of review remarks from @johnzilke (Project-MONAI#8004 (review)): renamed class arguments, removed the self-defined LayerNorm, linked residual connections for the encoder and decoder
Signed-off-by: Robin CREMESE <[email protected]>
1 parent cea80a6 commit b0a968b

File tree: 5 files changed, +746 -0 lines changed


monai/networks/blocks/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@
 from .fcn import FCN, GCN, MCFCN, Refine
 from .feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool, LastLevelP6P7
 from .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock
+from .mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtOutBlock, MedNeXtUpBlock
 from .mlp import MLPBlock
 from .patchembedding import PatchEmbed, PatchEmbeddingBlock
 from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock
monai/networks/blocks/mednext_block.py

Lines changed: 238 additions & 0 deletions
@@ -0,0 +1,238 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Portions of this code are derived from the original repository at:
# https://github.com/MIC-DKFZ/MedNeXt
# and are used under the terms of the Apache License, Version 2.0.

from __future__ import annotations

import torch
import torch.nn as nn

__all__ = ["MedNeXtBlock", "MedNeXtDownBlock", "MedNeXtUpBlock", "MedNeXtOutBlock"]


class MedNeXtBlock(nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        expansion_ratio: int = 4,
        kernel_size: int = 7,
        use_residual_connection: bool = True,
        norm_type: str = "group",
        dim: str = "3d",
        grn: bool = False,
    ):
        super().__init__()

        self.do_res = use_residual_connection

        self.dim = dim
        if self.dim == "2d":
            conv = nn.Conv2d
            normalized_shape = [in_channels, kernel_size, kernel_size]
            grn_parameter_shape = (1, 1)
        elif self.dim == "3d":
            conv = nn.Conv3d
            normalized_shape = [in_channels, kernel_size, kernel_size, kernel_size]
            grn_parameter_shape = (1, 1, 1)
        else:
            raise ValueError("dim must be either '2d' or '3d'")

        # First convolution layer with depthwise convolutions
        self.conv1 = conv(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
            groups=in_channels,
        )

        # Normalization layer. GroupNorm is used by default.
        if norm_type == "group":
            self.norm = nn.GroupNorm(num_groups=in_channels, num_channels=in_channels)
        elif norm_type == "layer":
            self.norm = nn.LayerNorm(normalized_shape=normalized_shape)

        # Second convolution (expansion) layer with 1x1(x1) kernels
        self.conv2 = conv(
            in_channels=in_channels, out_channels=expansion_ratio * in_channels, kernel_size=1, stride=1, padding=0
        )

        # GELU activation
        self.act = nn.GELU()

        # Third convolution (compression) layer with 1x1(x1) kernels
        self.conv3 = conv(
            in_channels=expansion_ratio * in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
        )

        self.grn = grn
        if self.grn:
            grn_parameter_shape = (1, expansion_ratio * in_channels) + grn_parameter_shape
            self.grn_beta = nn.Parameter(torch.zeros(grn_parameter_shape), requires_grad=True)
            self.grn_gamma = nn.Parameter(torch.zeros(grn_parameter_shape), requires_grad=True)

    def forward(self, x):
        x1 = x
        x1 = self.conv1(x1)
        x1 = self.act(self.conv2(self.norm(x1)))

        if self.grn:
            # Global response normalization: gamma and beta are learnable affine
            # parameters applied to the expanded features of shape (N, C, *spatial).
            if self.dim == "2d":
                gx = torch.norm(x1, p=2, dim=(-2, -1), keepdim=True)
            else:
                gx = torch.norm(x1, p=2, dim=(-3, -2, -1), keepdim=True)
            nx = gx / (gx.mean(dim=1, keepdim=True) + 1e-6)
            x1 = self.grn_gamma * (x1 * nx) + self.grn_beta + x1
        x1 = self.conv3(x1)
        if self.do_res:
            x1 = x + x1
        return x1


class MedNeXtDownBlock(MedNeXtBlock):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        expansion_ratio: int = 4,
        kernel_size: int = 7,
        use_residual_connection: bool = False,
        norm_type: str = "group",
        dim: str = "3d",
        grn: bool = False,
    ):
        super().__init__(
            in_channels,
            out_channels,
            expansion_ratio,
            kernel_size,
            use_residual_connection=False,
            norm_type=norm_type,
            dim=dim,
            grn=grn,
        )

        if dim == "2d":
            conv = nn.Conv2d
        else:
            conv = nn.Conv3d
        self.resample_do_res = use_residual_connection
        if use_residual_connection:
            self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2)

        # Override the depthwise convolution with a strided version for downsampling.
        self.conv1 = conv(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=2,
            padding=kernel_size // 2,
            groups=in_channels,
        )

    def forward(self, x):
        x1 = super().forward(x)

        if self.resample_do_res:
            res = self.res_conv(x)
            x1 = x1 + res

        return x1


class MedNeXtUpBlock(MedNeXtBlock):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        expansion_ratio: int = 4,
        kernel_size: int = 7,
        use_residual_connection: bool = False,
        norm_type: str = "group",
        dim: str = "3d",
        grn: bool = False,
    ):
        super().__init__(
            in_channels,
            out_channels,
            expansion_ratio,
            kernel_size,
            use_residual_connection=False,
            norm_type=norm_type,
            dim=dim,
            grn=grn,
        )

        self.resample_do_res = use_residual_connection

        self.dim = dim
        if dim == "2d":
            conv = nn.ConvTranspose2d
        else:
            conv = nn.ConvTranspose3d
        if use_residual_connection:
            self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2)

        # Override the depthwise convolution with a strided transposed version for upsampling.
        self.conv1 = conv(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=2,
            padding=kernel_size // 2,
            groups=in_channels,
        )

    def forward(self, x):
        x1 = super().forward(x)
        # Asymmetric padding, but necessary to match the target shape.
        if self.dim == "2d":
            x1 = torch.nn.functional.pad(x1, (1, 0, 1, 0))
        else:
            x1 = torch.nn.functional.pad(x1, (1, 0, 1, 0, 1, 0))

        if self.resample_do_res:
            res = self.res_conv(x)
            if self.dim == "2d":
                res = torch.nn.functional.pad(res, (1, 0, 1, 0))
            else:
                res = torch.nn.functional.pad(res, (1, 0, 1, 0, 1, 0))
            x1 = x1 + res

        return x1


class MedNeXtOutBlock(nn.Module):

    def __init__(self, in_channels: int, n_classes: int, dim: str):
        super().__init__()

        if dim == "2d":
            conv = nn.ConvTranspose2d
        else:
            conv = nn.ConvTranspose3d
        self.conv_out = conv(in_channels, n_classes, kernel_size=1)

    def forward(self, x):
        return self.conv_out(x)
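
For orientation (not part of the commit itself), here is a minimal shape-check sketch of how these blocks compose; the 16^3 input size and channel counts are arbitrary illustrative choices, and the imports follow the re-exports added to monai/networks/blocks/__init__.py above:

import torch

from monai.networks.blocks import MedNeXtBlock, MedNeXtDownBlock, MedNeXtOutBlock, MedNeXtUpBlock

x = torch.randn(1, 32, 16, 16, 16)  # (N, C, D, H, W)

block = MedNeXtBlock(in_channels=32, out_channels=32, expansion_ratio=4, kernel_size=7, dim="3d")
down = MedNeXtDownBlock(in_channels=32, out_channels=64, use_residual_connection=True, dim="3d")
up = MedNeXtUpBlock(in_channels=64, out_channels=32, use_residual_connection=True, dim="3d")
head = MedNeXtOutBlock(in_channels=32, n_classes=3, dim="3d")

y = block(x)      # spatial size preserved: (1, 32, 16, 16, 16)
y = down(y)       # strided depthwise conv halves each axis: (1, 64, 8, 8, 8)
y = up(y)         # transposed conv + asymmetric pad doubles back: (1, 32, 16, 16, 16)
logits = head(y)  # 1x1x1 projection to class channels: (1, 3, 16, 16, 16)

Note that the asymmetric padding in MedNeXtUpBlock is applied to both the main path and the 1x1 residual path, which is what keeps the two branches shape-compatible before they are summed.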

monai/networks/nets/__init__.py

Lines changed: 19 additions & 0 deletions
@@ -53,6 +53,25 @@
 from .generator import Generator
 from .highresnet import HighResBlock, HighResNet
 from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet
+from .mednext import (
+    MedNeXt,
+    MedNext,
+    MedNextB,
+    MedNeXtB,
+    MedNextBase,
+    MedNextL,
+    MedNeXtL,
+    MedNeXtLarge,
+    MedNextLarge,
+    MedNextM,
+    MedNeXtM,
+    MedNeXtMedium,
+    MedNextMedium,
+    MedNextS,
+    MedNeXtS,
+    MedNeXtSmall,
+    MedNextSmall,
+)
 from .milmodel import MILModel
 from .netadapter import NetAdapter
 from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator
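
This hunk only adds the re-exports; monai/networks/nets/mednext.py itself, where MedNeXt and the S/B/M/L variant wrappers are defined, is among the changed files not shown in this excerpt. As a rough, hypothetical usage sketch only — the constructor keywords below (spatial_dims, in_channels, out_channels) are assumptions based on the MedNeXt API that was eventually merged into MONAI, not something this diff confirms:

import torch

from monai.networks.nets import MedNeXtS  # aliases MedNextS / MedNeXtSmall / MedNextSmall

# Hypothetical keywords: mednext.py is not part of this excerpt, so
# spatial_dims / in_channels / out_channels are assumed, not confirmed.
net = MedNeXtS(spatial_dims=3, in_channels=1, out_channels=2)
logits = net(torch.randn(1, 1, 64, 64, 64))  # (N, out_channels, D, H, W)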
