Skip to content

Commit

Permalink
fix minor bugs and add tests (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
alisterburt authored Jul 5, 2024
1 parent bc68930 commit 69de1a7
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 13 deletions.
4 changes: 2 additions & 2 deletions src/torch_image_lerp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .lerp_2d import sample_image_2d, insert_into_image_2d
from .lerp_3d import sample_image_3d, insert_into_image_3d
from .linear_interpolation_2d import sample_image_2d, insert_into_image_2d
from .linear_interpolation_3d import sample_image_3d, insert_into_image_3d
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,15 @@ def sample_image_2d(
# setup for sampling with torch.nn.functional.grid_sample
coordinates, ps = einops.pack([coordinates], pattern='* yx')
n_samples = coordinates.shape[0]
h, w = image.shape[-2:]

image = einops.repeat(image, 'h w -> b 1 h w', b=n_samples) # b c h w
coordinates = einops.rearrange(coordinates, 'b yx -> b 1 1 yx') # b h w zyx

# take the samples
samples = F.grid_sample(
input=image,
grid=array_to_grid_sample(coordinates, array_shape=image.shape[-2:]),
grid=array_to_grid_sample(coordinates, array_shape=(h, w)),
mode='bilinear',
padding_mode='border', # this increases sampling fidelity at edges
align_corners=True,
Expand All @@ -46,7 +47,7 @@ def sample_image_2d(

# set samples from outside of image to zero
coordinates = einops.rearrange(coordinates, 'b 1 1 yx -> b yx')
image_shape = torch.as_tensor(image.shape)
image_shape = torch.as_tensor((h, w), device=image.device)
inside = torch.logical_and(coordinates >= 0, coordinates <= image_shape - 1)
inside = torch.all(inside, dim=-1) # (b, d, h, w)
samples[~inside] *= 0
Expand All @@ -60,7 +61,7 @@ def insert_into_image_2d(
data: torch.Tensor,
coordinates: torch.Tensor,
image: torch.Tensor,
weights: torch.Tensor,
weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Insert values into a 2D image with bilinear interpolation.
Expand All @@ -74,7 +75,7 @@ def insert_into_image_2d(
- Coordinates span the range `[0, N-1]` for a dimension of length N.
image: torch.Tensor
`(h, w)` array containing the image into which data will be inserted.
weights: torch.Tensor
weights: torch.Tensor | None
`(h, w)` array containing weights associated with each pixel in `image`.
This is useful for tracking weights across multiple calls to this function.
Expand All @@ -85,6 +86,10 @@ def insert_into_image_2d(
"""
if data.shape != coordinates.shape[:-1]:
raise ValueError('One coordinate pair is required for each value in data.')
if coordinates.shape[-1] != 2:
raise ValueError('Coordinates must be of shape (..., 2).')
if weights is None:
weights = torch.zeros_like(image)

# linearise data and coordinates
data, _ = einops.pack([data], pattern='*')
Expand All @@ -97,12 +102,12 @@ def insert_into_image_2d(
data, coordinates = data[in_image_idx], coordinates[in_image_idx]

# calculate and cache floor and ceil of coordinates for each value to be inserted
corner_coordinates = torch.empty(size=(data.shape[0], 2, 2), dtype=torch.long)
corner_coordinates = torch.empty(size=(data.shape[0], 2, 2), dtype=torch.long, device=image.device)
corner_coordinates[:, 0] = torch.floor(coordinates)
corner_coordinates[:, 1] = torch.ceil(coordinates)

# calculate linear interpolation weights for each data point being inserted
_weights = torch.empty(size=(data.shape[0], 2, 2)) # (b, 2, yx)
_weights = torch.empty(size=(data.shape[0], 2, 2), device=image.device) # (b, 2, yx)
_weights[:, 1] = coordinates - corner_coordinates[:, 0] # upper corner weights
_weights[:, 0] = 1 - _weights[:, 1] # lower corner weights

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def insert_into_image_3d(
data: torch.Tensor,
coordinates: torch.Tensor,
image: torch.Tensor,
weights: torch.Tensor,
weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Insert values into a 3D image with trilinear interpolation.
Expand All @@ -88,7 +88,7 @@ def insert_into_image_3d(
- Coordinates span the range `[0, N-1]` for a dimension of length N.
image: torch.Tensor
`(d, h, w)` array into which data will be inserted.
weights: torch.Tensor
weights: torch.Tensor | None
`(d, h, w)` array containing weights associated with each pixel in `image`.
This is useful for tracking weights across multiple calls to this function.
Expand All @@ -99,6 +99,10 @@ def insert_into_image_3d(
"""
if data.shape != coordinates.shape[:-1]:
raise ValueError('One coordinate triplet is required for each value in data.')
if coordinates.shape[-1] != 3:
raise ValueError('Coordinates must be of shape (..., 3).')
if weights is None:
weights = torch.zeros_like(image)

# linearise data and coordinates
data, _ = einops.pack([data], pattern='*')
Expand All @@ -116,7 +120,7 @@ def insert_into_image_3d(
_c[:, 1] = torch.ceil(coordinates) # for upper corners

# calculate linear interpolation weights for each data point being inserted
_w = torch.empty(size=(data.shape[0], 2, 3), dtype=torch.float64, device=image.device) # (b, 2, zyx)
_w = torch.empty(size=(data.shape[0], 2, 3), dtype=image.dtype, device=image.device) # (b, 2, zyx)
_w[:, 1] = coordinates - _c[:, 0] # upper corner weights
_w[:, 0] = 1 - _w[:, 1] # lower corner weights

Expand Down
2 changes: 0 additions & 2 deletions tests/test_lerp_2d.py

This file was deleted.

Empty file removed tests/test_lerp_3d.py
Empty file.
51 changes: 51 additions & 0 deletions tests/test_linear_interpolation_2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import einops
import torch
import numpy as np

from torch_image_lerp import sample_image_2d, insert_into_image_2d


def test_sample_image_2d():
# basic sanity check only
image = torch.rand((28, 28))

# make an arbitrary stack (..., 2) of 2d coords
coords = torch.tensor(np.random.randint(low=0, high=27, size=(6, 7, 8, 2)))

# sample
samples = sample_image_2d(image=image, coordinates=coords)
assert samples.shape == (6, 7, 8)


def test_insert_into_image_2d():
image = torch.zeros((28, 28)).float()

# single value
value = torch.tensor([5]).float()
coordinate = torch.tensor([10.5, 14.5]).view((1, 2))

# sample
image, weights = insert_into_image_2d(value, coordinates=coordinate, image=image)

# check value (5) is evenly split over 4 nearest pixels
expected = einops.repeat(torch.tensor([5 / 4]), '1 -> 2 2')
assert torch.allclose(image[10:12, 14:16], expected)

# check for zeros elsewhere
assert torch.allclose(image[:10, :14], torch.zeros_like(image[:10, :14]))


def test_insert_into_image_2d_multiple():
image = torch.zeros((28, 28)).float()

# multiple values
values = torch.ones(size=(6, 7, 8)).float()
coordinates = torch.tensor(np.random.randint(low=0, high=27, size=(6, 7, 8, 2)))

# sample
image, weights = insert_into_image_2d(values, coordinates=coordinates, image=image)

# check for nonzero value at one point
sample_point = coordinates[0, 0, 0]
y, x = sample_point
assert image[y, x] > 0
51 changes: 51 additions & 0 deletions tests/test_linear_interpolation_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import numpy as np
import torch
import einops

from torch_image_lerp import sample_image_3d, insert_into_image_3d


def test_sample_image_3d():
# basic sanity check only
image = torch.rand((28, 28, 28))

# make an arbitrary stack (..., 3) of 3d coords
coords = torch.tensor(np.random.randint(low=0, high=27, size=(6, 7, 8, 3)))

# sample
samples = sample_image_3d(image=image, coordinates=coords)
assert samples.shape == (6, 7, 8)


def test_insert_into_image_3d():
image = torch.zeros((28, 28, 28)).float()

# single value
value = torch.tensor([5]).float()
coordinate = torch.tensor([10.5, 14.5, 18.5]).view((1, 3))

# sample
image, weights = insert_into_image_3d(value, coordinates=coordinate, image=image)

# check value (5) is evenly split over 4 nearest pixels
expected = einops.repeat(torch.tensor([5 / 8]), '1 -> 2 2 2')
assert torch.allclose(image[10:12, 14:16, 18:20], expected)

# check for zeros elsewhere
assert torch.allclose(image[:10, :14, :18], torch.zeros_like(image[:10, :14, :18]))


def test_insert_into_image_3d_multiple():
image = torch.zeros((28, 28, 28)).float()

# multiple values
values = torch.ones(size=(6, 7, 8)).float()
coordinates = torch.tensor(np.random.randint(low=0, high=27, size=(6, 7, 8, 3)))

# sample
image, weights = insert_into_image_3d(values, coordinates=coordinates, image=image)

# check for nonzero value at one point
sample_point = coordinates[0, 0, 0]
z, y, x = sample_point
assert image[z, y, x] > 0

0 comments on commit 69de1a7

Please sign in to comment.