Skip to content

Commit 43dec88

Browse files
author
Fabio Ferreira
committed
chore: add docstrings to checkpointed unet
1 parent da5a3a4 commit 43dec88

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

monai/networks/nets/unet.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from monai.networks.layers.factories import Act, Norm
2323
from monai.networks.layers.simplelayers import SkipConnection
2424

25-
__all__ = ["UNet", "Unet"]
25+
__all__ = ["UNet", "Unet", "CheckpointUNet"]
2626

2727

2828
class UNet(nn.Module):
@@ -300,6 +300,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
300300

301301

302302
class CheckpointUNet(UNet):
303+
"""UNet variant that wraps internal connection blocks with activation checkpointing.
304+
305+
See `UNet` for constructor arguments. During training with gradients enabled,
306+
intermediate activations inside encoder–decoder connections are recomputed in
307+
the backward pass to reduce peak memory usage at the cost of extra compute.
308+
"""
309+
303310
def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
304311
subblock = ActivationCheckpointWrapper(subblock)
305312
down_path = ActivationCheckpointWrapper(down_path)

0 commit comments

Comments
 (0)