Skip to content

Commit

Permalink
[ENH] Collection and series base tidy (#2352)
Browse files Browse the repository at this point in the history
* base changes

* comment fix

* fixes

* fixes

* add abstract to base init

* init

* more init

* more init

* merge
  • Loading branch information
MatthewMiddlehurst authored Nov 22, 2024
1 parent ef18e29 commit 23f3f0b
Show file tree
Hide file tree
Showing 42 changed files with 287 additions and 243 deletions.
31 changes: 16 additions & 15 deletions aeon/anomaly_detection/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pandas as pd

from aeon.base import BaseSeriesEstimator
from aeon.base._base_series import VALID_INPUT_TYPES
from aeon.base._base_series import VALID_SERIES_INPUT_TYPES


class BaseAnomalyDetector(BaseSeriesEstimator):
Expand Down Expand Up @@ -76,7 +76,7 @@ class BaseAnomalyDetector(BaseSeriesEstimator):
"""

_tags = {
"X_inner_type": "np.ndarray", # One of VALID_INNER_TYPES
"X_inner_type": "np.ndarray", # One of VALID_SERIES_INNER_TYPES
"fit_is_empty": True,
"requires_y": False,
}
Expand All @@ -95,14 +95,14 @@ def fit(self, X, y=None, axis=1):
Parameters
----------
X : one of aeon.base._base_series.VALID_INPUT_TYPES
X : one of aeon.base._base_series.VALID_SERIES_INPUT_TYPES
The time series to fit the model to.
A valid aeon time series data structure. See
aeon.base._base_series.VALID_INPUT_TYPES for aeon supported types.
y : one of aeon.base._base_series.VALID_INPUT_TYPES, default=None
aeon.base._base_series.VALID_SERIES_INPUT_TYPES for aeon supported types.
y : one of aeon.base._base_series.VALID_SERIES_INPUT_TYPES, default=None
The target values for the time series.
A valid aeon time series data structure. See
aeon.base._base_series.VALID_INPUT_TYPES for aeon supported types.
aeon.base._base_series.VALID_SERIES_INPUT_TYPES for aeon supported types.
axis : int
The time point axis of the input series if it is 2D. If ``axis==0``, it is
assumed each column is a time series and each row is a time point. i.e. the
Expand Down Expand Up @@ -142,10 +142,10 @@ def predict(self, X, axis=1) -> np.ndarray:
Parameters
----------
X : one of aeon.base._base_series.VALID_INPUT_TYPES
X : one of aeon.base._base_series.VALID_SERIES_INPUT_TYPES
The time series to fit the model to.
A valid aeon time series data structure. See
aeon.base._base_series.VALID_INPUT_TYPES for aeon supported types.
aeon.base._base_series.VALID_SERIES_INPUT_TYPES for aeon supported types.
axis : int, default=1
The time point axis of the input series if it is 2D. If ``axis==0``, it is
assumed each column is a time series and each row is a time point. i.e. the
Expand Down Expand Up @@ -173,14 +173,14 @@ def fit_predict(self, X, y=None, axis=1) -> np.ndarray:
Parameters
----------
X : one of aeon.base._base_series.VALID_INPUT_TYPES
X : one of aeon.base._base_series.VALID_SERIES_INPUT_TYPES
The time series to fit the model to.
A valid aeon time series data structure. See
aeon.base._base_series.VALID_INPUT_TYPES for aeon supported types.
y : one of aeon.base._base_series.VALID_INPUT_TYPES, default=None
y : one of aeon.base._base_series.VALID_SERIES_INPUT_TYPES, default=None
The target values for the time series.
A valid aeon time series data structure. See
aeon.base._base_series.VALID_INPUT_TYPES for aeon supported types.
aeon.base._base_series.VALID_SERIES_INPUT_TYPES for aeon supported types.
axis : int, default=1
The time point axis of the input series if it is 2D. If ``axis==0``, it is
assumed each column is a time series and each row is a time point. i.e. the
Expand Down Expand Up @@ -226,7 +226,7 @@ def _fit_predict(self, X, y):
self._fit(X, y)
return self._predict(X)

def _check_y(self, y: VALID_INPUT_TYPES) -> np.ndarray:
def _check_y(self, y: VALID_SERIES_INPUT_TYPES) -> np.ndarray:
# Remind user if y is not required for this estimator on failure
req_msg = (
f"{self.__class__.__name__} does not require a y input."
Expand All @@ -235,7 +235,8 @@ def _check_y(self, y: VALID_INPUT_TYPES) -> np.ndarray:
)
new_y = y

# must be a valid input type, see VALID_INPUT_TYPES in BaseSeriesEstimator
# must be a valid input type, see VALID_SERIES_INPUT_TYPES in
# BaseSeriesEstimator
if isinstance(y, np.ndarray):
# check valid shape
if y.ndim > 1:
Expand Down Expand Up @@ -284,8 +285,8 @@ def _check_y(self, y: VALID_INPUT_TYPES) -> np.ndarray:
new_y = y.squeeze().values
else:
raise ValueError(
f"Error in input type for y: it should be one of {VALID_INPUT_TYPES}, "
f"saw {type(y)}"
f"Error in input type for y: it should be one of "
f"{VALID_SERIES_INPUT_TYPES}, saw {type(y)}"
)

new_y = new_y.astype(bool)
Expand Down
2 changes: 1 addition & 1 deletion aeon/anomaly_detection/tests/test_left_stampi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Tests for the LaftSTAMPi class."""
"""Tests for the LeftSTAMPi class."""

__maintainer__ = ["ferewi"]

Expand Down
7 changes: 6 additions & 1 deletion aeon/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
__all__ = ["BaseAeonEstimator"]

import inspect
from abc import ABC
from abc import ABC, abstractmethod
from copy import deepcopy

from sklearn import clone
from sklearn.base import BaseEstimator
from sklearn.ensemble._base import _set_random_states
from sklearn.exceptions import NotFittedError

from aeon.utils.validation._dependencies import _check_estimator_deps


class BaseAeonEstimator(BaseEstimator, ABC):
"""
Expand Down Expand Up @@ -44,12 +46,15 @@ class BaseAeonEstimator(BaseEstimator, ABC):
"capability:multithreading": False,
}

@abstractmethod
def __init__(self):
self.is_fitted = False # flag to indicate if fit has been called
self._tags_dynamic = dict() # storage for dynamic tags

super().__init__()

_check_estimator_deps(self)

def reset(self, keep=None):
"""
Reset the object to a clean post-init state.
Expand Down
Loading

0 comments on commit 23f3f0b

Please sign in to comment.