Skip to content

Commit efe93b2

Browse files
committed
Add docstrings
Signed-off-by: Suraj Pai <[email protected]>
1 parent 17d0579 commit efe93b2

File tree

2 files changed

+100
-20
lines changed

2 files changed

+100
-20
lines changed

monai/networks/blocks/mednext_block.py

Lines changed: 93 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,19 @@ def get_conv_layer(spatial_dim: int = 3, transpose: bool = False):
2929

3030

3131
class MedNeXtBlock(nn.Module):
32+
"""
33+
MedNeXtBlock class for the MedNeXt model.
34+
35+
Args:
36+
in_channels (int): Number of input channels.
37+
out_channels (int): Number of output channels.
38+
expansion_ratio (int): Expansion ratio for the block. Defaults to 4.
39+
kernel_size (int): Kernel size for convolutions. Defaults to 7.
40+
use_residual_connection (int): Whether to use residual connection. Defaults to True.
41+
norm_type (str): Type of normalization to use. Defaults to "group".
42+
dim (str): Dimension of the input. Can be "2d" or "3d". Defaults to "3d".
43+
global_resp_norm (bool): Whether to use global response normalization. Defaults to False.
44+
"""
3245

3346
def __init__(
3447
self,
@@ -39,7 +52,7 @@ def __init__(
3952
use_residual_connection: int = True,
4053
norm_type: str = "group",
4154
dim="3d",
42-
grn=False,
55+
global_resp_norm=False,
4356
):
4457

4558
super().__init__()
@@ -48,7 +61,7 @@ def __init__(
4861

4962
self.dim = dim
5063
conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3)
51-
grn_parameter_shape = (1,) * (2 if dim == "2d" else 3)
64+
global_resp_norm_param_shape = (1,) * (2 if dim == "2d" else 3)
5265
# First convolution layer with DepthWise Convolutions
5366
self.conv1 = conv(
5467
in_channels=in_channels,
@@ -79,34 +92,55 @@ def __init__(
7992
in_channels=expansion_ratio * in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
8093
)
8194

82-
self.grn = grn
83-
if self.grn:
84-
grn_parameter_shape = (1, expansion_ratio * in_channels) + grn_parameter_shape
85-
self.grn_beta = nn.Parameter(torch.zeros(grn_parameter_shape), requires_grad=True)
86-
self.grn_gamma = nn.Parameter(torch.zeros(grn_parameter_shape), requires_grad=True)
95+
self.global_resp_norm = global_resp_norm
96+
if self.global_resp_norm:
97+
global_resp_norm_param_shape = (1, expansion_ratio * in_channels) + global_resp_norm_param_shape
98+
self.global_resp_beta = nn.Parameter(torch.zeros(global_resp_norm_param_shape), requires_grad=True)
99+
self.global_resp_gamma = nn.Parameter(torch.zeros(global_resp_norm_param_shape), requires_grad=True)
87100

88101
def forward(self, x):
102+
"""
103+
Forward pass of the MedNeXtBlock.
89104
105+
Args:
106+
x (torch.Tensor): Input tensor.
107+
108+
Returns:
109+
torch.Tensor: Output tensor.
110+
"""
90111
x1 = x
91112
x1 = self.conv1(x1)
92113
x1 = self.act(self.conv2(self.norm(x1)))
93114

94-
if self.grn:
115+
if self.global_resp_norm:
95116
# gamma, beta: learnable affine transform parameters
96117
# X: input of shape (N,C,H,W,D)
97118
if self.dim == "2d":
98119
gx = torch.norm(x1, p=2, dim=(-2, -1), keepdim=True)
99120
else:
100121
gx = torch.norm(x1, p=2, dim=(-3, -2, -1), keepdim=True)
101122
nx = gx / (gx.mean(dim=1, keepdim=True) + 1e-6)
102-
x1 = self.grn_gamma * (x1 * nx) + self.grn_beta + x1
123+
x1 = self.global_resp_gamma * (x1 * nx) + self.global_resp_beta + x1
103124
x1 = self.conv3(x1)
104125
if self.do_res:
105126
x1 = x + x1
106127
return x1
107128

108129

109130
class MedNeXtDownBlock(MedNeXtBlock):
131+
"""
132+
MedNeXtDownBlock class for downsampling in the MedNeXt model.
133+
134+
Args:
135+
in_channels (int): Number of input channels.
136+
out_channels (int): Number of output channels.
137+
expansion_ratio (int): Expansion ratio for the block. Defaults to 4.
138+
kernel_size (int): Kernel size for convolutions. Defaults to 7.
139+
use_residual_connection (bool): Whether to use residual connection. Defaults to False.
140+
norm_type (str): Type of normalization to use. Defaults to "group".
141+
dim (str): Dimension of the input. Can be "2d" or "3d". Defaults to "3d".
142+
global_resp_norm (bool): Whether to use global response normalization. Defaults to False.
143+
"""
110144

111145
def __init__(
112146
self,
@@ -117,7 +151,7 @@ def __init__(
117151
use_residual_connection: bool = False,
118152
norm_type: str = "group",
119153
dim: str = "3d",
120-
grn: bool = False,
154+
global_resp_norm: bool = False,
121155
):
122156

123157
super().__init__(
@@ -128,7 +162,7 @@ def __init__(
128162
use_residual_connection=False,
129163
norm_type=norm_type,
130164
dim=dim,
131-
grn=grn,
165+
global_resp_norm=global_resp_norm,
132166
)
133167

134168
conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3)
@@ -146,7 +180,15 @@ def __init__(
146180
)
147181

148182
def forward(self, x):
183+
"""
184+
Forward pass of the MedNeXtDownBlock.
185+
186+
Args:
187+
x (torch.Tensor): Input tensor.
149188
189+
Returns:
190+
torch.Tensor: Output tensor.
191+
"""
150192
x1 = super().forward(x)
151193

152194
if self.resample_do_res:
@@ -157,6 +199,19 @@ def forward(self, x):
157199

158200

159201
class MedNeXtUpBlock(MedNeXtBlock):
202+
"""
203+
MedNeXtUpBlock class for upsampling in the MedNeXt model.
204+
205+
Args:
206+
in_channels (int): Number of input channels.
207+
out_channels (int): Number of output channels.
208+
expansion_ratio (int): Expansion ratio for the block. Defaults to 4.
209+
kernel_size (int): Kernel size for convolutions. Defaults to 7.
210+
use_residual_connection (bool): Whether to use residual connection. Defaults to False.
211+
norm_type (str): Type of normalization to use. Defaults to "group".
212+
dim (str): Dimension of the input. Can be "2d" or "3d". Defaults to "3d".
213+
global_resp_norm (bool): Whether to use global response normalization. Defaults to False.
214+
"""
160215

161216
def __init__(
162217
self,
@@ -167,7 +222,7 @@ def __init__(
167222
use_residual_connection: bool = False,
168223
norm_type: str = "group",
169224
dim: str = "3d",
170-
grn: bool = False,
225+
global_resp_norm: bool = False,
171226
):
172227
super().__init__(
173228
in_channels,
@@ -177,7 +232,7 @@ def __init__(
177232
use_residual_connection=False,
178233
norm_type=norm_type,
179234
dim=dim,
180-
grn=grn,
235+
global_resp_norm=global_resp_norm,
181236
)
182237

183238
self.resample_do_res = use_residual_connection
@@ -197,7 +252,15 @@ def __init__(
197252
)
198253

199254
def forward(self, x):
255+
"""
256+
Forward pass of the MedNeXtUpBlock.
257+
258+
Args:
259+
x (torch.Tensor): Input tensor.
200260
261+
Returns:
262+
torch.Tensor: Output tensor.
263+
"""
201264
x1 = super().forward(x)
202265
# Asymmetry but necessary to match shape
203266

@@ -218,6 +281,14 @@ def forward(self, x):
218281

219282

220283
class MedNeXtOutBlock(nn.Module):
284+
"""
285+
MedNeXtOutBlock class for the output block in the MedNeXt model.
286+
287+
Args:
288+
in_channels (int): Number of input channels.
289+
n_classes (int): Number of output classes.
290+
dim (str): Dimension of the input. Can be "2d" or "3d".
291+
"""
221292

222293
def __init__(self, in_channels, n_classes, dim):
223294
super().__init__()
@@ -226,4 +297,13 @@ def __init__(self, in_channels, n_classes, dim):
226297
self.conv_out = conv(in_channels, n_classes, kernel_size=1)
227298

228299
def forward(self, x):
300+
"""
301+
Forward pass of the MedNeXtOutBlock.
302+
303+
Args:
304+
x (torch.Tensor): Input tensor.
305+
306+
Returns:
307+
torch.Tensor: Output tensor.
308+
"""
229309
return self.conv_out(x)

monai/networks/nets/mednext.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class MedNeXt(nn.Module):
6363
blocks_bottleneck: number of blocks in bottleneck stage. Defaults to 2.
6464
blocks_up: number of blocks in each decoder stage. Defaults to [2, 2, 2, 2].
6565
norm_type: type of normalization layer. Defaults to 'group'.
66-
grn: whether to use Global Response Normalization (GRN). Defaults to False.
66+
global_resp_norm: whether to use Global Response Normalization. Defaults to False. Refer: https://arxiv.org/abs/2301.00808
6767
"""
6868

6969
def __init__(
@@ -82,7 +82,7 @@ def __init__(
8282
blocks_bottleneck: int = 2,
8383
blocks_up: Sequence[int] = (2, 2, 2, 2),
8484
norm_type: str = "group",
85-
grn: bool = False,
85+
global_resp_norm: bool = False,
8686
):
8787
"""
8888
Initialize the MedNeXt model.
@@ -126,7 +126,7 @@ def __init__(
126126
use_residual_connection=use_residual_connection,
127127
norm_type=norm_type,
128128
dim=spatial_dims_str,
129-
grn=grn,
129+
global_resp_norm=global_resp_norm,
130130
)
131131
for _ in range(num_blocks)
132132
]
@@ -158,7 +158,7 @@ def __init__(
158158
use_residual_connection=use_residual_connection,
159159
norm_type=norm_type,
160160
dim=spatial_dims_str,
161-
grn=grn,
161+
global_resp_norm=global_resp_norm,
162162
)
163163
for _ in range(blocks_bottleneck)
164164
]
@@ -176,7 +176,7 @@ def __init__(
176176
use_residual_connection=use_residual_connection,
177177
norm_type=norm_type,
178178
dim=spatial_dims_str,
179-
grn=grn,
179+
global_resp_norm=global_resp_norm,
180180
)
181181
)
182182

@@ -191,7 +191,7 @@ def __init__(
191191
use_residual_connection=use_residual_connection,
192192
norm_type=norm_type,
193193
dim=spatial_dims_str,
194-
grn=grn,
194+
global_resp_norm=global_resp_norm,
195195
)
196196
for _ in range(num_blocks)
197197
]
@@ -299,7 +299,7 @@ def create_mednext(
299299
"deep_supervision": deep_supervision,
300300
"use_residual_connection": True,
301301
"norm_type": "group",
302-
"grn": False,
302+
"global_resp_norm": False,
303303
"init_filters": 32,
304304
}
305305

0 commit comments

Comments
 (0)