Skip to content

Commit 2e09773

Browse files
[DOC,ENH] base docs and testing (#2273)
* base docs * docs and tests for base * tag function usage * tag function usage * refactor * compose testing * compose testing 2 * compose
1 parent 03c6cd2 commit 2e09773

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+663
-443
lines changed

aeon/anomaly_detection/base.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,11 @@ def fit(self, X, y=None, axis=1):
115115
BaseAnomalyDetector
116116
The fitted estimator, reference to self.
117117
"""
118-
if self.get_class_tag("fit_is_empty"):
118+
if self.get_tag("fit_is_empty"):
119119
self.is_fitted = True
120120
return self
121121

122-
if self.get_class_tag("requires_y"):
122+
if self.get_tag("requires_y"):
123123
if y is None:
124124
raise ValueError("Tag requires_y is true, but fit called with y=None")
125125

@@ -159,7 +159,7 @@ def predict(self, X, axis=1) -> np.ndarray:
159159
A boolean, int or float array of length len(X), where each element indicates
160160
whether the corresponding subsequence is anomalous or its anomaly score.
161161
"""
162-
fit_empty = self.get_class_tag("fit_is_empty")
162+
fit_empty = self.get_tag("fit_is_empty")
163163
if not fit_empty:
164164
self._check_is_fitted()
165165

@@ -194,7 +194,7 @@ def fit_predict(self, X, y=None, axis=1) -> np.ndarray:
194194
A boolean, int or float array of length len(X), where each element indicates
195195
whether the corresponding subsequence is anomalous or its anomaly score.
196196
"""
197-
if self.get_class_tag("requires_y"):
197+
if self.get_tag("requires_y"):
198198
if y is None:
199199
raise ValueError("Tag requires_y is true, but fit called with y=None")
200200

@@ -203,7 +203,7 @@ def fit_predict(self, X, y=None, axis=1) -> np.ndarray:
203203

204204
X = self._preprocess_series(X, axis, True)
205205

206-
if self.get_class_tag("fit_is_empty"):
206+
if self.get_tag("fit_is_empty"):
207207
self.is_fitted = True
208208
return self._predict(X)
209209

@@ -230,7 +230,7 @@ def _check_y(self, y: VALID_INPUT_TYPES) -> np.ndarray:
230230
# Remind user if y is not required for this estimator on failure
231231
req_msg = (
232232
f"{self.__class__.__name__} does not require a y input."
233-
if self.get_class_tag("requires_y")
233+
if self.get_tag("requires_y")
234234
else ""
235235
)
236236
new_y = y

aeon/base/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
"BaseAeonEstimator",
55
"BaseCollectionEstimator",
66
"BaseSeriesEstimator",
7-
"_ComposableEstimatorMixin",
7+
"ComposableEstimatorMixin",
88
]
99

1010
from aeon.base._base import BaseAeonEstimator
1111
from aeon.base._base_collection import BaseCollectionEstimator
1212
from aeon.base._base_series import BaseSeriesEstimator
13-
from aeon.base._meta import _ComposableEstimatorMixin
13+
from aeon.base._compose import ComposableEstimatorMixin

aeon/base/_base.py

+44-30
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,19 @@ class BaseAeonEstimator(BaseEstimator, ABC):
1919
2020
Contains the following methods:
2121
22-
reset estimator to post-init - reset(keep)
23-
clone stimator (copy) - clone(random_state)
24-
inspect tags (class method) - get_class_tags()
25-
inspect tags (one tag, class) - get_class_tag(tag_name, tag_value_default,
22+
- reset estimator to post-init - reset(keep)
23+
- clone stimator (copy) - clone(random_state)
24+
- inspect tags (class method) - get_class_tags()
25+
- inspect tags (one tag, class) - get_class_tag(tag_name, tag_value_default,
2626
raise_error)
27-
inspect tags (all) - get_tags()
28-
inspect tags (one tag) - get_tag(tag_name, tag_value_default, raise_error)
29-
setting dynamic tags - set_tags(**tag_dict)
30-
get fitted parameters - get_fitted_params(deep)
27+
- inspect tags (all) - get_tags()
28+
- inspect tags (one tag) - get_tag(tag_name, tag_value_default, raise_error)
29+
- setting dynamic tags - set_tags(**tag_dict)
30+
- get fitted parameters - get_fitted_params(deep)
3131
3232
All estimators have the attribute:
3333
34-
fitted state flag - is_fitted
34+
- fitted state flag - is_fitted
3535
"""
3636

3737
_tags = {
@@ -63,7 +63,7 @@ def reset(self, keep=None):
6363
hyper-parameters (arguments of ``__init__``)
6464
object attributes containing double-underscores, i.e., the string "__"
6565
runs ``__init__`` with current values of hyperparameters (result of
66-
get_params)
66+
``get_params``)
6767
6868
Not affected by the reset are:
6969
object attributes containing double-underscores
@@ -73,13 +73,13 @@ class and object methods, class attributes
7373
Parameters
7474
----------
7575
keep : None, str, or list of str, default=None
76-
If None, all attributes are removed except hyper-parameters.
76+
If None, all attributes are removed except hyperparameters.
7777
If str, only the attribute with this name is kept.
7878
If list of str, only the attributes with these names are kept.
7979
8080
Returns
8181
-------
82-
self
82+
self : object
8383
Reference to self.
8484
"""
8585
# retrieve parameters to copy them later
@@ -163,30 +163,35 @@ def get_class_tags(cls):
163163
return deepcopy(collected_tags)
164164

165165
@classmethod
166-
def get_class_tag(cls, tag_name, tag_value_default=None, raise_error=False):
166+
def get_class_tag(
167+
cls,
168+
tag_name,
169+
raise_error=True,
170+
tag_value_default=None,
171+
):
167172
"""
168173
Get tag value from estimator class (only class tags).
169174
170175
Parameters
171176
----------
172177
tag_name : str
173178
Name of tag value.
174-
tag_value_default : any type
175-
Default/fallback value if tag is not found.
176-
raise_error : bool
179+
raise_error : bool, default=True
177180
Whether a ValueError is raised when the tag is not found.
181+
tag_value_default : any type, default=None
182+
Default/fallback value if tag is not found and error is not raised.
178183
179184
Returns
180185
-------
181186
tag_value
182-
Value of the ``tag_name`` tag in self.
183-
If not found, returns an error if raise_error is True, otherwise it
184-
returns `tag_value_default`.
187+
Value of the ``tag_name`` tag in cls.
188+
If not found, returns an error if ``raise_error`` is True, otherwise it
189+
returns ``tag_value_default``.
185190
186191
Raises
187192
------
188193
ValueError
189-
if raise_error is ``True`` and ``tag_name`` is not in
194+
if ``raise_error`` is True and ``tag_name`` is not in
190195
``self.get_tags().keys()``
191196
192197
Examples
@@ -221,7 +226,7 @@ def get_tags(self):
221226
collected_tags.update(self._tags_dynamic)
222227
return deepcopy(collected_tags)
223228

224-
def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
229+
def get_tag(self, tag_name, raise_error=True, tag_value_default=None):
225230
"""
226231
Get tag value from estimator class.
227232
@@ -231,17 +236,17 @@ def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
231236
----------
232237
tag_name : str
233238
Name of tag to be retrieved.
234-
tag_value_default : any type, default=None
235-
Default/fallback value if tag is not found.
236-
raise_error : bool
239+
raise_error : bool, default=True
237240
Whether a ValueError is raised when the tag is not found.
241+
tag_value_default : any type, default=None
242+
Default/fallback value if tag is not found and error is not raised.
238243
239244
Returns
240245
-------
241246
tag_value
242247
Value of the ``tag_name`` tag in self.
243-
If not found, returns an error if raise_error is True, otherwise it
244-
returns `tag_value_default`.
248+
If not found, returns an error if ``raise_error`` is True, otherwise it
249+
returns ``tag_value_default``.
245250
246251
Raises
247252
------
@@ -276,7 +281,7 @@ def set_tags(self, **tag_dict):
276281
277282
Returns
278283
-------
279-
self
284+
self : object
280285
Reference to self.
281286
"""
282287
tag_update = deepcopy(tag_dict)
@@ -297,7 +302,7 @@ def get_fitted_params(self, deep=True):
297302
298303
Returns
299304
-------
300-
fitted_params : mapping of string to any
305+
fitted_params : dict
301306
Fitted parameter names mapped to their values.
302307
"""
303308
self._check_is_fitted()
@@ -312,7 +317,13 @@ def _get_fitted_params(self, est, deep):
312317

313318
out = dict()
314319
for key in fitted_params:
315-
value = getattr(est, key)
320+
# some of these can be properties and can make assumptions which may not be
321+
# true in aeon i.e. sklearn Pipeline feature_names_in_
322+
try:
323+
value = getattr(est, key)
324+
except AttributeError:
325+
continue
326+
316327
if deep and isinstance(value, BaseEstimator):
317328
deep_items = self._get_fitted_params(value, deep).items()
318329
out.update((key + "__" + k, val) for k, val in deep_items)
@@ -406,7 +417,10 @@ def _validate_data(self, **kwargs):
406417
)
407418

408419
def get_metadata_routing(self):
409-
"""Sklearn metadata routing."""
420+
"""Sklearn metadata routing.
421+
422+
Not supported by ``aeon`` estimators.
423+
"""
410424
raise NotImplementedError(
411425
"aeon estimators do not have a get_metadata_routing method."
412426
)

aeon/base/_meta.py renamed to aeon/base/_compose.py

+10-11
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
"""Implements meta estimator for estimators composed of other estimators."""
22

33
__maintainer__ = ["MatthewMiddlehurst"]
4-
__all__ = ["_ComposableEstimatorMixin"]
4+
__all__ = ["ComposableEstimatorMixin"]
55

66
from abc import ABC, abstractmethod
77

88
from aeon.base import BaseAeonEstimator
99
from aeon.base._base import _clone_estimator
1010

1111

12-
class _ComposableEstimatorMixin(ABC):
12+
class ComposableEstimatorMixin(ABC):
1313
"""Handles parameter management for estimators composed of named estimators.
1414
1515
Parts (i.e. get_params and set_params) adapted or copied from the scikit-learn
@@ -52,9 +52,8 @@ def get_params(self, deep=True):
5252
out.update(estimators)
5353

5454
for name, estimator in estimators:
55-
if hasattr(estimator, "get_params"):
56-
for key, value in estimator.get_params(deep=True).items():
57-
out[f"{name}__{key}"] = value
55+
for key, value in estimator.get_params(deep=True).items():
56+
out[f"{name}__{key}"] = value
5857
return out
5958

6059
def set_params(self, **params):
@@ -119,7 +118,7 @@ def get_fitted_params(self, deep=True):
119118
120119
Returns
121120
-------
122-
fitted_params : mapping of string to any
121+
fitted_params : dict
123122
Fitted parameter names mapped to their values.
124123
"""
125124
self._check_is_fitted()
@@ -190,16 +189,16 @@ def _check_estimators(
190189
for obj in estimators:
191190
if isinstance(obj, tuple):
192191
if not allow_tuples:
193-
raise TypeError(
192+
raise ValueError(
194193
f"{attr_name} should only contain singular estimators instead "
195194
f"of (str, estimator) tuples."
196195
)
197196
if not len(obj) == 2 or not isinstance(obj[0], str):
198-
raise TypeError(
197+
raise ValueError(
199198
f"All tuples in {attr_name} must be of form (str, estimator)."
200199
)
201200
if not isinstance(obj[1], class_type):
202-
raise TypeError(
201+
raise ValueError(
203202
f"All estimators in {attr_name} must be an instance "
204203
f"of {class_type}."
205204
)
@@ -213,15 +212,15 @@ def _check_estimators(
213212
raise ValueError(f"Estimator name is invalid: {obj[0]}")
214213
if unique_names:
215214
if obj[0] in names:
216-
raise TypeError(
215+
raise ValueError(
217216
f"Names in {attr_name} must be unique. Found duplicate "
218217
f"name: {obj[0]}."
219218
)
220219
else:
221220
names.append(obj[0])
222221
elif isinstance(obj, class_type):
223222
if not allow_single_estimators:
224-
raise TypeError(
223+
raise ValueError(
225224
f"{attr_name} should only contain (str, estimator) tuples "
226225
f"instead of singular estimators."
227226
)
File renamed without changes.

aeon/base/estimator/compose/collection_channel_ensemble.py renamed to aeon/base/estimators/compose/collection_channel_ensemble.py

+18-6
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313
from aeon.base import (
1414
BaseAeonEstimator,
1515
BaseCollectionEstimator,
16-
_ComposableEstimatorMixin,
16+
ComposableEstimatorMixin,
1717
)
1818
from aeon.base._base import _clone_estimator
1919

2020

21-
class BaseCollectionChannelEnsemble(_ComposableEstimatorMixin, BaseCollectionEstimator):
21+
class BaseCollectionChannelEnsemble(ComposableEstimatorMixin, BaseCollectionEstimator):
2222
"""Applies estimators to channels of an array.
2323
2424
Parameters
@@ -101,7 +101,11 @@ def __init__(
101101
missing = all(
102102
[
103103
(
104-
e[1].get_tag("capability:missing_values", False, raise_error=False)
104+
e[1].get_tag(
105+
"capability:missing_values",
106+
raise_error=False,
107+
tag_value_default=False,
108+
)
105109
if isinstance(e[1], BaseAeonEstimator)
106110
else False
107111
)
@@ -110,14 +114,20 @@ def __init__(
110114
)
111115
remainder_missing = remainder is None or (
112116
isinstance(remainder, BaseAeonEstimator)
113-
and remainder.get_tag("capability:missing_values", False, raise_error=False)
117+
and remainder.get_tag(
118+
"capability:missing_values", raise_error=False, tag_value_default=False
119+
)
114120
)
115121

116122
# can handle unequal length if all estimators can
117123
unequal = all(
118124
[
119125
(
120-
e[1].get_tag("capability:unequal_length", False, raise_error=False)
126+
e[1].get_tag(
127+
"capability:unequal_length",
128+
raise_error=False,
129+
tag_value_default=False,
130+
)
121131
if isinstance(e[1], BaseAeonEstimator)
122132
else False
123133
)
@@ -126,7 +136,9 @@ def __init__(
126136
)
127137
remainder_unequal = remainder is None or (
128138
isinstance(remainder, BaseAeonEstimator)
129-
and remainder.get_tag("capability:unequal_length", False, raise_error=False)
139+
and remainder.get_tag(
140+
"capability:unequal_length", raise_error=False, tag_value_default=False
141+
)
130142
)
131143

132144
tags_to_set = {

0 commit comments

Comments
 (0)