Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for sampling from complex images in 2D #8

Merged
merged 1 commit into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions src/torch_image_lerp/linear_interpolation_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,24 @@ def sample_image_2d(
samples: torch.Tensor
`(..., )` array of samples from `image`.
"""
complex_input = torch.is_complex(image)

# setup for sampling with torch.nn.functional.grid_sample
coordinates, ps = einops.pack([coordinates], pattern='* yx')
n_samples = coordinates.shape[0]
h, w = image.shape[-2:]

image = einops.repeat(image, 'h w -> b 1 h w', b=n_samples) # b c h w
coordinates = einops.rearrange(coordinates, 'b yx -> b 1 1 yx') # b h w zyx
if complex_input is True:
# cannot sample complex tensors directly with grid_sample
# c.f. https://github.com/pytorch/pytorch/issues/67634
# workaround: treat real and imaginary parts as separate channels
image = torch.view_as_real(image)
image = einops.rearrange(image, 'h w complex -> complex h w')
image = einops.repeat(image, 'complex h w -> b complex h w', b=n_samples)
coordinates = einops.rearrange(coordinates, 'b zyx -> b 1 1 zyx') # b h w zyx
else:
image = einops.repeat(image, 'h w -> b 1 h w', b=n_samples) # b c h w
coordinates = einops.rearrange(coordinates, 'b zyx -> b 1 1 zyx') # b h w zyx

# take the samples
samples = F.grid_sample(
Expand All @@ -43,7 +54,11 @@ def sample_image_2d(
padding_mode='border', # this increases sampling fidelity at edges
align_corners=True,
)
samples = einops.rearrange(samples, 'b 1 1 1 -> b')
if complex_input is True:
samples = einops.rearrange(samples, 'b complex 1 1 -> b complex')
samples = torch.view_as_complex(samples.contiguous()) # (b, )
else:
samples = einops.rearrange(samples, 'b 1 1 1 -> b')

# set samples from outside of image to zero
coordinates = einops.rearrange(coordinates, 'b 1 1 yx -> b yx')
Expand Down
12 changes: 12 additions & 0 deletions tests/test_linear_interpolation_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,18 @@ def test_sample_image_2d():
assert samples.shape == (6, 7, 8)


def test_sample_image_2d_complex_input():
# basic sanity check only
image = torch.complex(real=torch.rand((28, 28)), imag=torch.rand(28, 28))

# make an arbitrary stack (..., 2) of 2d coords
coords = torch.tensor(np.random.randint(low=0, high=27, size=(6, 7, 8, 2)))

# sample
samples = sample_image_2d(image=image, coordinates=coords)
assert samples.shape == (6, 7, 8)


def test_insert_into_image_2d():
image = torch.zeros((28, 28)).float()

Expand Down
Loading