7
7
from aeon .base ._base import _clone_estimator
8
8
from aeon .clustering .deep_learning import BaseDeepClusterer
9
9
from aeon .testing .testing_data import FULL_TEST_DATA_DICT
10
+ from aeon .utils .validation import get_n_cases
10
11
11
12
12
13
def _yield_clustering_checks (estimator_class , estimator_instances , datatypes ):
@@ -26,6 +27,10 @@ def _yield_clustering_checks(estimator_class, estimator_instances, datatypes):
26
27
estimator = estimator ,
27
28
datatype = datatypes [i ][0 ],
28
29
)
30
+ for datatype in datatypes [i ]:
31
+ yield partial (
32
+ check_clusterer_output , estimator = estimator , datatype = datatype
33
+ )
29
34
30
35
31
36
def check_clusterer_tags_consistent (estimator_class ):
@@ -82,3 +87,39 @@ def check_clustering_random_state_deep_learning(estimator, datatype):
82
87
_weight2 = np .asarray (weights2 [j ])
83
88
84
89
np .testing .assert_almost_equal (_weight1 , _weight2 , 4 )
90
+
91
+
92
def check_clusterer_output(estimator, datatype):
    """Check the types, shapes and values a fitted clusterer produces.

    A clone of ``estimator`` is fitted on the train split stored under
    ``datatype`` in ``FULL_TEST_DATA_DICT``. The check then verifies that:

    - the fitted estimator exposes ``labels_`` as a ``np.ndarray``;
    - ``predict`` on the test split returns a 1D ``np.ndarray`` with one entry
      per test case, containing only values present in the train ``y``;
    - ``predict_proba`` on the test split returns a ``np.ndarray`` of shape
      ``(n_test_cases, n_unique_train_labels)`` whose rows sum to one
      (to 4 decimal places).

    Parameters
    ----------
    estimator : object
        Clusterer instance to test. It is cloned first, and only the clone is
        fitted.
    datatype : hashable
        Key selecting the dataset in ``FULL_TEST_DATA_DICT``.
    """
    clusterer = _clone_estimator(estimator)

    data = FULL_TEST_DATA_DICT[datatype]
    X_train, y_train = data["train"][0], data["train"][1]
    X_test = data["test"][0]

    # The distinct train targets are the reference set of admissible outputs.
    unique_labels = np.unique(y_train)

    # Fit, then inspect the fitted state.
    clusterer.fit(X_train, y_train)
    assert hasattr(clusterer, "labels_")
    assert isinstance(clusterer.labels_, np.ndarray)

    # predict: one label per test case, drawn from the reference set.
    y_pred = clusterer.predict(X_test)
    assert isinstance(y_pred, np.ndarray)
    n_test_cases = get_n_cases(X_test)
    assert y_pred.shape == (n_test_cases,)
    predicted_values = np.unique(y_pred)
    assert np.isin(predicted_values, unique_labels).all()

    # predict_proba: called unconditionally, so every clusterer under test is
    # expected to provide it.
    y_proba = clusterer.predict_proba(X_test)
    assert isinstance(y_proba, np.ndarray)
    expected_shape = (n_test_cases, len(unique_labels))
    assert y_proba.shape == expected_shape

    # Each row is a probability distribution over the clusters.
    row_totals = y_proba.sum(axis=1)
    np.testing.assert_almost_equal(row_totals, 1, decimal=4)
0 commit comments