@@ -13,30 +13,18 @@
 
 import warnings
 from collections.abc import Sequence
-from typing import cast
 
 import torch
 import torch.nn as nn
-from torch.utils.checkpoint import checkpoint
 
+from monai.networks.blocks.activation_checkpointing import ActivationCheckpointWrapper
 from monai.networks.blocks.convolutions import Convolution, ResidualUnit
 from monai.networks.layers.factories import Act, Norm
 from monai.networks.layers.simplelayers import SkipConnection
 
 __all__ = ["UNet", "Unet"]
 
 
-class _ActivationCheckpointWrapper(nn.Module):
-    """Apply activation checkpointing to the wrapped module during training."""
-
-    def __init__(self, module: nn.Module) -> None:
-        super().__init__()
-        self.module = module
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
-
-
 class UNet(nn.Module):
     """
     Enhanced version of UNet which has residual units implemented with the ResidualUnit class.
@@ -313,9 +301,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class CheckpointUNet(UNet):
     def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
-        subblock = _ActivationCheckpointWrapper(subblock)
-        down_path = _ActivationCheckpointWrapper(down_path)
-        up_path = _ActivationCheckpointWrapper(up_path)
+        subblock = ActivationCheckpointWrapper(subblock)
+        down_path = ActivationCheckpointWrapper(down_path)
+        up_path = ActivationCheckpointWrapper(up_path)
         return super()._get_connection_block(down_path, up_path, subblock)
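
A minimal usage sketch, not part of this diff: `CheckpointUNet` subclasses `UNet` and wraps each down path, up path, and subblock in `ActivationCheckpointWrapper`, which, judging by the removed local wrapper it replaces, applies non-reentrant `torch.utils.checkpoint.checkpoint` to the wrapped module. The import path and constructor values below are assumptions based on `UNet`'s standard signature.

```python
import torch

# Assumed import path: CheckpointUNet is defined alongside UNet in monai/networks/nets/unet.py.
from monai.networks.nets.unet import CheckpointUNet

# CheckpointUNet inherits UNet's constructor; these argument values are illustrative only.
net = CheckpointUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128),
    strides=(2, 2, 2),
    num_res_units=2,
)
net.train()  # checkpointing matters in training, where backward triggers recomputation

x = torch.rand(1, 1, 64, 64, 64)
y = net(x)          # forward stores fewer intermediate activations
y.sum().backward()  # wrapped blocks recompute their forward here; gradients match plain UNet
```

The trade-off is the usual one for activation checkpointing: lower peak activation memory in exchange for recomputing each wrapped block's forward pass during backward.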