Skip to content

Commit e66e357

Browse files
Update monai/networks/nets/unet.py
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Fábio S. Ferreira <[email protected]>
1 parent 66edcb5 commit e66e357

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

monai/networks/nets/unet.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,12 @@ def __init__(self, module: nn.Module) -> None:
3232
self.module = module
3333

3434
def forward(self, x: torch.Tensor) -> torch.Tensor:
35-
if self.training:
36-
return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
35+
if self.training and torch.is_grad_enabled() and x.requires_grad:
36+
try:
37+
return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
38+
except TypeError:
39+
# Fallback for older PyTorch without `use_reentrant`
40+
return cast(torch.Tensor, checkpoint(self.module, x))
3741
return cast(torch.Tensor, self.module(x))
3842

3943

0 commit comments

Comments
 (0)