Skip to content

Commit f729608

Browse files
authored
Merge pull request #1 from kvttt/voxelmorph-warp-kvttt
Add `moving_seg` and `fixed_keypoints`
2 parents d22e4fc + 52311f0 commit f729608

File tree

3 files changed

+106
-9
lines changed

3 files changed

+106
-9
lines changed

monai/networks/blocks/warp.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,18 +110,29 @@ def get_reference_grid(self, ddf: torch.Tensor, jitter: bool = False, seed: int
110110
self.ref_grid.requires_grad = False
111111
return self.ref_grid
112112

113-
def forward(self, image: torch.Tensor, ddf: torch.Tensor):
113+
def forward(
114+
self, image: torch.Tensor, ddf: torch.Tensor, keypoints: torch.Tensor | None = None
115+
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
114116
"""
115117
Args:
116118
image: Tensor in shape (batch, num_channels, H, W[, D])
117119
ddf: Tensor in the same spatial size as image, in shape (batch, ``spatial_dims``, H, W[, D])
120+
keypoints: Tensor in shape (batch, N, ``spatial_dims``), optional
118121
119122
Returns:
120123
warped_image in the same shape as image (batch, num_channels, H, W[, D])
124+
warped_keypoints in the same shape as keypoints (batch, N, ``spatial_dims``), if keypoints is not None
121125
"""
126+
batch_size = image.shape[0]
122127
spatial_dims = len(image.shape) - 2
123128
if spatial_dims not in (2, 3):
124129
raise NotImplementedError(f"got unsupported spatial_dims={spatial_dims}, currently support 2 or 3.")
130+
if keypoints is not None:
131+
if keypoints.shape[-1] != spatial_dims:
132+
raise ValueError(
133+
f"Given input {spatial_dims}-d image, the last dimension of the input keypoints must be {spatial_dims}, "
134+
f"got {keypoints.shape}."
135+
)
125136
ddf_shape = (image.shape[0], spatial_dims) + tuple(image.shape[2:])
126137
if ddf.shape != ddf_shape:
127138
raise ValueError(
@@ -136,13 +147,24 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor):
136147
grid[..., i] = grid[..., i] * 2 / (dim - 1) - 1
137148
index_ordering: list[int] = list(range(spatial_dims - 1, -1, -1))
138149
grid = grid[..., index_ordering] # z, y, x -> x, y, z
139-
return F.grid_sample(
150+
warped_image = F.grid_sample(
140151
image, grid, mode=self._interp_mode, padding_mode=f"{self._padding_mode}", align_corners=True
141152
)
142-
143-
# using csrc resampling
144-
return grid_pull(image, grid, bound=self._padding_mode, extrapolate=True, interpolation=self._interp_mode)
145-
153+
else:
154+
# using csrc resampling
155+
warped_image = grid_pull(image, grid, bound=self._padding_mode, extrapolate=True, interpolation=self._interp_mode)
156+
if keypoints is not None:
157+
with torch.no_grad():
158+
offset = torch.as_tensor(image.shape[2:]).to(keypoints) / 2.0
159+
offset = offset.unsqueeze(0).unsqueeze(0)
160+
normalized_keypoints = torch.flip((keypoints - offset) / offset, (-1,))
161+
ddf_keypoints = F.grid_sample(
162+
ddf, normalized_keypoints.view(batch_size, -1, 1, 1, spatial_dims),
163+
mode=self._interp_mode, padding_mode=f"{self._padding_mode}", align_corners=True
164+
).view(batch_size, 3, -1).permute((0, 2, 1))
165+
warped_keypoints = keypoints + ddf_keypoints
166+
return warped_image, warped_keypoints
167+
return warped_image
146168

147169
class DVF2DDF(nn.Module):
148170
"""

monai/networks/nets/voxelmorph.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -440,14 +440,33 @@ def __init__(
440440
self.dvf2ddf = DVF2DDF(num_steps=self.integration_steps, mode="bilinear", padding_mode="zeros")
441441
self.warp = Warp(mode="bilinear", padding_mode="zeros")
442442

443-
def forward(self, moving: torch.Tensor, fixed: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
444-
# TODO: add optional moving_seg, fixed_seg arguments and handle warping of segmentation maps
443+
def forward(
444+
self,
445+
moving: torch.Tensor,
446+
fixed: torch.Tensor,
447+
moving_seg: torch.Tensor | None = None,
448+
fixed_keypoints: torch.Tensor | None = None
449+
) -> tuple[torch.Tensor, ...]:
445450
if moving.shape != fixed.shape:
446451
raise ValueError(
447452
"The spatial shape of the moving image should be the same as the spatial shape of the fixed image."
448453
f" Got {moving.shape} and {fixed.shape} instead."
449454
)
450455

456+
if moving_seg is not None:
457+
if moving_seg[-3:] != moving.shape[-3:]:
458+
raise ValueError(
459+
"The spatial shape of the moving segmentation should be the same as the spatial shape of the"
460+
f" moving image. Got {moving_seg.shape} and {moving.shape} instead."
461+
)
462+
463+
if fixed_keypoints is not None:
464+
if fixed_keypoints.shape[-1] != self.spatial_dims:
465+
raise ValueError(
466+
"The last dimension of the fixed keypoints should be equal to the number of spatial dimensions."
467+
f" Got {fixed_keypoints.shape[-1]} and {self.spatial_dims} instead."
468+
)
469+
451470
x = self.backbone(torch.cat([moving, fixed], dim=1))
452471

453472
if x.shape[1] != self.spatial_dims:
@@ -471,7 +490,14 @@ def forward(self, moving: torch.Tensor, fixed: torch.Tensor) -> tuple[torch.Tens
471490
if self.half_res:
472491
x = F.interpolate(x * 0.5, scale_factor=2.0, mode="trilinear", align_corners=True)
473492

474-
return self.warp(moving, x), x
493+
if moving_seg is None and fixed_keypoints is None:
494+
return self.warp(moving, x), x
495+
elif moving_seg is None and fixed_keypoints is not None:
496+
return *self.warp(moving, x, fixed_keypoints), x
497+
elif moving_seg is not None and fixed_keypoints is None:
498+
return self.warp(moving, x), self.warp(moving_seg, x), x
499+
else:
500+
return self.warp(moving, x), *self.warp(moving_seg, x, fixed_keypoints), x
475501

476502

477503
voxelmorph = VoxelMorph

tests/networks/nets/test_voxelmorph.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,18 @@
171171
TEST_CASE_9,
172172
]
173173

174+
TEST_CASE_SEG_0 = [
175+
{"spatial_dims": 3},
176+
(1, 1, 96, 96, 48), # moving image
177+
(1, 1, 96, 96, 48), # fixed image
178+
(1, 2, 96, 96, 48), # moving label
179+
(1, 1, 96, 96, 48), # expected warped moving image
180+
(1, 2, 96, 96, 48), # expected warped moving label
181+
(1, 3, 96, 96, 48), # expected ddf
182+
]
183+
184+
CASES_SEG = [TEST_CASE_SEG_0]
185+
174186
ILL_CASE_0 = [ # spatial_dims = 1
175187
{
176188
"spatial_dims": 1,
@@ -243,6 +255,15 @@
243255

244256
ILL_CASES_IN_SHAPE = [ILL_CASES_IN_SHAPE_0, ILL_CASES_IN_SHAPE_1]
245257

258+
ILL_CASE_SEG_SHAPE_0 = [ # moving_seg and moving image shape not match
259+
{"spatial_dims": 3},
260+
(1, 1, 96, 96, 48),
261+
(1, 1, 96, 96, 48),
262+
(1, 2, 80, 96, 48),
263+
]
264+
265+
ILL_CASES_SEG_SHAPE = [ILL_CASE_SEG_SHAPE_0]
266+
246267

247268
class TestVOXELMORPH(unittest.TestCase):
248269
@parameterized.expand(CASES)
@@ -252,6 +273,24 @@ def test_shape(self, input_param, input_shape, expected_shape):
252273
result = net.forward(torch.randn(input_shape).to(device))
253274
self.assertEqual(result.shape, expected_shape)
254275

276+
@parameterized.expand(CASES_SEG)
277+
def test_shape_seg(
278+
self,
279+
input_param,
280+
moving_shape, fixed_shape, moving_seg_shape,
281+
expected_warped_moving_shape, expected_warped_moving_seg_shape, expected_ddf_shape
282+
):
283+
net = VoxelMorph(**input_param).to(device)
284+
with eval_mode(net):
285+
warped_moving, warped_moving_seg, ddf = net.forward(
286+
torch.randn(moving_shape).to(device),
287+
torch.randn(fixed_shape).to(device),
288+
torch.randn(moving_seg_shape).to(device),
289+
)
290+
self.assertEqual(warped_moving.shape, expected_warped_moving_shape)
291+
self.assertEqual(warped_moving_seg.shape, expected_warped_moving_seg_shape)
292+
self.assertEqual(ddf.shape, expected_ddf_shape)
293+
255294
def test_script(self):
256295
net = VoxelMorphUNet(
257296
spatial_dims=2,
@@ -275,6 +314,16 @@ def test_ill_input_shape(self, input_param, moving_shape, fixed_shape):
275314
with eval_mode(net):
276315
_ = net.forward(torch.randn(moving_shape).to(device), torch.randn(fixed_shape).to(device))
277316

317+
@parameterized.expand(ILL_CASES_SEG_SHAPE)
318+
def test_ill_input_seg_shape(self, input_param, moving_shape, fixed_shape, moving_seg_shape):
319+
with self.assertRaises((ValueError, RuntimeError)):
320+
net = VoxelMorph(**input_param).to(device)
321+
with eval_mode(net):
322+
_ = net.forward(
323+
torch.randn(moving_shape).to(device),
324+
torch.randn(fixed_shape).to(device),
325+
torch.randn(moving_seg_shape).to(device),
326+
)
278327

279328
if __name__ == "__main__":
280329
unittest.main()

0 commit comments

Comments
 (0)