diff --git a/aeon/testing/estimator_checking/_estimator_checking.py b/aeon/testing/estimator_checking/_estimator_checking.py index 0ea0ebfbe3..c07815ea89 100644 --- a/aeon/testing/estimator_checking/_estimator_checking.py +++ b/aeon/testing/estimator_checking/_estimator_checking.py @@ -18,7 +18,12 @@ from aeon.testing.estimator_checking._yield_estimator_checks import ( _yield_all_aeon_checks, ) -from aeon.testing.testing_config import EXCLUDE_ESTIMATORS, EXCLUDED_TESTS +from aeon.testing.testing_config import ( + EXCLUDE_ESTIMATORS, + EXCLUDED_TESTS, + EXCLUDED_TESTS_NO_NUMBA, + NUMBA_DISABLED, +) from aeon.utils.validation._dependencies import ( _check_estimator_deps, _check_soft_dependencies, @@ -313,6 +318,8 @@ def _should_be_skipped(estimator, check, has_dependencies): return True, "In aeon estimator exclude list", check_name elif check_name in EXCLUDED_TESTS.get(est_name, []): return True, "In aeon test exclude list for estimator", check_name + elif NUMBA_DISABLED and check_name in EXCLUDED_TESTS_NO_NUMBA.get(est_name, []): + return True, "In aeon no numba test exclude list for estimator", check_name return False, "", check_name diff --git a/aeon/testing/estimator_checking/_yield_classification_checks.py b/aeon/testing/estimator_checking/_yield_classification_checks.py index ac83d1c878..00b82705de 100644 --- a/aeon/testing/estimator_checking/_yield_classification_checks.py +++ b/aeon/testing/estimator_checking/_yield_classification_checks.py @@ -41,19 +41,14 @@ def _yield_classification_checks(estimator_class, estimator_instances, datatypes results_dict=unit_test_proba, resample_seed=0, ) - # the test currently fails when numba is disabled. See issue #622 - if ( - estimator_class.__name__ != "HIVECOTEV2" - or os.environ.get("NUMBA_DISABLE_JIT") != "1" - ): - yield partial( - check_classifier_against_expected_results, - estimator_class=estimator_class, - data_name="BasicMotions", - data_loader=load_basic_motions, - results_dict=basic_motions_proba, - resample_seed=4, - ) + yield partial( + check_classifier_against_expected_results, + estimator_class=estimator_class, + data_name="BasicMotions", + data_loader=load_basic_motions, + results_dict=basic_motions_proba, + resample_seed=4, + ) yield partial(check_classifier_overrides_and_tags, estimator_class=estimator_class) # data type irrelevant diff --git a/aeon/testing/estimator_checking/_yield_collection_transformation_checks.py b/aeon/testing/estimator_checking/_yield_collection_transformation_checks.py deleted file mode 100644 index c62b9d5440..0000000000 --- a/aeon/testing/estimator_checking/_yield_collection_transformation_checks.py +++ /dev/null @@ -1,45 +0,0 @@ -"""Tests for all collection transformers.""" - -from functools import partial - -import numpy as np - -from aeon.testing.data_generation import make_example_3d_numpy -from aeon.transformations.collection.channel_selection.base import BaseChannelSelector - - -def _yield_collection_transformation_checks( - estimator_class, estimator_instances, datatypes -): - """Yield all collection transformer checks for an aeon estimator.""" - # only class required - yield partial( - check_does_not_override_final_methods, estimator_class=estimator_class - ) - - if issubclass(estimator_class, BaseChannelSelector): - yield partial(check_channel_selectors, estimator_class=estimator_class) - - -def check_does_not_override_final_methods(estimator_class): - """Test does not override final methods.""" - assert "fit" not in estimator_class.__dict__ - assert "transform" not in 
estimator_class.__dict__ - assert "fit_transform" not in estimator_class.__dict__ - - -def check_channel_selectors(estimator_class): - """Test channel selectors. - - Needs fit and must select at least one channel - """ - X, _ = make_example_3d_numpy(n_cases=20, n_channels=6, n_timepoints=30) - y = np.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]) - cs = estimator_class._create_test_instance(return_first=True) - assert not cs.get_tag("fit_is_empty") - cs.fit(X, y) - assert cs.channels_selected_ is not None - assert len(cs.channels_selected_) > 0 - X2 = cs.transform(X) - assert isinstance(X2, np.ndarray) - assert X2.ndim == 3 diff --git a/aeon/testing/estimator_checking/_yield_estimator_checks.py b/aeon/testing/estimator_checking/_yield_estimator_checks.py index 098534d34c..b5fb60c1d3 100644 --- a/aeon/testing/estimator_checking/_yield_estimator_checks.py +++ b/aeon/testing/estimator_checking/_yield_estimator_checks.py @@ -24,7 +24,6 @@ from aeon.regression import BaseRegressor from aeon.regression.deep_learning.base import BaseDeepRegressor from aeon.segmentation import BaseSegmenter -from aeon.similarity_search import BaseSimilaritySearch from aeon.testing.estimator_checking._yield_anomaly_detection_checks import ( _yield_anomaly_detection_checks, ) @@ -34,9 +33,6 @@ from aeon.testing.estimator_checking._yield_clustering_checks import ( _yield_clustering_checks, ) -from aeon.testing.estimator_checking._yield_collection_transformation_checks import ( - _yield_collection_transformation_checks, -) from aeon.testing.estimator_checking._yield_early_classification_checks import ( _yield_early_classification_checks, ) @@ -49,12 +45,6 @@ from aeon.testing.estimator_checking._yield_segmentation_checks import ( _yield_segmentation_checks, ) -from aeon.testing.estimator_checking._yield_series_transformation_checks import ( - _yield_series_transformation_checks, -) -from aeon.testing.estimator_checking._yield_similarity_search_checks import ( - _yield_similarity_search_checks, -) from aeon.testing.estimator_checking._yield_soft_dependency_checks import ( _yield_soft_dependency_checks, ) @@ -69,8 +59,6 @@ from aeon.testing.utils.deep_equals import deep_equals from aeon.testing.utils.estimator_checks import _get_tag, _run_estimator_method from aeon.transformations.base import BaseTransformer -from aeon.transformations.collection import BaseCollectionTransformer -from aeon.transformations.series import BaseSeriesTransformer from aeon.utils.base import VALID_ESTIMATOR_BASES from aeon.utils.tags import check_valid_tags from aeon.utils.validation._dependencies import _check_estimator_deps @@ -153,26 +141,11 @@ def _yield_all_aeon_checks( estimator_class, estimator_instances, datatypes ) - if issubclass(estimator_class, BaseSimilaritySearch): - yield from _yield_similarity_search_checks( - estimator_class, estimator_instances, datatypes - ) - if issubclass(estimator_class, BaseTransformer): yield from _yield_transformation_checks( estimator_class, estimator_instances, datatypes ) - if issubclass(estimator_class, BaseCollectionTransformer): - yield from _yield_collection_transformation_checks( - estimator_class, estimator_instances, datatypes - ) - - if issubclass(estimator_class, BaseSeriesTransformer): - yield from _yield_series_transformation_checks( - estimator_class, estimator_instances, datatypes - ) - def _yield_estimator_checks(estimator_class, estimator_instances, datatypes): """Yield all general checks for an aeon estimator.""" @@ -289,6 +262,11 @@ def 
check_has_common_interface(estimator_class): "axis" not in estimator_class.__dict__ ), "axis should not be a class parameter" + # Must have at least one set to True + multi = estimator_class.get_class_tag(tag_name="capability:multivariate") + uni = estimator_class.get_class_tag(tag_name="capability:univariate") + assert multi or uni + def check_set_params(estimator_class): """Check that set_params works correctly.""" diff --git a/aeon/testing/estimator_checking/_yield_series_transformation_checks.py b/aeon/testing/estimator_checking/_yield_series_transformation_checks.py deleted file mode 100644 index 6a681a3b0e..0000000000 --- a/aeon/testing/estimator_checking/_yield_series_transformation_checks.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Tests for all series transformers.""" - - -def _yield_series_transformation_checks( - estimator_class, estimator_instances, datatypes -): - """Yield all series transformer checks for an aeon estimator.""" - # nothing currently! - return [] diff --git a/aeon/testing/estimator_checking/_yield_similarity_search_checks.py b/aeon/testing/estimator_checking/_yield_similarity_search_checks.py deleted file mode 100644 index 1c8c61f21b..0000000000 --- a/aeon/testing/estimator_checking/_yield_similarity_search_checks.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Tests for all similarity searchers.""" - - -def _yield_similarity_search_checks(estimator_class, estimator_instances, datatypes): - """Yield all similarity search checks for an aeon estimator.""" - # nothing currently! - return [] diff --git a/aeon/testing/estimator_checking/_yield_transformation_checks.py b/aeon/testing/estimator_checking/_yield_transformation_checks.py index 88936cd719..507c8c1e08 100644 --- a/aeon/testing/estimator_checking/_yield_transformation_checks.py +++ b/aeon/testing/estimator_checking/_yield_transformation_checks.py @@ -1,7 +1,7 @@ """Tests for all transformers.""" +import sys from functools import partial -from sys import platform import numpy as np import pandas as pd @@ -9,100 +9,193 @@ from sklearn.utils._testing import set_random_state from aeon.base._base import _clone_estimator +from aeon.base._base_series import VALID_SERIES_INNER_TYPES from aeon.datasets import load_basic_motions, load_unit_test from aeon.testing.expected_results.expected_transform_outputs import ( basic_motions_result, unit_test_result, ) from aeon.testing.testing_data import FULL_TEST_DATA_DICT +from aeon.testing.utils.deep_equals import deep_equals from aeon.testing.utils.estimator_checks import _run_estimator_method +from aeon.transformations.collection.channel_selection.base import BaseChannelSelector +from aeon.transformations.series import BaseSeriesTransformer +from aeon.utils import COLLECTIONS_DATA_TYPES def _yield_transformation_checks(estimator_class, estimator_instances, datatypes): """Yield all transformation checks for an aeon transformer.""" # only class required - yield partial( - check_transformer_against_expected_results, estimator_class=estimator_class - ) + if sys.platform != "darwin": + yield partial( + check_transformer_against_expected_results, + estimator_class=estimator_class, + data_name="UnitTest", + data_loader=load_unit_test, + results_dict=unit_test_result, + resample_seed=0, + ) + yield partial( + check_transformer_against_expected_results, + estimator_class=estimator_class, + data_name="BasicMotions", + data_loader=load_basic_motions, + results_dict=basic_motions_result, + resample_seed=4, + ) + yield partial(check_transformer_overrides_and_tags, estimator_class=estimator_class) # test 
class instances for i, estimator in enumerate(estimator_instances): # test all data types for datatype in datatypes[i]: yield partial( - check_transform_inverse_transform_equivalent, - estimator=estimator, - datatype=datatype, + check_transformer_output, estimator=estimator, datatype=datatype ) - -def check_transformer_against_expected_results(estimator_class): + if isinstance(estimator, BaseChannelSelector): + yield partial( + check_channel_selectors, + estimator=estimator, + datatype=datatype, + ) + + if estimator is not None and estimator.get_tag( + "capability:inverse_transform" + ): + yield partial( + check_transform_inverse_transform_equivalent, + estimator=estimator, + datatype=datatype, + ) + + +def check_transformer_against_expected_results( + estimator_class, data_name, data_loader, results_dict, resample_seed +): """Test transformer against stored results.""" + # retrieve expected transform output, and skip test if not available + if estimator_class.__name__ in results_dict.keys(): + expected_results = results_dict[estimator_class.__name__] + else: + # skip test if no expected results are registered + return f"No stored results for {estimator_class.__name__} on {data_name}" + # we only use the first estimator instance for testing - classname = estimator_class.__name__ - - # We cannot guarantee same results on ARM macOS - if platform == "darwin": - return None - - for data_name, data_dict, data_loader, data_seed in [ - ["UnitTest", unit_test_result, load_unit_test, 0], - ["BasicMotions", basic_motions_result, load_basic_motions, 4], - ]: - # retrieve expected transform output, and skip test if not available - if classname in data_dict.keys(): - expected_results = data_dict[classname] - else: - # skip test if no expected results are registered - continue - - # we only use the first estimator instance for testing - estimator_instance = estimator_class._create_test_instance( - parameter_set="results_comparison" - ) - # set random seed if possible - set_random_state(estimator_instance, 0) + estimator_instance = estimator_class._create_test_instance( + parameter_set="results_comparison" + ) + # set random seed if possible + set_random_state(estimator_instance, 0) - # load test data - X_train, y_train = data_loader(split="train") - indices = np.random.RandomState(data_seed).choice( - len(y_train), 5, replace=False - ) + # load test data + X_train, y_train = data_loader(split="train") + indices = np.random.RandomState(resample_seed).choice( + len(y_train), 5, replace=False + ) - # fit transformer and transform - results = np.nan_to_num( - estimator_instance.fit_transform(X_train[indices], y_train[indices]), - False, - 0, - 0, - 0, - ) + # fit transformer and transform + results = np.nan_to_num( + estimator_instance.fit_transform(X_train[indices], y_train[indices]), + False, + 0, + 0, + 0, + ) - # assert results are the same - assert_array_almost_equal( - results, - expected_results, - decimal=2, - err_msg=f"Failed to reproduce results for {classname} on {data_name}", - ) + # assert results are the same + assert_array_almost_equal( + results, + expected_results, + decimal=2, + err_msg=f"Failed to reproduce results for {estimator_class.__name__} " + f"on {data_name}", + ) -def check_transform_inverse_transform_equivalent(estimator, datatype): - """Test that inverse_transform is indeed inverse to transform.""" - # skip this test if the estimator does not have inverse_transform - if not estimator.get_class_tag("capability:inverse_transform", False): - return None +def 
check_transformer_overrides_and_tags(estimator_class): + """Test does not override final methods.""" + final_methods = [ + "fit", + "transform", + "fit_transform", + ] + for method in final_methods: + if method in estimator_class.__dict__: + raise ValueError( + f"Transformer {estimator_class} overrides the method {method}. " + f"Override _{method} instead." + ) - estimator = _clone_estimator(estimator) + dtypes = ( + VALID_SERIES_INNER_TYPES + if issubclass(estimator_class, BaseSeriesTransformer) + else COLLECTIONS_DATA_TYPES + ) + + # Test valid tag for X_inner_type + X_inner_type = estimator_class.get_class_tag(tag_name="X_inner_type") + if isinstance(X_inner_type, str): + assert X_inner_type in dtypes + else: # must be a list + assert all([t in dtypes for t in X_inner_type]) + + # one of X_inner_types must be capable of storing unequal length + if estimator_class.get_class_tag( + "capability:unequal_length", raise_error=False, tag_value_default=False + ): + valid_unequal_types = ["np-list", "df-list", "pd-multiindex"] + if isinstance(X_inner_type, str): + assert X_inner_type in valid_unequal_types + else: # must be a list + assert any([t in valid_unequal_types for t in X_inner_type]) + + if estimator_class.get_class_tag("capability:inverse_transform"): + assert "_inverse_transform" in estimator_class.__dict__ + else: + assert "_inverse_transform" not in estimator_class.__dict__ - X = FULL_TEST_DATA_DICT[datatype]["train"][0] +def check_transformer_output(estimator, datatype): + """Test transformer outputs.""" + estimator = _clone_estimator(estimator) + set_random_state(estimator, 0) + + # run fit and predict _run_estimator_method(estimator, "fit", datatype, "train") Xt = _run_estimator_method(estimator, "transform", datatype, "train") + if "_fit_transform" in estimator.__class__.__dict__: + Xt2 = _run_estimator_method(estimator, "fit_transform", datatype, "train") + assert deep_equals(Xt, Xt2, ignore_index=True) + + +def check_channel_selectors(estimator, datatype): + """Test channel selectors have fit and select at least one channel.""" + estimator = _clone_estimator(estimator) + + assert not estimator.get_tag("fit_is_empty") + + Xt = _run_estimator_method(estimator, "fit_transform", datatype, "train") + + assert hasattr(estimator, "channels_selected_") + assert isinstance(estimator.channels_selected_, (list, np.ndarray)) + assert len(estimator.channels_selected_) > 0 + assert isinstance(Xt, np.ndarray) + assert Xt.ndim == 3 + + +def check_transform_inverse_transform_equivalent(estimator, datatype): + """Test that inverse_transform is inverse to transform.""" + estimator = _clone_estimator(estimator) + + X = FULL_TEST_DATA_DICT[datatype]["train"][0] + Xt = _run_estimator_method(estimator, "fit_transform", datatype, "train") Xit = estimator.inverse_transform(Xt) - if isinstance(X, pd.DataFrame): - assert_array_almost_equal(X.loc[Xit.index], Xit) - else: - assert_array_almost_equal(X, Xit) + if isinstance(X, (np.ndarray, pd.DataFrame)): + X = X.squeeze() + if isinstance(Xit, (np.ndarray, pd.DataFrame)): + Xit = Xit.squeeze() + + assert deep_equals(X, Xit, ignore_index=True) diff --git a/aeon/testing/testing_config.py b/aeon/testing/testing_config.py index ddbc82de30..41477dead2 100644 --- a/aeon/testing/testing_config.py +++ b/aeon/testing/testing_config.py @@ -1,7 +1,14 @@ """Test configuration.""" __maintainer__ = ["MatthewMiddlehurst"] -__all__ = ["PR_TESTING", "EXCLUDE_ESTIMATORS", "EXCLUDED_TESTS"] +__all__ = [ + "PR_TESTING", + "EXCLUDE_ESTIMATORS", + "EXCLUDED_TESTS", + 
"EXCLUDED_TESTS_NO_NUMBA", +] + +import os import aeon.testing._cicd_numba_caching # noqa: F401 @@ -12,14 +19,13 @@ # --enablethreading True flag MULTITHREAD_TESTING = False -# exclude estimators here for short term fixes -EXCLUDE_ESTIMATORS = [ - "ClearSkyTransformer", - # See #2071 - "RISTRegressor", -] +# whether numba is disabled vis environment variable +NUMBA_DISABLED = os.environ.get("NUMBA_DISABLE_JIT") == "1" +# exclude estimators here for short term fixes +EXCLUDE_ESTIMATORS = [] +# Exclude specific tests for estimators here EXCLUDED_TESTS = { # Early classifiers (EC) intentionally retain information from previous predict # calls for #1 (test_non_state_changing_method_contract). @@ -38,20 +44,12 @@ "check_persistence_via_pickle", "check_save_estimators_to_file", ], - # has a keras fail, unknown reason, see #1387 - "LearningShapeletClassifier": ["check_fit_deterministic"], # needs investigation "SASTClassifier": ["check_fit_deterministic"], "RSASTClassifier": ["check_fit_deterministic"], "SAST": ["check_fit_deterministic"], "RSAST": ["check_fit_deterministic"], "SFA": ["check_persistence_via_pickle", "check_fit_deterministic"], - "CollectionId": ["check_transform_inverse_transform_equivalent"], - "ScaledLogitSeriesTransformer": ["check_transform_inverse_transform_equivalent"], - # also uncomment in test_check_estimator.py - "MockMultivariateSeriesTransformer": [ - "check_transform_inverse_transform_equivalent" - ], # missed in legacy testing, changes state in predict/transform "FLUSSSegmenter": ["check_non_state_changing_method"], "InformationGainSegmenter": ["check_non_state_changing_method"], @@ -75,8 +73,15 @@ # if the next predict calls uses the same query length parameter. "QuerySearch": ["check_non_state_changing_method"], "SeriesSearch": ["check_non_state_changing_method"], - # Unknown issue not producing the same results for Covid3Month (other is fine) + # Unknown issue not producing the same results "RDSTRegressor": ["check_regressor_against_expected_results"], + "RISTRegressor": ["check_regressor_against_expected_results"], +} + +# Exclude specific tests for estimators here only when numba is disabled +EXCLUDED_TESTS_NO_NUMBA = { + # See issue #622 + "HIVECOTEV2": ["check_classifier_against_expected_results"], } diff --git a/aeon/testing/utils/deep_equals.py b/aeon/testing/utils/deep_equals.py index aedaa202d4..86c6c5cd96 100644 --- a/aeon/testing/utils/deep_equals.py +++ b/aeon/testing/utils/deep_equals.py @@ -10,7 +10,7 @@ from scipy.sparse import csr_matrix -def deep_equals(x, y, return_msg=False): +def deep_equals(x, y, ignore_index=False, return_msg=False): """Test two objects for equality in value. Intended for: @@ -27,6 +27,8 @@ def deep_equals(x, y, return_msg=False): First item to compare. y : object Second item to compare. + ignore_index : bool, default=False + If True, will ignore the index of pd.Series and pd.DataFrame. return_msg : bool, default=False Whether to return an informative message about what is not equal. @@ -39,26 +41,26 @@ def deep_equals(x, y, return_msg=False): Only returned if return_msg is True. 
Indication of what is the reason for not being equal """ - eq, msg = _deep_equals(x, y, 0) + eq, msg = _deep_equals(x, y, 0, ignore_index) return eq if not return_msg else (eq, msg) -def _deep_equals(x, y, depth): +def _deep_equals(x, y, depth, ignore_index): if x is y: return True, "" if type(x) is not type(y): return False, f"x.type ({type(x)}) != y.type ({type(y)}), depth={depth}" if isinstance(x, pd.Series): - return _series_equals(x, y, depth) + return _series_equals(x, y, depth, ignore_index) elif isinstance(x, pd.DataFrame): - return _dataframe_equals(x, y, depth) + return _dataframe_equals(x, y, depth, ignore_index) elif isinstance(x, np.ndarray): return _numpy_equals(x, y, depth) elif isinstance(x, (list, tuple)): - return _list_equals(x, y, depth) + return _list_equals(x, y, depth, ignore_index) elif isinstance(x, dict): - return _dict_equals(x, y, depth) + return _dict_equals(x, y, depth, ignore_index) elif isinstance(x, csr_matrix): return _csrmatrix_equals(x, y, depth) # non-iterable types @@ -79,14 +81,16 @@ def _deep_equals(x, y, depth): raise ValueError(f"Unknown type: {type(x)}, depth={depth}") -def _series_equals(x, y, depth): +def _series_equals(x, y, depth, ignore_index): if x.dtype != y.dtype: return False, f"x.dtype ({x.dtype}) != y.dtype ({y.dtype}), depth={depth}" # if columns are object, recurse over entries and index if x.dtype == "object": - index_equal = x.index.equals(y.index) - values_equal, values_msg = _deep_equals(list(x.values), list(y.values), depth) + index_equal = ignore_index or x.index.equals(y.index) + values_equal, values_msg = _deep_equals( + list(x.values), list(y.values), depth, ignore_index + ) if not values_equal: msg = values_msg @@ -102,20 +106,24 @@ def _series_equals(x, y, depth): return eq, msg -def _dataframe_equals(x, y, depth): +def _dataframe_equals(x, y, depth, ignore_index): if not x.columns.equals(y.columns): return False, f"x.columns ({x.columns}) != y.columns ({y.columns})" # if columns are equal and at least one is object, recurse over Series if sum(x.dtypes == "object") > 0: for i, c in enumerate(x.columns): - eq, msg = _deep_equals(x[c], y[c], depth + 1) + eq, msg = _deep_equals(x[c], y[c], depth + 1, ignore_index) if not eq: return False, msg + f", idx={i}" return True, "" else: - eq = x.equals(y) + eq = ( + np.allclose(x.values, y.values, equal_nan=True) + if ignore_index + else x.equals(y) + ) msg = "" if eq else f"x ({x}) != y ({y}), depth={depth}" return eq, msg @@ -124,30 +132,30 @@ def _numpy_equals(x, y, depth): if x.dtype != y.dtype: return False, f"x.dtype ({x.dtype}) != y.dtype ({y.dtype})" - eq = np.array_equal(x, y, equal_nan=True) + eq = np.allclose(x, y, equal_nan=True) msg = "" if eq else f"x ({x}) != y ({y}), depth={depth}" return eq, msg def _csrmatrix_equals(x, y, depth): - if not np.allclose(x.toarray(), y.toarray()): + if not np.allclose(x.toarray(), y.toarray(), equal_nan=True): return False, f"x ({x}) != y ({y}), depth={depth}" return True, "" -def _list_equals(x, y, depth): +def _list_equals(x, y, depth, ignore_index): if len(x) != len(y): return False, f"x.len ({len(x)}) != y.len ({len(y)}), depth={depth}" for i in range(len(x)): - eq, msg = _deep_equals(x[i], y[i], depth + 1) + eq, msg = _deep_equals(x[i], y[i], depth + 1, ignore_index) if not eq: return False, msg + f", idx={i}" return True, "" -def _dict_equals(x, y, depth): +def _dict_equals(x, y, depth, ignore_index): xkeys = set(x.keys()) ykeys = set(y.keys()) if xkeys != ykeys: @@ -164,7 +172,7 @@ def _dict_equals(x, y, depth): # we now know that 
xkeys == ykeys for i, key in enumerate(xkeys): - eq, msg = _deep_equals(x[key], y[key], depth + 1) + eq, msg = _deep_equals(x[key], y[key], depth + 1, ignore_index) if not eq: return False, msg + f", idx={i}" diff --git a/aeon/transformations/series/_scaled_logit.py b/aeon/transformations/series/_scaled_logit.py index be483b9955..e400385f35 100644 --- a/aeon/transformations/series/_scaled_logit.py +++ b/aeon/transformations/series/_scaled_logit.py @@ -146,7 +146,6 @@ def _get_test_params(cls, parameter_set="default"): Name of the set of test parameters to return, for use in tests. If no special parameters are defined for a value, will return `"default"` set. - Returns ------- params : dict or list of dict, default = {} @@ -154,10 +153,4 @@ def _get_test_params(cls, parameter_set="default"): Each dict are parameters to construct an "interesting" test instance, i.e., `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. """ - test_params = [ - {"lower_bound": None, "upper_bound": None}, - {"lower_bound": -(10**6), "upper_bound": None}, - {"lower_bound": None, "upper_bound": 10**6}, - {"lower_bound": -(10**6), "upper_bound": 10**6}, - ] - return test_params + return {"lower_bound": -(10**6), "upper_bound": 10**6} diff --git a/aeon/transformations/series/tests/test_scaled_logit.py b/aeon/transformations/series/tests/test_scaled_logit.py index adb8c88085..7661458f5b 100644 --- a/aeon/transformations/series/tests/test_scaled_logit.py +++ b/aeon/transformations/series/tests/test_scaled_logit.py @@ -1,13 +1,9 @@ """ScaledLogit transform unit tests.""" -__maintainer__ = [] - -from warnings import warn - import numpy as np import pytest +from numpy.testing import assert_array_equal -from aeon.datasets import load_airline from aeon.transformations.series._scaled_logit import ScaledLogitSeriesTransformer TEST_SERIES = np.array([30, 40, 60]) @@ -26,31 +22,17 @@ def test_scaled_logit_transform(lower, upper, output): """Test that we get the right output.""" transformer = ScaledLogitSeriesTransformer(lower, upper) y_transformed = transformer.fit_transform(TEST_SERIES) - assert np.all(output == y_transformed) + assert_array_equal(y_transformed.squeeze(), output) -@pytest.mark.parametrize( - "lower, upper, message", - [ - ( - 0, - 300, - ( - "X in ScaledLogitSeriesTransformer should not have values greater" - "than upper_bound" - ), - ), - ( - 300, - 700, - "X in ScaledLogitSeriesTransformer should not have values lower than " - "lower_bound", - ), - ], -) -def test_scaled_logit_bound_errors(lower, upper, message): - """Tests all exceptions.""" - y = load_airline() - with pytest.warns(RuntimeWarning): - ScaledLogitSeriesTransformer(lower, upper).fit_transform(y) - warn(message, RuntimeWarning) +def test_scaled_logit_bound_warnings(): + """Tests all warnings.""" + with pytest.warns(RuntimeWarning, match="not have values lower than lower_bound"): + ScaledLogitSeriesTransformer(lower_bound=300, upper_bound=0).fit_transform( + TEST_SERIES + ) + + with pytest.warns(RuntimeWarning, match="not have values greater than upper_bound"): + ScaledLogitSeriesTransformer(lower_bound=300, upper_bound=0).fit_transform( + TEST_SERIES + )
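
# Usage sketch (not part of the diff above): how the new EXCLUDED_TESTS_NO_NUMBA
# list in aeon/testing/testing_config.py feeds the extra branch added to
# _should_be_skipped. NUMBA_DISABLED is resolved once, at import time, from the
# NUMBA_DISABLE_JIT environment variable, so it has to be exported before aeon is
# first imported for the gating to take effect.
import os

os.environ["NUMBA_DISABLE_JIT"] = "1"  # set before importing aeon

from aeon.testing.testing_config import EXCLUDED_TESTS_NO_NUMBA, NUMBA_DISABLED

est_name = "HIVECOTEV2"
check_name = "check_classifier_against_expected_results"

# Mirrors the new branch: the check is skipped only when numba is disabled AND
# the (estimator, check) pair is listed in EXCLUDED_TESTS_NO_NUMBA.
skip = NUMBA_DISABLED and check_name in EXCLUDED_TESTS_NO_NUMBA.get(est_name, [])
print(skip)  # True under NUMBA_DISABLE_JIT=1, False otherwise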
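
# Minimal sketch of the new ignore_index flag on
# aeon.testing.utils.deep_equals.deep_equals, following the implementation in the
# diff: with ignore_index=True, pandas indexes are ignored and values are compared
# with np.allclose (NaNs in matching positions count as equal).
import numpy as np
import pandas as pd

from aeon.testing.utils.deep_equals import deep_equals

x = pd.DataFrame({"a": [1.0, 2.0, np.nan]}, index=[0, 1, 2])
y = pd.DataFrame({"a": [1.0, 2.0, np.nan]}, index=[5, 6, 7])

assert not deep_equals(x, y)  # indexes differ, so the default comparison fails
assert deep_equals(x, y, ignore_index=True)  # values-only comparison passes

# return_msg=True returns (eq, msg); msg describes the first mismatch found
eq, msg = deep_equals(x, y, return_msg=True)
print(eq, msg)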
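
# Hedged round-trip sketch of the inverse_transform contract now exercised by
# check_transform_inverse_transform_equivalent, using the
# ScaledLogitSeriesTransformer touched in this diff. The bounds mirror the new
# single test parameter set; the squeeze() assumes the series transformer returns
# a 2D array for 1D input, as in the updated unit tests above.
import numpy as np

from aeon.transformations.series._scaled_logit import ScaledLogitSeriesTransformer

X = np.array([30.0, 40.0, 60.0])
t = ScaledLogitSeriesTransformer(lower_bound=-(10**6), upper_bound=10**6)

Xt = t.fit_transform(X)
Xit = t.inverse_transform(Xt)

# transform followed by inverse_transform should recover the input (up to float
# precision), which is what the check asserts via deep_equals
np.testing.assert_allclose(np.asarray(Xit).squeeze(), X)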