Skip to content

Commit c61c3c9

Browse files
committed
fixed kmeans bug stopping tests working
1 parent f6b0ffb commit c61c3c9

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

aeon/clustering/_k_means.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def _fit_one_init(self, X: np.ndarray) -> tuple:
268268
prev_inertia = curr_inertia
269269
prev_labels = curr_labels
270270

271-
if change_in_centres < self.tol:
271+
if change_in_centres < self.tol or (i + 1) == self.max_iter:
272272
break
273273

274274
# Compute new cluster centres

aeon/testing/estimator_checking/_yield_clustering_checks.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,12 @@ def check_clusterer_output(estimator, datatype):
108108
estimator = _clone_estimator(estimator)
109109

110110
# run fit and predict
111-
estimator.fit(
112-
FULL_TEST_DATA_DICT[datatype]["train"][0],
113-
FULL_TEST_DATA_DICT[datatype]["train"][1],
114-
)
111+
data = FULL_TEST_DATA_DICT[datatype]["train"][0]
112+
estimator.fit(data)
115113
assert hasattr(estimator, "labels_")
116114
assert isinstance(estimator.labels_, np.ndarray)
115+
assert np.array_equal(estimator.labels_, estimator.fit_predict(data))
116+
assert np.array_equal(estimator.labels_, estimator.predict(data))
117117

118118
y_pred = estimator.predict(FULL_TEST_DATA_DICT[datatype]["test"][0])
119119

0 commit comments

Comments
 (0)