Skip to content

Commit

Permalink
fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisholder committed Nov 20, 2024
1 parent c61c3c9 commit 5cc63cf
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 5 deletions.
2 changes: 1 addition & 1 deletion aeon/clustering/_k_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def _fit(self, X, y=None):

self._tslearn_k_shapes.fit(_X)
self._cluster_centers = self._tslearn_k_shapes.cluster_centers_
self.labels_ = self._tslearn_k_shapes.labels_
self.labels_ = self._tslearn_k_shapes.predict(_X)
self.inertia_ = self._tslearn_k_shapes.inertia_
self.n_iter_ = self._tslearn_k_shapes.n_iter_

Expand Down
2 changes: 1 addition & 1 deletion aeon/clustering/_k_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def _fit(self, X, y=None):

self._tslearn_k_shapes.fit(_X)
self._cluster_centers = self._tslearn_k_shapes.cluster_centers_
self.labels_ = self._tslearn_k_shapes.labels_
self.labels_ = self._tslearn_k_shapes.predict(_X)
self.inertia_ = self._tslearn_k_shapes.inertia_
self.n_iter_ = self._tslearn_k_shapes.n_iter_

Expand Down
4 changes: 2 additions & 2 deletions aeon/clustering/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class DummyClusterer(BaseClusterer):
Parameters
----------
strategy : str, default="random"
strategy : str, default="uniform"
The strategy to use for generating cluster labels. Supported strategies are:
- "random": Assign clusters randomly.
- "uniform": Distribute clusters uniformly among samples.
Expand Down Expand Up @@ -54,7 +54,7 @@ class DummyClusterer(BaseClusterer):
array([0, 1, 0])
"""

def __init__(self, strategy="random", n_clusters=3, random_state=None):
def __init__(self, strategy="uniform", n_clusters=3, random_state=None):
self.strategy = strategy
self.random_state = random_state
self.n_clusters = n_clusters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ def check_clusterer_output(estimator, datatype):
estimator.fit(data)
assert hasattr(estimator, "labels_")
assert isinstance(estimator.labels_, np.ndarray)
assert np.array_equal(estimator.labels_, estimator.fit_predict(data))
assert np.array_equal(estimator.labels_, estimator.predict(data))

y_pred = estimator.predict(FULL_TEST_DATA_DICT[datatype]["test"][0])
Expand Down

0 comments on commit 5cc63cf

Please sign in to comment.