Commit e7647fb

make dynunet and mednext api compatible in deep supervision mode
Signed-off-by: elitap <[email protected]>
1 parent 8ac8e0d commit e7647fb

File tree

2 files changed: +23 -6 lines changed


monai/networks/nets/mednext.py

Lines changed: 15 additions & 2 deletions
@@ -19,6 +19,7 @@
 
 import torch
 import torch.nn as nn
+from torch.nn.functional import interpolate
 
 from monai.networks.blocks.mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtOutBlock, MedNeXtUpBlock
 

@@ -57,7 +58,16 @@ class MedNeXt(nn.Module):
         decoder_expansion_ratio: expansion ratio for decoder blocks. Defaults to 2.
         bottleneck_expansion_ratio: expansion ratio for bottleneck blocks. Defaults to 2.
         kernel_size: kernel size for convolutions. Defaults to 7.
-        deep_supervision: whether to use deep supervision. Defaults to False.
+        deep_supervision: whether to use deep supervision. Defaults to ``False``.
+            If ``True``, in training mode, the forward function will output not only the final feature map
+            (from the `out_0` block), but also the feature maps that come from the intermediate up sample layers.
+            In order to unify the return type, all intermediate feature maps are interpolated into the same size
+            as the final feature map and stacked together (with a new dimension in the first axis) into one single tensor.
+            For instance, if there are two intermediate feature maps with shapes (1, 2, 16, 12) and
+            (1, 2, 8, 6), and the final feature map has the shape (1, 2, 32, 24), then all intermediate feature maps
+            will be interpolated into (1, 2, 32, 24), and the stacked tensor will have the shape (1, 3, 2, 32, 24).
+            When calculating the loss, you can use torch.unbind to get all feature maps and compute the loss
+            one by one against the ground truth, then take a weighted average of the losses to obtain the final loss.
         use_residual_connection: whether to use residual connections in standard, down and up blocks. Defaults to False.
         blocks_down: number of blocks in each encoder stage. Defaults to [2, 2, 2, 2].
         blocks_bottleneck: number of blocks in bottleneck stage. Defaults to 2.
@@ -260,7 +270,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor | Sequence[torch.Tensor]:
 
         # Return output(s)
         if self.do_ds and self.training:
-            return (x, *ds_outputs[::-1])
+            out_all = [x]
+            for feature_map in ds_outputs[::-1]:
+                out_all.append(interpolate(feature_map, x.shape[2:]))
+            return torch.stack(out_all, dim=1)
         else:
             return x
 
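With this change, a deep-supervision forward pass in training mode returns a single stacked tensor rather than a tuple, matching DynUNet. Below is a minimal sketch of the loss pattern the docstring suggests, assuming MONAI's DiceLoss; the constructor arguments (spatial_dims, in_channels, out_channels) and the per-map weights are illustrative, not part of this diff.

```python
import torch

from monai.losses import DiceLoss
from monai.networks.nets import MedNeXt

# Illustrative configuration: only deep_supervision is exercised by this diff;
# the remaining arguments follow common MONAI network conventions.
net = MedNeXt(spatial_dims=2, in_channels=1, out_channels=2, deep_supervision=True)
net.train()

x = torch.randn(1, 1, 64, 64)                    # (batch, channel, H, W)
y = torch.randint(0, 2, (1, 2, 64, 64)).float()  # one-hot ground truth

out = net(x)                                     # stacked tensor: (B, n_maps, C, H, W)
loss_fn = DiceLoss(sigmoid=True)

# torch.unbind along dim=1 recovers the final map plus the interpolated
# intermediate maps; combine the per-map losses with (illustrative) weights.
maps = torch.unbind(out, dim=1)
weights = [0.5**i for i in range(len(maps))]
loss = sum(w * loss_fn(m, y) for w, m in zip(weights, maps)) / sum(weights)
loss.backward()
```

Dividing by the weight sum keeps the loss on the same scale as the single-output case, so learning-rate settings carry over.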

tests/test_mednext.py

Lines changed: 8 additions & 4 deletions
@@ -75,8 +75,10 @@ def test_shape(self, input_param, input_shape, expected_shape):
         with eval_mode(net):
             result = net(torch.randn(input_shape).to(device))
             if input_param["deep_supervision"] and net.training:
-                assert isinstance(result, tuple)
-                self.assertEqual(result[0].shape, expected_shape, msg=str(input_param))
+                assert isinstance(result, torch.Tensor)
+                result = torch.unbind(result, dim=1)
+                for r in result:
+                    self.assertEqual(r.shape, expected_shape, msg=str(input_param))
             else:
                 self.assertEqual(result.shape, expected_shape, msg=str(input_param))
 

@@ -87,8 +89,10 @@ def test_shape2(self, input_param, input_shape, expected_shape):
         net.train()
         result = net(torch.randn(input_shape).to(device))
         if input_param["deep_supervision"]:
-            assert isinstance(result, tuple)
-            self.assertEqual(result[0].shape, expected_shape, msg=str(input_param))
+            assert isinstance(result, torch.Tensor)
+            result = torch.unbind(result, dim=1)
+            for r in result:
+                self.assertEqual(r.shape, expected_shape, msg=str(input_param))
         else:
             assert isinstance(result, torch.Tensor)
             self.assertEqual(result.shape, expected_shape, msg=str(input_param))
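The updated tests encode the train/eval contract: deep supervision changes the return value only in training mode. A quick sketch of the same check outside the test harness, reusing the illustrative arguments from above:

```python
import torch

from monai.networks.nets import MedNeXt

net = MedNeXt(spatial_dims=2, in_channels=1, out_channels=2, deep_supervision=True)
x = torch.randn(1, 1, 64, 64)

net.train()
stacked = net(x)            # one tensor: (1, n_maps, 2, 64, 64)
assert stacked.ndim == 5

net.eval()
with torch.no_grad():
    final = net(x)          # only the final feature map: (1, 2, 64, 64)
assert final.shape == (1, 2, 64, 64)
```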
