diff --git a/aeon/anomaly_detection/base.py b/aeon/anomaly_detection/base.py index 934f2ef6da..87ca40aa19 100644 --- a/aeon/anomaly_detection/base.py +++ b/aeon/anomaly_detection/base.py @@ -115,11 +115,11 @@ def fit(self, X, y=None, axis=1): BaseAnomalyDetector The fitted estimator, reference to self. """ - if self.get_class_tag("fit_is_empty"): + if self.get_tag("fit_is_empty"): self.is_fitted = True return self - if self.get_class_tag("requires_y"): + if self.get_tag("requires_y"): if y is None: raise ValueError("Tag requires_y is true, but fit called with y=None") @@ -159,7 +159,7 @@ def predict(self, X, axis=1) -> np.ndarray: A boolean, int or float array of length len(X), where each element indicates whether the corresponding subsequence is anomalous or its anomaly score. """ - fit_empty = self.get_class_tag("fit_is_empty") + fit_empty = self.get_tag("fit_is_empty") if not fit_empty: self._check_is_fitted() @@ -194,7 +194,7 @@ def fit_predict(self, X, y=None, axis=1) -> np.ndarray: A boolean, int or float array of length len(X), where each element indicates whether the corresponding subsequence is anomalous or its anomaly score. """ - if self.get_class_tag("requires_y"): + if self.get_tag("requires_y"): if y is None: raise ValueError("Tag requires_y is true, but fit called with y=None") @@ -203,7 +203,7 @@ def fit_predict(self, X, y=None, axis=1) -> np.ndarray: X = self._preprocess_series(X, axis, True) - if self.get_class_tag("fit_is_empty"): + if self.get_tag("fit_is_empty"): self.is_fitted = True return self._predict(X) @@ -230,7 +230,7 @@ def _check_y(self, y: VALID_INPUT_TYPES) -> np.ndarray: # Remind user if y is not required for this estimator on failure req_msg = ( f"{self.__class__.__name__} does not require a y input." 
- if self.get_class_tag("requires_y") + if self.get_tag("requires_y") else "" ) new_y = y diff --git a/aeon/base/__init__.py b/aeon/base/__init__.py index 062c44b11b..a9edf83e52 100644 --- a/aeon/base/__init__.py +++ b/aeon/base/__init__.py @@ -4,10 +4,10 @@ "BaseAeonEstimator", "BaseCollectionEstimator", "BaseSeriesEstimator", - "_ComposableEstimatorMixin", + "ComposableEstimatorMixin", ] from aeon.base._base import BaseAeonEstimator from aeon.base._base_collection import BaseCollectionEstimator from aeon.base._base_series import BaseSeriesEstimator -from aeon.base._meta import _ComposableEstimatorMixin +from aeon.base._compose import ComposableEstimatorMixin diff --git a/aeon/base/_base.py b/aeon/base/_base.py index 015d510fb4..6a9f7dbb70 100644 --- a/aeon/base/_base.py +++ b/aeon/base/_base.py @@ -19,19 +19,19 @@ class BaseAeonEstimator(BaseEstimator, ABC): Contains the following methods: - reset estimator to post-init - reset(keep) - clone stimator (copy) - clone(random_state) - inspect tags (class method) - get_class_tags() - inspect tags (one tag, class) - get_class_tag(tag_name, tag_value_default, + - reset estimator to post-init - reset(keep) + - clone stimator (copy) - clone(random_state) + - inspect tags (class method) - get_class_tags() + - inspect tags (one tag, class) - get_class_tag(tag_name, tag_value_default, raise_error) - inspect tags (all) - get_tags() - inspect tags (one tag) - get_tag(tag_name, tag_value_default, raise_error) - setting dynamic tags - set_tags(**tag_dict) - get fitted parameters - get_fitted_params(deep) + - inspect tags (all) - get_tags() + - inspect tags (one tag) - get_tag(tag_name, tag_value_default, raise_error) + - setting dynamic tags - set_tags(**tag_dict) + - get fitted parameters - get_fitted_params(deep) All estimators have the attribute: - fitted state flag - is_fitted + - fitted state flag - is_fitted """ _tags = { @@ -63,7 +63,7 @@ def reset(self, keep=None): hyper-parameters (arguments of ``__init__``) object 
attributes containing double-underscores, i.e., the string "__" runs ``__init__`` with current values of hyperparameters (result of - get_params) + ``get_params``) Not affected by the reset are: object attributes containing double-underscores @@ -73,13 +73,13 @@ class and object methods, class attributes Parameters ---------- keep : None, str, or list of str, default=None - If None, all attributes are removed except hyper-parameters. + If None, all attributes are removed except hyperparameters. If str, only the attribute with this name is kept. If list of str, only the attributes with these names are kept. Returns ------- - self + self : object Reference to self. """ # retrieve parameters to copy them later @@ -163,7 +163,12 @@ def get_class_tags(cls): return deepcopy(collected_tags) @classmethod - def get_class_tag(cls, tag_name, tag_value_default=None, raise_error=False): + def get_class_tag( + cls, + tag_name, + raise_error=True, + tag_value_default=None, + ): """ Get tag value from estimator class (only class tags). @@ -171,22 +176,22 @@ def get_class_tag(cls, tag_name, tag_value_default=None, raise_error=False): ---------- tag_name : str Name of tag value. - tag_value_default : any type - Default/fallback value if tag is not found. - raise_error : bool + raise_error : bool, default=True Whether a ValueError is raised when the tag is not found. + tag_value_default : any type, default=None + Default/fallback value if tag is not found and error is not raised. Returns ------- tag_value - Value of the ``tag_name`` tag in self. - If not found, returns an error if raise_error is True, otherwise it - returns `tag_value_default`. + Value of the ``tag_name`` tag in cls. + If not found, returns an error if ``raise_error`` is True, otherwise it + returns ``tag_value_default``. 
Raises ------ ValueError - if raise_error is ``True`` and ``tag_name`` is not in + if ``raise_error`` is True and ``tag_name`` is not in ``self.get_tags().keys()`` Examples @@ -221,7 +226,7 @@ def get_tags(self): collected_tags.update(self._tags_dynamic) return deepcopy(collected_tags) - def get_tag(self, tag_name, tag_value_default=None, raise_error=True): + def get_tag(self, tag_name, raise_error=True, tag_value_default=None): """ Get tag value from estimator class. @@ -231,17 +236,17 @@ def get_tag(self, tag_name, tag_value_default=None, raise_error=True): ---------- tag_name : str Name of tag to be retrieved. - tag_value_default : any type, default=None - Default/fallback value if tag is not found. - raise_error : bool + raise_error : bool, default=True Whether a ValueError is raised when the tag is not found. + tag_value_default : any type, default=None + Default/fallback value if tag is not found and error is not raised. Returns ------- tag_value Value of the ``tag_name`` tag in self. - If not found, returns an error if raise_error is True, otherwise it - returns `tag_value_default`. + If not found, returns an error if ``raise_error`` is True, otherwise it + returns ``tag_value_default``. Raises ------ @@ -276,7 +281,7 @@ def set_tags(self, **tag_dict): Returns ------- - self + self : object Reference to self. """ tag_update = deepcopy(tag_dict) @@ -297,7 +302,7 @@ def get_fitted_params(self, deep=True): Returns ------- - fitted_params : mapping of string to any + fitted_params : dict Fitted parameter names mapped to their values. """ self._check_is_fitted() @@ -312,7 +317,13 @@ def _get_fitted_params(self, est, deep): out = dict() for key in fitted_params: - value = getattr(est, key) + # some of these can be properties and can make assumptions which may not be + # true in aeon i.e. 
sklearn Pipeline feature_names_in_ + try: + value = getattr(est, key) + except AttributeError: + continue + if deep and isinstance(value, BaseEstimator): deep_items = self._get_fitted_params(value, deep).items() out.update((key + "__" + k, val) for k, val in deep_items) @@ -406,7 +417,10 @@ def _validate_data(self, **kwargs): ) def get_metadata_routing(self): - """Sklearn metadata routing.""" + """Sklearn metadata routing. + + Not supported by ``aeon`` estimators. + """ raise NotImplementedError( "aeon estimators do not have a get_metadata_routing method." ) diff --git a/aeon/base/_meta.py b/aeon/base/_compose.py similarity index 95% rename from aeon/base/_meta.py rename to aeon/base/_compose.py index 6637aa47f1..0995e85de6 100644 --- a/aeon/base/_meta.py +++ b/aeon/base/_compose.py @@ -1,7 +1,7 @@ """Implements meta estimator for estimators composed of other estimators.""" __maintainer__ = ["MatthewMiddlehurst"] -__all__ = ["_ComposableEstimatorMixin"] +__all__ = ["ComposableEstimatorMixin"] from abc import ABC, abstractmethod @@ -9,7 +9,7 @@ from aeon.base._base import _clone_estimator -class _ComposableEstimatorMixin(ABC): +class ComposableEstimatorMixin(ABC): """Handles parameter management for estimators composed of named estimators. Parts (i.e. get_params and set_params) adapted or copied from the scikit-learn @@ -52,9 +52,8 @@ def get_params(self, deep=True): out.update(estimators) for name, estimator in estimators: - if hasattr(estimator, "get_params"): - for key, value in estimator.get_params(deep=True).items(): - out[f"{name}__{key}"] = value + for key, value in estimator.get_params(deep=True).items(): + out[f"{name}__{key}"] = value return out def set_params(self, **params): @@ -119,7 +118,7 @@ def get_fitted_params(self, deep=True): Returns ------- - fitted_params : mapping of string to any + fitted_params : dict Fitted parameter names mapped to their values. 
""" self._check_is_fitted() @@ -190,16 +189,16 @@ def _check_estimators( for obj in estimators: if isinstance(obj, tuple): if not allow_tuples: - raise TypeError( + raise ValueError( f"{attr_name} should only contain singular estimators instead " f"of (str, estimator) tuples." ) if not len(obj) == 2 or not isinstance(obj[0], str): - raise TypeError( + raise ValueError( f"All tuples in {attr_name} must be of form (str, estimator)." ) if not isinstance(obj[1], class_type): - raise TypeError( + raise ValueError( f"All estimators in {attr_name} must be an instance " f"of {class_type}." ) @@ -213,7 +212,7 @@ def _check_estimators( raise ValueError(f"Estimator name is invalid: {obj[0]}") if unique_names: if obj[0] in names: - raise TypeError( + raise ValueError( f"Names in {attr_name} must be unique. Found duplicate " f"name: {obj[0]}." ) @@ -221,7 +220,7 @@ def _check_estimators( names.append(obj[0]) elif isinstance(obj, class_type): if not allow_single_estimators: - raise TypeError( + raise ValueError( f"{attr_name} should only contain (str, estimator) tuples " f"instead of singular estimators." 
) diff --git a/aeon/base/estimator/__init__.py b/aeon/base/estimators/__init__.py similarity index 100% rename from aeon/base/estimator/__init__.py rename to aeon/base/estimators/__init__.py diff --git a/aeon/base/estimator/compose/__init__.py b/aeon/base/estimators/compose/__init__.py similarity index 100% rename from aeon/base/estimator/compose/__init__.py rename to aeon/base/estimators/compose/__init__.py diff --git a/aeon/base/estimator/compose/collection_channel_ensemble.py b/aeon/base/estimators/compose/collection_channel_ensemble.py similarity index 91% rename from aeon/base/estimator/compose/collection_channel_ensemble.py rename to aeon/base/estimators/compose/collection_channel_ensemble.py index 9f21f9dece..4164536f19 100644 --- a/aeon/base/estimator/compose/collection_channel_ensemble.py +++ b/aeon/base/estimators/compose/collection_channel_ensemble.py @@ -13,12 +13,12 @@ from aeon.base import ( BaseAeonEstimator, BaseCollectionEstimator, - _ComposableEstimatorMixin, + ComposableEstimatorMixin, ) from aeon.base._base import _clone_estimator -class BaseCollectionChannelEnsemble(_ComposableEstimatorMixin, BaseCollectionEstimator): +class BaseCollectionChannelEnsemble(ComposableEstimatorMixin, BaseCollectionEstimator): """Applies estimators to channels of an array. 
Parameters @@ -101,7 +101,11 @@ def __init__( missing = all( [ ( - e[1].get_tag("capability:missing_values", False, raise_error=False) + e[1].get_tag( + "capability:missing_values", + raise_error=False, + tag_value_default=False, + ) if isinstance(e[1], BaseAeonEstimator) else False ) @@ -110,14 +114,20 @@ def __init__( ) remainder_missing = remainder is None or ( isinstance(remainder, BaseAeonEstimator) - and remainder.get_tag("capability:missing_values", False, raise_error=False) + and remainder.get_tag( + "capability:missing_values", raise_error=False, tag_value_default=False + ) ) # can handle unequal length if all estimators can unequal = all( [ ( - e[1].get_tag("capability:unequal_length", False, raise_error=False) + e[1].get_tag( + "capability:unequal_length", + raise_error=False, + tag_value_default=False, + ) if isinstance(e[1], BaseAeonEstimator) else False ) @@ -126,7 +136,9 @@ def __init__( ) remainder_unequal = remainder is None or ( isinstance(remainder, BaseAeonEstimator) - and remainder.get_tag("capability:unequal_length", False, raise_error=False) + and remainder.get_tag( + "capability:unequal_length", raise_error=False, tag_value_default=False + ) ) tags_to_set = { diff --git a/aeon/base/estimator/compose/collection_ensemble.py b/aeon/base/estimators/compose/collection_ensemble.py similarity index 91% rename from aeon/base/estimator/compose/collection_ensemble.py rename to aeon/base/estimators/compose/collection_ensemble.py index dd379937cd..1223414ae9 100644 --- a/aeon/base/estimator/compose/collection_ensemble.py +++ b/aeon/base/estimators/compose/collection_ensemble.py @@ -15,12 +15,12 @@ from aeon.base import ( BaseAeonEstimator, BaseCollectionEstimator, - _ComposableEstimatorMixin, + ComposableEstimatorMixin, ) from aeon.base._base import _clone_estimator -class BaseCollectionEnsemble(_ComposableEstimatorMixin, BaseCollectionEstimator): +class BaseCollectionEnsemble(ComposableEstimatorMixin, BaseCollectionEstimator): """Weighted ensemble of 
collection estimators with fittable ensemble weight. Parameters @@ -111,7 +111,11 @@ def __init__( multivariate = all( [ ( - e[1].get_tag("capability:multivariate", False, raise_error=False) + e[1].get_tag( + "capability:multivariate", + raise_error=False, + tag_value_default=False, + ) if isinstance(e[1], BaseAeonEstimator) else False ) @@ -123,7 +127,11 @@ def __init__( missing = all( [ ( - e[1].get_tag("capability:missing_values", False, raise_error=False) + e[1].get_tag( + "capability:missing_values", + raise_error=False, + tag_value_default=False, + ) if isinstance(e[1], BaseAeonEstimator) else False ) @@ -135,7 +143,11 @@ def __init__( unequal = all( [ ( - e[1].get_tag("capability:unequal_length", False, raise_error=False) + e[1].get_tag( + "capability:unequal_length", + raise_error=False, + tag_value_default=False, + ) if isinstance(e[1], BaseAeonEstimator) else False ) diff --git a/aeon/base/estimator/compose/collection_pipeline.py b/aeon/base/estimators/compose/collection_pipeline.py similarity index 83% rename from aeon/base/estimator/compose/collection_pipeline.py rename to aeon/base/estimators/compose/collection_pipeline.py index a21b82be4d..48e333d431 100644 --- a/aeon/base/estimator/compose/collection_pipeline.py +++ b/aeon/base/estimators/compose/collection_pipeline.py @@ -13,12 +13,12 @@ from aeon.base import ( BaseAeonEstimator, BaseCollectionEstimator, - _ComposableEstimatorMixin, + ComposableEstimatorMixin, ) from aeon.base._base import _clone_estimator -class BaseCollectionPipeline(_ComposableEstimatorMixin, BaseCollectionEstimator): +class BaseCollectionPipeline(ComposableEstimatorMixin, BaseCollectionEstimator): """Base class for composable pipelines in collection based modules. 
Parameters @@ -85,7 +85,11 @@ def __init__(self, transformers, _estimator, random_state=None): # *or* transformer chain removes multivariate multivariate_tags = [ ( - e[1].get_tag("capability:multivariate", False, raise_error=False) + e[1].get_tag( + "capability:multivariate", + raise_error=False, + tag_value_default=False, + ) if isinstance(e[1], BaseAeonEstimator) else False ) @@ -96,13 +100,17 @@ def __init__(self, transformers, _estimator, random_state=None): for e in self._steps: if ( isinstance(e[1], BaseAeonEstimator) - and e[1].get_tag("capability:multivariate", False, raise_error=False) + and e[1].get_tag( + "capability:multivariate", + raise_error=False, + tag_value_default=False, + ) and e[1].get_tag("output_data_type", raise_error=False) == "Tabular" ): multivariate_rm_tag = True break elif not isinstance(e[1], BaseAeonEstimator) or not e[1].get_tag( - "capability:multivariate", False, raise_error=False + "capability:multivariate", raise_error=False, tag_value_default=False ): break @@ -112,7 +120,11 @@ def __init__(self, transformers, _estimator, random_state=None): # *or* transformer chain removes missing data missing_tags = [ ( - e[1].get_tag("capability:missing_values", False, raise_error=False) + e[1].get_tag( + "capability:missing_values", + raise_error=False, + tag_value_default=False, + ) if isinstance(e[1], BaseAeonEstimator) else False ) @@ -123,13 +135,19 @@ def __init__(self, transformers, _estimator, random_state=None): for e in self._steps: if ( isinstance(e[1], BaseAeonEstimator) - and e[1].get_tag("capability:missing_values", False, raise_error=False) - and e[1].get_tag("removes_missing_values", False, raise_error=False) + and e[1].get_tag( + "capability:missing_values", + raise_error=False, + tag_value_default=False, + ) + and e[1].get_tag( + "removes_missing_values", raise_error=False, tag_value_default=False + ) ): missing_rm_tag = True break elif not isinstance(e[1], BaseAeonEstimator) or not e[1].get_tag( - 
"capability:missing_values", False, raise_error=False + "capability:missing_values", raise_error=False, tag_value_default=False ): break @@ -140,7 +158,11 @@ def __init__(self, transformers, _estimator, random_state=None): # *or* transformer chain transforms the series to a tabular format unequal_tags = [ ( - e[1].get_tag("capability:unequal_length", False, raise_error=False) + e[1].get_tag( + "capability:unequal_length", + raise_error=False, + tag_value_default=False, + ) if isinstance(e[1], BaseAeonEstimator) else False ) @@ -151,16 +173,24 @@ def __init__(self, transformers, _estimator, random_state=None): for e in self._steps: if ( isinstance(e[1], BaseAeonEstimator) - and e[1].get_tag("capability:unequal_length", False, raise_error=False) + and e[1].get_tag( + "capability:unequal_length", + raise_error=False, + tag_value_default=False, + ) and ( - e[1].get_tag("removes_unequal_length", False, raise_error=False) + e[1].get_tag( + "removes_unequal_length", + raise_error=False, + tag_value_default=False, + ) or e[1].get_tag("output_data_type", raise_error=False) == "Tabular" ) ): unequal_rm_tag = True break elif not isinstance(e[1], BaseAeonEstimator) or not e[1].get_tag( - "capability:unequal_length", False, raise_error=False + "capability:unequal_length", raise_error=False, tag_value_default=False ): break diff --git a/aeon/base/estimator/hybrid/__init__.py b/aeon/base/estimators/hybrid/__init__.py similarity index 57% rename from aeon/base/estimator/hybrid/__init__.py rename to aeon/base/estimators/hybrid/__init__.py index 164aee492a..642a5cc0bc 100644 --- a/aeon/base/estimator/hybrid/__init__.py +++ b/aeon/base/estimators/hybrid/__init__.py @@ -2,4 +2,4 @@ __all__ = ["BaseRIST"] -from aeon.base.estimator.hybrid.base_rist import BaseRIST +from aeon.base.estimators.hybrid.base_rist import BaseRIST diff --git a/aeon/base/estimator/hybrid/base_rist.py b/aeon/base/estimators/hybrid/base_rist.py similarity index 100% rename from 
aeon/base/estimator/hybrid/base_rist.py rename to aeon/base/estimators/hybrid/base_rist.py diff --git a/aeon/base/estimator/hybrid/tests/__init__.py b/aeon/base/estimators/hybrid/tests/__init__.py similarity index 100% rename from aeon/base/estimator/hybrid/tests/__init__.py rename to aeon/base/estimators/hybrid/tests/__init__.py diff --git a/aeon/base/estimator/hybrid/tests/test_base_rist.py b/aeon/base/estimators/hybrid/tests/test_base_rist.py similarity index 100% rename from aeon/base/estimator/hybrid/tests/test_base_rist.py rename to aeon/base/estimators/hybrid/tests/test_base_rist.py diff --git a/aeon/base/estimator/interval_based/__init__.py b/aeon/base/estimators/interval_based/__init__.py similarity index 52% rename from aeon/base/estimator/interval_based/__init__.py rename to aeon/base/estimators/interval_based/__init__.py index 1c499261fc..4a65216eed 100644 --- a/aeon/base/estimator/interval_based/__init__.py +++ b/aeon/base/estimators/interval_based/__init__.py @@ -2,4 +2,4 @@ __all__ = ["BaseIntervalForest"] -from aeon.base.estimator.interval_based.base_interval_forest import BaseIntervalForest +from aeon.base.estimators.interval_based.base_interval_forest import BaseIntervalForest diff --git a/aeon/base/estimator/interval_based/base_interval_forest.py b/aeon/base/estimators/interval_based/base_interval_forest.py similarity index 100% rename from aeon/base/estimator/interval_based/base_interval_forest.py rename to aeon/base/estimators/interval_based/base_interval_forest.py diff --git a/aeon/base/estimator/interval_based/tests/__init__.py b/aeon/base/estimators/interval_based/tests/__init__.py similarity index 100% rename from aeon/base/estimator/interval_based/tests/__init__.py rename to aeon/base/estimators/interval_based/tests/__init__.py diff --git a/aeon/base/estimator/interval_based/tests/test_base_interval_forest.py b/aeon/base/estimators/interval_based/tests/test_base_interval_forest.py similarity index 100% rename from 
aeon/base/estimator/interval_based/tests/test_base_interval_forest.py rename to aeon/base/estimators/interval_based/tests/test_base_interval_forest.py diff --git a/aeon/base/tests/test_base.py b/aeon/base/tests/test_base.py index 15b185e99d..1caafa0cdf 100644 --- a/aeon/base/tests/test_base.py +++ b/aeon/base/tests/test_base.py @@ -1,319 +1,334 @@ -""" -Tests for BaseAeonEstimator universal base class. +"""Tests for BaseAeonEstimator universal base class.""" -tests in this module: +import pytest +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler +from sklearn.tree import DecisionTreeClassifier +from sklearn.utils._metadata_requests import MetadataRequest - test_get_class_tags - tests get_class_tags inheritance logic - test_get_class_tag - tests get_class_tag logic, incl default value - test_get_tags - tests get_tags inheritance logic - test_get_tag - tests get_tag logic, incl default value - test_set_tags - tests set_tags logic and related get_tags inheritance +from aeon.base import BaseAeonEstimator +from aeon.base._base import _clone_estimator +from aeon.classification import BaseClassifier +from aeon.classification.feature_based import SummaryClassifier +from aeon.testing.mock_estimators import MockClassifier +from aeon.testing.mock_estimators._mock_classifiers import ( + MockClassifierComposite, + MockClassifierFullTags, + MockClassifierParams, +) +from aeon.testing.testing_data import EQUAL_LENGTH_UNIVARIATE_CLASSIFICATION +from aeon.transformations.collection import Tabularizer - test_reset - tests reset logic on a simple, non-composite estimator - test_reset_composite - tests reset logic on a composite estimator - test_components - tests retrieval of list of components via _components - test_get_fitted_params - tests get_fitted_params logic, nested and non-nested -""" +def test_reset(): + """Tests reset method for correct behaviour, on a simple estimator.""" + X, y = 
EQUAL_LENGTH_UNIVARIATE_CLASSIFICATION["numpy3D"]["train"] -__maintainer__ = [] + clf = MockClassifierParams(return_ones=True) + clf.fit(X, y) -__all__ = [ - "test_get_class_tags", - "test_get_class_tag", - "test_get_tags", - "test_get_tag", - "test_set_tags", - "test_reset", - "test_reset_composite", - "test_get_fitted_params", -] + assert clf.return_ones is True + assert clf.value == 50 + assert clf.foo_ == "bar" + assert clf.is_fitted is True + clf.__secret_att = 42 -from copy import deepcopy + clf.reset() -import pytest + assert hasattr(clf, "return_ones") and clf.return_ones is True + assert hasattr(clf, "value") and clf.value == 50 + assert hasattr(clf, "_tags") and clf._tags == MockClassifierParams._tags + assert hasattr(clf, "is_fitted") and clf.is_fitted is False + assert hasattr(clf, "__secret_att") and clf.__secret_att == 42 + assert hasattr(clf, "fit") + assert not hasattr(clf, "foo_") -from aeon.base import BaseAeonEstimator + clf.fit(X, y) + clf.reset(keep="foo_") + assert hasattr(clf, "is_fitted") and clf.is_fitted is False + assert hasattr(clf, "foo_") and clf.foo_ == "bar" -# Fixture class for testing tag system -class FixtureClassParent(BaseAeonEstimator): - _tags = {"A": "1", "B": 2, "C": 1234, 3: "D"} + clf.fit(X, y) + clf.random_att = 60 + clf.unwanted_att = 70 + clf.reset(keep=["foo_", "random_att"]) + assert hasattr(clf, "is_fitted") and clf.is_fitted is False + assert hasattr(clf, "foo_") and clf.foo_ == "bar" + assert hasattr(clf, "random_att") and clf.random_att == 60 + assert not hasattr(clf, "unwanted_att") -# Fixture class for testing tag system, child overrides tags -class FixtureClassChild(FixtureClassParent): - _tags = {"A": 42, 3: "E"} +def test_reset_composite(): + """Test reset method for correct behaviour, on a composite estimator.""" + X, y = EQUAL_LENGTH_UNIVARIATE_CLASSIFICATION["numpy3D"]["train"] -FIXTURE_CLASSCHILD = FixtureClassChild + clf = MockClassifierComposite(mock=MockClassifierParams(return_ones=True)) + clf.fit(X, 
y) -FIXTURE_CLASSCHILD_TAGS = { - "python_version": None, - "python_dependencies": None, - "cant_pickle": False, - "non_deterministic": False, - "algorithm_type": None, - "capability:missing_values": False, - "capability:multithreading": False, - "A": 42, - "B": 2, - "C": 1234, - 3: "E", -} + assert clf.foo_ == "bar" + assert clf.mock_.foo_ == "bar" + assert clf.mock.return_ones is True + assert clf.mock_.return_ones is True -# Fixture class for testing tag system, object overrides class tags -FIXTURE_OBJECT = FixtureClassChild() -FIXTURE_OBJECT._tags_dynamic = {"A": 42424241, "B": 3} + clf.reset() -FIXTURE_OBJECT_TAGS = { - "python_version": None, - "python_dependencies": None, - "cant_pickle": False, - "non_deterministic": False, + assert hasattr(clf.mock, "return_ones") and clf.mock.return_ones is True + assert not hasattr(clf, "mock_") + assert not hasattr(clf, "foo_") + assert not hasattr(clf.mock, "foo_") + + clf.fit(X, y) + clf.reset(keep="mock_") + + assert not hasattr(clf, "foo_") + assert hasattr(clf, "mock_") + assert hasattr(clf.mock_, "foo_") and clf.mock_.foo_ == "bar" + assert hasattr(clf.mock_, "return_ones") and clf.mock_.return_ones is True + + +def test_reset_invalid(): + """Tests that reset method raises error for invalid keep argument.""" + clf = MockClassifier() + with pytest.raises(TypeError, match=r"keep must be a string or list"): + clf.reset(keep=1) + + +def test_clone(): + """Tests that clone method correctly clones an estimator.""" + X, y = EQUAL_LENGTH_UNIVARIATE_CLASSIFICATION["numpy3D"]["train"] + + clf = MockClassifierParams(return_ones=True) + clf.fit(X, y) + + clf_clone = clf.clone() + assert clf_clone.return_ones is True + assert not hasattr(clf_clone, "foo_") + + clf = SummaryClassifier(random_state=100) + + clf_clone = clf.clone(random_state=42) + assert clf_clone.random_state == 1608637542 + + +def test_clone_function(): + """Tests that _clone_estimator function correctly clones an estimator.""" + X, y = 
EQUAL_LENGTH_UNIVARIATE_CLASSIFICATION["numpy3D"]["train"] + + clf = MockClassifierParams(return_ones=True) + clf.fit(X, y) + + clf_clone = _clone_estimator(clf) + assert clf_clone.return_ones is True + assert not hasattr(clf_clone, "foo_") + + clf = SummaryClassifier(random_state=100) + + clf_clone = _clone_estimator(clf, random_state=42) + assert clf_clone.random_state == 1608637542 + + +EXPECTED_MOCK_TAGS = { + "X_inner_type": ["np-list", "numpy3D"], "algorithm_type": None, - "capability:missing_values": False, + "cant_pickle": False, + "capability:contractable": False, + "capability:missing_values": True, "capability:multithreading": False, - "A": 42424241, - "B": 3, - "C": 1234, - 3: "E", + "capability:multivariate": True, + "capability:train_estimate": False, + "capability:unequal_length": True, + "capability:univariate": True, + "fit_is_empty": False, + "non_deterministic": False, + "python_dependencies": None, + "python_version": None, } def test_get_class_tags(): - """Tests get_class_tags class method of BaseAeonEstimator for correctness. - - Raises - ------ - AssertError if inheritance logic in get_class_tags is incorrect - """ - child_tags = FIXTURE_CLASSCHILD.get_class_tags() - - msg = "Inheritance logic in BaseAeonEstimator.get_class_tags is incorrect" - - assert child_tags == FIXTURE_CLASSCHILD_TAGS, msg + """Tests get_class_tags class method of BaseAeonEstimator for correctness.""" + child_tags = MockClassifierFullTags.get_class_tags() + assert child_tags == EXPECTED_MOCK_TAGS def test_get_class_tag(): - """Tests get_class_tag class method of BaseAeonEstimator for correctness. 
+ """Tests get_class_tag class method of BaseAeonEstimator for correctness.""" + for key in EXPECTED_MOCK_TAGS.keys(): + assert EXPECTED_MOCK_TAGS[key] == MockClassifierFullTags.get_class_tag(key) - Raises - ------ - AssertError if inheritance logic in get_tag is incorrect - AssertError if default override logic in get_tag is incorrect - """ - child_tags = dict() - child_tags_keys = FIXTURE_CLASSCHILD_TAGS.keys() + # these should be true for inherited class above, but false for the parent class + assert BaseClassifier.get_class_tag("capability:missing_values") is False + assert BaseClassifier.get_class_tag("capability:multivariate") is False + assert BaseClassifier.get_class_tag("capability:unequal_length") is False - for key in child_tags_keys: - child_tags[key] = FIXTURE_CLASSCHILD.get_class_tag(key) + assert ( + BaseAeonEstimator.get_class_tag( + "invalid_tag", raise_error=False, tag_value_default=50 + ) + == 50 + ) - child_tag_default = FIXTURE_CLASSCHILD.get_class_tag("foo", "bar") - child_tag_defaultNone = FIXTURE_CLASSCHILD.get_class_tag("bar") + with pytest.raises(ValueError, match=r"Tag with name invalid_tag"): + BaseAeonEstimator.get_class_tag("invalid_tag") - msg = "Inheritance logic in BaseAeonEstimator.get_class_tag is incorrect" - for key in child_tags_keys: - assert child_tags[key] == FIXTURE_CLASSCHILD_TAGS[key], msg +def test_get_tags(): + """Tests get_tags method of BaseAeonEstimator for correctness.""" + child_tags = MockClassifierFullTags().get_tags() + assert child_tags == EXPECTED_MOCK_TAGS - msg = "Default override logic in BaseAeonEstimator.get_class_tag is incorrect" - assert child_tag_default == "bar", msg - assert child_tag_defaultNone is None, msg +def test_get_tag(): + """Tests get_tag method of BaseAeonEstimator for correctness.""" + clf = MockClassifierFullTags() + for key in EXPECTED_MOCK_TAGS.keys(): + assert EXPECTED_MOCK_TAGS[key] == clf.get_tag(key) + # these should be true for class above which overrides, but false for this 
which + # does not + clf = MockClassifier() + assert clf.get_tag("capability:missing_values") is False + assert clf.get_tag("capability:multivariate") is False + assert clf.get_tag("capability:unequal_length") is False -def test_get_tags(): - """Tests get_tags method of BaseAeonEstimator for correctness. + assert clf.get_tag("invalid_tag", raise_error=False, tag_value_default=50) == 50 - Raises - ------ - AssertError if inheritance logic in get_tags is incorrect - """ - object_tags = FIXTURE_OBJECT.get_tags() + with pytest.raises(ValueError, match=r"Tag with name invalid_tag"): + clf.get_tag("invalid_tag") - msg = "Inheritance logic in BaseAeonEstimator.get_tags is incorrect" - assert object_tags == FIXTURE_OBJECT_TAGS, msg +def test_set_tags(): + """Tests set_tags method of BaseAeonEstimator for correctness.""" + clf = MockClassifier() + tags_to_set = { + "capability:multivariate": True, + "capability:missing_values": True, + "capability:unequal_length": True, + } + clf.set_tags(**tags_to_set) -def test_get_tag(): - """Tests get_tag method of BaseAeonEstimator for correctness. 
+ assert clf.get_tag("capability:missing_values") is True + assert clf.get_tag("capability:multivariate") is True + assert clf.get_tag("capability:unequal_length") is True - Raises - ------ - AssertError if inheritance logic in get_tag is incorrect - AssertError if default override logic in get_tag is incorrect - """ - object_tags = dict() - object_tags_keys = FIXTURE_OBJECT_TAGS.keys() + clf.reset() - for key in object_tags_keys: - object_tags[key] = FIXTURE_OBJECT.get_tag(key, raise_error=False) + assert clf.get_tag("capability:missing_values") is False + assert clf.get_tag("capability:multivariate") is False + assert clf.get_tag("capability:unequal_length") is False - object_tag_default = FIXTURE_OBJECT.get_tag("foo", "bar", raise_error=False) - object_tag_defaultNone = FIXTURE_OBJECT.get_tag("bar", raise_error=False) - msg = "Inheritance logic in BaseAeonEstimator.get_tag is incorrect" +def test_get_fitted_params(): + """Tests fitted parameter retrieval.""" + X, y = EQUAL_LENGTH_UNIVARIATE_CLASSIFICATION["numpy3D"]["train"] - for key in object_tags_keys: - assert object_tags[key] == FIXTURE_OBJECT_TAGS[key], msg + non_composite = MockClassifier() + non_composite.fit(X, y) + composite = MockClassifierComposite() + composite.fit(X, y) - msg = "Default override logic in BaseAeonEstimator.get_tag is incorrect" + params = non_composite.get_fitted_params() + comp_params = composite.get_fitted_params() - assert object_tag_default == "bar", msg - assert object_tag_defaultNone is None, msg + expected = { + "fit_time_", + "foo_", + "classes_", + "metadata_", + "n_classes_", + } + assert isinstance(params, dict) + assert set(params.keys()) == expected + assert params["foo_"] is composite.foo_ -def test_get_tag_raises(): - """Tests that get_tag method raises error for unknown tag. 
+ assert isinstance(comp_params, dict) + assert set(comp_params.keys()) == expected.union( + { + "mock_", + "mock___classes_", + "mock___fit_time_", + "mock___foo_", + "mock___metadata_", + "mock___n_classes_", + } + ) + assert comp_params["foo_"] is composite.foo_ + assert comp_params["mock___foo_"] is composite.mock_.foo_ - Raises - ------ - AssertError if get_tag does not raise error for unknown tag. - """ - with pytest.raises(ValueError, match=r"Tag with name"): - FIXTURE_OBJECT.get_tag("bar") + params_shallow = non_composite.get_fitted_params(deep=False) + comp_params_shallow = composite.get_fitted_params(deep=False) + assert isinstance(params_shallow, dict) + assert set(params_shallow.keys()) == set(params.keys()) -FIXTURE_TAG_SET = {"A": 42424243, "E": 3} -FIXTURE_OBJECT_SET = deepcopy(FIXTURE_OBJECT).set_tags(**FIXTURE_TAG_SET) -FIXTURE_OBJECT_SET_TAGS = { - "python_version": None, - "python_dependencies": None, - "cant_pickle": False, - "non_deterministic": False, - "algorithm_type": None, - "capability:missing_values": False, - "capability:multithreading": False, - "A": 42424243, - "B": 3, - "C": 1234, - 3: "E", - "E": 3, -} -FIXTURE_OBJECT_SET_DYN = {"A": 42424243, "B": 3, "E": 3} + assert isinstance(comp_params_shallow, dict) + assert set(comp_params_shallow.keys()) == set(params.keys()).union({"mock_"}) -def test_set_tags(): - """Tests set_tags method of BaseAeonEstimator for correctness. 
+def test_get_fitted_params_sklearn(): + """Tests fitted parameter retrieval with sklearn components.""" + X, y = EQUAL_LENGTH_UNIVARIATE_CLASSIFICATION["numpy3D"]["train"] - Raises - ------ - AssertionError if override logic in set_tags is incorrect - """ - msg = "Setter/override logic in BaseAeonEstimator.set_tags is incorrect" + clf = SummaryClassifier(estimator=DecisionTreeClassifier()) + clf.fit(X, y) - assert FIXTURE_OBJECT_SET._tags_dynamic == FIXTURE_OBJECT_SET_DYN, msg - assert FIXTURE_OBJECT_SET.get_tags() == FIXTURE_OBJECT_SET_TAGS, msg + params = clf.get_fitted_params() + assert "estimator_" in params.keys() + assert "transformer_" in params.keys() + assert "estimator___tree_" in params.keys() + assert "estimator___max_features_" in params.keys() -class CompositionDummy(BaseAeonEstimator): - """Potentially composite object, for testing.""" + # pipeline + pipe = make_pipeline(Tabularizer(), StandardScaler(), DecisionTreeClassifier()) + clf = SummaryClassifier(estimator=pipe) + clf.fit(X, y) - def __init__(self, foo, bar=84): - self.foo = foo - self.foo_ = deepcopy(foo) - self.bar = bar + params = clf.get_fitted_params() + assert "estimator_" in params.keys() + assert "transformer_" in params.keys() -class ResetTester(BaseAeonEstimator): - clsvar = 210 - def __init__(self, a, b=42): - self.a = a - self.b = b - self.c = 84 +def test_check_is_fitted(): + """Test _check_is_fitted works correctly.""" + X, y = EQUAL_LENGTH_UNIVARIATE_CLASSIFICATION["numpy3D"]["train"] - def foo(self, d=126): - self.d = deepcopy(d) - self._d = deepcopy(d) - self.d_ = deepcopy(d) - self.f__o__o = 252 + clf = MockClassifier() + with pytest.raises(ValueError, match=r"has not been fitted yet"): + clf._check_is_fitted() -def test_reset(): - """Tests reset method for correct behaviour, on a simple estimator. 
- - Raises - ------ - AssertionError if logic behind reset is incorrect, logic tested: - reset should remove any object attributes that are not hyper-parameters, - with the exception of attributes containing double-underscore "__" - reset should not remove class attributes or methods - reset should set hyper-parameters as in pre-reset state - """ - x = ResetTester(168) - x.foo() - - x.reset() - - assert hasattr(x, "a") and x.a == 168 - assert hasattr(x, "b") and x.b == 42 - assert hasattr(x, "c") and x.c == 84 - assert hasattr(x, "clsvar") and x.clsvar == 210 - assert not hasattr(x, "d") - assert not hasattr(x, "_d") - assert not hasattr(x, "d_") - assert hasattr(x, "f__o__o") and x.f__o__o == 252 - assert hasattr(x, "foo") + clf.fit(X, y) + clf._check_is_fitted() -def test_reset_composite(): - """Test reset method for correct behaviour, on a composite estimator.""" - y = ResetTester(42) - x = ResetTester(a=y) - x.foo(y) - x.d.foo() +def test_create_test_instance(): + """Test _create_test_instance works as expected.""" + clf = SummaryClassifier._create_test_instance() - x.reset() + assert isinstance(clf, SummaryClassifier) + assert clf.estimator.n_estimators == 2 - assert hasattr(x, "a") - assert not hasattr(x, "d") - assert not hasattr(x.a, "d") +def test_overridden_sklearn(): + """Tests that overridden sklearn components return expected outputs.""" + X, y = EQUAL_LENGTH_UNIVARIATE_CLASSIFICATION["numpy3D"]["train"] -class FittableCompositionDummy(BaseAeonEstimator): - """Potentially composite object, for testing.""" + clf = MockClassifier() + clf.fit(X, y) - def __init__(self, foo, bar=84): - self.foo = foo - self.foo_ = deepcopy(foo) - self.bar = bar + assert clf.__sklearn_is_fitted__() == clf.is_fitted - def fit(self): - if hasattr(self.foo_, "fit"): - self.foo_.fit() - self.is_fitted = True + assert isinstance(clf._get_default_requests(), MetadataRequest) + with pytest.raises(NotImplementedError): + clf._validate_data() -def test_get_fitted_params(): - 
"""Tests fitted parameter retrieval. - - Raises - ------ - AssertionError if logic behind get_fitted_params is incorrect, logic tested: - calling get_fitted_params on a non-composite fittable returns the fitted param - calling get_fitted_params on a composite returns all nested params - """ - non_composite = FittableCompositionDummy(foo=42) - composite = FittableCompositionDummy(foo=deepcopy(non_composite)) - - non_composite.fit() - composite.fit() - - non_comp_f_params = non_composite.get_fitted_params() - comp_f_params = composite.get_fitted_params() - comp_f_params_shallow = composite.get_fitted_params(deep=False) - - assert isinstance(non_comp_f_params, dict) - assert set(non_comp_f_params.keys()) == {"foo_"} - - assert isinstance(comp_f_params, dict) - assert set(comp_f_params) == {"foo_", "foo___foo_"} - assert set(comp_f_params_shallow) == {"foo_"} - assert comp_f_params["foo_"] is composite.foo_ - assert comp_f_params["foo_"] is not composite.foo - assert comp_f_params_shallow["foo_"] is composite.foo_ - assert comp_f_params_shallow["foo_"] is not composite.foo + with pytest.raises(NotImplementedError): + clf.get_metadata_routing() diff --git a/aeon/base/tests/test_base_aeon.py b/aeon/base/tests/test_base_aeon.py deleted file mode 100644 index f9d3b57481..0000000000 --- a/aeon/base/tests/test_base_aeon.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Tests for universal base class that require aeon or sklearn imports.""" - -__maintainer__ = [] - -from sklearn.preprocessing import StandardScaler -from sklearn.tree import DecisionTreeClassifier - -from aeon.classification.feature_based import SummaryClassifier -from aeon.pipeline import make_pipeline -from aeon.testing.data_generation import make_example_3d_numpy -from aeon.transformations.collection import Tabularizer - - -def test_get_fitted_params_sklearn(): - """Tests fitted parameter retrieval with sklearn components. 
- - Raises - ------ - AssertionError if logic behind get_fitted_params is incorrect, logic tested: - calling get_fitted_params on obj aeon component returns expected nested params - """ - X, y = make_example_3d_numpy() - clf = SummaryClassifier(estimator=DecisionTreeClassifier()) - clf.fit(X, y) - - # params = clf.get_fitted_params() - - # todo v1.0.0 fix this - - -def test_get_fitted_params_sklearn_nested(): - """Tests fitted parameter retrieval with sklearn components. - - Raises - ------ - AssertionError if logic behind get_fitted_params is incorrect, logic tested: - calling get_fitted_params on obj aeon component returns expected nested params - """ - X, y = make_example_3d_numpy() - pipe = make_pipeline(Tabularizer(), StandardScaler(), DecisionTreeClassifier()) - clf = SummaryClassifier(estimator=pipe) - clf.fit(X, y) - - # params = clf.get_fitted_params() - - # todo v1.0.0 fix this diff --git a/aeon/base/tests/test_compose.py b/aeon/base/tests/test_compose.py new file mode 100644 index 0000000000..55ba965e72 --- /dev/null +++ b/aeon/base/tests/test_compose.py @@ -0,0 +1,174 @@ +"""Test composable estimator mixin.""" + +import pytest + +from aeon.classification.compose import ClassifierEnsemble +from aeon.testing.mock_estimators import MockClassifier, MockClassifierParams +from aeon.testing.testing_data import EQUAL_LENGTH_UNIVARIATE_CLASSIFICATION + + +def test_get_params(): + """Test get_params retrieval for composable estimators.""" + ens = [("clf1", MockClassifierParams()), ("clf2", MockClassifierParams())] + clf = ClassifierEnsemble(ens) + + params = clf.get_params(deep=False) + + expected = { + "classifiers", + "cv", + "majority_vote", + "metric", + "metric_probas", + "random_state", + "weights", + } + + assert isinstance(params, dict) + assert set(params.keys()) == expected + assert params["classifiers"] == ens + + params = clf.get_params() + + expected = expected.union( + { + "clf1", + "clf2", + "clf1__return_ones", + "clf1__value", + 
"clf2__return_ones", + "clf2__value", + } + ) + + assert isinstance(params, dict) + assert set(params.keys()) == expected + assert params["clf1__value"] == 50 + + +def test_set_params(): + """Test set_params for composable estimators.""" + clf = ClassifierEnsemble( + [("clf1", MockClassifierParams()), ("clf2", MockClassifierParams())] + ) + + ens = [("clf3", MockClassifierParams()), ("clf4", MockClassifierParams())] + params = {"_ensemble": ens, "clf3__value": 100, "clf4__return_ones": True} + clf.set_params(**params) + + assert clf._ensemble[0][1].value == 100 + assert clf._ensemble[1][1].return_ones is True + + +def test_get_fitted_params(): + """Test get_fitted_params for composable estimators.""" + X, y = EQUAL_LENGTH_UNIVARIATE_CLASSIFICATION["numpy3D"]["train"] + + clf = ClassifierEnsemble( + [("clf1", MockClassifierParams()), ("clf2", MockClassifierParams())] + ) + clf.fit(X, y) + + params = clf.get_fitted_params(deep=False) + + expected = { + "classes_", + "ensemble_", + "fit_time_", + "metadata_", + "n_classes_", + "weights_", + } + + assert isinstance(params, dict) + assert set(params.keys()) == expected + assert params["n_classes_"] == clf.n_classes_ + + params = clf.get_fitted_params() + + expected = expected.union( + { + "clf1", + "clf1__classes_", + "clf1__fit_time_", + "clf1__foo_", + "clf1__metadata_", + "clf1__n_classes_", + "clf2", + "clf2__classes_", + "clf2__fit_time_", + "clf2__foo_", + "clf2__metadata_", + "clf2__n_classes_", + } + ) + + assert isinstance(params, dict) + assert set(params.keys()) == expected + assert params["clf1__n_classes_"] == 2 + + +def test_check_estimators(): + """Test check_estimators for composable estimators.""" + ens = [("clf1", MockClassifier()), MockClassifier()] + clf = ClassifierEnsemble(ens) + + clf._check_estimators(ens, unique_names=False) + + with pytest.raises(ValueError, match="estimators should only contain singular"): + clf._check_estimators(ens, allow_tuples=False) + + with pytest.raises(ValueError, 
match="should only contain"): + clf._check_estimators(ens, allow_single_estimators=False) + + with pytest.raises(ValueError, match="must be an instance of"): + clf._check_estimators([("class", MockClassifier)]) + + with pytest.raises(ValueError, match="must be of form"): + clf._check_estimators([(MockClassifier(),)]) + + with pytest.raises(ValueError, match="must be of form"): + clf._check_estimators([(MockClassifier, "class")]) + + with pytest.raises(ValueError, match="conflicts with constructor arguments"): + clf._check_estimators([("classifiers", MockClassifier())]) + + with pytest.raises(ValueError, match="Estimator name must not contain"): + clf._check_estimators([("__clf", MockClassifier())]) + + with pytest.raises(ValueError, match="must be unique"): + clf._check_estimators( + [("clf", MockClassifier()), ("clf", MockClassifier())], unique_names=True + ) + + with pytest.raises(ValueError, match="name is invalid"): + clf._check_estimators(ens, invalid_names=["clf1"]) + + with pytest.raises(ValueError, match="name is invalid"): + clf._check_estimators(ens, invalid_names="clf1") + + with pytest.raises(TypeError, match="tuple or estimator"): + clf._check_estimators(["invalid"]) + + with pytest.raises(TypeError, match="Invalid estimators attribute"): + clf._check_estimators([]) + + +def test_convert_estimators(): + """Test convert_estimators for composable estimators.""" + ens = [ + ("clf1", MockClassifierParams()), + MockClassifierParams(), + MockClassifierParams(), + ] + clf = ClassifierEnsemble(ens) + ens2 = clf._convert_estimators(ens) + + assert isinstance(ens2, list) + assert len(ens2) == 3 + assert ens2[0][0] == "clf1" + assert ens2[1][0] == "MockClassifierParams_0" + assert ens2[2][0] == "MockClassifierParams_1" + assert isinstance(ens2[0][1], MockClassifierParams) + assert isinstance(ens2[1][1], MockClassifierParams) + assert isinstance(ens2[2][1], MockClassifierParams) diff --git a/aeon/classification/compose/_channel_ensemble.py 
b/aeon/classification/compose/_channel_ensemble.py index e425098ad4..a1ddc71e81 100644 --- a/aeon/classification/compose/_channel_ensemble.py +++ b/aeon/classification/compose/_channel_ensemble.py @@ -10,7 +10,7 @@ import numpy as np from sklearn.utils import check_random_state -from aeon.base.estimator.compose.collection_channel_ensemble import ( +from aeon.base.estimators.compose.collection_channel_ensemble import ( BaseCollectionChannelEnsemble, ) from aeon.classification.base import BaseClassifier diff --git a/aeon/classification/compose/_ensemble.py b/aeon/classification/compose/_ensemble.py index b6dad5341c..d409adaab7 100644 --- a/aeon/classification/compose/_ensemble.py +++ b/aeon/classification/compose/_ensemble.py @@ -7,7 +7,7 @@ import numpy as np from sklearn.utils import check_random_state -from aeon.base.estimator.compose.collection_ensemble import BaseCollectionEnsemble +from aeon.base.estimators.compose.collection_ensemble import BaseCollectionEnsemble from aeon.classification.base import BaseClassifier from aeon.classification.sklearn._wrapper import SklearnClassifierWrapper from aeon.utils.sklearn import is_sklearn_classifier diff --git a/aeon/classification/compose/_pipeline.py b/aeon/classification/compose/_pipeline.py index 8fa1e94d50..7a2fb2d076 100644 --- a/aeon/classification/compose/_pipeline.py +++ b/aeon/classification/compose/_pipeline.py @@ -4,7 +4,7 @@ __all__ = ["ClassifierPipeline"] -from aeon.base.estimator.compose.collection_pipeline import BaseCollectionPipeline +from aeon.base.estimators.compose.collection_pipeline import BaseCollectionPipeline from aeon.classification.base import BaseClassifier diff --git a/aeon/classification/feature_based/_summary.py b/aeon/classification/feature_based/_summary.py index 6d7e02cc55..a4f34ff688 100644 --- a/aeon/classification/feature_based/_summary.py +++ b/aeon/classification/feature_based/_summary.py @@ -50,6 +50,10 @@ class SummaryClassifier(BaseClassifier): Number of classes. 
Extracted from the data. classes_ : ndarray of shape (n_classes) Holds the label for each class. + estimator_ : sklearn classifier + The fitted estimator. + transformer_ : SevenNumberSummary + The fitted transformer. See Also -------- @@ -88,9 +92,6 @@ def __init__( self.n_jobs = n_jobs self.random_state = random_state - self._transformer = None - self._estimator = None - super().__init__() def _fit(self, X, y): @@ -113,11 +114,11 @@ def _fit(self, X, y): Changes state by creating a fitted model that updates attributes ending in "_" and sets is_fitted flag to True. """ - self._transformer = SevenNumberSummary( + self.transformer_ = SevenNumberSummary( summary_stats=self.summary_stats, ) - self._estimator = _clone_estimator( + self.estimator_ = _clone_estimator( ( RandomForestClassifier(n_estimators=200) if self.estimator is None @@ -126,12 +127,12 @@ def _fit(self, X, y): self.random_state, ) - m = getattr(self._estimator, "n_jobs", None) + m = getattr(self.estimator_, "n_jobs", None) if m is not None: - self._estimator.n_jobs = self._n_jobs + self.estimator_.n_jobs = self._n_jobs - X_t = self._transformer.fit_transform(X, y) - self._estimator.fit(X_t, y) + X_t = self.transformer_.fit_transform(X, y) + self.estimator_.fit(X_t, y) return self @@ -148,7 +149,7 @@ def _predict(self, X) -> np.ndarray: y : array-like, shape = [n_cases] Predicted class labels. """ - return self._estimator.predict(self._transformer.transform(X)) + return self.estimator_.predict(self.transformer_.transform(X)) def _predict_proba(self, X) -> np.ndarray: """Predict class probabilities for n instances in X. @@ -163,12 +164,12 @@ def _predict_proba(self, X) -> np.ndarray: y : array-like, shape = [n_cases, n_classes_] Predicted probabilities using the ordering in classes_. 
""" - m = getattr(self._estimator, "predict_proba", None) + m = getattr(self.estimator_, "predict_proba", None) if callable(m): - return self._estimator.predict_proba(self._transformer.transform(X)) + return self.estimator_.predict_proba(self.transformer_.transform(X)) else: dists = np.zeros((X.shape[0], self.n_classes_)) - preds = self._estimator.predict(self._transformer.transform(X)) + preds = self.estimator_.predict(self.transformer_.transform(X)) for i in range(0, X.shape[0]): dists[i, self._class_dictionary[preds[i]]] = 1 return dists diff --git a/aeon/classification/hybrid/_rist.py b/aeon/classification/hybrid/_rist.py index d3db758567..f098a6b9c6 100644 --- a/aeon/classification/hybrid/_rist.py +++ b/aeon/classification/hybrid/_rist.py @@ -6,7 +6,7 @@ from sklearn.ensemble import ExtraTreesClassifier from sklearn.preprocessing import FunctionTransformer -from aeon.base.estimator.hybrid import BaseRIST +from aeon.base.estimators.hybrid import BaseRIST from aeon.classification import BaseClassifier from aeon.utils.numba.general import first_order_differences_3d diff --git a/aeon/classification/interval_based/_cif.py b/aeon/classification/interval_based/_cif.py index a0c91d3706..c46a23dc9a 100644 --- a/aeon/classification/interval_based/_cif.py +++ b/aeon/classification/interval_based/_cif.py @@ -8,7 +8,7 @@ import numpy as np -from aeon.base.estimator.interval_based import BaseIntervalForest +from aeon.base.estimators.interval_based import BaseIntervalForest from aeon.classification import BaseClassifier from aeon.classification.sklearn import ContinuousIntervalTree from aeon.transformations.collection.feature_based import Catch22 diff --git a/aeon/classification/interval_based/_drcif.py b/aeon/classification/interval_based/_drcif.py index 650bcf42e9..90811f2539 100644 --- a/aeon/classification/interval_based/_drcif.py +++ b/aeon/classification/interval_based/_drcif.py @@ -10,7 +10,7 @@ import numpy as np from sklearn.preprocessing import FunctionTransformer 
-from aeon.base.estimator.interval_based import BaseIntervalForest +from aeon.base.estimators.interval_based import BaseIntervalForest from aeon.classification.base import BaseClassifier from aeon.classification.sklearn._continuous_interval_tree import ContinuousIntervalTree from aeon.transformations.collection import PeriodogramTransformer diff --git a/aeon/classification/interval_based/_interval_forest.py b/aeon/classification/interval_based/_interval_forest.py index f1593adf10..9cf6f33d43 100644 --- a/aeon/classification/interval_based/_interval_forest.py +++ b/aeon/classification/interval_based/_interval_forest.py @@ -5,7 +5,7 @@ import numpy as np -from aeon.base.estimator.interval_based.base_interval_forest import BaseIntervalForest +from aeon.base.estimators.interval_based.base_interval_forest import BaseIntervalForest from aeon.classification.base import BaseClassifier diff --git a/aeon/classification/interval_based/_rise.py b/aeon/classification/interval_based/_rise.py index b298fa59f2..e17ce0ff7f 100644 --- a/aeon/classification/interval_based/_rise.py +++ b/aeon/classification/interval_based/_rise.py @@ -5,7 +5,7 @@ import numpy as np -from aeon.base.estimator.interval_based.base_interval_forest import BaseIntervalForest +from aeon.base.estimators.interval_based.base_interval_forest import BaseIntervalForest from aeon.classification import BaseClassifier from aeon.classification.sklearn import ContinuousIntervalTree from aeon.transformations.collection import ( diff --git a/aeon/classification/interval_based/_stsf.py b/aeon/classification/interval_based/_stsf.py index f782cb47ec..4642be7e11 100644 --- a/aeon/classification/interval_based/_stsf.py +++ b/aeon/classification/interval_based/_stsf.py @@ -11,7 +11,7 @@ import numpy as np from sklearn.preprocessing import FunctionTransformer -from aeon.base.estimator.interval_based.base_interval_forest import BaseIntervalForest +from aeon.base.estimators.interval_based.base_interval_forest import 
BaseIntervalForest from aeon.classification.base import BaseClassifier from aeon.transformations.collection import PeriodogramTransformer from aeon.utils.numba.general import first_order_differences_3d diff --git a/aeon/classification/interval_based/_tsf.py b/aeon/classification/interval_based/_tsf.py index 17ecd5f79b..ae827c6950 100644 --- a/aeon/classification/interval_based/_tsf.py +++ b/aeon/classification/interval_based/_tsf.py @@ -8,7 +8,7 @@ import numpy as np -from aeon.base.estimator.interval_based.base_interval_forest import BaseIntervalForest +from aeon.base.estimators.interval_based.base_interval_forest import BaseIntervalForest from aeon.classification import BaseClassifier from aeon.classification.sklearn import ContinuousIntervalTree diff --git a/aeon/clustering/compose/_pipeline.py b/aeon/clustering/compose/_pipeline.py index 763f872e49..bb972a3d68 100644 --- a/aeon/clustering/compose/_pipeline.py +++ b/aeon/clustering/compose/_pipeline.py @@ -4,7 +4,7 @@ __all__ = ["ClustererPipeline"] -from aeon.base.estimator.compose.collection_pipeline import BaseCollectionPipeline +from aeon.base.estimators.compose.collection_pipeline import BaseCollectionPipeline from aeon.clustering import BaseClusterer diff --git a/aeon/regression/compose/_ensemble.py b/aeon/regression/compose/_ensemble.py index 128690d2e7..14d3f837bb 100644 --- a/aeon/regression/compose/_ensemble.py +++ b/aeon/regression/compose/_ensemble.py @@ -6,7 +6,7 @@ import numpy as np -from aeon.base.estimator.compose.collection_ensemble import BaseCollectionEnsemble +from aeon.base.estimators.compose.collection_ensemble import BaseCollectionEnsemble from aeon.regression import BaseRegressor from aeon.regression.sklearn._wrapper import SklearnRegressorWrapper from aeon.utils.sklearn import is_sklearn_regressor diff --git a/aeon/regression/compose/_pipeline.py b/aeon/regression/compose/_pipeline.py index 618dd2d193..3d161bf5df 100644 --- a/aeon/regression/compose/_pipeline.py +++ 
b/aeon/regression/compose/_pipeline.py @@ -3,7 +3,7 @@ __maintainer__ = ["MatthewMiddlehurst"] __all__ = ["RegressorPipeline"] -from aeon.base.estimator.compose.collection_pipeline import BaseCollectionPipeline +from aeon.base.estimators.compose.collection_pipeline import BaseCollectionPipeline from aeon.regression.base import BaseRegressor diff --git a/aeon/regression/hybrid/_rist.py b/aeon/regression/hybrid/_rist.py index f7471ef482..15e0f763fb 100644 --- a/aeon/regression/hybrid/_rist.py +++ b/aeon/regression/hybrid/_rist.py @@ -1,7 +1,7 @@ from sklearn.ensemble import ExtraTreesRegressor from sklearn.preprocessing import FunctionTransformer -from aeon.base.estimator.hybrid import BaseRIST +from aeon.base.estimators.hybrid import BaseRIST from aeon.regression import BaseRegressor from aeon.utils.numba.general import first_order_differences_3d diff --git a/aeon/regression/interval_based/_cif.py b/aeon/regression/interval_based/_cif.py index 61b029068d..4899f39cab 100644 --- a/aeon/regression/interval_based/_cif.py +++ b/aeon/regression/interval_based/_cif.py @@ -5,7 +5,7 @@ import numpy as np -from aeon.base.estimator.interval_based import BaseIntervalForest +from aeon.base.estimators.interval_based import BaseIntervalForest from aeon.regression import BaseRegressor from aeon.transformations.collection.feature_based import Catch22 from aeon.utils.numba.stats import row_mean, row_slope, row_std diff --git a/aeon/regression/interval_based/_drcif.py b/aeon/regression/interval_based/_drcif.py index 152547964f..843bb3c7b4 100644 --- a/aeon/regression/interval_based/_drcif.py +++ b/aeon/regression/interval_based/_drcif.py @@ -6,7 +6,7 @@ from sklearn.preprocessing import FunctionTransformer -from aeon.base.estimator.interval_based import BaseIntervalForest +from aeon.base.estimators.interval_based import BaseIntervalForest from aeon.regression import BaseRegressor from aeon.transformations.collection import PeriodogramTransformer from 
aeon.transformations.collection.feature_based import Catch22 diff --git a/aeon/regression/interval_based/_interval_forest.py b/aeon/regression/interval_based/_interval_forest.py index 3ddc52de31..aa0195298f 100644 --- a/aeon/regression/interval_based/_interval_forest.py +++ b/aeon/regression/interval_based/_interval_forest.py @@ -5,7 +5,7 @@ import numpy as np -from aeon.base.estimator.interval_based.base_interval_forest import BaseIntervalForest +from aeon.base.estimators.interval_based.base_interval_forest import BaseIntervalForest from aeon.regression.base import BaseRegressor diff --git a/aeon/regression/interval_based/_rise.py b/aeon/regression/interval_based/_rise.py index 40506552dc..ef1d34d8bb 100644 --- a/aeon/regression/interval_based/_rise.py +++ b/aeon/regression/interval_based/_rise.py @@ -5,7 +5,7 @@ import numpy as np -from aeon.base.estimator.interval_based.base_interval_forest import BaseIntervalForest +from aeon.base.estimators.interval_based.base_interval_forest import BaseIntervalForest from aeon.regression import BaseRegressor from aeon.transformations.collection import ( AutocorrelationFunctionTransformer, diff --git a/aeon/regression/interval_based/_tsf.py b/aeon/regression/interval_based/_tsf.py index b01575f062..c15da5a3ad 100644 --- a/aeon/regression/interval_based/_tsf.py +++ b/aeon/regression/interval_based/_tsf.py @@ -8,7 +8,7 @@ import numpy as np -from aeon.base.estimator.interval_based.base_interval_forest import BaseIntervalForest +from aeon.base.estimators.interval_based.base_interval_forest import BaseIntervalForest from aeon.regression import BaseRegressor diff --git a/aeon/segmentation/base.py b/aeon/segmentation/base.py index d35f964b3b..6fbcc93100 100644 --- a/aeon/segmentation/base.py +++ b/aeon/segmentation/base.py @@ -105,10 +105,10 @@ def fit(self, X, y=None, axis=1): self Fitted estimator """ - if self.get_class_tag("fit_is_empty"): + if self.get_tag("fit_is_empty"): self.is_fitted = True return self - if 
self.get_class_tag("requires_y"): + if self.get_tag("requires_y"): if y is None: raise ValueError("Tag requires_y is true, but fit called with y=None") # reset estimator at the start of fit @@ -149,7 +149,7 @@ def predict(self, X, axis=1): self._check_is_fitted() if axis is None: axis = self.axis - X = self._preprocess_series(X, axis, self.get_class_tag("fit_is_empty")) + X = self._preprocess_series(X, axis, self.get_tag("fit_is_empty")) return self._predict(X) def fit_predict(self, X, y=None, axis=1): diff --git a/aeon/testing/estimator_checking/_yield_anomaly_detection_checks.py b/aeon/testing/estimator_checking/_yield_anomaly_detection_checks.py index 3686f6e0b9..2763442df7 100644 --- a/aeon/testing/estimator_checking/_yield_anomaly_detection_checks.py +++ b/aeon/testing/estimator_checking/_yield_anomaly_detection_checks.py @@ -65,7 +65,7 @@ def check_anomaly_detector_univariate(estimator): """Test the anomaly detector on univariate data.""" estimator = _clone_estimator(estimator) - if estimator.get_class_tag(tag_name="capability:univariate"): + if estimator.get_tag(tag_name="capability:univariate"): pred = estimator.fit_predict(uv_series, labels) assert isinstance(pred, np.ndarray) assert pred.shape == (15,) @@ -79,7 +79,7 @@ def check_anomaly_detector_multivariate(estimator): """Test the anomaly detector on multivariate data.""" estimator = _clone_estimator(estimator) - if estimator.get_class_tag(tag_name="capability:multivariate"): + if estimator.get_tag(tag_name="capability:multivariate"): pred = estimator.fit_predict(mv_series, labels) assert isinstance(pred, np.ndarray) assert pred.shape == (15,) diff --git a/aeon/testing/estimator_checking/_yield_segmentation_checks.py b/aeon/testing/estimator_checking/_yield_segmentation_checks.py index 7f10d86d0f..898f034f05 100644 --- a/aeon/testing/estimator_checking/_yield_segmentation_checks.py +++ b/aeon/testing/estimator_checking/_yield_segmentation_checks.py @@ -56,12 +56,12 @@ def _assert_output(output, dense, 
length): else: # Segment labels returned, must be same length sas series assert len(output) == length - multivariate = estimator.get_class_tag(tag_name="capability:multivariate") + multivariate = estimator.get_tag(tag_name="capability:multivariate") X = np.random.random(size=(5, 20)) # Also tests does not fail if y is passed y = np.array([0, 0, 0, 1, 1]) # Test that capability:multivariate is correctly set - dense = estimator.get_class_tag(tag_name="returns_dense") + dense = estimator.get_tag(tag_name="returns_dense") if multivariate: output = estimator.fit_predict(X, y, axis=1) _assert_output(output, dense, X.shape[1]) @@ -70,7 +70,7 @@ def _assert_output(output, dense, length): estimator.fit_predict(X, y, axis=1) # Test that output is correct type X = np.random.random(size=(20)) - uni = estimator.get_class_tag(tag_name="capability:univariate") + uni = estimator.get_tag(tag_name="capability:univariate") if uni: output = estimator.fit_predict(X, y=X) _assert_output(output, dense, len(X)) diff --git a/aeon/testing/estimator_checking/tests/test_check_estimator.py b/aeon/testing/estimator_checking/tests/test_check_estimator.py index f8f5b41fef..dc3a33369a 100644 --- a/aeon/testing/estimator_checking/tests/test_check_estimator.py +++ b/aeon/testing/estimator_checking/tests/test_check_estimator.py @@ -9,7 +9,7 @@ from aeon.testing.estimator_checking._estimator_checking import _get_check_estimator_ids from aeon.testing.mock_estimators import ( MockClassifier, - MockClassifierMultiTestParams, + MockClassifierParams, MockRegressor, MockSegmenter, ) @@ -25,7 +25,7 @@ MockAnomalyDetector, # MockMultivariateSeriesTransformer, TimeSeriesScaler, - MockClassifierMultiTestParams, + MockClassifierParams, ] test_classes = {c.__name__: c for c in test_classes} diff --git a/aeon/testing/mock_estimators/__init__.py b/aeon/testing/mock_estimators/__init__.py index 624b566c61..32d947cb7d 100644 --- a/aeon/testing/mock_estimators/__init__.py +++ b/aeon/testing/mock_estimators/__init__.py 
@@ -5,7 +5,7 @@ "MockClassifier", "MockClassifierPredictProba", "MockClassifierFullTags", - "MockClassifierMultiTestParams", + "MockClassifierParams", "MockCluster", "MockDeepClusterer", "MockSegmenter", @@ -22,7 +22,7 @@ from aeon.testing.mock_estimators._mock_classifiers import ( MockClassifier, MockClassifierFullTags, - MockClassifierMultiTestParams, + MockClassifierParams, MockClassifierPredictProba, ) from aeon.testing.mock_estimators._mock_clusterers import MockCluster, MockDeepClusterer diff --git a/aeon/testing/mock_estimators/_mock_classifiers.py b/aeon/testing/mock_estimators/_mock_classifiers.py index af18857aff..da766b8e16 100644 --- a/aeon/testing/mock_estimators/_mock_classifiers.py +++ b/aeon/testing/mock_estimators/_mock_classifiers.py @@ -5,14 +5,16 @@ import numpy as np +from aeon.base._base import _clone_estimator from aeon.classification import BaseClassifier class MockClassifier(BaseClassifier): - """Dummy classifier for testing base class fit/predict.""" + """Mock classifier for testing fit/predict.""" def _fit(self, X, y): """Fit dummy.""" + self.foo_ = "bar" return self def _predict(self, X): @@ -21,7 +23,7 @@ def _predict(self, X): class MockClassifierPredictProba(MockClassifier): - """Dummy classifier for testing base class fit/predict/predict_proba.""" + """Mock classifier for testing fit/predict/predict_proba.""" def _predict_proba(self, X): """Predict proba dummy.""" @@ -31,7 +33,7 @@ def _predict_proba(self, X): class MockClassifierFullTags(MockClassifierPredictProba): - """Dummy classifier able to handle all input types.""" + """Mock classifier able to handle all input types.""" _tags = { "capability:multivariate": True, @@ -41,8 +43,8 @@ class MockClassifierFullTags(MockClassifierPredictProba): } -class MockClassifierMultiTestParams(BaseClassifier): - """Dummy classifier for testing base class fit/predict with multiple test params. 
+class MockClassifierParams(MockClassifier): + """Mock classifier for testing fit/predict with multiple parameters. Parameters ---------- @@ -50,17 +52,18 @@ class MockClassifierMultiTestParams(BaseClassifier): If True, predict ones, else zeros. """ - def __init__(self, return_ones=False): + def __init__(self, return_ones=False, value=50): self.return_ones = return_ones + self.value = value super().__init__() - def _fit(self, X, y): - """Fit dummy.""" - return self - def _predict(self, X): """Predict dummy.""" - return np.zeros(shape=(len(X),)) + return ( + np.zeros(shape=(len(X),)) + if not self.return_ones + else np.ones(shape=(len(X),)) + ) @classmethod def _get_test_params(cls, parameter_set="default"): @@ -79,4 +82,26 @@ def _get_test_params(cls, parameter_set="default"): Each dict are parameters to construct an "interesting" test instance, i.e., `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. """ - return [{"return_ones": False}, {"return_ones": True}] + return [{"return_ones": False, "value": 10}, {"return_ones": True}] + + +class MockClassifierComposite(BaseClassifier): + """Mock classifier which contains another mock classifier.""" + + def __init__(self, mock=None): + self.mock = mock + super().__init__() + + def _fit(self, X, y): + """Fit dummy.""" + self.mock_ = ( + MockClassifier().fit(X, y) + if self.mock is None + else _clone_estimator(self.mock).fit(X, y) + ) + self.foo_ = "bar" + return self + + def _predict(self, X): + """Predict dummy.""" + return self.mock_.predict(X) diff --git a/aeon/testing/testing_data.py b/aeon/testing/testing_data.py index 7ab34d2be3..55b9092443 100644 --- a/aeon/testing/testing_data.py +++ b/aeon/testing/testing_data.py @@ -788,16 +788,16 @@ def _get_capabilities_for_estimator(estimator): Tuple of valid capabilities for the estimator.
""" univariate = estimator.get_tag( - "capability:univariate", tag_value_default=True, raise_error=False + "capability:univariate", raise_error=False, tag_value_default=True ) multivariate = estimator.get_tag( - "capability:multivariate", tag_value_default=False, raise_error=False + "capability:multivariate", raise_error=False, tag_value_default=False ) unequal_length = estimator.get_tag( - "capability:unequal_length", tag_value_default=False, raise_error=False + "capability:unequal_length", raise_error=False, tag_value_default=False ) missing_values = estimator.get_tag( - "capability:missing_values", tag_value_default=False, raise_error=False + "capability:missing_values", raise_error=False, tag_value_default=False ) return univariate, multivariate, unequal_length, missing_values diff --git a/aeon/testing/utils/estimator_checks.py b/aeon/testing/utils/estimator_checks.py index ca774cbd59..1c9e8f8cb3 100644 --- a/aeon/testing/utils/estimator_checks.py +++ b/aeon/testing/utils/estimator_checks.py @@ -43,11 +43,11 @@ def _get_tag(estimator, tag_name, default=None, raise_error=False): return None elif isclass(estimator): return estimator.get_class_tag( - tag_name=tag_name, tag_value_default=default, raise_error=raise_error + tag_name=tag_name, raise_error=raise_error, tag_value_default=default ) else: return estimator.get_tag( - tag_name=tag_name, tag_value_default=default, raise_error=raise_error + tag_name=tag_name, raise_error=raise_error, tag_value_default=default ) diff --git a/aeon/transformations/collection/compose/_pipeline.py b/aeon/transformations/collection/compose/_pipeline.py index 56f7697261..d4c57b4957 100644 --- a/aeon/transformations/collection/compose/_pipeline.py +++ b/aeon/transformations/collection/compose/_pipeline.py @@ -4,7 +4,7 @@ __all__ = ["CollectionTransformerPipeline"] -from aeon.base.estimator.compose.collection_pipeline import BaseCollectionPipeline +from aeon.base.estimators.compose.collection_pipeline import BaseCollectionPipeline 
from aeon.transformations.collection import BaseCollectionTransformer from aeon.transformations.collection.compose import CollectionId diff --git a/aeon/utils/discovery.py b/aeon/utils/discovery.py index 0819795516..8fd4a05efe 100644 --- a/aeon/utils/discovery.py +++ b/aeon/utils/discovery.py @@ -230,7 +230,7 @@ def _filter_tags(tags, estimators, name): cond_sat = True for key, value in tags.items(): - est_tag = est[1].get_class_tag(key) + est_tag = est[1].get_class_tag(key, raise_error=False) est_tag = est_tag if isinstance(est_tag, list) else [est_tag] if isinstance(value, list): diff --git a/docs/api_reference/base.rst b/docs/api_reference/base.rst index 5a7ba0d80a..3d06b37103 100644 --- a/docs/api_reference/base.rst +++ b/docs/api_reference/base.rst @@ -5,10 +5,6 @@ Base The :mod:`aeon.base` module contains abstract base classes. -.. automodule:: aeon.base - :no-members: - :no-inherited-members: - Base classes ------------ @@ -21,15 +17,3 @@ Base classes BaseAeonEstimator BaseCollectionEstimator BaseSeriesEstimator - -Estimator base classes ----------------------- - -.. currentmodule:: aeon.base.estimator - -.. autosummary:: - :toctree: auto_generated/ - :template: class.rst - - hybrid.BaseRIST - interval_based.BaseIntervalForest