Skip to content

Commit 5b51f50

Browse files
committed
Merge remote-tracking branch 'boris/maisi-trt' into maisi-trt
2 parents 03cb981 + dd12d09 commit 5b51f50

File tree

5 files changed

+805
-0
lines changed

5 files changed

+805
-0
lines changed

monai/networks/blocks/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from .fcn import FCN, GCN, MCFCN, Refine
2727
from .feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool, LastLevelP6P7
2828
from .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock
29+
from .mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtOutBlock, MedNeXtUpBlock
2930
from .mlp import MLPBlock
3031
from .patchembedding import PatchEmbed, PatchEmbeddingBlock
3132
from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock
Lines changed: 309 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,309 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
# Portions of this code are derived from the original repository at:
13+
# https://github.com/MIC-DKFZ/MedNeXt
14+
# and are used under the terms of the Apache License, Version 2.0.
15+
16+
from __future__ import annotations
17+
18+
import torch
19+
import torch.nn as nn
20+
21+
all = ["MedNeXtBlock", "MedNeXtDownBlock", "MedNeXtUpBlock", "MedNeXtOutBlock"]
22+
23+
24+
def get_conv_layer(spatial_dim: int = 3, transpose: bool = False):
25+
if spatial_dim == 2:
26+
return nn.ConvTranspose2d if transpose else nn.Conv2d
27+
else: # spatial_dim == 3
28+
return nn.ConvTranspose3d if transpose else nn.Conv3d
29+
30+
31+
class MedNeXtBlock(nn.Module):
32+
"""
33+
MedNeXtBlock class for the MedNeXt model.
34+
35+
Args:
36+
in_channels (int): Number of input channels.
37+
out_channels (int): Number of output channels.
38+
expansion_ratio (int): Expansion ratio for the block. Defaults to 4.
39+
kernel_size (int): Kernel size for convolutions. Defaults to 7.
40+
use_residual_connection (int): Whether to use residual connection. Defaults to True.
41+
norm_type (str): Type of normalization to use. Defaults to "group".
42+
dim (str): Dimension of the input. Can be "2d" or "3d". Defaults to "3d".
43+
global_resp_norm (bool): Whether to use global response normalization. Defaults to False.
44+
"""
45+
46+
def __init__(
47+
self,
48+
in_channels: int,
49+
out_channels: int,
50+
expansion_ratio: int = 4,
51+
kernel_size: int = 7,
52+
use_residual_connection: int = True,
53+
norm_type: str = "group",
54+
dim="3d",
55+
global_resp_norm=False,
56+
):
57+
58+
super().__init__()
59+
60+
self.do_res = use_residual_connection
61+
62+
self.dim = dim
63+
conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3)
64+
global_resp_norm_param_shape = (1,) * (2 if dim == "2d" else 3)
65+
# First convolution layer with DepthWise Convolutions
66+
self.conv1 = conv(
67+
in_channels=in_channels,
68+
out_channels=in_channels,
69+
kernel_size=kernel_size,
70+
stride=1,
71+
padding=kernel_size // 2,
72+
groups=in_channels,
73+
)
74+
75+
# Normalization Layer. GroupNorm is used by default.
76+
if norm_type == "group":
77+
self.norm = nn.GroupNorm(num_groups=in_channels, num_channels=in_channels) # type: ignore
78+
elif norm_type == "layer":
79+
self.norm = nn.LayerNorm(
80+
normalized_shape=[in_channels] + [kernel_size] * (2 if dim == "2d" else 3) # type: ignore
81+
)
82+
# Second convolution (Expansion) layer with Conv3D 1x1x1
83+
self.conv2 = conv(
84+
in_channels=in_channels, out_channels=expansion_ratio * in_channels, kernel_size=1, stride=1, padding=0
85+
)
86+
87+
# GeLU activations
88+
self.act = nn.GELU()
89+
90+
# Third convolution (Compression) layer with Conv3D 1x1x1
91+
self.conv3 = conv(
92+
in_channels=expansion_ratio * in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
93+
)
94+
95+
self.global_resp_norm = global_resp_norm
96+
if self.global_resp_norm:
97+
global_resp_norm_param_shape = (1, expansion_ratio * in_channels) + global_resp_norm_param_shape
98+
self.global_resp_beta = nn.Parameter(torch.zeros(global_resp_norm_param_shape), requires_grad=True)
99+
self.global_resp_gamma = nn.Parameter(torch.zeros(global_resp_norm_param_shape), requires_grad=True)
100+
101+
def forward(self, x):
102+
"""
103+
Forward pass of the MedNeXtBlock.
104+
105+
Args:
106+
x (torch.Tensor): Input tensor.
107+
108+
Returns:
109+
torch.Tensor: Output tensor.
110+
"""
111+
x1 = x
112+
x1 = self.conv1(x1)
113+
x1 = self.act(self.conv2(self.norm(x1)))
114+
115+
if self.global_resp_norm:
116+
# gamma, beta: learnable affine transform parameters
117+
# X: input of shape (N,C,H,W,D)
118+
if self.dim == "2d":
119+
gx = torch.norm(x1, p=2, dim=(-2, -1), keepdim=True)
120+
else:
121+
gx = torch.norm(x1, p=2, dim=(-3, -2, -1), keepdim=True)
122+
nx = gx / (gx.mean(dim=1, keepdim=True) + 1e-6)
123+
x1 = self.global_resp_gamma * (x1 * nx) + self.global_resp_beta + x1
124+
x1 = self.conv3(x1)
125+
if self.do_res:
126+
x1 = x + x1
127+
return x1
128+
129+
130+
class MedNeXtDownBlock(MedNeXtBlock):
131+
"""
132+
MedNeXtDownBlock class for downsampling in the MedNeXt model.
133+
134+
Args:
135+
in_channels (int): Number of input channels.
136+
out_channels (int): Number of output channels.
137+
expansion_ratio (int): Expansion ratio for the block. Defaults to 4.
138+
kernel_size (int): Kernel size for convolutions. Defaults to 7.
139+
use_residual_connection (bool): Whether to use residual connection. Defaults to False.
140+
norm_type (str): Type of normalization to use. Defaults to "group".
141+
dim (str): Dimension of the input. Can be "2d" or "3d". Defaults to "3d".
142+
global_resp_norm (bool): Whether to use global response normalization. Defaults to False.
143+
"""
144+
145+
def __init__(
146+
self,
147+
in_channels: int,
148+
out_channels: int,
149+
expansion_ratio: int = 4,
150+
kernel_size: int = 7,
151+
use_residual_connection: bool = False,
152+
norm_type: str = "group",
153+
dim: str = "3d",
154+
global_resp_norm: bool = False,
155+
):
156+
157+
super().__init__(
158+
in_channels,
159+
out_channels,
160+
expansion_ratio,
161+
kernel_size,
162+
use_residual_connection=False,
163+
norm_type=norm_type,
164+
dim=dim,
165+
global_resp_norm=global_resp_norm,
166+
)
167+
168+
conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3)
169+
self.resample_do_res = use_residual_connection
170+
if use_residual_connection:
171+
self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2)
172+
173+
self.conv1 = conv(
174+
in_channels=in_channels,
175+
out_channels=in_channels,
176+
kernel_size=kernel_size,
177+
stride=2,
178+
padding=kernel_size // 2,
179+
groups=in_channels,
180+
)
181+
182+
def forward(self, x):
183+
"""
184+
Forward pass of the MedNeXtDownBlock.
185+
186+
Args:
187+
x (torch.Tensor): Input tensor.
188+
189+
Returns:
190+
torch.Tensor: Output tensor.
191+
"""
192+
x1 = super().forward(x)
193+
194+
if self.resample_do_res:
195+
res = self.res_conv(x)
196+
x1 = x1 + res
197+
198+
return x1
199+
200+
201+
class MedNeXtUpBlock(MedNeXtBlock):
202+
"""
203+
MedNeXtUpBlock class for upsampling in the MedNeXt model.
204+
205+
Args:
206+
in_channels (int): Number of input channels.
207+
out_channels (int): Number of output channels.
208+
expansion_ratio (int): Expansion ratio for the block. Defaults to 4.
209+
kernel_size (int): Kernel size for convolutions. Defaults to 7.
210+
use_residual_connection (bool): Whether to use residual connection. Defaults to False.
211+
norm_type (str): Type of normalization to use. Defaults to "group".
212+
dim (str): Dimension of the input. Can be "2d" or "3d". Defaults to "3d".
213+
global_resp_norm (bool): Whether to use global response normalization. Defaults to False.
214+
"""
215+
216+
def __init__(
217+
self,
218+
in_channels: int,
219+
out_channels: int,
220+
expansion_ratio: int = 4,
221+
kernel_size: int = 7,
222+
use_residual_connection: bool = False,
223+
norm_type: str = "group",
224+
dim: str = "3d",
225+
global_resp_norm: bool = False,
226+
):
227+
super().__init__(
228+
in_channels,
229+
out_channels,
230+
expansion_ratio,
231+
kernel_size,
232+
use_residual_connection=False,
233+
norm_type=norm_type,
234+
dim=dim,
235+
global_resp_norm=global_resp_norm,
236+
)
237+
238+
self.resample_do_res = use_residual_connection
239+
240+
self.dim = dim
241+
conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3, transpose=True)
242+
if use_residual_connection:
243+
self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2)
244+
245+
self.conv1 = conv(
246+
in_channels=in_channels,
247+
out_channels=in_channels,
248+
kernel_size=kernel_size,
249+
stride=2,
250+
padding=kernel_size // 2,
251+
groups=in_channels,
252+
)
253+
254+
def forward(self, x):
255+
"""
256+
Forward pass of the MedNeXtUpBlock.
257+
258+
Args:
259+
x (torch.Tensor): Input tensor.
260+
261+
Returns:
262+
torch.Tensor: Output tensor.
263+
"""
264+
x1 = super().forward(x)
265+
# Asymmetry but necessary to match shape
266+
267+
if self.dim == "2d":
268+
x1 = torch.nn.functional.pad(x1, (1, 0, 1, 0))
269+
else:
270+
x1 = torch.nn.functional.pad(x1, (1, 0, 1, 0, 1, 0))
271+
272+
if self.resample_do_res:
273+
res = self.res_conv(x)
274+
if self.dim == "2d":
275+
res = torch.nn.functional.pad(res, (1, 0, 1, 0))
276+
else:
277+
res = torch.nn.functional.pad(res, (1, 0, 1, 0, 1, 0))
278+
x1 = x1 + res
279+
280+
return x1
281+
282+
283+
class MedNeXtOutBlock(nn.Module):
284+
"""
285+
MedNeXtOutBlock class for the output block in the MedNeXt model.
286+
287+
Args:
288+
in_channels (int): Number of input channels.
289+
n_classes (int): Number of output classes.
290+
dim (str): Dimension of the input. Can be "2d" or "3d".
291+
"""
292+
293+
def __init__(self, in_channels, n_classes, dim):
294+
super().__init__()
295+
296+
conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3, transpose=True)
297+
self.conv_out = conv(in_channels, n_classes, kernel_size=1)
298+
299+
def forward(self, x):
300+
"""
301+
Forward pass of the MedNeXtOutBlock.
302+
303+
Args:
304+
x (torch.Tensor): Input tensor.
305+
306+
Returns:
307+
torch.Tensor: Output tensor.
308+
"""
309+
return self.conv_out(x)

monai/networks/nets/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,25 @@
5353
from .generator import Generator
5454
from .highresnet import HighResBlock, HighResNet
5555
from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet
56+
from .mednext import (
57+
MedNeXt,
58+
MedNext,
59+
MedNextB,
60+
MedNeXtB,
61+
MedNextBase,
62+
MedNextL,
63+
MedNeXtL,
64+
MedNeXtLarge,
65+
MedNextLarge,
66+
MedNextM,
67+
MedNeXtM,
68+
MedNeXtMedium,
69+
MedNextMedium,
70+
MedNextS,
71+
MedNeXtS,
72+
MedNeXtSmall,
73+
MedNextSmall,
74+
)
5675
from .milmodel import MILModel
5776
from .netadapter import NetAdapter
5877
from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator

0 commit comments

Comments
 (0)