@@ -33,13 +33,7 @@ def __init__(self, module: nn.Module) -> None:
3333 self .module = module
3434
3535 def forward (self , x : torch .Tensor ) -> torch .Tensor :
36- if self .training and torch .is_grad_enabled () and x .requires_grad :
37- try :
38- return cast (torch .Tensor , checkpoint (self .module , x , use_reentrant = False ))
39- except TypeError :
40- # Fallback for older PyTorch without `use_reentrant`
41- return cast (torch .Tensor , checkpoint (self .module , x ))
42- return cast (torch .Tensor , self .module (x ))
36+ return cast (torch .Tensor , checkpoint (self .module , x , use_reentrant = False ))
4337
4438
4539class UNet (nn .Module ):
@@ -138,7 +132,6 @@ def __init__(
138132 dropout : float = 0.0 ,
139133 bias : bool = True ,
140134 adn_ordering : str = "NDA" ,
141- use_checkpointing : bool = False ,
142135 ) -> None :
143136 super ().__init__ ()
144137
@@ -167,7 +160,6 @@ def __init__(
167160 self .dropout = dropout
168161 self .bias = bias
169162 self .adn_ordering = adn_ordering
170- self .use_checkpointing = use_checkpointing
171163
172164 def _create_block (
173165 inc : int , outc : int , channels : Sequence [int ], strides : Sequence [int ], is_top : bool
@@ -214,8 +206,6 @@ def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblo
214206 subblock: block defining the next layer in the network.
215207 Returns: block for this layer: `nn.Sequential(down_path, SkipConnection(subblock), up_path)`
216208 """
217- if self .use_checkpointing :
218- subblock = _ActivationCheckpointWrapper (subblock )
219209 return nn .Sequential (down_path , SkipConnection (subblock ), up_path )
220210
221211 def _get_down_layer (self , in_channels : int , out_channels : int , strides : int , is_top : bool ) -> nn .Module :
@@ -321,5 +311,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
321311 x = self .model (x )
322312 return x
323313
314+ class CheckpointUNet (UNet ):
315+ def _get_connection_block (self , down_path : nn .Module , up_path : nn .Module , subblock : nn .Module ) -> nn .Module :
316+ subblock = _ActivationCheckpointWrapper (subblock )
317+ return super ()._get_connection_block (down_path , up_path , subblock )
324318
325319Unet = UNet
0 commit comments