
Commit da5a3a4

refactor: move activation checkpointing wrapper to blocks

Author: Fabio Ferreira (committed)
1 parent 515c659

File tree

2 files changed: +45 -16 lines changed

monai/networks/blocks/activation_checkpointing.py
monai/networks/nets/unet.py
monai/networks/blocks/activation_checkpointing.py (new file)

Lines changed: 41 additions & 0 deletions

@@ -0,0 +1,41 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import cast
+
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+
+
+class ActivationCheckpointWrapper(nn.Module):
+    """Wrapper applying activation checkpointing to a module during training.
+
+    Args:
+        module: The module to wrap with activation checkpointing.
+    """
+
+    def __init__(self, module: nn.Module) -> None:
+        super().__init__()
+        self.module = module
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass with optional activation checkpointing.
+
+        Args:
+            x: Input tensor.
+
+        Returns:
+            Output tensor from the wrapped module.
+        """
+        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
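For orientation, here is a minimal usage sketch of the new wrapper; the wrapped block and tensor shapes are illustrative assumptions, not part of this commit. Activation checkpointing frees intermediate activations during the forward pass and recomputes them during backward, trading extra compute for lower peak memory; with use_reentrant=False, a checkpointed call under torch.no_grad() effectively runs the module normally.

import torch
import torch.nn as nn

from monai.networks.blocks.activation_checkpointing import ActivationCheckpointWrapper

# Illustrative submodule; any nn.Module whose forward takes one tensor works.
block = nn.Sequential(nn.Conv2d(1, 8, kernel_size=3, padding=1), nn.ReLU())
wrapped = ActivationCheckpointWrapper(block)

x = torch.randn(2, 1, 32, 32, requires_grad=True)
y = wrapped(x)        # forward as usual; inner activations are not stored
y.sum().backward()    # activations inside `block` are recomputed here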

monai/networks/nets/unet.py

Lines changed: 4 additions & 16 deletions
@@ -13,30 +13,18 @@
 
 import warnings
 from collections.abc import Sequence
-from typing import cast
 
 import torch
 import torch.nn as nn
-from torch.utils.checkpoint import checkpoint
 
+from monai.networks.blocks.activation_checkpointing import ActivationCheckpointWrapper
 from monai.networks.blocks.convolutions import Convolution, ResidualUnit
 from monai.networks.layers.factories import Act, Norm
 from monai.networks.layers.simplelayers import SkipConnection
 
 __all__ = ["UNet", "Unet"]
 
 
-class _ActivationCheckpointWrapper(nn.Module):
-    """Apply activation checkpointing to the wrapped module during training."""
-
-    def __init__(self, module: nn.Module) -> None:
-        super().__init__()
-        self.module = module
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
-
-
 class UNet(nn.Module):
     """
     Enhanced version of UNet which has residual units implemented with the ResidualUnit class.
@@ -313,9 +301,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class CheckpointUNet(UNet):
     def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
-        subblock = _ActivationCheckpointWrapper(subblock)
-        down_path = _ActivationCheckpointWrapper(down_path)
-        up_path = _ActivationCheckpointWrapper(up_path)
+        subblock = ActivationCheckpointWrapper(subblock)
+        down_path = ActivationCheckpointWrapper(down_path)
+        up_path = ActivationCheckpointWrapper(up_path)
         return super()._get_connection_block(down_path, up_path, subblock)
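Because CheckpointUNet only overrides _get_connection_block, it is a drop-in replacement for UNet: every down path, up path, and skip subblock is wrapped in ActivationCheckpointWrapper. A minimal sketch of how it might be used follows; the constructor arguments mirror the standard MONAI UNet signature, and the specific values are illustrative.

import torch

from monai.networks.nets.unet import CheckpointUNet

# Same constructor as UNet; these hyperparameters are illustrative.
net = CheckpointUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128),
    strides=(2, 2, 2),
    num_res_units=2,
)

x = torch.randn(1, 1, 64, 64, 64)
out = net(x)           # activations in wrapped blocks are freed after forward
out.sum().backward()   # and recomputed here, lowering peak memory

The trade-off only pays off when activation memory dominates, e.g. deep configurations or large patch sizes; for small models the recomputation overhead may outweigh the savings.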
