Skip to content

Commit

Permalink
register the interpolation matrix buffer so that it is pushed to devi…
Browse files Browse the repository at this point in the history
…ce with the model (#13)
  • Loading branch information
alisterburt authored Mar 27, 2023
1 parent bf3b291 commit adf8b8c
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 34 deletions.
12 changes: 10 additions & 2 deletions src/torch_cubic_spline_grids/_base_cubic_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class CubicSplineGrid(torch.nn.Module):
n_channels: int
_data: torch.nn.Parameter
_interpolation_function: Callable
_interpolation_matrix: torch.Tensor
_minibatch_size: int

def __init__(
Expand All @@ -27,12 +28,19 @@ def __init__(
grid_shape = tuple([n_channels, *resolution])
self.data = torch.zeros(size=grid_shape)
self._minibatch_size = minibatch_size
self.register_buffer(
name='interpolation_matrix',
tensor=self._interpolation_matrix,
persistent=False
)

def forward(self, u: torch.Tensor) -> torch.Tensor:
u = self._coerce_to_batched_coordinates(u) # (b, d)
interpolated = [
self._interpolation_function(self._data, minibatch)
for minibatch in batch(u, n=self._minibatch_size)
self._interpolation_function(
self._data, minibatch_u, matrix=self._interpolation_matrix
)
for minibatch_u in batch(u, n=self._minibatch_size)
] # List[Tensor[(b, d)]]
interpolated = torch.cat(interpolated, dim=0) # (b, d)
return self._unpack_interpolated_output(interpolated)
Expand Down
28 changes: 12 additions & 16 deletions src/torch_cubic_spline_grids/b_spline_grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
CoordinateLike = Union[float, Sequence[float], torch.Tensor]


class CubicBSplineGrid1d(CubicSplineGrid):
class _CubicBSplineGrid(CubicSplineGrid):
_interpolation_matrix = CUBIC_B_SPLINE_MATRIX


class CubicBSplineGrid1d(_CubicBSplineGrid):
"""Continuous parametrisation of a 1D space with a specific resolution."""
ndim: int = 1
_interpolation_function: Callable = partial(
_interpolate_grid_1d, matrix=CUBIC_B_SPLINE_MATRIX
)
_interpolation_function: Callable = partial(_interpolate_grid_1d)

def __init__(
self,
Expand All @@ -35,25 +37,19 @@ def __init__(
)


class CubicBSplineGrid2d(CubicSplineGrid):
class CubicBSplineGrid2d(_CubicBSplineGrid):
"""Continuous parametrisation of a 2D space with a specific resolution."""
ndim: int = 2
_interpolation_function: Callable = partial(
_interpolate_grid_2d, matrix=CUBIC_B_SPLINE_MATRIX
)
_interpolation_function: Callable = partial(_interpolate_grid_2d)


class CubicBSplineGrid3d(CubicSplineGrid):
class CubicBSplineGrid3d(_CubicBSplineGrid):
"""Continuous parametrisation of a 3D space with a specific resolution."""
ndim: int = 3
_interpolation_function: Callable = partial(
_interpolate_grid_3d, matrix=CUBIC_B_SPLINE_MATRIX
)
_interpolation_function: Callable = partial(_interpolate_grid_3d)


class CubicBSplineGrid4d(CubicSplineGrid):
class CubicBSplineGrid4d(_CubicBSplineGrid):
"""Continuous parametrisation of a 4D space with a specific resolution."""
ndim: int = 4
_interpolation_function: Callable = partial(
_interpolate_grid_4d, matrix=CUBIC_B_SPLINE_MATRIX
)
_interpolation_function: Callable = partial(_interpolate_grid_4d)
28 changes: 12 additions & 16 deletions src/torch_cubic_spline_grids/catmull_rom_grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
CoordinateLike = Union[float, Sequence[float], torch.Tensor]


class CubicCatmullRomGrid1d(CubicSplineGrid):
class _CubicCatmullRomGrid(CubicSplineGrid):
_interpolation_matrix = CUBIC_CATMULL_ROM_MATRIX


class CubicCatmullRomGrid1d(_CubicCatmullRomGrid):
"""Continuous parametrisation of a 1D space with a specific resolution."""
ndim: int = 1
_interpolation_function: Callable = partial(
_interpolate_grid_1d, matrix=CUBIC_CATMULL_ROM_MATRIX
)
_interpolation_function: Callable = partial(_interpolate_grid_1d)

def __init__(
self,
Expand All @@ -35,25 +37,19 @@ def __init__(
)


class CubicCatmullRomGrid2d(CubicSplineGrid):
class CubicCatmullRomGrid2d(_CubicCatmullRomGrid):
"""Continuous parametrisation of a 2D space with a specific resolution."""
ndim: int = 2
_interpolation_function: Callable = partial(
_interpolate_grid_2d, matrix=CUBIC_CATMULL_ROM_MATRIX
)
_interpolation_function: Callable = partial(_interpolate_grid_2d)


class CubicCatmullRomGrid3d(CubicSplineGrid):
class CubicCatmullRomGrid3d(_CubicCatmullRomGrid):
"""Continuous parametrisation of a 3D space with a specific resolution."""
ndim: int = 3
_interpolation_function: Callable = partial(
_interpolate_grid_3d, matrix=CUBIC_CATMULL_ROM_MATRIX
)
_interpolation_function: Callable = partial(_interpolate_grid_3d)


class CubicCatmullRomGrid4d(CubicSplineGrid):
class CubicCatmullRomGrid4d(_CubicCatmullRomGrid):
"""Continuous parametrisation of a 4D space with a specific resolution."""
ndim: int = 4
_interpolation_function: Callable = partial(
_interpolate_grid_4d, matrix=CUBIC_CATMULL_ROM_MATRIX
)
_interpolation_function: Callable = partial(_interpolate_grid_4d)

0 comments on commit adf8b8c

Please sign in to comment.