Skip to content

Commit

Permalink
ensure correct device output for control points and interpolation coo…
Browse files Browse the repository at this point in the history
…rdinates (#15)
  • Loading branch information
alisterburt authored Mar 27, 2023
1 parent ff0bce5 commit 82e6f58
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions src/torch_cubic_spline_grids/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from torch_cubic_spline_grids.pad_grids import pad_grid_1d


def generate_sample_positions_for_padded_grid_1d(n_samples: int) -> torch.Tensor:
def generate_sample_positions_for_padded_grid_1d(
n_samples: int, device: torch.device) -> torch.Tensor:
"""Generate a 1D vector of sample coordinates for a padded grid.
Coordinate system is [0, 1] covering each dimension, pre-padding.
Expand All @@ -18,6 +19,8 @@ def generate_sample_positions_for_padded_grid_1d(n_samples: int) -> torch.Tensor
----------
n_samples: int
The number of samples on the grid prior to padding.
device: torch.device
The torch device on which to store the tensor.
Returns
-------
Expand All @@ -26,7 +29,7 @@ def generate_sample_positions_for_padded_grid_1d(n_samples: int) -> torch.Tensor
padded grid.
"""
du = 1 / (n_samples - 1)
sample_coordinates = torch.linspace(-du, 1 + du, steps=n_samples + 2)
sample_coordinates = torch.linspace(-du, 1 + du, steps=n_samples + 2, device=device)

# fix for numerical stability issues around 0 and 1
# ensures valid control point indices are selected
Expand All @@ -36,8 +39,9 @@ def generate_sample_positions_for_padded_grid_1d(n_samples: int) -> torch.Tensor
return sample_coordinates


def find_control_point_idx_1d(sample_positions: torch.Tensor,
query_points: torch.Tensor):
def find_control_point_idx_1d(
sample_positions: torch.Tensor, query_points: torch.Tensor
):
"""Find indices of four control points required for cubic interpolation.
E.g. for sample positions `[0, 1, 2, 3, 4, 5]` and query point `2.5` the control
Expand Down Expand Up @@ -99,9 +103,14 @@ def interpolants_to_interpolation_data_1d(
interpolation coordinate associated with the interval `[p1, p2]`.
"""
interpolants = torch.clamp(interpolants, min=0, max=1)
device = interpolants.device
if n_samples > 1:
grid_u = generate_sample_positions_for_padded_grid_1d(n_samples)
control_point_idx = find_control_point_idx_1d(grid_u, query_points=interpolants)
grid_u = generate_sample_positions_for_padded_grid_1d(
n_samples, device=device
)
control_point_idx = find_control_point_idx_1d(
sample_positions=grid_u, query_points=interpolants
)
u_p1 = grid_u[control_point_idx[:, 1]]
du = 1 / (n_samples - 1)
interpolation_coordinate = (interpolants - u_p1) / du
Expand All @@ -110,7 +119,7 @@ def interpolants_to_interpolation_data_1d(
torch.tensor([0, 1, 2, 3]), 'p -> b p', b=len(interpolants)
)
interpolation_coordinate = einops.repeat(
torch.tensor([0.5]), '1 -> b', b=len(interpolants)
torch.tensor([0.5], device=device), '1 -> b', b=len(interpolants)
)
return control_point_idx, interpolation_coordinate

Expand Down

0 comments on commit 82e6f58

Please sign in to comment.