Skip to content

Commit

Permalink
add support for sampling from complex images in 2D (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
alisterburt authored Jul 5, 2024
1 parent 67bb126 commit 35c02ab
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
21 changes: 18 additions & 3 deletions src/torch_image_lerp/linear_interpolation_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,24 @@ def sample_image_2d(
samples: torch.Tensor
`(..., )` array of samples from `image`.
"""
complex_input = torch.is_complex(image)

# setup for sampling with torch.nn.functional.grid_sample
coordinates, ps = einops.pack([coordinates], pattern='* yx')
n_samples = coordinates.shape[0]
h, w = image.shape[-2:]

image = einops.repeat(image, 'h w -> b 1 h w', b=n_samples) # b c h w
coordinates = einops.rearrange(coordinates, 'b yx -> b 1 1 yx') # b h w zyx
if complex_input is True:
# cannot sample complex tensors directly with grid_sample
# c.f. https://github.com/pytorch/pytorch/issues/67634
# workaround: treat real and imaginary parts as separate channels
image = torch.view_as_real(image)
image = einops.rearrange(image, 'h w complex -> complex h w')
image = einops.repeat(image, 'complex h w -> b complex h w', b=n_samples)
coordinates = einops.rearrange(coordinates, 'b zyx -> b 1 1 zyx') # b h w zyx
else:
image = einops.repeat(image, 'h w -> b 1 h w', b=n_samples) # b c h w
coordinates = einops.rearrange(coordinates, 'b zyx -> b 1 1 zyx') # b h w zyx

# take the samples
samples = F.grid_sample(
Expand All @@ -43,7 +54,11 @@ def sample_image_2d(
padding_mode='border', # this increases sampling fidelity at edges
align_corners=True,
)
samples = einops.rearrange(samples, 'b 1 1 1 -> b')
if complex_input is True:
samples = einops.rearrange(samples, 'b complex 1 1 -> b complex')
samples = torch.view_as_complex(samples.contiguous()) # (b, )
else:
samples = einops.rearrange(samples, 'b 1 1 1 -> b')

# set samples from outside of image to zero
coordinates = einops.rearrange(coordinates, 'b 1 1 yx -> b yx')
Expand Down
12 changes: 12 additions & 0 deletions tests/test_linear_interpolation_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,18 @@ def test_sample_image_2d():
assert samples.shape == (6, 7, 8)


def test_sample_image_2d_complex_input():
# basic sanity check only
image = torch.complex(real=torch.rand((28, 28)), imag=torch.rand(28, 28))

# make an arbitrary stack (..., 2) of 2d coords
coords = torch.tensor(np.random.randint(low=0, high=27, size=(6, 7, 8, 2)))

# sample
samples = sample_image_2d(image=image, coordinates=coords)
assert samples.shape == (6, 7, 8)


def test_insert_into_image_2d():
image = torch.zeros((28, 28)).float()

Expand Down

0 comments on commit 35c02ab

Please sign in to comment.