From 69de1a74e0882c106a5fe51d607a50dc3d6ea2c2 Mon Sep 17 00:00:00 2001 From: alisterburt Date: Thu, 4 Jul 2024 17:14:31 -0700 Subject: [PATCH] fix minor bugs and add tests (#4) --- src/torch_image_lerp/__init__.py | 4 +- ...{lerp_2d.py => linear_interpolation_2d.py} | 17 ++++--- ...{lerp_3d.py => linear_interpolation_3d.py} | 10 ++-- tests/test_lerp_2d.py | 2 - tests/test_lerp_3d.py | 0 tests/test_linear_interpolation_2d.py | 51 +++++++++++++++++++ tests/test_linear_interpolation_3d.py | 51 +++++++++++++++++++ 7 files changed, 122 insertions(+), 13 deletions(-) rename src/torch_image_lerp/{lerp_2d.py => linear_interpolation_2d.py} (89%) rename src/torch_image_lerp/{lerp_3d.py => linear_interpolation_3d.py} (94%) delete mode 100644 tests/test_lerp_2d.py delete mode 100644 tests/test_lerp_3d.py create mode 100644 tests/test_linear_interpolation_2d.py create mode 100644 tests/test_linear_interpolation_3d.py diff --git a/src/torch_image_lerp/__init__.py b/src/torch_image_lerp/__init__.py index 8253000..f5891eb 100644 --- a/src/torch_image_lerp/__init__.py +++ b/src/torch_image_lerp/__init__.py @@ -1,2 +1,2 @@ -from .lerp_2d import sample_image_2d, insert_into_image_2d -from .lerp_3d import sample_image_3d, insert_into_image_3d +from .linear_interpolation_2d import sample_image_2d, insert_into_image_2d +from .linear_interpolation_3d import sample_image_3d, insert_into_image_3d diff --git a/src/torch_image_lerp/lerp_2d.py b/src/torch_image_lerp/linear_interpolation_2d.py similarity index 89% rename from src/torch_image_lerp/lerp_2d.py rename to src/torch_image_lerp/linear_interpolation_2d.py index 578800a..2319723 100644 --- a/src/torch_image_lerp/lerp_2d.py +++ b/src/torch_image_lerp/linear_interpolation_2d.py @@ -30,6 +30,7 @@ def sample_image_2d( # setup for sampling with torch.nn.functional.grid_sample coordinates, ps = einops.pack([coordinates], pattern='* yx') n_samples = coordinates.shape[0] + h, w = image.shape[-2:] image = einops.repeat(image, 'h w -> b 1 h w', b=n_samples) # b c h w coordinates = einops.rearrange(coordinates, 'b yx -> b 1 1 yx') # b h w zyx @@ -37,7 +38,7 @@ def sample_image_2d( # take the samples samples = F.grid_sample( input=image, - grid=array_to_grid_sample(coordinates, array_shape=image.shape[-2:]), + grid=array_to_grid_sample(coordinates, array_shape=(h, w)), mode='bilinear', padding_mode='border', # this increases sampling fidelity at edges align_corners=True, @@ -46,7 +47,7 @@ def sample_image_2d( # set samples from outside of image to zero coordinates = einops.rearrange(coordinates, 'b 1 1 yx -> b yx') - image_shape = torch.as_tensor(image.shape) + image_shape = torch.as_tensor((h, w), device=image.device) inside = torch.logical_and(coordinates >= 0, coordinates <= image_shape - 1) inside = torch.all(inside, dim=-1) # (b, d, h, w) samples[~inside] *= 0 @@ -60,7 +61,7 @@ def insert_into_image_2d( data: torch.Tensor, coordinates: torch.Tensor, image: torch.Tensor, - weights: torch.Tensor, + weights: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Insert values into a 2D image with bilinear interpolation. @@ -74,7 +75,7 @@ def insert_into_image_2d( - Coordinates span the range `[0, N-1]` for a dimension of length N. image: torch.Tensor `(h, w)` array containing the image into which data will be inserted. - weights: torch.Tensor + weights: torch.Tensor | None `(h, w)` array containing weights associated with each pixel in `image`. This is useful for tracking weights across multiple calls to this function. @@ -85,6 +86,10 @@ def insert_into_image_2d( """ if data.shape != coordinates.shape[:-1]: raise ValueError('One coordinate pair is required for each value in data.') + if coordinates.shape[-1] != 2: + raise ValueError('Coordinates must be of shape (..., 2).') + if weights is None: + weights = torch.zeros_like(image) # linearise data and coordinates data, _ = einops.pack([data], pattern='*') @@ -97,12 +102,12 @@ def insert_into_image_2d( data, coordinates = data[in_image_idx], coordinates[in_image_idx] # calculate and cache floor and ceil of coordinates for each value to be inserted - corner_coordinates = torch.empty(size=(data.shape[0], 2, 2), dtype=torch.long) + corner_coordinates = torch.empty(size=(data.shape[0], 2, 2), dtype=torch.long, device=image.device) corner_coordinates[:, 0] = torch.floor(coordinates) corner_coordinates[:, 1] = torch.ceil(coordinates) # calculate linear interpolation weights for each data point being inserted - _weights = torch.empty(size=(data.shape[0], 2, 2)) # (b, 2, yx) + _weights = torch.empty(size=(data.shape[0], 2, 2), device=image.device) # (b, 2, yx) _weights[:, 1] = coordinates - corner_coordinates[:, 0] # upper corner weights _weights[:, 0] = 1 - _weights[:, 1] # lower corner weights diff --git a/src/torch_image_lerp/lerp_3d.py b/src/torch_image_lerp/linear_interpolation_3d.py similarity index 94% rename from src/torch_image_lerp/lerp_3d.py rename to src/torch_image_lerp/linear_interpolation_3d.py index 880b5b1..74bfb48 100644 --- a/src/torch_image_lerp/lerp_3d.py +++ b/src/torch_image_lerp/linear_interpolation_3d.py @@ -74,7 +74,7 @@ def insert_into_image_3d( data: torch.Tensor, coordinates: torch.Tensor, image: torch.Tensor, - weights: torch.Tensor, + weights: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Insert values into a 3D image with trilinear interpolation. @@ -88,7 +88,7 @@ def insert_into_image_3d( - Coordinates span the range `[0, N-1]` for a dimension of length N. image: torch.Tensor `(d, h, w)` array into which data will be inserted. - weights: torch.Tensor + weights: torch.Tensor | None `(d, h, w)` array containing weights associated with each pixel in `image`. This is useful for tracking weights across multiple calls to this function. @@ -99,6 +99,10 @@ def insert_into_image_3d( """ if data.shape != coordinates.shape[:-1]: raise ValueError('One coordinate triplet is required for each value in data.') + if coordinates.shape[-1] != 3: + raise ValueError('Coordinates must be of shape (..., 3).') + if weights is None: + weights = torch.zeros_like(image) # linearise data and coordinates data, _ = einops.pack([data], pattern='*') @@ -116,7 +120,7 @@ def insert_into_image_3d( _c[:, 1] = torch.ceil(coordinates) # for upper corners # calculate linear interpolation weights for each data point being inserted - _w = torch.empty(size=(data.shape[0], 2, 3), dtype=torch.float64, device=image.device) # (b, 2, zyx) + _w = torch.empty(size=(data.shape[0], 2, 3), dtype=image.dtype, device=image.device) # (b, 2, zyx) _w[:, 1] = coordinates - _c[:, 0] # upper corner weights _w[:, 0] = 1 - _w[:, 1] # lower corner weights diff --git a/tests/test_lerp_2d.py b/tests/test_lerp_2d.py deleted file mode 100644 index 363b3e2..0000000 --- a/tests/test_lerp_2d.py +++ /dev/null @@ -1,2 +0,0 @@ -def test_something(): - pass diff --git a/tests/test_lerp_3d.py b/tests/test_lerp_3d.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_linear_interpolation_2d.py b/tests/test_linear_interpolation_2d.py new file mode 100644 index 0000000..3123d05 --- /dev/null +++ b/tests/test_linear_interpolation_2d.py @@ -0,0 +1,51 @@ +import einops +import torch +import numpy as np + +from torch_image_lerp import sample_image_2d, insert_into_image_2d + + +def test_sample_image_2d(): + # basic sanity check only + image = torch.rand((28, 28)) + + # make an arbitrary stack (..., 2) of 2d coords + coords = torch.tensor(np.random.randint(low=0, high=27, size=(6, 7, 8, 2))) + + # sample + samples = sample_image_2d(image=image, coordinates=coords) + assert samples.shape == (6, 7, 8) + + +def test_insert_into_image_2d(): + image = torch.zeros((28, 28)).float() + + # single value + value = torch.tensor([5]).float() + coordinate = torch.tensor([10.5, 14.5]).view((1, 2)) + + # sample + image, weights = insert_into_image_2d(value, coordinates=coordinate, image=image) + + # check value (5) is evenly split over 4 nearest pixels + expected = einops.repeat(torch.tensor([5 / 4]), '1 -> 2 2') + assert torch.allclose(image[10:12, 14:16], expected) + + # check for zeros elsewhere + assert torch.allclose(image[:10, :14], torch.zeros_like(image[:10, :14])) + + +def test_insert_into_image_2d_multiple(): + image = torch.zeros((28, 28)).float() + + # multiple values + values = torch.ones(size=(6, 7, 8)).float() + coordinates = torch.tensor(np.random.randint(low=0, high=27, size=(6, 7, 8, 2))) + + # sample + image, weights = insert_into_image_2d(values, coordinates=coordinates, image=image) + + # check for nonzero value at one point + sample_point = coordinates[0, 0, 0] + y, x = sample_point + assert image[y, x] > 0 diff --git a/tests/test_linear_interpolation_3d.py b/tests/test_linear_interpolation_3d.py new file mode 100644 index 0000000..8fc29fa --- /dev/null +++ b/tests/test_linear_interpolation_3d.py @@ -0,0 +1,51 @@ +import numpy as np +import torch +import einops + +from torch_image_lerp import sample_image_3d, insert_into_image_3d + + +def test_sample_image_3d(): + # basic sanity check only + image = torch.rand((28, 28, 28)) + + # make an arbitrary stack (..., 3) of 3d coords + coords = torch.tensor(np.random.randint(low=0, high=27, size=(6, 7, 8, 3))) + + # sample + samples = sample_image_3d(image=image, coordinates=coords) + assert samples.shape == (6, 7, 8) + + +def test_insert_into_image_3d(): + image = torch.zeros((28, 28, 28)).float() + + # single value + value = torch.tensor([5]).float() + coordinate = torch.tensor([10.5, 14.5, 18.5]).view((1, 3)) + + # sample + image, weights = insert_into_image_3d(value, coordinates=coordinate, image=image) + + # check value (5) is evenly split over 4 nearest pixels + expected = einops.repeat(torch.tensor([5 / 8]), '1 -> 2 2 2') + assert torch.allclose(image[10:12, 14:16, 18:20], expected) + + # check for zeros elsewhere + assert torch.allclose(image[:10, :14, :18], torch.zeros_like(image[:10, :14, :18])) + + +def test_insert_into_image_3d_multiple(): + image = torch.zeros((28, 28, 28)).float() + + # multiple values + values = torch.ones(size=(6, 7, 8)).float() + coordinates = torch.tensor(np.random.randint(low=0, high=27, size=(6, 7, 8, 3))) + + # sample + image, weights = insert_into_image_3d(values, coordinates=coordinates, image=image) + + # check for nonzero value at one point + sample_point = coordinates[0, 0, 0] + z, y, x = sample_point + assert image[z, y, x] > 0