Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MNT] Unit testing revamp part 2: classification #1770

Merged
merged 16 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ repos:
additional_dependencies: [ isort==5.13.2 ]
args: [ "--nbqa-dont-skip-bad-cells", "--profile=black", "--multi-line=3" ]
- id: nbqa-black
additional_dependencies: [ black==24.2.0 ]
additional_dependencies: [ black==24.4.2 ]
args: [ "--nbqa-dont-skip-bad-cells" ]
- id: nbqa-flake8
additional_dependencies: [ flake8==7.0.0 ]
Expand Down
20 changes: 16 additions & 4 deletions aeon/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ class attribute via nested inheritance. NOT overridden by dynamic
return deepcopy(collected_tags)

@classmethod
def get_class_tag(cls, tag_name, tag_value_default=None):
def get_class_tag(cls, tag_name, tag_value_default=None, raise_error=False):
"""
Get tag value from estimator class (only class tags).

Expand All @@ -293,12 +293,19 @@ def get_class_tag(cls, tag_name, tag_value_default=None):
Name of tag value.
tag_value_default : any type
Default/fallback value if tag is not found.
raise_error : bool
Whether a ValueError is raised when the tag is not found.

Returns
-------
tag_value :
Value of the `tag_name` tag in self. If not found, returns
`tag_value_default`.
Value of the `tag_name` tag in self. If not found, returns an error if
raise_error is True, otherwise it returns `tag_value_default`.

Raises
------
ValueError if raise_error is True i.e. if tag_name is not in self.get_tags(
).keys()

See Also
--------
Expand All @@ -314,7 +321,12 @@ def get_class_tag(cls, tag_name, tag_value_default=None):
"""
collected_tags = cls.get_class_tags()

return collected_tags.get(tag_name, tag_value_default)
tag_value = collected_tags.get(tag_name, tag_value_default)

if raise_error and tag_name not in collected_tags.keys():
raise ValueError(f"Tag with name {tag_name} could not be found.")

return tag_value

def get_tags(self):
"""
Expand Down
17 changes: 10 additions & 7 deletions aeon/base/tests/test_base_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@
import pytest

from aeon.base import BaseCollectionEstimator
from aeon.testing.testing_data import EQUAL_LENGTH_UNIVARIATE, UNEQUAL_LENGTH_UNIVARIATE
from aeon.testing.testing_data import (
EQUAL_LENGTH_UNIVARIATE_CLASSIFICATION,
UNEQUAL_LENGTH_UNIVARIATE_CLASSIFICATION,
)
from aeon.utils import COLLECTIONS_DATA_TYPES
from aeon.utils.validation import get_type


@pytest.mark.parametrize("data", COLLECTIONS_DATA_TYPES)
def test__get_metadata(data):
"""Test get meta data."""
X = EQUAL_LENGTH_UNIVARIATE[data]
X = EQUAL_LENGTH_UNIVARIATE_CLASSIFICATION[data]["train"][0]
meta = BaseCollectionEstimator._get_X_metadata(X)
assert not meta["multivariate"]
assert not meta["missing_values"]
Expand Down Expand Up @@ -68,7 +71,7 @@ def test__convert_X(internal_type, data):
"""
cls = BaseCollectionEstimator()
# Equal length should default to numpy3D
X = EQUAL_LENGTH_UNIVARIATE[data]
X = EQUAL_LENGTH_UNIVARIATE_CLASSIFICATION[data]["train"][0]
cls.metadata_ = cls._check_X(X)
X2 = cls._convert_X(X)
assert get_type(X2) == cls.get_tag("X_inner_type")
Expand All @@ -94,9 +97,9 @@ def test__convert_X(internal_type, data):
X2 = cls._convert_X(X)
assert get_type(X2) == "numpy3D" if data != internal_type else internal_type

if data in UNEQUAL_LENGTH_UNIVARIATE.keys():
if internal_type in UNEQUAL_LENGTH_UNIVARIATE.keys():
X = UNEQUAL_LENGTH_UNIVARIATE[data]
if data in UNEQUAL_LENGTH_UNIVARIATE_CLASSIFICATION.keys():
if internal_type in UNEQUAL_LENGTH_UNIVARIATE_CLASSIFICATION.keys():
X = UNEQUAL_LENGTH_UNIVARIATE_CLASSIFICATION[data]["train"][0]

# Should stay as internal_type
cls.set_tags(**{"capability:unequal_length": True})
Expand All @@ -114,7 +117,7 @@ def test__convert_X(internal_type, data):
@pytest.mark.parametrize("data", COLLECTIONS_DATA_TYPES)
def test_preprocess_collection(data):
"""Test the functionality for preprocessing fit."""
data = EQUAL_LENGTH_UNIVARIATE[data]
data = EQUAL_LENGTH_UNIVARIATE_CLASSIFICATION[data]["train"][0]
cls = BaseCollectionEstimator()
X = cls._preprocess_collection(data)
assert cls._n_jobs == 1
Expand Down
24 changes: 4 additions & 20 deletions aeon/classification/convolution_based/_arsenal.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,8 @@ def _fit_arsenal(self, X, y, keep_transformed_data=False):
else:
raise ValueError(f"Invalid Rocket transformer: {self.rocket_transform}")

rng = check_random_state(self.random_state)

if time_limit > 0:
self.n_estimators_ = 0
self.estimators_ = []
Expand All @@ -308,16 +310,7 @@ def _fit_arsenal(self, X, y, keep_transformed_data=False):
fit = Parallel(n_jobs=self._n_jobs, prefer="threads")(
delayed(self._fit_ensemble_estimator)(
_clone_estimator(
base_rocket,
(
None
if self.random_state is None
else (
255 if self.random_state == 0 else self.random_state
)
* 37
* (i + 1)
),
base_rocket, rng.randint(np.iinfo(np.int32).max)
),
X,
y,
Expand All @@ -336,16 +329,7 @@ def _fit_arsenal(self, X, y, keep_transformed_data=False):
else:
fit = Parallel(n_jobs=self._n_jobs, prefer="threads")(
delayed(self._fit_ensemble_estimator)(
_clone_estimator(
base_rocket,
(
None
if self.random_state is None
else (255 if self.random_state == 0 else self.random_state)
* 37
* (i + 1)
),
),
_clone_estimator(base_rocket, rng.randint(np.iinfo(np.int32).max)),
X,
y,
keep_transformed_data=keep_transformed_data,
Expand Down

This file was deleted.

Loading