Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MNT] Add/rework transformation tests and remove from exclude list #2360

Merged
merged 16 commits into from
Nov 24, 2024
Merged
9 changes: 8 additions & 1 deletion aeon/testing/estimator_checking/_estimator_checking.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
from aeon.testing.estimator_checking._yield_estimator_checks import (
_yield_all_aeon_checks,
)
from aeon.testing.testing_config import EXCLUDE_ESTIMATORS, EXCLUDED_TESTS
from aeon.testing.testing_config import (
EXCLUDE_ESTIMATORS,
EXCLUDED_TESTS,
EXCLUDED_TESTS_NO_NUMBA,
NUMBA_DISABLED,
)
from aeon.utils.validation._dependencies import (
_check_estimator_deps,
_check_soft_dependencies,
Expand Down Expand Up @@ -313,6 +318,8 @@ def _should_be_skipped(estimator, check, has_dependencies):
return True, "In aeon estimator exclude list", check_name
elif check_name in EXCLUDED_TESTS.get(est_name, []):
return True, "In aeon test exclude list for estimator", check_name
elif NUMBA_DISABLED and check_name in EXCLUDED_TESTS_NO_NUMBA.get(est_name, []):
return True, "In aeon no numba test exclude list for estimator", check_name

return False, "", check_name

Expand Down
29 changes: 8 additions & 21 deletions aeon/testing/estimator_checking/_yield_classification_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,14 @@ def _yield_classification_checks(estimator_class, estimator_instances, datatypes
results_dict=unit_test_proba,
resample_seed=0,
)
# the test currently fails when numba is disabled. See issue #622
if (
estimator_class.__name__ != "HIVECOTEV2"
or os.environ.get("NUMBA_DISABLE_JIT") != "1"
):
yield partial(
check_classifier_against_expected_results,
estimator_class=estimator_class,
data_name="BasicMotions",
data_loader=load_basic_motions,
results_dict=basic_motions_proba,
resample_seed=4,
)
yield partial(
check_classifier_against_expected_results,
estimator_class=estimator_class,
data_name="BasicMotions",
data_loader=load_basic_motions,
results_dict=basic_motions_proba,
resample_seed=4,
)
yield partial(check_classifier_overrides_and_tags, estimator_class=estimator_class)

# data type irrelevant
Expand Down Expand Up @@ -154,9 +149,6 @@ def check_classifier_overrides_and_tags(estimator_class):
f"Override _{method} instead."
)

# axis class parameter is for internal use only
assert "axis" not in estimator_class.__dict__

# Test valid tag for X_inner_type
X_inner_type = estimator_class.get_class_tag(tag_name="X_inner_type")
if isinstance(X_inner_type, str):
Expand All @@ -172,11 +164,6 @@ def check_classifier_overrides_and_tags(estimator_class):
else: # must be a list
assert any([t in valid_unequal_types for t in X_inner_type])

# Must have at least one set to True
multi = estimator_class.get_class_tag(tag_name="capability:multivariate")
uni = estimator_class.get_class_tag(tag_name="capability:univariate")
assert multi or uni

valid_algorithm_types = [
"distance",
"deeplearning",
Expand Down

This file was deleted.

35 changes: 8 additions & 27 deletions aeon/testing/estimator_checking/_yield_estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from aeon.regression import BaseRegressor
from aeon.regression.deep_learning.base import BaseDeepRegressor
from aeon.segmentation import BaseSegmenter
from aeon.similarity_search import BaseSimilaritySearch
from aeon.testing.estimator_checking._yield_anomaly_detection_checks import (
_yield_anomaly_detection_checks,
)
Expand All @@ -36,9 +35,6 @@
from aeon.testing.estimator_checking._yield_clustering_checks import (
_yield_clustering_checks,
)
from aeon.testing.estimator_checking._yield_collection_transformation_checks import (
_yield_collection_transformation_checks,
)
from aeon.testing.estimator_checking._yield_early_classification_checks import (
_yield_early_classification_checks,
)
Expand All @@ -48,12 +44,6 @@
from aeon.testing.estimator_checking._yield_segmentation_checks import (
_yield_segmentation_checks,
)
from aeon.testing.estimator_checking._yield_series_transformation_checks import (
_yield_series_transformation_checks,
)
from aeon.testing.estimator_checking._yield_similarity_search_checks import (
_yield_similarity_search_checks,
)
from aeon.testing.estimator_checking._yield_soft_dependency_checks import (
_yield_soft_dependency_checks,
)
Expand All @@ -68,8 +58,6 @@
from aeon.testing.utils.deep_equals import deep_equals
from aeon.testing.utils.estimator_checks import _get_tag, _run_estimator_method
from aeon.transformations.base import BaseTransformer
from aeon.transformations.collection import BaseCollectionTransformer
from aeon.transformations.series import BaseSeriesTransformer
from aeon.utils.base import VALID_ESTIMATOR_BASES
from aeon.utils.tags import check_valid_tags
from aeon.utils.validation._dependencies import _check_estimator_deps
Expand Down Expand Up @@ -148,26 +136,11 @@ def _yield_all_aeon_checks(
estimator_class, estimator_instances, datatypes
)

if issubclass(estimator_class, BaseSimilaritySearch):
yield from _yield_similarity_search_checks(
estimator_class, estimator_instances, datatypes
)

if issubclass(estimator_class, BaseTransformer):
yield from _yield_transformation_checks(
estimator_class, estimator_instances, datatypes
)

if issubclass(estimator_class, BaseCollectionTransformer):
yield from _yield_collection_transformation_checks(
estimator_class, estimator_instances, datatypes
)

if issubclass(estimator_class, BaseSeriesTransformer):
yield from _yield_series_transformation_checks(
estimator_class, estimator_instances, datatypes
)


def _yield_estimator_checks(estimator_class, estimator_instances, datatypes):
"""Yield all general checks for an aeon estimator."""
Expand Down Expand Up @@ -299,6 +272,14 @@ def check_has_common_interface(estimator_class):
estimator_class.get_fitted_params
)

# axis class parameter is for internal use only
assert "axis" not in estimator_class.__dict__

# Must have at least one set to True
multi = estimator_class.get_class_tag(tag_name="capability:multivariate")
uni = estimator_class.get_class_tag(tag_name="capability:univariate")
assert multi or uni


def check_set_params_sklearn(estimator_class):
"""Check that set_params works correctly, mirrors sklearn check_set_params.
Expand Down
8 changes: 0 additions & 8 deletions aeon/testing/estimator_checking/_yield_regression_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,6 @@ def check_regressor_overrides_and_tags(estimator_class):
f"Override _{method} instead."
)

# axis class parameter is for internal use only
assert "axis" not in estimator_class.__dict__

# Test valid tag for X_inner_type
X_inner_type = estimator_class.get_class_tag(tag_name="X_inner_type")
if isinstance(X_inner_type, str):
Expand All @@ -163,11 +160,6 @@ def check_regressor_overrides_and_tags(estimator_class):
else: # must be a list
assert any([t in valid_unequal_types for t in X_inner_type])

# Must have at least one set to True
multi = estimator_class.get_class_tag(tag_name="capability:multivariate")
uni = estimator_class.get_class_tag(tag_name="capability:univariate")
assert multi or uni

valid_algorithm_types = [
"distance",
"deeplearning",
Expand Down

This file was deleted.

This file was deleted.

Loading