From 67bb12660c1ff8e6541f3b9cdedf7a5cd20f49d2 Mon Sep 17 00:00:00 2001 From: alisterburt Date: Thu, 4 Jul 2024 17:41:57 -0700 Subject: [PATCH] fix error in 3D sampling leading to test failure (#7) --- src/torch_image_lerp/linear_interpolation_3d.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/torch_image_lerp/linear_interpolation_3d.py b/src/torch_image_lerp/linear_interpolation_3d.py index 74bfb48..ab2381b 100644 --- a/src/torch_image_lerp/linear_interpolation_3d.py +++ b/src/torch_image_lerp/linear_interpolation_3d.py @@ -28,6 +28,7 @@ def sample_image_3d( `(..., )` array of samples from `image`. """ complex_input = torch.is_complex(image) + # setup for sampling with torch.nn.functional.grid_sample coordinates, ps = einops.pack([coordinates], pattern='* zyx') n_samples = coordinates.shape[0] @@ -52,11 +53,12 @@ def sample_image_3d( padding_mode='border', # this increases sampling fidelity at edges align_corners=True, ) + if complex_input is True: samples = einops.rearrange(samples, 'b complex 1 1 1 -> b complex') samples = torch.view_as_complex(samples.contiguous()) # (b, ) else: - samples = einops.rearrange(samples, 'b c 1 1 1 -> b c') + samples = einops.rearrange(samples, 'b 1 1 1 1 -> b') # set samples from outside of volume to zero coordinates = einops.rearrange(coordinates, 'b 1 1 1 zyx -> b zyx')