Skip to content

Commit 2631920

Browse files
committed
Autofix linting and type checks
Signed-off-by: Kheil-Z <[email protected]>
1 parent bf43e59 commit 2631920

File tree

3 files changed

+23
-8
lines changed

3 files changed

+23
-8
lines changed

monai/networks/blocks/warp.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -152,20 +152,30 @@ def forward(
152152
)
153153
else:
154154
# using csrc resampling
155-
warped_image = grid_pull(image, grid, bound=self._padding_mode, extrapolate=True, interpolation=self._interp_mode)
155+
warped_image = grid_pull(
156+
image, grid, bound=self._padding_mode, extrapolate=True, interpolation=self._interp_mode
157+
)
156158
if keypoints is not None:
157159
with torch.no_grad():
158160
offset = torch.as_tensor(image.shape[2:]).to(keypoints) / 2.0
159161
offset = offset.unsqueeze(0).unsqueeze(0)
160162
normalized_keypoints = torch.flip((keypoints - offset) / offset, (-1,))
161-
ddf_keypoints = F.grid_sample(
162-
ddf, normalized_keypoints.view(batch_size, -1, 1, 1, spatial_dims),
163-
mode=self._interp_mode, padding_mode=f"{self._padding_mode}", align_corners=True
164-
).view(batch_size, 3, -1).permute((0, 2, 1))
163+
ddf_keypoints = (
164+
F.grid_sample(
165+
ddf,
166+
normalized_keypoints.view(batch_size, -1, 1, 1, spatial_dims),
167+
mode=self._interp_mode,
168+
padding_mode=f"{self._padding_mode}",
169+
align_corners=True,
170+
)
171+
.view(batch_size, 3, -1)
172+
.permute((0, 2, 1))
173+
)
165174
warped_keypoints = keypoints + ddf_keypoints
166175
return warped_image, warped_keypoints
167176
return warped_image
168177

178+
169179
class DVF2DDF(nn.Module):
170180
"""
171181
Layer calculates a dense displacement field (DDF) from a dense velocity field (DVF)

monai/networks/nets/voxelmorph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@ def forward(
445445
moving: torch.Tensor,
446446
fixed: torch.Tensor,
447447
moving_seg: torch.Tensor | None = None,
448-
fixed_keypoints: torch.Tensor | None = None
448+
fixed_keypoints: torch.Tensor | None = None,
449449
) -> tuple[torch.Tensor, ...]:
450450
if moving.shape != fixed.shape:
451451
raise ValueError(

tests/networks/nets/test_voxelmorph.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,12 @@ def test_shape(self, input_param, input_shape, expected_shape):
277277
def test_shape_seg(
278278
self,
279279
input_param,
280-
moving_shape, fixed_shape, moving_seg_shape,
281-
expected_warped_moving_shape, expected_warped_moving_seg_shape, expected_ddf_shape
280+
moving_shape,
281+
fixed_shape,
282+
moving_seg_shape,
283+
expected_warped_moving_shape,
284+
expected_warped_moving_seg_shape,
285+
expected_ddf_shape,
282286
):
283287
net = VoxelMorph(**input_param).to(device)
284288
with eval_mode(net):
@@ -325,5 +329,6 @@ def test_ill_input_seg_shape(self, input_param, moving_shape, fixed_shape, movin
325329
torch.randn(moving_seg_shape).to(device),
326330
)
327331

332+
328333
if __name__ == "__main__":
329334
unittest.main()

0 commit comments

Comments
 (0)