Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix GPU clustering #205

Merged
merged 4 commits into from
Sep 14, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 33 additions & 13 deletions vamb/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,9 @@ def _init_histogram_kept_mask(self, N: int) -> tuple[_Tensor, _Tensor]:

kept_mask = _torch.ones(N, dtype=_torch.bool)
if self.cuda:
histogram = _torch.empty(_ceil(_XMAX / _DELTA_X), dtype=_torch.float).cuda()
kept_mask = kept_mask.cuda()
else:
histogram = _torch.empty(_ceil(_XMAX / _DELTA_X))

histogram = _torch.empty(_ceil(_XMAX / _DELTA_X))
return histogram, kept_mask

def __init__(
Expand Down Expand Up @@ -279,8 +277,10 @@ def __init__(
_normalize(torch_matrix, inplace=True)

# Move to GPU
torch_lengths = _torch.Tensor(lengths)
if cuda:
torch_matrix = torch_matrix.cuda()
torch_lengths = torch_lengths.cuda()

self.maxsteps: int = maxsteps
self.minsuccesses: int = minsuccesses
Expand All @@ -293,7 +293,7 @@ def __init__(
self.indices = _torch.arange(len(matrix))
self.order = _np.argsort(lengths)[::-1]
self.order_index = 0
self.lengths = _torch.Tensor(lengths)
self.lengths = torch_lengths
self.n_emitted_clusters = 0
self.n_remaining_points = len(torch_matrix)
self.peak_valley_ratio = 0.1
Expand Down Expand Up @@ -330,13 +330,16 @@ def __next__(self) -> Cluster:
def pack(self):
"Remove all used points from the matrix and indices, and reset kept_mask."
if self.cuda:
cpu_kept_mask = self.kept_mask.cpu()
self.matrix = _vambtools.torch_inplace_maskarray(
self.matrix.cpu(), self.kept_mask
self.matrix.cpu(), cpu_kept_mask
).cuda()
self.indices = self.indices[cpu_kept_mask]

else:
_vambtools.torch_inplace_maskarray(self.matrix, self.kept_mask)
self.indices = self.indices[self.kept_mask]

self.indices = self.indices[self.kept_mask]
self.lengths = self.lengths[self.kept_mask]
self.kept_mask.resize_(len(self.matrix))
self.kept_mask[:] = 1
Expand Down Expand Up @@ -381,6 +384,7 @@ def get_next_seed(self) -> int:
if (
new_index >= len(self.indices)
or self.indices[new_index].item() != order
or (self.cuda and not self.kept_mask[new_index].item())
):
self.order[i] = -1
continue
Expand Down Expand Up @@ -475,18 +479,25 @@ def find_threshold(
# The histogram must include only the unclustered distances. When run on GPU,
# clustered points have not been removed yet, so we must filter with kept_mask.
if self.cuda:
picked_distances = distances[self.kept_mask]
below_xmax = (distances <= _XMAX) & self.kept_mask
picked_distances = distances[below_xmax].cpu()
picked_lengths = self.lengths[below_xmax].cpu()
else:
picked_distances = distances
below_xmax = distances <= _XMAX
picked_distances = distances[below_xmax]
picked_lengths = self.lengths[below_xmax]

# TODO: https://github.com/pytorch/pytorch/issues/69519
# Currently, torch.histogram does not run on GPU. This means we must
# copy the lengths and distances over to the CPU on every call, which is very slow.
# If the issue is resolved, this step could be sped up considerably on GPU.
_torch.histogram(
input=picked_distances,
bins=len(self.histogram),
range=(0.0, _XMAX),
out=((self.histogram, self.histogram_edges)),
weight=self.lengths,
weight=picked_lengths,
)
# TODO: Decide: Should we remove the self point? This might create an invalid initial peak.
# On the other hand, if it's large, the peak is valid...

# When the peak_valley_ratio is too high, we need to return something to not get caught
# in an infinite loop.
Expand Down Expand Up @@ -669,8 +680,17 @@ def _sample_medoid(
"""

distances = _calc_distances(matrix, medoid)
cluster = _smaller_indices(distances, kept_mask, threshold, cuda)
local_density = (lengths[cluster] * (threshold - distances[cluster])).sum().item()

if cuda:
within_threshold = (distances <= threshold) & kept_mask
cluster = _torch.nonzero(within_threshold).flatten().cpu()
else:
within_threshold = distances.numpy() <= threshold
cluster = _torch.from_numpy(within_threshold.nonzero()[0])

closeness = threshold - distances[within_threshold]
local_density = (lengths[within_threshold] * closeness).sum().item()

return cluster, distances, local_density


Expand Down