Skip to content

Commit

Permalink
Clustering: Output observed PVR for each cluster
Browse files Browse the repository at this point in the history
Each Cluster object now also stores the actual observed peak/valley ratio for
the given cluster, where it previously only stored the maximally allowed PVR.
This metadata can be valuable, as clusters with low PVR are usually of higher
quality.
  • Loading branch information
jakobnissen committed Nov 10, 2023
1 parent ce9ab88 commit 51c5e8e
Showing 1 changed file with 23 additions and 11 deletions.
34 changes: 23 additions & 11 deletions vamb/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from math import ceil as _ceil
from torch.functional import Tensor as _Tensor
import vamb.vambtools as _vambtools
from typing import TypeVar, Union, cast
from typing import TypeVar, Union, Optional, cast
from collections.abc import Sequence, Iterable

_DEFAULT_RADIUS = 0.12
Expand Down Expand Up @@ -77,7 +77,8 @@ class Cluster:
"medoid",
"seed",
"members",
"pvr",
"maximal_pvr",
"observed_pvr",
"radius",
"isdefault",
"successes",
Expand All @@ -89,7 +90,8 @@ def __init__(
medoid: int,
seed: int,
members: _np.ndarray,
pvr: float,
maximal_pvr: float,
observed_pvr: Optional[float],
radius: float,
isdefault: bool,
successes: int,
Expand All @@ -98,7 +100,8 @@ def __init__(
self.medoid = medoid
self.seed = seed
self.members = members
self.pvr = pvr
self.maximal_pvr = maximal_pvr
self.observed_pvr = observed_pvr
self.radius = radius
self.isdefault = isdefault
self.successes = successes
Expand All @@ -112,8 +115,8 @@ def as_tuple(self) -> tuple[int, set[int]]:

def dump(self) -> str:
return (
f"{self.medoid}\t{self.seed}\t{self.pvr}\t{self.radius}\t{self.isdefault}"
f"\t{self.successes}\t{self.attempts}\t"
f"{self.medoid}\t{self.seed}\t{self.maximal_pvr}\t{self.observed_pvr}"
f"\t{self.radius}\t{self.isdefault}\t{self.successes}\t{self.attempts}\t"
) + ",".join([str(i) for i in self.members])

def __str__(self) -> str:
Expand All @@ -126,7 +129,8 @@ def __str__(self) -> str:
seed: {self.seed}
radius: {radius}
successes: {self.successes} / {self.attempts}
pvr: {self.pvr:.1f}
max pvr: {self.maximal_pvr:.1f}
obs pvr: {self.observed_pvr:.1f}
"""


Expand Down Expand Up @@ -461,7 +465,9 @@ def wander_medoid(self, seed) -> tuple[int, _Tensor]:

return (medoid, distances)

def find_threshold(self, distances: _Tensor) -> Union[Loner, NoThreshold, float]:
def find_threshold(
self, distances: _Tensor
) -> Union[Loner, NoThreshold, tuple[float, float]]:
# If the point is a loner, immediately return a threshold in where only
# that point is contained.
if _torch.count_nonzero(distances < 0.05) == 1:
Expand Down Expand Up @@ -512,7 +518,8 @@ def find_threshold(self, distances: _Tensor) -> Union[Loner, NoThreshold, float]
# Analyze the point densities to find the valley
x = 0
density_at_minimum = 0.0
for density in densities:
for density_ in densities:
density = cast(float, density_.item())
# Define the first "peak" in point density. That's simply the max until
# the peak is defined as being over.
if not peak_over and density > peak_density:
Expand Down Expand Up @@ -548,7 +555,8 @@ def find_threshold(self, distances: _Tensor) -> Union[Loner, NoThreshold, float]
if threshold > 0.2 + self.peak_valley_ratio:
return NoThreshold()
else:
return threshold
observed_pvr = density_at_minimum / peak_density
return (threshold, observed_pvr)

def find_cluster(self) -> tuple[Cluster, int, _Tensor]:
while True:
Expand All @@ -561,6 +569,7 @@ def find_cluster(self) -> tuple[Cluster, int, _Tensor]:
seed,
_np.array([self.indices[medoid].item()]),
self.peak_valley_ratio,
None,
_DEFAULT_RADIUS,
False,
self.successes,
Expand All @@ -580,6 +589,7 @@ def find_cluster(self) -> tuple[Cluster, int, _Tensor]:
seed,
self.indices[points].numpy(),
self.peak_valley_ratio,
None,
_DEFAULT_RADIUS,
True,
self.successes,
Expand All @@ -589,7 +599,8 @@ def find_cluster(self) -> tuple[Cluster, int, _Tensor]:
else:
self.update_successes(False)

elif isinstance(threshold, float):
elif isinstance(threshold, tuple):
(threshold, observed_pvr) = threshold
points = _smaller_indices(
distances, self.kept_mask, threshold, self.cuda
)
Expand All @@ -598,6 +609,7 @@ def find_cluster(self) -> tuple[Cluster, int, _Tensor]:
seed,
self.indices[points].numpy(),
self.peak_valley_ratio,
observed_pvr,
threshold,
False,
self.successes,
Expand Down

0 comments on commit 51c5e8e

Please sign in to comment.