Skip to content

Commit

Permalink
[Fix] Fix index error when using multi-samplers strategy (#2094)
Browse files Browse the repository at this point in the history
* fix index error when using multi-samplers strategy

* After every loop, change the last_fps_end_index to fps_sample_range
  • Loading branch information
zhangtingyu11 authored Jul 6, 2022
1 parent c67ab9a commit d77557b
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions mmcv/ops/points_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ def forward(self, points_xyz: Tensor, features: Tensor) -> Tensor:
"""
indices = []
last_fps_end_index = 0

for fps_sample_range, sampler, npoint in zip(
self.fps_sample_range_list, self.samplers, self.num_point):
assert fps_sample_range < points_xyz.shape[1]
Expand All @@ -116,8 +115,8 @@ def forward(self, points_xyz: Tensor, features: Tensor) -> Tensor:
else:
sample_features = None
else:
sample_points_xyz = \
points_xyz[:, last_fps_end_index:fps_sample_range]
sample_points_xyz = points_xyz[:, last_fps_end_index:
fps_sample_range]
if features is not None:
sample_features = features[:, :, last_fps_end_index:
fps_sample_range]
Expand All @@ -128,7 +127,7 @@ def forward(self, points_xyz: Tensor, features: Tensor) -> Tensor:
npoint)

indices.append(fps_idx + last_fps_end_index)
last_fps_end_index += fps_sample_range
last_fps_end_index = fps_sample_range
indices = torch.cat(indices, dim=1)

return indices
Expand Down

0 comments on commit d77557b

Please sign in to comment.