Skip to content

Commit

Permalink
[ENH] Kmeans allow empty clusters (#1400)
Browse files Browse the repository at this point in the history
* enabled kmeans to handle empty clusters

* update comments
  • Loading branch information
chrisholder authored Apr 18, 2024
1 parent 9e36ed7 commit 3e9d959
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 3 deletions.
46 changes: 44 additions & 2 deletions aeon/clustering/_k_means.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,13 @@ def _fit_one_init(self, X: np.ndarray) -> tuple:
curr_labels = curr_pw.argmin(axis=1)
curr_inertia = curr_pw.min(axis=1).sum()

# If an empty cluster is encountered
if np.unique(curr_labels).size < self.n_clusters:
raise EmptyClusterError
curr_pw, curr_labels, curr_inertia, cluster_centres = (
self._handle_empty_cluster(
X, cluster_centres, curr_pw, curr_labels, curr_inertia
)
)

if self.verbose:
print("%.3f" % curr_inertia, end=" --> ") # noqa: T001, T201
Expand Down Expand Up @@ -290,7 +295,7 @@ def _check_params(self, X: np.ndarray) -> None:
isinstance(self.init_algorithm, np.ndarray)
and len(self.init_algorithm) == self.n_clusters
):
self._init_algorithm = self.init_algorithm
self._init_algorithm = self.init_algorithm.copy()
else:
raise ValueError(
f"The value provided for init_algorithm: {self.init_algorithm} is "
Expand Down Expand Up @@ -347,6 +352,43 @@ def _kmeans_plus_plus_center_initializer(self, X: np.ndarray):
centers = X[indexes]
return centers

def _handle_empty_cluster(
self,
X: np.ndarray,
cluster_centres: np.ndarray,
curr_pw: np.ndarray,
curr_labels: np.ndarray,
curr_inertia: float,
):
"""Handle an empty cluster.
This functions finds the time series that is furthest from its assigned centre
and then uses that as the new centre for the empty cluster. In terms of
optimisation this means it selects the time series that will reduce inertia
by the most.
"""
empty_clusters = np.setdiff1d(np.arange(self.n_clusters), curr_labels)
j = 0

while empty_clusters.size > 0:
# Assign each time series to the cluster that is closest to it
# and then find the time series that is furthest from its assigned centre
current_empty_cluster_index = empty_clusters[0]
index_furthest_from_centre = curr_pw.min(axis=1).argmax()
cluster_centres[current_empty_cluster_index] = X[index_furthest_from_centre]
curr_pw = pairwise_distance(
X, cluster_centres, metric=self.distance, **self._distance_params
)
curr_labels = curr_pw.argmin(axis=1)
curr_inertia = curr_pw.min(axis=1).sum()
empty_clusters = np.setdiff1d(np.arange(self.n_clusters), curr_labels)
j += 1
if j > self.n_clusters:
# This should be unreachable but just a safety check to stop it looping
# forever
raise EmptyClusterError
return curr_pw, curr_labels, curr_inertia, cluster_centres

@classmethod
def get_test_params(cls, parameter_set="default"):
"""Return testing parameter settings for the estimator.
Expand Down
79 changes: 78 additions & 1 deletion aeon/clustering/tests/test_k_means.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def test_means_init():
distance="euclidean",
n_clusters=num_clusters,
)
kmeans.fit(X_train)
kmeans.fit(custom_init_centres)

assert np.array_equal(kmeans.cluster_centers_, custom_init_centres)

Expand All @@ -342,3 +342,80 @@ def test_custom_distance_params():
average_params={"init_barycenter": "mean"},
)
assert not np.array_equal(default_dist, custom_params_dist)


def test_empty_cluster():
"""Test empty cluster handling."""
first = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
second = np.array([[4, 5, 6], [7, 8, 9], [11, 12, 13]])
third = np.array([[24, 25, 26], [27, 28, 29], [30, 31, 32]])
forth = np.array([[14, 15, 16], [17, 18, 19], [20, 21, 22]])

# Test where two swap must happen to avoid empty clusters
empty_cluster = np.array([[100, 100, 100], [100, 100, 100], [100, 100, 100]])
init_centres = np.array([first, empty_cluster, empty_cluster])

kmeans = TimeSeriesKMeans(
random_state=1,
n_init=1,
max_iter=5,
init_algorithm=init_centres,
distance="euclidean",
averaging_method="mean",
n_clusters=3,
)

kmeans.fit(np.array([first, second, third, forth]))

assert not np.array_equal(kmeans.cluster_centers_, init_centres)
assert np.unique(kmeans.labels_).size == 3

# Test that if a duplicate centre would be created the algorithm
init_centres = np.array([first, first, first])

kmeans = TimeSeriesKMeans(
random_state=1,
n_init=1,
max_iter=5,
init_algorithm=init_centres,
distance="euclidean",
averaging_method="mean",
n_clusters=3,
)

kmeans.fit(np.array([first, second, third]))

assert not np.array_equal(kmeans.cluster_centers_, init_centres)
assert np.unique(kmeans.labels_).size == 3

# Test duplicate data in dataset
init_centres = np.array([first, empty_cluster])
kmeans = TimeSeriesKMeans(
random_state=1,
n_init=1,
max_iter=5,
init_algorithm=init_centres,
distance="euclidean",
averaging_method="mean",
n_clusters=2,
)

kmeans.fit(np.array([first, first, first, first, second]))

assert not np.array_equal(kmeans.cluster_centers_, init_centres)
assert np.unique(kmeans.labels_).size == 2

# Test impossible to have 3 different clusters
init_centres = np.array([first, empty_cluster, empty_cluster])
kmeans = TimeSeriesKMeans(
random_state=1,
n_init=1,
max_iter=5,
init_algorithm=init_centres,
distance="euclidean",
averaging_method="mean",
n_clusters=3,
)

with pytest.raises(ValueError):
kmeans.fit(np.array([first, first, first, first, first]))

0 comments on commit 3e9d959

Please sign in to comment.