We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 66edcb5 commit e66e357Copy full SHA for e66e357
monai/networks/nets/unet.py
@@ -32,8 +32,12 @@ def __init__(self, module: nn.Module) -> None:
32
self.module = module
33
34
def forward(self, x: torch.Tensor) -> torch.Tensor:
35
- if self.training:
36
- return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+ if self.training and torch.is_grad_enabled() and x.requires_grad:
+ try:
37
+ return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
38
+ except TypeError:
39
+ # Fallback for older PyTorch without `use_reentrant`
40
+ return cast(torch.Tensor, checkpoint(self.module, x))
41
return cast(torch.Tensor, self.module(x))
42
43
0 commit comments