@@ -110,18 +110,29 @@ def get_reference_grid(self, ddf: torch.Tensor, jitter: bool = False, seed: int
         self.ref_grid.requires_grad = False
         return self.ref_grid
 
-    def forward(self, image: torch.Tensor, ddf: torch.Tensor):
+    def forward(
+        self, image: torch.Tensor, ddf: torch.Tensor, keypoints: torch.Tensor | None = None
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
             image: Tensor in shape (batch, num_channels, H, W[, D])
             ddf: Tensor in the same spatial size as image, in shape (batch, ``spatial_dims``, H, W[, D])
+            keypoints: Tensor in shape (batch, N, ``spatial_dims``), optional
 
         Returns:
             warped_image in the same shape as image (batch, num_channels, H, W[, D])
+            warped_keypoints in the same shape as keypoints (batch, N, ``spatial_dims``), if keypoints is not None
         """
+        batch_size = image.shape[0]
         spatial_dims = len(image.shape) - 2
         if spatial_dims not in (2, 3):
             raise NotImplementedError(f"got unsupported spatial_dims={spatial_dims}, currently support 2 or 3.")
+        if keypoints is not None:
+            if keypoints.shape[-1] != spatial_dims:
+                raise ValueError(
+                    f"Given input {spatial_dims}-d image, the last dimension of the input keypoints must be {spatial_dims}, "
+                    f"got {keypoints.shape}."
+                )
         ddf_shape = (image.shape[0], spatial_dims) + tuple(image.shape[2:])
         if ddf.shape != ddf_shape:
             raise ValueError(
@@ -136,13 +147,24 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor):
                 grid[..., i] = grid[..., i] * 2 / (dim - 1) - 1
             index_ordering: list[int] = list(range(spatial_dims - 1, -1, -1))
             grid = grid[..., index_ordering]  # z, y, x -> x, y, z
-            return F.grid_sample(
+            warped_image = F.grid_sample(
                 image, grid, mode=self._interp_mode, padding_mode=f"{self._padding_mode}", align_corners=True
             )
-
-        # using csrc resampling
-        return grid_pull(image, grid, bound=self._padding_mode, extrapolate=True, interpolation=self._interp_mode)
-
+        else:
+            # using csrc resampling
+            warped_image = grid_pull(image, grid, bound=self._padding_mode, extrapolate=True, interpolation=self._interp_mode)
+        if keypoints is not None:
+            with torch.no_grad():
+                offset = torch.as_tensor(image.shape[2:]).to(keypoints) / 2.0
+                offset = offset.unsqueeze(0).unsqueeze(0)
+                normalized_keypoints = torch.flip((keypoints - offset) / offset, (-1,))  # scale to [-1, 1] and reorder to x, y, z
+                ddf_keypoints = F.grid_sample(
+                    ddf, normalized_keypoints.view(batch_size, -1, *([1] * (spatial_dims - 1)), spatial_dims),
+                    mode=self._interp_mode, padding_mode=f"{self._padding_mode}", align_corners=True
+                ).view(batch_size, spatial_dims, -1).permute((0, 2, 1))
+            warped_keypoints = keypoints + ddf_keypoints
+            return warped_image, warped_keypoints
+        return warped_image
 
 class DVF2DDF(nn.Module):
     """
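For readers who want to try the new argument, here is a minimal usage sketch (not part of the diff above). It assumes MONAI's Warp block from monai.networks.blocks.warp with this keypoints extension applied; all tensor shapes are illustrative only.

import torch
from monai.networks.blocks.warp import Warp

warp = Warp()  # default bilinear interpolation; see the class docstring for other modes
image = torch.rand(1, 1, 32, 32, 32)        # (batch, num_channels, H, W, D)
ddf = torch.rand(1, 3, 32, 32, 32) * 2 - 1  # (batch, spatial_dims, H, W, D)
keypoints = torch.tensor([[[8.0, 16.0, 24.0], [10.0, 10.0, 10.0]]])  # (batch, N, spatial_dims), voxel coordinates

# Without keypoints the behaviour is unchanged: only the warped image is returned.
warped_image = warp(image, ddf)

# With keypoints, the DDF is sampled at each point and added to it, so a tuple is returned.
warped_image, warped_keypoints = warp(image, ddf, keypoints)
print(warped_image.shape, warped_keypoints.shape)  # torch.Size([1, 1, 32, 32, 32]) torch.Size([1, 2, 3])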