|
19 | 19 |
|
20 | 20 | import torch |
21 | 21 | import torch.nn as nn |
| 22 | +from torch.nn.functional import interpolate |
22 | 23 |
|
23 | 24 | from monai.networks.blocks.mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtOutBlock, MedNeXtUpBlock |
24 | 25 |
|
@@ -57,7 +58,16 @@ class MedNeXt(nn.Module): |
57 | 58 | decoder_expansion_ratio: expansion ratio for decoder blocks. Defaults to 2. |
58 | 59 | bottleneck_expansion_ratio: expansion ratio for bottleneck blocks. Defaults to 2. |
59 | 60 | kernel_size: kernel size for convolutions. Defaults to 7. |
60 | | - deep_supervision: whether to use deep supervision. Defaults to False. |
| 61 | + deep_supervision: whether to use deep supervision. Defaults to ``False``. |
| 62 | + If ``True``, in training mode, the forward function will output not only the final feature map |
| 63 | + (from the `out_0` block), but also the feature maps that come from the intermediate up sample layers. |
| 64 | + In order to unify the return type, all intermediate feature maps are interpolated into the same size |
| 65 | + as the final feature map and stacked together (with a new dimension in the first axis) into one single tensor. |
| 66 | + For instance, if there are two intermediate feature maps with shapes: (1, 2, 16, 12) and |
| 67 | + (1, 2, 8, 6), and the final feature map has the shape (1, 2, 32, 24), then all intermediate feature maps |
| 68 | +will be interpolated into (1, 2, 32, 24), and the stacked tensor will have the shape (1, 3, 2, 32, 24). |
| 69 | +When calculating the loss, you can use torch.unbind to get all feature maps and compute the loss |
| 70 | +one by one with the ground truth, then take a weighted average of all losses to obtain the final loss. |
61 | 71 | use_residual_connection: whether to use residual connections in standard, down and up blocks. Defaults to False. |
62 | 72 | blocks_down: number of blocks in each encoder stage. Defaults to [2, 2, 2, 2]. |
63 | 73 | blocks_bottleneck: number of blocks in bottleneck stage. Defaults to 2. |
@@ -260,7 +270,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor | Sequence[torch.Tensor]: |
260 | 270 |
|
261 | 271 | # Return output(s) |
262 | 272 | if self.do_ds and self.training: |
263 | | - return (x, *ds_outputs[::-1]) |
| 273 | + out_all = [x] |
| 274 | + for feature_map in ds_outputs[::-1]: |
| 275 | + out_all.append(interpolate(feature_map, x.shape[2:])) |
| 276 | + return torch.stack(out_all, dim=1) |
264 | 277 | else: |
265 | 278 | return x |
266 | 279 |
|
|
0 commit comments