Skip to content

Commit f990ccc

Browse files
committed
Try a fix for Windows issue
Signed-off-by: Eric Kerfoot <[email protected]>
1 parent c978eae commit f990ccc

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

monai/networks/layers/spatial_transforms.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from __future__ import annotations
1313

1414
from collections.abc import Sequence
15+
import sys
1516

1617
import torch
1718
import torch.nn as nn
@@ -526,6 +527,11 @@ def forward(
526527
ValueError: When affine and image batch dimension differ.
527528
528529
"""
530+
531+
# In some cases it's necessary to convert inputs to grid_sample from float64 to float32 to work around known
532+
# issues with PyTorch, see https://github.com/Project-MONAI/MONAI/pull/8429
533+
convert_f32 = sys.platform != "win32" and src.dtype == torch.float64 and src.device == torch.device("cpu")
534+
529535
# validate `theta`
530536
if not isinstance(theta, torch.Tensor):
531537
raise TypeError(f"theta must be torch.Tensor but is {type(theta).__name__}.")
@@ -582,11 +588,21 @@ def forward(
582588
)
583589

584590
grid = nn.functional.affine_grid(theta=theta[:, :sr], size=list(dst_size), align_corners=self.align_corners)
591+
592+
_input = src.contiguous()
593+
if convert_f32:
594+
_input = _input.to(torch.float32)
595+
grid = grid.to(torch.float32)
596+
585597
dst = nn.functional.grid_sample(
586-
input=src.contiguous(),
598+
input=_input,
587599
grid=grid,
588600
mode=self.mode,
589601
padding_mode=self.padding_mode,
590602
align_corners=self.align_corners,
591603
)
604+
605+
if convert_f32:
606+
dst = dst.to(torch.float64)
607+
592608
return dst

0 commit comments

Comments
 (0)