Skip to content

Commit

Permalink
[ENH] More base class method removals (#2171)
Browse files Browse the repository at this point in the history
* base purge cont.

* remove useless checks

* fix

* make base classes abstract

* Apply suggestions from code review

Co-authored-by: Sebastian Schmidl <[email protected]>

* Automatic `pre-commit` fixes

* remove metaclasses

* problem for next PR

* Update _base.py

* more test skips for when the meta class gets sorted out

---------

Co-authored-by: Sebastian Schmidl <[email protected]>
Co-authored-by: MatthewMiddlehurst <[email protected]>
Co-authored-by: Tony Bagnall <[email protected]>
  • Loading branch information
4 people authored Oct 19, 2024
1 parent 1ee770b commit 4741999
Show file tree
Hide file tree
Showing 26 changed files with 171 additions and 360 deletions.
6 changes: 3 additions & 3 deletions aeon/anomaly_detection/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
__maintainer__ = ["MatthewMiddlehurst"]
__all__ = ["BaseAnomalyDetector"]

from abc import ABC, abstractmethod
from abc import abstractmethod
from typing import final

import numpy as np
Expand All @@ -13,7 +13,7 @@
from aeon.base._base_series import VALID_INPUT_TYPES


class BaseAnomalyDetector(BaseSeriesEstimator, ABC):
class BaseAnomalyDetector(BaseSeriesEstimator):
"""Base class for anomaly detection algorithms.
Anomaly detection algorithms are used to identify anomalous subsequences in time
Expand Down Expand Up @@ -161,7 +161,7 @@ def predict(self, X, axis=1) -> np.ndarray:
"""
fit_empty = self.get_class_tag("fit_is_empty")
if not fit_empty:
self.check_is_fitted()
self._check_is_fitted()

X = self._preprocess_series(X, axis, False)

Expand Down
333 changes: 68 additions & 265 deletions aeon/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
__all__ = ["BaseAeonEstimator"]

import inspect
from abc import ABC
from copy import deepcopy

from sklearn import clone
Expand All @@ -12,7 +13,7 @@
from sklearn.exceptions import NotFittedError


class BaseAeonEstimator(BaseEstimator):
class BaseAeonEstimator(BaseEstimator, ABC):
"""
Base class for defining estimators in aeon.
Expand Down Expand Up @@ -281,6 +282,72 @@ def set_tags(self, **tag_dict):
self._tags_dynamic.update(tag_update)
return self

def get_fitted_params(self, deep=True):
    """Return the fitted parameters of this estimator.

    State required:
        Requires state to be "fitted".

    Parameters
    ----------
    deep : bool, default=True
        Whether fitted parameters of components are included.

        * If True, the returned dict also holds the fitted parameters of
          every fitted attribute that is itself an estimator.
        * If False, only the fitted parameters of this object are returned.

    Returns
    -------
    fitted_params : dict with str-valued keys
        Dictionary of fitted parameters, paramname : paramvalue.
        Key-value pairs include:

        * always: all fitted parameters of this object
        * if ``deep=True``, the parameters of components, indexed as
          ``[componentname]__[paramname]``
        * if ``deep=True``, arbitrary levels of component recursion, e.g.
          ``[componentname]__[componentcomponentname]__[paramname]``

    Raises
    ------
    NotFittedError
        If the estimator has not been fitted yet.
    """
    # Fail early: nothing is collected from an unfitted estimator.
    self._check_is_fitted()
    fitted_params = self._get_fitted_params(self, deep)
    return fitted_params

def _get_fitted_params(self, est, deep):
"""Recursive function to get fitted parameters."""
# retrieves all self attributes ending in "_"
fitted_params = [
attr for attr in dir(est) if attr.endswith("_") and not attr.startswith("_")
]

out = dict()
for key in fitted_params:
value = getattr(est, key)
if deep and isinstance(value, BaseEstimator):
deep_items = self._get_fitted_params(value, deep).items()
out.update((key + "__" + k, val) for k, val in deep_items)
out[key] = value
return out

# private functions to help testing

def _check_is_fitted(self):
"""
Check if the estimator has been fitted.
Raises
------
NotFittedError
If the estimator has not been fitted yet.
"""
if not self.is_fitted:
raise NotFittedError(
f"This instance of {self.__class__.__name__} has not "
f"been fitted yet; please call `fit` first."
)

@classmethod
def get_test_params(cls, parameter_set="default"):
"""
Expand Down Expand Up @@ -339,270 +406,6 @@ def create_test_instance(cls, parameter_set="default", return_first=True):
else:
return [cls(**params)]

def _components(self, base_class=None):
"""
Return references to all state changing BaseAeonEstimator type attributes.
This *excludes* the blue-print-like components passed in the __init__.
Caution: this method returns *references* and not *copies*.
Writing to the reference will change the respective attribute of self.
Parameters
----------
base_class : subclass of BaseAeonEstimator, default=None
if None, behaves the same as `base_class=BaseAeonEstimator`
if not None, return dict collects descendants of `base_class`.
Returns
-------
dict with key = attribute name, value = reference to attribute.
dict contains all attributes of `self` that inherit from `base_class`, and:
whose names do not contain the string "__", e.g., hidden attributes
are not class attributes, and are not hyper-parameters (`__init__` args).
"""
if base_class is None:
base_class = BaseAeonEstimator
if base_class is not None and not inspect.isclass(base_class):
raise TypeError(f"base_class must be a class, but found {type(base_class)}")
# if base_class is not None and not issubclass(base_class, BaseAeonEstimator):
# raise TypeError("base_class must be a subclass of BaseAeonEstimator")

# retrieve parameter names to exclude them later
param_names = self.get_params(deep=False).keys()

# retrieve all attributes that are BaseAeonEstimator descendants
attrs = [attr for attr in dir(self) if "__" not in attr]
cls_attrs = [attr for attr in dir(type(self))]
self_attrs = set(attrs).difference(cls_attrs).difference(param_names)

comp_dict = {x: getattr(self, x) for x in self_attrs}
comp_dict = {x: y for (x, y) in comp_dict.items() if isinstance(y, base_class)}

return comp_dict

def save(self, path=None):
    """
    Save serialized self to bytes-like object or to (.zip) file.

    Behaviour:
    if `path` is None, returns an in-memory serialized self
    if `path` is a file location, stores self at that location as a zip file

    saved files are zip files with following contents:
    _metadata - contains class of self, i.e., type(self)
    _obj - serialized self. This class uses the default serialization (pickle).

    Parameters
    ----------
    path : None or file location (str or Path).
        if None, self is saved to an in-memory object
        if file location, self is saved to that file location. If:
            path="estimator" then a zip file `estimator.zip` will be made at cwd.
            path="/home/stored/estimator" then a zip file `estimator.zip` will be
            stored in `/home/stored/`.

    Returns
    -------
    if `path` is None - tuple ``(type(self), pickle.dumps(self))``
    if `path` is file location - ZipFile with reference to the file.

    Raises
    ------
    TypeError
        If `path` is neither None, a str nor a Path.
    """
    import pickle
    import shutil
    from pathlib import Path
    from zipfile import ZipFile

    if path is None:
        return (type(self), pickle.dumps(self))
    if not isinstance(path, (str, Path)):
        raise TypeError(
            "`path` is expected to either be a string or a Path object "
            f"but found of type:{type(path)}."
        )

    # `path` is first used as a staging directory that is zipped and removed.
    path = Path(path)
    path.mkdir()

    # Context managers make sure both files are flushed and closed before the
    # directory is archived and deleted.
    with open(path / "_metadata", "wb") as metadata_file:
        pickle.dump(type(self), metadata_file)
    with open(path / "_obj", "wb") as obj_file:
        pickle.dump(self, obj_file)

    # make_archive returns the name of the file it wrote. Re-deriving it from
    # `path.stem` would point at the wrong file for names containing a dot.
    archive_name = shutil.make_archive(
        base_name=str(path), format="zip", root_dir=path
    )
    shutil.rmtree(path)
    return ZipFile(archive_name)

@classmethod
def load_from_serial(cls, serial):
    """
    Rebuild an object from its in-memory pickled form.

    Parameters
    ----------
    serial : bytes-like object
        Pickled object, e.g. the second item of the tuple returned by
        `save(None)`.

    Returns
    -------
    The object obtained by unpickling `serial`.

    Notes
    -----
    NOTE(review): unpickling can execute arbitrary code, so `serial` must
    come from a trusted source.
    """
    from pickle import loads as _unpickle

    restored = _unpickle(serial)
    return restored

@classmethod
def load_from_path(cls, serial):
    """
    Load object from file location.

    Parameters
    ----------
    serial : str, Path or file-like object
        Zip archive that holds the pickled object in a member named
        ``_obj``, as written by `save(path)`.

    Returns
    -------
    deserialized self resulting in output at `path`, of `cls.save(path)`

    Notes
    -----
    NOTE(review): unpickling can execute arbitrary code, so the archive must
    come from a trusted source.
    """
    import pickle
    from zipfile import ZipFile

    # ZipFile.read opens and closes the member itself, so no member handle
    # is left open once the bytes have been read.
    with ZipFile(serial, "r") as archive:
        return pickle.loads(archive.read("_obj"))

def check_is_fitted(self):
    """
    Ensure the estimator has been fitted before it is used.

    Does nothing when ``self.is_fitted`` is truthy.

    Raises
    ------
    NotFittedError
        If ``self.is_fitted`` is falsy.
    """
    already_fitted = bool(self.is_fitted)
    if already_fitted:
        return None
    raise NotFittedError(
        f"This instance of {self.__class__.__name__} has not "
        f"been fitted yet; please call `fit` first."
    )

def get_fitted_params(self, deep=True):
    """Return the fitted parameters of this estimator.

    State required:
        Requires state to be "fitted".

    Parameters
    ----------
    deep : bool, default=True
        Whether fitted parameters of components are included.

        * If True, fitted parameters of fitted BaseAeonEstimator components
          and of BaseEstimator-valued fitted parameters are added as well.
        * If False, only the result of ``_get_fitted_params`` is returned.

    Returns
    -------
    fitted_params : dict with str-valued keys
        Dictionary of fitted parameters, paramname : paramvalue.
        Key-value pairs include:

        * always: all fitted parameters of this object
        * if ``deep=True``, parameters of components, indexed as
          ``[componentname]__[paramname]`` where trailing underscores are
          removed from ``componentname``
        * if ``deep=True``, arbitrary levels of component recursion, e.g.
          ``[componentname]__[componentcomponentname]__[paramname]``

    Raises
    ------
    NotFittedError
        If ``self.is_fitted`` is falsy.
    """
    if not self.is_fitted:
        raise NotFittedError(
            f"estimator of type {type(self).__name__} has not been "
            "fitted yet, please call fit on data before get_fitted_params"
        )

    # non-nested fitted params of self; all that is needed when deep=False
    collected = self._get_fitted_params()
    if not deep:
        return collected

    def _prefixed(owner, params):
        """Key ``params`` by ``owner`` (trailing underscores removed)."""
        stem = owner.rstrip("_")
        return {f"{stem}__{name}": value for name, value in params.items()}

    # nested parameters of fitted aeon components
    for owner, component in self._components().items():
        if isinstance(component, BaseAeonEstimator) and component.is_fitted:
            collected.update(_prefixed(owner, component.get_fitted_params()))

    # nested parameters of sklearn estimators: keep expanding the entries
    # found in the previous pass until a pass discovers nothing new
    frontier = collected
    while True:
        discovered = {}
        for owner, value in frontier.items():
            if isinstance(value, BaseEstimator):
                nested = self._get_fitted_params_default(value)
                discovered.update(_prefixed(owner, nested))
        collected.update(discovered)
        if not discovered:
            break
        frontier = discovered

    return collected

def _get_fitted_params_default(self, obj=None):
"""Obtain fitted params of object, per sklearn convention.
Extracts a dict with {paramstr : paramvalue} contents,
where paramstr are all string names of "fitted parameters".
A "fitted attribute" of obj is one that ends in "_" but does not start with "_".
"fitted parameters" are names of fitted attributes, minus the "_" at the end.
Parameters
----------
obj : any object, optional, default=self.
Returns
-------
fitted_params : dict with str keys
fitted parameters, keyed by names of fitted parameter.
"""
obj = obj if obj else self

# default retrieves all self attributes ending in "_"
# and returns them with keys that have the "_" removed
fitted_params = [attr for attr in dir(obj) if attr.endswith("_")]
fitted_params = [x for x in fitted_params if not x.startswith("_")]
fitted_params = [x for x in fitted_params if hasattr(obj, x)]
fitted_param_dict = {p[:-1]: getattr(obj, p) for p in fitted_params}

return fitted_param_dict

def _get_fitted_params(self):
"""Get fitted parameters.
private _get_fitted_params, called from get_fitted_params
State required:
Requires state to be "fitted".
Returns
-------
fitted_params : dict with str keys
fitted parameters, keyed by names of fitted parameter.
"""
return self._get_fitted_params_default()

# override some sklearn private methods

def __sklearn_is_fitted__(self):
Expand Down
1 change: 1 addition & 0 deletions aeon/base/_base_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
__maintainer__ = ["TonyBagnall", "MatthewMiddlehurst"]
__all__ = ["BaseSeriesEstimator"]


import numpy as np
import pandas as pd

Expand Down
Loading

0 comments on commit 4741999

Please sign in to comment.