Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DOC,ENH] base docs and testing #2273

Merged
merged 8 commits into from
Nov 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions aeon/anomaly_detection/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,11 @@ def fit(self, X, y=None, axis=1):
BaseAnomalyDetector
The fitted estimator, reference to self.
"""
if self.get_class_tag("fit_is_empty"):
if self.get_tag("fit_is_empty"):
self.is_fitted = True
return self

if self.get_class_tag("requires_y"):
if self.get_tag("requires_y"):
if y is None:
raise ValueError("Tag requires_y is true, but fit called with y=None")

Expand Down Expand Up @@ -159,7 +159,7 @@ def predict(self, X, axis=1) -> np.ndarray:
A boolean, int or float array of length len(X), where each element indicates
whether the corresponding subsequence is anomalous or its anomaly score.
"""
fit_empty = self.get_class_tag("fit_is_empty")
fit_empty = self.get_tag("fit_is_empty")
if not fit_empty:
self._check_is_fitted()

Expand Down Expand Up @@ -194,7 +194,7 @@ def fit_predict(self, X, y=None, axis=1) -> np.ndarray:
A boolean, int or float array of length len(X), where each element indicates
whether the corresponding subsequence is anomalous or its anomaly score.
"""
if self.get_class_tag("requires_y"):
if self.get_tag("requires_y"):
if y is None:
raise ValueError("Tag requires_y is true, but fit called with y=None")

Expand All @@ -203,7 +203,7 @@ def fit_predict(self, X, y=None, axis=1) -> np.ndarray:

X = self._preprocess_series(X, axis, True)

if self.get_class_tag("fit_is_empty"):
if self.get_tag("fit_is_empty"):
self.is_fitted = True
return self._predict(X)

Expand All @@ -230,7 +230,7 @@ def _check_y(self, y: VALID_INPUT_TYPES) -> np.ndarray:
# Remind user if y is not required for this estimator on failure
req_msg = (
f"{self.__class__.__name__} does not require a y input."
if self.get_class_tag("requires_y")
if self.get_tag("requires_y")
else ""
)
new_y = y
Expand Down
4 changes: 2 additions & 2 deletions aeon/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
"BaseAeonEstimator",
"BaseCollectionEstimator",
"BaseSeriesEstimator",
"_ComposableEstimatorMixin",
"ComposableEstimatorMixin",
]

from aeon.base._base import BaseAeonEstimator
from aeon.base._base_collection import BaseCollectionEstimator
from aeon.base._base_series import BaseSeriesEstimator
from aeon.base._meta import _ComposableEstimatorMixin
from aeon.base._compose import ComposableEstimatorMixin
74 changes: 44 additions & 30 deletions aeon/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,19 @@ class BaseAeonEstimator(BaseEstimator, ABC):

Contains the following methods:

reset estimator to post-init - reset(keep)
clone stimator (copy) - clone(random_state)
inspect tags (class method) - get_class_tags()
inspect tags (one tag, class) - get_class_tag(tag_name, tag_value_default,
- reset estimator to post-init - reset(keep)
- clone estimator (copy) - clone(random_state)
- inspect tags (class method) - get_class_tags()
- inspect tags (one tag, class) - get_class_tag(tag_name, tag_value_default,
raise_error)
inspect tags (all) - get_tags()
inspect tags (one tag) - get_tag(tag_name, tag_value_default, raise_error)
setting dynamic tags - set_tags(**tag_dict)
get fitted parameters - get_fitted_params(deep)
- inspect tags (all) - get_tags()
- inspect tags (one tag) - get_tag(tag_name, tag_value_default, raise_error)
- setting dynamic tags - set_tags(**tag_dict)
- get fitted parameters - get_fitted_params(deep)

All estimators have the attribute:

fitted state flag - is_fitted
- fitted state flag - is_fitted
"""

_tags = {
Expand Down Expand Up @@ -63,7 +63,7 @@ def reset(self, keep=None):
hyper-parameters (arguments of ``__init__``)
object attributes containing double-underscores, i.e., the string "__"
runs ``__init__`` with current values of hyperparameters (result of
get_params)
``get_params``)

Not affected by the reset are:
object attributes containing double-underscores
Expand All @@ -73,13 +73,13 @@ class and object methods, class attributes
Parameters
----------
keep : None, str, or list of str, default=None
If None, all attributes are removed except hyper-parameters.
If None, all attributes are removed except hyperparameters.
If str, only the attribute with this name is kept.
If list of str, only the attributes with these names are kept.

Returns
-------
self
self : object
Reference to self.
"""
# retrieve parameters to copy them later
Expand Down Expand Up @@ -163,30 +163,35 @@ def get_class_tags(cls):
return deepcopy(collected_tags)

@classmethod
def get_class_tag(cls, tag_name, tag_value_default=None, raise_error=False):
def get_class_tag(
cls,
tag_name,
raise_error=True,
tag_value_default=None,
):
"""
Get tag value from estimator class (only class tags).

Parameters
----------
tag_name : str
Name of tag value.
tag_value_default : any type
Default/fallback value if tag is not found.
raise_error : bool
raise_error : bool, default=True
Whether a ValueError is raised when the tag is not found.
tag_value_default : any type, default=None
Default/fallback value if tag is not found and error is not raised.

Returns
-------
tag_value
Value of the ``tag_name`` tag in self.
If not found, returns an error if raise_error is True, otherwise it
returns `tag_value_default`.
Value of the ``tag_name`` tag in cls.
If not found, raises a ValueError if ``raise_error`` is True, otherwise it
returns ``tag_value_default``.

Raises
------
ValueError
if raise_error is ``True`` and ``tag_name`` is not in
if ``raise_error`` is True and ``tag_name`` is not in
``self.get_tags().keys()``

Examples
Expand Down Expand Up @@ -221,7 +226,7 @@ def get_tags(self):
collected_tags.update(self._tags_dynamic)
return deepcopy(collected_tags)

def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
def get_tag(self, tag_name, raise_error=True, tag_value_default=None):
TonyBagnall marked this conversation as resolved.
Show resolved Hide resolved
"""
Get tag value from estimator class.

Expand All @@ -231,17 +236,17 @@ def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
----------
tag_name : str
Name of tag to be retrieved.
tag_value_default : any type, default=None
Default/fallback value if tag is not found.
raise_error : bool
raise_error : bool, default=True
Whether a ValueError is raised when the tag is not found.
tag_value_default : any type, default=None
Default/fallback value if tag is not found and error is not raised.

Returns
-------
tag_value
Value of the ``tag_name`` tag in self.
If not found, returns an error if raise_error is True, otherwise it
returns `tag_value_default`.
If not found, raises a ValueError if ``raise_error`` is True, otherwise it
returns ``tag_value_default``.

Raises
------
Expand Down Expand Up @@ -276,7 +281,7 @@ def set_tags(self, **tag_dict):

Returns
-------
self
self : object
Reference to self.
"""
tag_update = deepcopy(tag_dict)
Expand All @@ -297,7 +302,7 @@ def get_fitted_params(self, deep=True):

Returns
-------
fitted_params : mapping of string to any
fitted_params : dict
Fitted parameter names mapped to their values.
"""
self._check_is_fitted()
Expand All @@ -312,7 +317,13 @@ def _get_fitted_params(self, est, deep):

out = dict()
for key in fitted_params:
value = getattr(est, key)
# some of these can be properties that make assumptions which may not hold in
# aeon, e.g. sklearn Pipeline feature_names_in_
try:
value = getattr(est, key)
except AttributeError:
continue

if deep and isinstance(value, BaseEstimator):
deep_items = self._get_fitted_params(value, deep).items()
out.update((key + "__" + k, val) for k, val in deep_items)
Expand Down Expand Up @@ -406,7 +417,10 @@ def _validate_data(self, **kwargs):
)

def get_metadata_routing(self):
"""Sklearn metadata routing."""
"""Sklearn metadata routing.

Not supported by ``aeon`` estimators.
"""
raise NotImplementedError(
"aeon estimators do not have a get_metadata_routing method."
)
Expand Down
21 changes: 10 additions & 11 deletions aeon/base/_meta.py → aeon/base/_compose.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
"""Implements meta estimator for estimators composed of other estimators."""

__maintainer__ = ["MatthewMiddlehurst"]
__all__ = ["_ComposableEstimatorMixin"]
__all__ = ["ComposableEstimatorMixin"]

from abc import ABC, abstractmethod

from aeon.base import BaseAeonEstimator
from aeon.base._base import _clone_estimator


class _ComposableEstimatorMixin(ABC):
class ComposableEstimatorMixin(ABC):
"""Handles parameter management for estimators composed of named estimators.

Parts (i.e. get_params and set_params) adapted or copied from the scikit-learn
Expand Down Expand Up @@ -52,9 +52,8 @@ def get_params(self, deep=True):
out.update(estimators)

for name, estimator in estimators:
if hasattr(estimator, "get_params"):
for key, value in estimator.get_params(deep=True).items():
out[f"{name}__{key}"] = value
for key, value in estimator.get_params(deep=True).items():
out[f"{name}__{key}"] = value
return out

def set_params(self, **params):
Expand Down Expand Up @@ -119,7 +118,7 @@ def get_fitted_params(self, deep=True):

Returns
-------
fitted_params : mapping of string to any
fitted_params : dict
Fitted parameter names mapped to their values.
"""
self._check_is_fitted()
Expand Down Expand Up @@ -190,16 +189,16 @@ def _check_estimators(
for obj in estimators:
if isinstance(obj, tuple):
if not allow_tuples:
raise TypeError(
raise ValueError(
f"{attr_name} should only contain singular estimators instead "
f"of (str, estimator) tuples."
)
if not len(obj) == 2 or not isinstance(obj[0], str):
raise TypeError(
raise ValueError(
f"All tuples in {attr_name} must be of form (str, estimator)."
)
if not isinstance(obj[1], class_type):
raise TypeError(
raise ValueError(
f"All estimators in {attr_name} must be an instance "
f"of {class_type}."
)
Expand All @@ -213,15 +212,15 @@ def _check_estimators(
raise ValueError(f"Estimator name is invalid: {obj[0]}")
if unique_names:
if obj[0] in names:
raise TypeError(
raise ValueError(
f"Names in {attr_name} must be unique. Found duplicate "
f"name: {obj[0]}."
)
else:
names.append(obj[0])
elif isinstance(obj, class_type):
if not allow_single_estimators:
raise TypeError(
raise ValueError(
f"{attr_name} should only contain (str, estimator) tuples "
f"instead of singular estimators."
)
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
from aeon.base import (
BaseAeonEstimator,
BaseCollectionEstimator,
_ComposableEstimatorMixin,
ComposableEstimatorMixin,
)
from aeon.base._base import _clone_estimator


class BaseCollectionChannelEnsemble(_ComposableEstimatorMixin, BaseCollectionEstimator):
class BaseCollectionChannelEnsemble(ComposableEstimatorMixin, BaseCollectionEstimator):
"""Applies estimators to channels of an array.

Parameters
Expand Down Expand Up @@ -101,7 +101,11 @@ def __init__(
missing = all(
[
(
e[1].get_tag("capability:missing_values", False, raise_error=False)
e[1].get_tag(
"capability:missing_values",
raise_error=False,
tag_value_default=False,
)
if isinstance(e[1], BaseAeonEstimator)
else False
)
Expand All @@ -110,14 +114,20 @@ def __init__(
)
remainder_missing = remainder is None or (
isinstance(remainder, BaseAeonEstimator)
and remainder.get_tag("capability:missing_values", False, raise_error=False)
and remainder.get_tag(
"capability:missing_values", raise_error=False, tag_value_default=False
)
)

# can handle unequal length if all estimators can
unequal = all(
[
(
e[1].get_tag("capability:unequal_length", False, raise_error=False)
e[1].get_tag(
"capability:unequal_length",
raise_error=False,
tag_value_default=False,
)
if isinstance(e[1], BaseAeonEstimator)
else False
)
Expand All @@ -126,7 +136,9 @@ def __init__(
)
remainder_unequal = remainder is None or (
isinstance(remainder, BaseAeonEstimator)
and remainder.get_tag("capability:unequal_length", False, raise_error=False)
and remainder.get_tag(
"capability:unequal_length", raise_error=False, tag_value_default=False
)
)

tags_to_set = {
Expand Down
Loading