Skip to content

Commit a591c31

Browse files
committed
removed deep learner n_clusters and assert labels_ exists
1 parent 4c2475f commit a591c31

File tree

7 files changed

+41
-24
lines changed

7 files changed

+41
-24
lines changed

aeon/clustering/deep_learning/_ae_abgru.py

-4
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@ class AEAttentionBiGRUClusterer(BaseDeepClusterer):
2020
2121
Parameters
2222
----------
23-
n_clusters : int, default=None
24-
Number of clusters for the deep learnign model.
2523
clustering_algorithm : str, default="deprecated"
2624
Use 'estimator' parameter instead.
2725
clustering_params : dict, default=None
@@ -100,7 +98,6 @@ class AEAttentionBiGRUClusterer(BaseDeepClusterer):
10098

10199
def __init__(
102100
self,
103-
n_clusters=None,
104101
estimator=None,
105102
clustering_algorithm="deprecated",
106103
clustering_params=None,
@@ -143,7 +140,6 @@ def __init__(
143140
self.random_state = random_state
144141

145142
super().__init__(
146-
n_clusters=n_clusters,
147143
clustering_algorithm=clustering_algorithm,
148144
clustering_params=clustering_params,
149145
estimator=estimator,

aeon/clustering/deep_learning/_ae_bgru.py

-4
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@ class AEBiGRUClusterer(BaseDeepClusterer):
2020
2121
Parameters
2222
----------
23-
n_clusters : int, default=None
24-
Number of clusters for the deep learnign model.
2523
clustering_algorithm : str, default="deprecated"
2624
Use 'estimator' parameter instead.
2725
clustering_params : dict, default=None
@@ -99,7 +97,6 @@ class AEBiGRUClusterer(BaseDeepClusterer):
9997

10098
def __init__(
10199
self,
102-
n_clusters=None,
103100
clustering_algorithm="deprecated",
104101
estimator=None,
105102
clustering_params=None,
@@ -140,7 +137,6 @@ def __init__(
140137
self.save_last_model = save_last_model
141138
self.best_file_name = best_file_name
142139
self.random_state = random_state
143-
self.n_clusters = n_clusters
144140

145141
super().__init__(
146142
clustering_algorithm=clustering_algorithm,

aeon/clustering/deep_learning/_ae_dcnn.py

-4
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ class AEDCNNClusterer(BaseDeepClusterer):
1919
2020
Parameters
2121
----------
22-
n_clusters : int, default=None
23-
Number of clusters for the deep learnign model.
2422
clustering_algorithm : str, default="deprecated"
2523
Use 'estimator' parameter instead.
2624
clustering_params : dict, default=None
@@ -113,7 +111,6 @@ class AEDCNNClusterer(BaseDeepClusterer):
113111

114112
def __init__(
115113
self,
116-
n_clusters=None,
117114
estimator=None,
118115
clustering_algorithm="deprecated",
119116
clustering_params=None,
@@ -164,7 +161,6 @@ def __init__(
164161
self.random_state = random_state
165162

166163
super().__init__(
167-
n_clusters=n_clusters,
168164
clustering_params=clustering_params,
169165
clustering_algorithm=clustering_algorithm,
170166
estimator=estimator,

aeon/clustering/deep_learning/_ae_drnn.py

-4
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@ class AEDRNNClusterer(BaseDeepClusterer):
2424
2525
Parameters
2626
----------
27-
n_clusters : int, default=None
28-
Number of clusters for the deep learnign model.
2927
clustering_algorithm : str, default="deprecated"
3028
Please use the 'estimator' parameter.
3129
estimator : aeon clusterer, default=None
@@ -114,7 +112,6 @@ class AEDRNNClusterer(BaseDeepClusterer):
114112

115113
def __init__(
116114
self,
117-
n_clusters=None,
118115
estimator=None,
119116
clustering_algorithm="deprecated",
120117
clustering_params=None,
@@ -167,7 +164,6 @@ def __init__(
167164
self.random_state = random_state
168165

169166
super().__init__(
170-
n_clusters=n_clusters,
171167
estimator=estimator,
172168
clustering_algorithm=clustering_algorithm,
173169
clustering_params=clustering_params,

aeon/clustering/deep_learning/_ae_fcn.py

-4
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@ class AEFCNClusterer(BaseDeepClusterer):
2121
2222
Parameters
2323
----------
24-
n_clusters : int, default=None
25-
Please use 'estimator' parameter.
2624
estimator : aeon clusterer, default=None
2725
An aeon estimator to be built using the transformed data.
2826
Defaults to aeon TimeSeriesKMeans() with euclidean distance
@@ -122,7 +120,6 @@ class AEFCNClusterer(BaseDeepClusterer):
122120

123121
def __init__(
124122
self,
125-
n_clusters=None,
126123
estimator=None,
127124
clustering_algorithm="deprecated",
128125
clustering_params=None,
@@ -173,7 +170,6 @@ def __init__(
173170
self.save_last_model = save_last_model
174171
self.best_file_name = best_file_name
175172
self.random_state = random_state
176-
self.n_clusters = n_clusters
177173

178174
super().__init__(
179175
estimator=estimator,

aeon/clustering/deep_learning/_ae_resnet.py

-4
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@ class AEResNetClusterer(BaseDeepClusterer):
2424
2525
Parameters
2626
----------
27-
n_clusters : int, default=None
28-
Please use 'estimator' parameter.
2927
estimator : aeon clusterer, default=None
3028
An aeon estimator to be built using the transformed data.
3129
Defaults to aeon TimeSeriesKMeans() with euclidean distance
@@ -131,7 +129,6 @@ class method save_last_model_to_file.
131129

132130
def __init__(
133131
self,
134-
n_clusters=None,
135132
estimator=None,
136133
n_residual_blocks=3,
137134
clustering_algorithm="deprecated",
@@ -182,7 +179,6 @@ def __init__(
182179
self.best_file_name = best_file_name
183180
self.last_file_name = last_file_name
184181
self.optimizer = optimizer
185-
self.n_clusters = n_clusters
186182

187183
self.history = None
188184

aeon/testing/estimator_checking/_yield_clustering_checks.py

+41
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from aeon.base._base import _clone_estimator
88
from aeon.clustering.deep_learning import BaseDeepClusterer
99
from aeon.testing.testing_data import FULL_TEST_DATA_DICT
10+
from aeon.utils.validation import get_n_cases
1011

1112

1213
def _yield_clustering_checks(estimator_class, estimator_instances, datatypes):
@@ -26,6 +27,10 @@ def _yield_clustering_checks(estimator_class, estimator_instances, datatypes):
2627
estimator=estimator,
2728
datatype=datatypes[i][0],
2829
)
30+
for datatype in datatypes[i]:
31+
yield partial(
32+
check_clusterer_output, estimator=estimator, datatype=datatype
33+
)
2934

3035

3136
def check_clusterer_tags_consistent(estimator_class):
@@ -82,3 +87,39 @@ def check_clustering_random_state_deep_learning(estimator, datatype):
8287
_weight2 = np.asarray(weights2[j])
8388

8489
np.testing.assert_almost_equal(_weight1, _weight2, 4)
90+
91+
92+
def check_clusterer_output(estimator, datatype):
93+
"""Test clusterer outputs the correct data types and values.
94+
95+
Test predict produces a np.array or pd.Series with only values seen in the train
96+
data, and that predict_proba probability estimates add up to one.
97+
"""
98+
estimator = _clone_estimator(estimator)
99+
100+
unique_labels = np.unique(FULL_TEST_DATA_DICT[datatype]["train"][1])
101+
102+
# run fit and predict
103+
estimator.fit(
104+
FULL_TEST_DATA_DICT[datatype]["train"][0],
105+
FULL_TEST_DATA_DICT[datatype]["train"][1],
106+
)
107+
assert hasattr(estimator, "labels_")
108+
assert isinstance(estimator.labels_, np.ndarray)
109+
110+
y_pred = estimator.predict(FULL_TEST_DATA_DICT[datatype]["test"][0])
111+
112+
# check predict
113+
assert isinstance(y_pred, np.ndarray)
114+
assert y_pred.shape == (get_n_cases(FULL_TEST_DATA_DICT[datatype]["test"][0]),)
115+
assert np.all(np.isin(np.unique(y_pred), unique_labels))
116+
117+
# check predict proba (all classifiers have predict_proba by default)
118+
y_proba = estimator.predict_proba(FULL_TEST_DATA_DICT[datatype]["test"][0])
119+
120+
assert isinstance(y_proba, np.ndarray)
121+
assert y_proba.shape == (
122+
get_n_cases(FULL_TEST_DATA_DICT[datatype]["test"][0]),
123+
len(unique_labels),
124+
)
125+
np.testing.assert_almost_equal(y_proba.sum(axis=1), 1, decimal=4)

0 commit comments

Comments
 (0)