Fixup: More refactoring
jakobnissen committed Sep 1, 2023
1 parent 5346123 commit 1da7de0
Showing 1 changed file with 81 additions and 52 deletions.
133 changes: 81 additions & 52 deletions vamb/cluster.py
@@ -191,29 +191,56 @@ class ClusterGenerator:
"""

__slots__ = [
# Maximum number of futile steps taken in the wander_medoid function until it gives up
# and emits the current best medoid
"maxsteps",
# Minimum number of successful clusters emitted out of the last `attempts`, lest the
# peak_valley_ratio be increased to prevent an infinite loop
"minsuccesses",
# Whether this clusterer runs on GPU
"cuda",
# Random number generator, currently used in wander_medoid function
"rng",
# Actual data to be clustered
"matrix",
# Marker object, a list of marker genes to score clusters by
"markers",
# These are the original indices of the rows of the matrix. Initially this is just 1..len(matrix),
# but if not on GPU, we delete used rows of the matrix (and indices) to speed up subsequent computations.
# We can then obtain the original row index by looking up in this array
"indices",
# Original indices are scored from most promising as seeds to least, and ordered in the `order` array.
# The best contig indices come first.
"order",
# The integer index in `order` which the clusterer on next iteration will try to use as medoid
"order_index",
"nclusters",
"n_emitted_clusters",
"n_remaining_points",
# Integer sample labels for each original data point, used for sample splitting
"sample_identifiers",
# A float value which determines how strictly the clusterer rejects potential clusters.
# The lower, the stricter. We increase this adaptively to avoid clustering forever
"peak_valley_ratio",
# A deque to store whether the last N candidate clusters were rejected. See `minsuccesses`
"attempts",
# An integer storing the number of True values in `attempts` to avoid looping over it
"successes",
# A buffer in which a histogram of distances from the current medoid is stored.
# It is overwritten on each iteration; we keep it around to avoid allocations.
"histogram",
# This bool array is False if the contig at the given index has been emitted in a previous iteration.
# When the matrix is on CPU, rows in this array (and the data matrix) are continuously deleted, and we use
# this array to keep track of which rows to delete each iteration.
# On GPU, deleting rows is not feasible because that would require copying the matrix from and to the GPU,
# so we instead use this array to mask away any point emitted in an earlier iteration
# (see the sketch after this attribute list).
"kept_mask",
]
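Below is a minimal sketch (not VAMB's actual API) of the two bookkeeping strategies the `kept_mask` comments describe, assuming a torch matrix of points and a boolean keep-mask:

    import torch

    def drop_used_rows_cpu(matrix: torch.Tensor, indices: torch.Tensor, kept_mask: torch.Tensor):
        # CPU strategy: physically remove emitted rows, so later distance
        # computations scan fewer points. `indices` preserves the original row ids.
        return matrix[kept_mask], indices[kept_mask]

    def mask_used_points_gpu(distances: torch.Tensor, kept_mask: torch.Tensor) -> torch.Tensor:
        # GPU strategy: leave the matrix intact (moving it off-device is slow) and
        # instead push distances of already-emitted points to infinity, so they
        # can never be picked again.
        return distances.masked_fill(~kept_mask, float("inf"))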

def __repr__(self) -> str:
return f"ClusterGenerator({len(self.matrix)} points, {self.nclusters} clusters)"
return f"ClusterGenerator({len(self.matrix)} points, {self.n_emitted_clusters} clusters)"

def __str__(self) -> str:
return f"""ClusterGenerator({len(self.matrix)} points, {self.nclusters} clusters)
return f"""ClusterGenerator({len(self.matrix)} points, {self.n_emitted_clusters} clusters)
CUDA: {self.cuda}
maxsteps: {self.maxsteps}
minsuccesses: {self.minsuccesses}
@@ -351,9 +378,10 @@ def __init__(
]
)
)[::-1]
self.order_index = -1
self.order_index = 0
self.markers = markers
self.nclusters = 0
self.n_emitted_clusters = 0
self.n_remaining_points = len(torch_matrix)
self.peak_valley_ratio = 0.1
self.attempts: _deque[bool] = _deque(maxlen=windowsize)
self.successes = 0
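The `)[::-1]` above reverses an ascending ordering so the highest-scoring seed contigs come first, matching the documentation of the `order` attribute. A hedged illustration, assuming the inner expression produces an argsort-like ranking:

    import numpy as np

    scores = np.array([0.2, 0.9, 0.5])  # hypothetical seed scores
    order = np.argsort(scores)[::-1]    # best-scoring index first
    assert order.tolist() == [1, 2, 0]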
@@ -373,16 +401,12 @@ def __iter__(self):
return self

def __next__(self) -> Cluster:
# Stop criterion. For CUDA, inplace masking the array is too slow, so the matrix is
# unchanged. On CPU, we continually modify the matrix by removing rows.
if self.cuda:
if not _torch.any(self.kept_mask).item():
raise StopIteration
elif len(self.matrix) == 0:
if self.n_remaining_points == 0:
raise StopIteration

cluster, _, points = self.find_cluster()
self.nclusters += 1
self.n_emitted_clusters += 1
self.n_remaining_points -= len(points)

for point in points:
self.kept_mask[point] = 0
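The new stop criterion replaces a scan over `kept_mask` with an O(1) counter check. A torch-free toy of the same iterator pattern (the cluster sizes here are made up):

    class Emitter:
        def __init__(self, n_points: int):
            self.n_remaining_points = n_points

        def __iter__(self):
            return self

        def __next__(self) -> int:
            if self.n_remaining_points == 0:
                raise StopIteration
            # Pretend each "cluster" uses up at most 3 points.
            emitted = min(3, self.n_remaining_points)
            self.n_remaining_points -= emitted
            return emitted

    assert sum(Emitter(10)) == 10  # 3 + 3 + 3 + 1, then StopIteration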
@@ -401,22 +425,22 @@ def __next__(self) -> Cluster:
def get_next_seed(self) -> int:
"Get the next seed index for a new medoid search"
n_original_contigs = len(self.order)
i = self.order_index
i = self.order_index - 1  # we increment by 1 at the beginning of the loop
while True:
# Get the order: That's the original index of the contig.
i = (i + 1) % n_original_contigs

# If we reach the final index, we "compact" self.order, removing any discarded -1 values.
# When we loop back to the first index after having passed over all indices before,
# we may have many used-up -1 values to skip, so we remove them.
# Since the clustering algorithm may loop over self.order many times, this can
# save time.
if i + 1 == n_original_contigs:
wrap = self.order[i] == -1
if i == 0 and self.n_emitted_clusters > 0:
self.order = self.order[self.order > -1]
assert len(self.order) > 0
n_original_contigs = len(self.order)
assert n_original_contigs > 0
i = 0 if wrap else n_original_contigs - 1

order = self.order[i]
# -1 signifies an index which has previously been used up
if order == -1:
continue
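The compaction step above is plain boolean-mask filtering of the seed order. A standalone illustration with made-up values (shown with NumPy; the real `order` array may be a different array type):

    import numpy as np

    order = np.array([7, -1, -1, 3, -1, 9])  # -1 marks used-up seeds
    order = order[order > -1]                # compact once per full pass
    assert order.tolist() == [7, 3, 9]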

@@ -432,10 +456,18 @@ def get_next_seed(self) -> int:
self.order[i] = -1
continue

self.order_index = i
self.order_index = (
i + 1
) # Move to next index for the next time this is called
return new_index

def update_successes(self, success: bool):
"""Keeps track of how many clusters have been rejected (False) and accepted (True).
When sufficiently many False values have been seen, the peak_valley_ratio is bumped, which relaxes
the criteria for rejection.
This prevents the clusterer from getting stuck in an infinite loop.
"""

# Keep an accurate count of successes if we exceed maxlen
if len(self.attempts) == self.attempts.maxlen:
self.successes -= self.attempts.popleft()
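The deque-plus-counter idiom keeps a running count of successes over a sliding window without rescanning it. A self-contained sketch with a window of 4:

    from collections import deque

    attempts = deque(maxlen=4)  # outcomes of the last 4 attempts
    successes = 0
    for outcome in [True, False, True, True, False, False]:
        if len(attempts) == attempts.maxlen:
            successes -= attempts.popleft()  # value sliding out of the window
        attempts.append(outcome)
        successes += outcome
    assert successes == sum(attempts) == 2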
@@ -454,46 +486,48 @@ def update_successes(self, success: bool):
self.attempts.clear()
self.successes = 0

# After relaxing criteria, start over from the best candidate
# seed contigs which may have been skipped the first time around
self.order_index = 0

def wander_medoid(self, seed) -> tuple[int, _Tensor]:
"""Keeps sampling new points within the cluster until it has sampled
max_attempts without getting a new set of cluster with lower average
distance"""

futile_attempts = 0
medoid = seed
tried = {medoid} # keep track of already-tried medoids
cluster, distances, average_distance = _sample_medoid(
self.matrix, self.kept_mask, seed, _MEDOID_RADIUS, self.cuda
)
candidates = self.rng.choices(
cluster.tolist(), k=min(len(cluster), self.maxsteps)
)
self.rng.shuffle(candidates)
i = 0

while len(cluster) - len(tried) > 0 and futile_attempts < self.maxsteps:
sampled_medoid = int(cluster[self.rng.randrange(len(cluster))].item())

# Prevent sampling same medoid multiple times.
while sampled_medoid in tried:
sampled_medoid = int(cluster[self.rng.randrange(len(cluster))].item())

tried.add(sampled_medoid)

while i < len(candidates):
sampled_medoid = candidates[i]
sampling = _sample_medoid(
self.matrix, self.kept_mask, sampled_medoid, _MEDOID_RADIUS, self.cuda
)
sample_cluster, sample_distances, sample_avg = sampling

# If the mean distance of inner points of the sample is lower,
# we move the medoid and reset the futile_attempts count
# we move the medoid and start over
if sample_avg < average_distance:
medoid = sampled_medoid
cluster = sample_cluster
average_distance = sample_avg
futile_attempts = 0
tried = {medoid}
distances = sample_distances

average_distance = sample_avg
candidates = self.rng.choices(
cluster.tolist(), k=min(len(cluster), self.maxsteps)
)
self.rng.shuffle(candidates)
i = 0
else:
futile_attempts += 1
i += 1

return medoid, distances
return (medoid, distances)
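The refactored loop pre-samples a shuffled candidate list and restarts it whenever the medoid improves, rather than counting futile attempts as before. A toy version of the control flow, assuming a hypothetical sample(m) helper that returns a candidate medoid's cluster members and their average distance:

    import random

    def wander(seed, sample, maxsteps: int, rng: random.Random):
        medoid = seed
        cluster, avg = sample(seed)
        candidates = rng.choices(cluster, k=min(len(cluster), maxsteps))
        rng.shuffle(candidates)
        i = 0
        while i < len(candidates):
            new_cluster, new_avg = sample(candidates[i])
            if new_avg < avg:
                # Better medoid: move there and start over with fresh candidates.
                medoid, cluster, avg = candidates[i], new_cluster, new_avg
                candidates = rng.choices(cluster, k=min(len(cluster), maxsteps))
                rng.shuffle(candidates)
                i = 0
            else:
                i += 1
        return medoid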

def find_threshold(
self, distances: _Tensor
@@ -517,10 +551,12 @@ def find_threshold(
if self.histogram[:10].sum().item() == 0:
return Loner()

# When the peak_valley_ratio is too high, we need to return something to not get caught
# in an infinite loop.
must_return_points = self.peak_valley_ratio > 0.55
peak_density = 0
peak_over = False
minimum_x = None
density_at_minimum = None
minimum_x = 0.0
threshold = None
delta_x = _XMAX / len(self.histogram)
pdf_len = len(_NORMALPDF)
@@ -530,10 +566,11 @@ def find_threshold(
else:
histogram = self.histogram

# This smooths out the histogram, so we can more reliably detect peaks
# and valleys.
densities = _torch.zeros(len(histogram) + pdf_len - 1)
for i in range(len(densities) - pdf_len + 1):
densities[i : i + pdf_len] += _NORMALPDF * histogram[i]

densities = densities[15:-15]
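The loop above is a discrete convolution of the histogram with a normal-PDF kernel. An equivalent standalone sketch (the kernel here is a stand-in for _NORMALPDF, whose real width is defined elsewhere in the module):

    import torch

    histogram = torch.tensor([0.0, 2.0, 5.0, 3.0, 1.0, 0.0])
    kernel = torch.softmax(-torch.linspace(-2, 2, 5) ** 2, dim=0)  # Gaussian-shaped, sums to 1
    densities = torch.zeros(len(histogram) + len(kernel) - 1)
    for i, h in enumerate(histogram):
        densities[i : i + len(kernel)] += kernel * h
    # densities is now a smoothed histogram, padded by len(kernel) - 1 bins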

# Analyze the point densities to find the valley
@@ -545,10 +582,7 @@
if not peak_over and density > peak_density:
# Do not accept the first peak if it lies after x = 0.1
if x > 0.1:
if self.peak_valley_ratio > 0.55:
return Default()
else:
return NoThreshold()
return Default() if must_return_points else NoThreshold()
peak_density = density

# Peak is over when density drops below 60% of peak density
@@ -570,18 +604,13 @@

x += delta_x

# If we have not detected a threshold, we can't return one. However, when the peak_valley_ratio
# is too high, we need to return something to not get caught in an infinite loop. So, we fall
# back to a default value.
# If we have not detected a threshold, we can't return one.
if threshold is None:
if self.peak_valley_ratio > 0.55:
return Default()
else:
return NoThreshold()
return Default() if must_return_points else NoThreshold()
# Else, we check whether the threshold is too high. If not, we return it.
else:
if threshold > 0.2 + self.peak_valley_ratio:
return NoThreshold()
return Default() if must_return_points else NoThreshold()
else:
return threshold
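find_threshold can therefore return either a float threshold or one of three marker values. A hedged sketch of how a caller might branch on the result; the empty classes stand in for this module's marker classes, and the suggested actions are inferred from the comments above rather than taken from the actual caller:

    class Loner: pass        # stand-ins for the module's marker classes
    class NoThreshold: pass
    class Default: pass

    def interpret(result) -> str:
        if isinstance(result, Loner):
            return "emit the medoid as a single-point cluster"
        if isinstance(result, NoThreshold):
            return "reject this medoid and try the next seed"
        if isinstance(result, Default):
            return "cluster with a fixed fallback radius"
        return f"cluster all points within distance {result:.2f}"

    assert interpret(0.09) == "cluster all points within distance 0.09"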
