
Commit f673ca1

Author: Fabio Ferreira (committed)

fix: avoid BatchNorm subblocks

1 parent e112457

File tree: 1 file changed (+9, -0 lines)


monai/networks/nets/unet.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -30,10 +30,19 @@ class _ActivationCheckpointWrapper(nn.Module):
     """Apply activation checkpointing to the wrapped module during training."""
     def __init__(self, module: nn.Module) -> None:
         super().__init__()
+        # Pre-detect BatchNorm presence for fast path
+        self._has_bn = any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in module.modules())
         self.module = module

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.training and torch.is_grad_enabled() and x.requires_grad:
+            if self._has_bn:
+                warnings.warn(
+                    "Activation checkpointing skipped for a subblock containing BatchNorm to avoid double-updating "
+                    "running statistics during recomputation.",
+                    RuntimeWarning,
+                )
+                return cast(torch.Tensor, self.module(x))
             try:
                 return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
             except TypeError:
```
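For context: activation checkpointing discards a subblock's intermediate activations in the forward pass and re-runs the subblock during backward to rebuild them. Any side effect inside the block therefore happens twice, including a BatchNorm layer's running-statistics update. The following standalone sketch (not part of this commit; `bn`, `x`, and `x2` are illustrative names) shows the failure mode the new warning guards against, assuming PyTorch's default BatchNorm momentum of 0.1:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)  # running_mean starts at zeros; default momentum is 0.1
x = torch.randn(8, 4, requires_grad=True)

# Plain forward/backward: the running statistics update exactly once.
bn(x).sum().backward()
mean_once = bn.running_mean.clone()

# Same batch through a checkpointed BatchNorm: the block is re-executed
# during backward to rebuild activations, so the update fires a second time.
bn.reset_running_stats()
x2 = x.detach().clone().requires_grad_(True)
checkpoint(bn, x2, use_reentrant=False).sum().backward()
mean_twice = bn.running_mean.clone()

print(mean_once)   # 0.1 * batch_mean
print(mean_twice)  # ~0.19 * batch_mean: the same batch was folded in twice
```

Skipping the checkpoint for BatchNorm-containing subblocks, as this commit does, trades a little extra activation memory for correct running statistics, which is exactly the trade-off the RuntimeWarning announces.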
