Skip to content

Commit

Permalink
[MNT] Unit testing revamp part 2: classification (#1770)
Browse files Browse the repository at this point in the history
* classification checks in progress

* rework yield checks

* rework yield checks to allow for class input

* fixes

* fix

* pr testing split

* missing value testing data

* fix
  • Loading branch information
MatthewMiddlehurst authored Aug 12, 2024
1 parent 84aa1b7 commit d835391
Show file tree
Hide file tree
Showing 23 changed files with 1,920 additions and 1,012 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ repos:
additional_dependencies: [ isort==5.13.2 ]
args: [ "--nbqa-dont-skip-bad-cells", "--profile=black", "--multi-line=3" ]
- id: nbqa-black
additional_dependencies: [ black==24.2.0 ]
additional_dependencies: [ black==24.4.2 ]
args: [ "--nbqa-dont-skip-bad-cells" ]
- id: nbqa-flake8
additional_dependencies: [ flake8==7.0.0 ]
Expand Down
20 changes: 16 additions & 4 deletions aeon/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ class attribute via nested inheritance. NOT overridden by dynamic
return deepcopy(collected_tags)

@classmethod
def get_class_tag(cls, tag_name, tag_value_default=None):
def get_class_tag(cls, tag_name, tag_value_default=None, raise_error=False):
"""
Get tag value from estimator class (only class tags).
Expand All @@ -293,12 +293,19 @@ def get_class_tag(cls, tag_name, tag_value_default=None):
Name of tag value.
tag_value_default : any type
Default/fallback value if tag is not found.
raise_error : bool
Whether a ValueError is raised when the tag is not found.
Returns
-------
tag_value :
Value of the `tag_name` tag in self. If not found, returns
`tag_value_default`.
Value of the `tag_name` tag in self. If not found, raises a ValueError if
raise_error is True, otherwise returns `tag_value_default`.
Raises
------
ValueError if raise_error is True and tag_name is not in
cls.get_class_tags().keys()
See Also
--------
Expand All @@ -314,7 +321,12 @@ def get_class_tag(cls, tag_name, tag_value_default=None):
"""
collected_tags = cls.get_class_tags()

return collected_tags.get(tag_name, tag_value_default)
tag_value = collected_tags.get(tag_name, tag_value_default)

if raise_error and tag_name not in collected_tags.keys():
raise ValueError(f"Tag with name {tag_name} could not be found.")

return tag_value

def get_tags(self):
"""
Expand Down
17 changes: 10 additions & 7 deletions aeon/base/tests/test_base_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@
import pytest

from aeon.base import BaseCollectionEstimator
from aeon.testing.testing_data import EQUAL_LENGTH_UNIVARIATE, UNEQUAL_LENGTH_UNIVARIATE
from aeon.testing.testing_data import (
EQUAL_LENGTH_UNIVARIATE_CLASSIFICATION,
UNEQUAL_LENGTH_UNIVARIATE_CLASSIFICATION,
)
from aeon.utils import COLLECTIONS_DATA_TYPES
from aeon.utils.validation import get_type


@pytest.mark.parametrize("data", COLLECTIONS_DATA_TYPES)
def test__get_metadata(data):
"""Test get meta data."""
X = EQUAL_LENGTH_UNIVARIATE[data]
X = EQUAL_LENGTH_UNIVARIATE_CLASSIFICATION[data]["train"][0]
meta = BaseCollectionEstimator._get_X_metadata(X)
assert not meta["multivariate"]
assert not meta["missing_values"]
Expand Down Expand Up @@ -68,7 +71,7 @@ def test__convert_X(internal_type, data):
"""
cls = BaseCollectionEstimator()
# Equal length should default to numpy3D
X = EQUAL_LENGTH_UNIVARIATE[data]
X = EQUAL_LENGTH_UNIVARIATE_CLASSIFICATION[data]["train"][0]
cls.metadata_ = cls._check_X(X)
X2 = cls._convert_X(X)
assert get_type(X2) == cls.get_tag("X_inner_type")
Expand All @@ -94,9 +97,9 @@ def test__convert_X(internal_type, data):
X2 = cls._convert_X(X)
assert get_type(X2) == "numpy3D" if data != internal_type else internal_type

if data in UNEQUAL_LENGTH_UNIVARIATE.keys():
if internal_type in UNEQUAL_LENGTH_UNIVARIATE.keys():
X = UNEQUAL_LENGTH_UNIVARIATE[data]
if data in UNEQUAL_LENGTH_UNIVARIATE_CLASSIFICATION.keys():
if internal_type in UNEQUAL_LENGTH_UNIVARIATE_CLASSIFICATION.keys():
X = UNEQUAL_LENGTH_UNIVARIATE_CLASSIFICATION[data]["train"][0]

# Should stay as internal_type
cls.set_tags(**{"capability:unequal_length": True})
Expand All @@ -114,7 +117,7 @@ def test__convert_X(internal_type, data):
@pytest.mark.parametrize("data", COLLECTIONS_DATA_TYPES)
def test_preprocess_collection(data):
"""Test the functionality for preprocessing fit."""
data = EQUAL_LENGTH_UNIVARIATE[data]
data = EQUAL_LENGTH_UNIVARIATE_CLASSIFICATION[data]["train"][0]
cls = BaseCollectionEstimator()
X = cls._preprocess_collection(data)
assert cls._n_jobs == 1
Expand Down
24 changes: 4 additions & 20 deletions aeon/classification/convolution_based/_arsenal.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,8 @@ def _fit_arsenal(self, X, y, keep_transformed_data=False):
else:
raise ValueError(f"Invalid Rocket transformer: {self.rocket_transform}")

rng = check_random_state(self.random_state)

if time_limit > 0:
self.n_estimators_ = 0
self.estimators_ = []
Expand All @@ -308,16 +310,7 @@ def _fit_arsenal(self, X, y, keep_transformed_data=False):
fit = Parallel(n_jobs=self._n_jobs, prefer="threads")(
delayed(self._fit_ensemble_estimator)(
_clone_estimator(
base_rocket,
(
None
if self.random_state is None
else (
255 if self.random_state == 0 else self.random_state
)
* 37
* (i + 1)
),
base_rocket, rng.randint(np.iinfo(np.int32).max)
),
X,
y,
Expand All @@ -336,16 +329,7 @@ def _fit_arsenal(self, X, y, keep_transformed_data=False):
else:
fit = Parallel(n_jobs=self._n_jobs, prefer="threads")(
delayed(self._fit_ensemble_estimator)(
_clone_estimator(
base_rocket,
(
None
if self.random_state is None
else (255 if self.random_state == 0 else self.random_state)
* 37
* (i + 1)
),
),
_clone_estimator(base_rocket, rng.randint(np.iinfo(np.int32).max)),
X,
y,
keep_transformed_data=keep_transformed_data,
Expand Down

This file was deleted.

Loading

0 comments on commit d835391

Please sign in to comment.