From d77557b08c6ee22639ea915898628e8533b17013 Mon Sep 17 00:00:00 2001 From: Grapymage <52558607+zhangtingyu11@users.noreply.github.com> Date: Wed, 6 Jul 2022 20:44:19 +0800 Subject: [PATCH] [Fix] Fix index error when using multi-samplers strategy (#2094) * fix index error when using multi-samplers strategy * After every loop, change the last_fps_end_index to fps_sample_range --- mmcv/ops/points_sampler.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mmcv/ops/points_sampler.py b/mmcv/ops/points_sampler.py index f9d1c29bd5..e1fd376051 100644 --- a/mmcv/ops/points_sampler.py +++ b/mmcv/ops/points_sampler.py @@ -104,7 +104,6 @@ def forward(self, points_xyz: Tensor, features: Tensor) -> Tensor: """ indices = [] last_fps_end_index = 0 - for fps_sample_range, sampler, npoint in zip( self.fps_sample_range_list, self.samplers, self.num_point): assert fps_sample_range < points_xyz.shape[1] @@ -116,8 +115,8 @@ def forward(self, points_xyz: Tensor, features: Tensor) -> Tensor: else: sample_features = None else: - sample_points_xyz = \ - points_xyz[:, last_fps_end_index:fps_sample_range] + sample_points_xyz = points_xyz[:, last_fps_end_index: + fps_sample_range] if features is not None: sample_features = features[:, :, last_fps_end_index: fps_sample_range] @@ -128,7 +127,7 @@ def forward(self, points_xyz: Tensor, features: Tensor) -> Tensor: npoint) indices.append(fps_idx + last_fps_end_index) - last_fps_end_index += fps_sample_range + last_fps_end_index = fps_sample_range indices = torch.cat(indices, dim=1) return indices