1 file changed: +0 −9 lines changed

@@ -30,19 +30,10 @@ class _ActivationCheckpointWrapper(nn.Module):
3030 """Apply activation checkpointing to the wrapped module during training."""
3131 def __init__ (self , module : nn .Module ) -> None :
3232 super ().__init__ ()
33- # Pre-detect BatchNorm presence for fast path
34- self ._has_bn = any (isinstance (m , nn .modules .batchnorm ._BatchNorm ) for m in module .modules ())
3533 self .module = module
3634
3735 def forward (self , x : torch .Tensor ) -> torch .Tensor :
3836 if self .training and torch .is_grad_enabled () and x .requires_grad :
39- if self ._has_bn :
40- warnings .warn (
41- "Activation checkpointing skipped for a subblock containing BatchNorm to avoid double-updating "
42- "running statistics during recomputation." ,
43- RuntimeWarning ,
44- )
45- return cast (torch .Tensor , self .module (x ))
4637 try :
4738 return cast (torch .Tensor , checkpoint (self .module , x , use_reentrant = False ))
4839 except TypeError :
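For context, a minimal runnable sketch of the wrapper as it reads after this change. The imports, the body of the except TypeError: fallback, and the non-checkpointed return path are assumptions here, since the hunk truncates before them:

from typing import cast

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class _ActivationCheckpointWrapper(nn.Module):
    """Apply activation checkpointing to the wrapped module during training."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Checkpoint only when it can help: training mode, grads enabled,
        # and an input that actually requires grad.
        if self.training and torch.is_grad_enabled() and x.requires_grad:
            try:
                # Non-reentrant checkpointing: activations are recomputed
                # during backward instead of being stored.
                return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
            except TypeError:
                # Assumed fallback for PyTorch builds whose checkpoint() does
                # not accept `use_reentrant`; the real fallback body is cut
                # off by the hunk.
                return cast(torch.Tensor, checkpoint(self.module, x))
        # Assumed eval/no-grad path (not shown in the hunk): run the module directly.
        return cast(torch.Tensor, self.module(x))

With the BatchNorm pre-detection removed, wrapping is unconditional, e.g. _ActivationCheckpointWrapper(nn.Sequential(nn.Linear(8, 8), nn.ReLU())); note that checkpointed BatchNorm submodules will update running statistics during both the forward pass and the backward recomputation, which is the double-update the deleted warning described.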