Skip to content

Commit 52311f0

Browse files
committed
make VoxelMorph.forward and Warp.forward able to handle keypoints
1 parent f50e8e8 commit 52311f0

File tree

2 files changed

+30
-10
lines changed

2 files changed

+30
-10
lines changed

monai/networks/blocks/warp.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,18 +110,29 @@ def get_reference_grid(self, ddf: torch.Tensor, jitter: bool = False, seed: int
110110
self.ref_grid.requires_grad = False
111111
return self.ref_grid
112112

113-
def forward(self, image: torch.Tensor, ddf: torch.Tensor):
113+
def forward(
114+
self, image: torch.Tensor, ddf: torch.Tensor, keypoints: torch.Tensor | None = None
115+
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
114116
"""
115117
Args:
116118
image: Tensor in shape (batch, num_channels, H, W[, D])
117119
ddf: Tensor in the same spatial size as image, in shape (batch, ``spatial_dims``, H, W[, D])
120+
keypoints: Tensor in shape (batch, N, ``spatial_dims``), optional
118121
119122
Returns:
120123
warped_image in the same shape as image (batch, num_channels, H, W[, D])
124+
warped_keypoints in the same shape as keypoints (batch, N, ``spatial_dims``), if keypoints is not None
121125
"""
126+
batch_size = image.shape[0]
122127
spatial_dims = len(image.shape) - 2
123128
if spatial_dims not in (2, 3):
124129
raise NotImplementedError(f"got unsupported spatial_dims={spatial_dims}, currently support 2 or 3.")
130+
if keypoints is not None:
131+
if keypoints.shape[-1] != spatial_dims:
132+
raise ValueError(
133+
f"Given input {spatial_dims}-d image, the last dimension of the input keypoints must be {spatial_dims}, "
134+
f"got {keypoints.shape}."
135+
)
125136
ddf_shape = (image.shape[0], spatial_dims) + tuple(image.shape[2:])
126137
if ddf.shape != ddf_shape:
127138
raise ValueError(
@@ -136,13 +147,24 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor):
136147
grid[..., i] = grid[..., i] * 2 / (dim - 1) - 1
137148
index_ordering: list[int] = list(range(spatial_dims - 1, -1, -1))
138149
grid = grid[..., index_ordering] # z, y, x -> x, y, z
139-
return F.grid_sample(
150+
warped_image = F.grid_sample(
140151
image, grid, mode=self._interp_mode, padding_mode=f"{self._padding_mode}", align_corners=True
141152
)
142-
143-
# using csrc resampling
144-
return grid_pull(image, grid, bound=self._padding_mode, extrapolate=True, interpolation=self._interp_mode)
145-
153+
else:
154+
# using csrc resampling
155+
warped_image = grid_pull(image, grid, bound=self._padding_mode, extrapolate=True, interpolation=self._interp_mode)
156+
if keypoints is not None:
157+
with torch.no_grad():
158+
offset = torch.as_tensor(image.shape[2:]).to(keypoints) / 2.0
159+
offset = offset.unsqueeze(0).unsqueeze(0)
160+
normalized_keypoints = torch.flip((keypoints - offset) / offset, (-1,))
161+
ddf_keypoints = F.grid_sample(
162+
ddf, normalized_keypoints.view(batch_size, -1, 1, 1, spatial_dims),
163+
mode=self._interp_mode, padding_mode=f"{self._padding_mode}", align_corners=True
164+
).view(batch_size, 3, -1).permute((0, 2, 1))
165+
warped_keypoints = keypoints + ddf_keypoints
166+
return warped_image, warped_keypoints
167+
return warped_image
146168

147169
class DVF2DDF(nn.Module):
148170
"""

monai/networks/nets/voxelmorph.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -493,13 +493,11 @@ def forward(
493493
if moving_seg is None and fixed_keypoints is None:
494494
return self.warp(moving, x), x
495495
elif moving_seg is None and fixed_keypoints is not None:
496-
# TODO: implement keypoint warping
497-
pass
496+
return *self.warp(moving, x, fixed_keypoints), x
498497
elif moving_seg is not None and fixed_keypoints is None:
499498
return self.warp(moving, x), self.warp(moving_seg, x), x
500499
else:
501-
# TODO: implement keypoint warping
502-
pass
500+
return self.warp(moving, x), *self.warp(moving_seg, x, fixed_keypoints), x
503501

504502

505503
voxelmorph = VoxelMorph

0 commit comments

Comments
 (0)