Skip to content

Commit

Permalink
Use torch's nonzero in smaller_indices
Browse files Browse the repository at this point in the history
We used to convert to Numpy, call nonzero, then convert back to Torch, since
the Numpy implementation was faster.
It no longer is, so I remove this needless optimization
  • Loading branch information
jakobnissen committed Sep 14, 2023
1 parent 72f388d commit f8b875f
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions vamb/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,18 +625,14 @@ def find_cluster(self) -> tuple[Cluster, int, _Tensor]:
def _smaller_indices(
tensor: _Tensor, kept_mask: _Tensor, threshold: float, cuda: bool
) -> _Tensor:
"""Get all indices where the tensor is smaller than the threshold.
Uses Numpy because Torch is slow - See https://github.com/pytorch/pytorch/pull/15190
"""
"""Get all indices where the tensor is smaller than the threshold."""

# If it's on GPU, we remove the already clustered points at this step.
# If it's on GPU, we remove the already clustered points at this step
# and move to CPU
if cuda:
return _torch.nonzero((tensor <= threshold) & kept_mask).flatten().cpu()
else:
arr = tensor.numpy()
indices = (arr <= threshold).nonzero()[0]
torch_indices = _torch.from_numpy(indices)
return torch_indices
return _torch.nonzero(tensor <= threshold).flatten()


def _normalize(matrix: _Tensor, inplace: bool = False) -> _Tensor:
Expand Down

0 comments on commit f8b875f

Please sign in to comment.