diff --git a/src/torch_image_lerp/linear_interpolation_2d.py b/src/torch_image_lerp/linear_interpolation_2d.py index 2319723..0532df4 100644 --- a/src/torch_image_lerp/linear_interpolation_2d.py +++ b/src/torch_image_lerp/linear_interpolation_2d.py @@ -27,13 +27,24 @@ def sample_image_2d( samples: torch.Tensor `(..., )` array of samples from `image`. """ + complex_input = torch.is_complex(image) + # setup for sampling with torch.nn.functional.grid_sample coordinates, ps = einops.pack([coordinates], pattern='* yx') n_samples = coordinates.shape[0] h, w = image.shape[-2:] - image = einops.repeat(image, 'h w -> b 1 h w', b=n_samples) # b c h w - coordinates = einops.rearrange(coordinates, 'b yx -> b 1 1 yx') # b h w zyx + if complex_input is True: + # cannot sample complex tensors directly with grid_sample + # c.f. https://github.com/pytorch/pytorch/issues/67634 + # workaround: treat real and imaginary parts as separate channels + image = torch.view_as_real(image) + image = einops.rearrange(image, 'h w complex -> complex h w') + image = einops.repeat(image, 'complex h w -> b complex h w', b=n_samples) + coordinates = einops.rearrange(coordinates, 'b zyx -> b 1 1 zyx') # b h w zyx + else: + image = einops.repeat(image, 'h w -> b 1 h w', b=n_samples) # b c h w + coordinates = einops.rearrange(coordinates, 'b zyx -> b 1 1 zyx') # b h w zyx # take the samples samples = F.grid_sample( @@ -43,7 +54,11 @@ def sample_image_2d( padding_mode='border', # this increases sampling fidelity at edges align_corners=True, ) - samples = einops.rearrange(samples, 'b 1 1 1 -> b') + if complex_input is True: + samples = einops.rearrange(samples, 'b complex 1 1 -> b complex') + samples = torch.view_as_complex(samples.contiguous()) # (b, ) + else: + samples = einops.rearrange(samples, 'b 1 1 1 -> b') # set samples from outside of image to zero coordinates = einops.rearrange(coordinates, 'b 1 1 yx -> b yx') diff --git a/tests/test_linear_interpolation_2d.py b/tests/test_linear_interpolation_2d.py index 3123d05..c2ce3aa 100644 --- a/tests/test_linear_interpolation_2d.py +++ b/tests/test_linear_interpolation_2d.py @@ -17,6 +17,18 @@ def test_sample_image_2d(): assert samples.shape == (6, 7, 8) +def test_sample_image_2d_complex_input(): + # basic sanity check only + image = torch.complex(real=torch.rand((28, 28)), imag=torch.rand(28, 28)) + + # make an arbitrary stack (..., 2) of 2d coords + coords = torch.tensor(np.random.randint(low=0, high=27, size=(6, 7, 8, 2))) + + # sample + samples = sample_image_2d(image=image, coordinates=coords) + assert samples.shape == (6, 7, 8) + + def test_insert_into_image_2d(): image = torch.zeros((28, 28)).float()