From 108948d969b0cf6994602b646c826be856a67efb Mon Sep 17 00:00:00 2001 From: Matthew Middlehurst Date: Fri, 22 Nov 2024 19:12:33 +0200 Subject: [PATCH] [MNT] Update and consolidate general estimator checks (#2377) * tidy general estimator tests * fixes * fixes --- .../estimator_checking/_estimator_checking.py | 4 +- .../_yield_classification_checks.py | 8 - .../_yield_estimator_checks.py | 365 +++++++----------- .../_yield_regression_checks.py | 8 - .../tests/test_check_estimator.py | 4 +- aeon/testing/testing_config.py | 11 +- 6 files changed, 158 insertions(+), 242 deletions(-) diff --git a/aeon/testing/estimator_checking/_estimator_checking.py b/aeon/testing/estimator_checking/_estimator_checking.py index 2211e659cb..0ea0ebfbe3 100644 --- a/aeon/testing/estimator_checking/_estimator_checking.py +++ b/aeon/testing/estimator_checking/_estimator_checking.py @@ -184,8 +184,8 @@ class is passed. >>> results = check_estimator(MockClassifier()) Running specific check for MockClassifier - >>> check_estimator(MockClassifier, checks_to_run="check_clone") - {'check_clone(estimator=MockClassifier())': 'PASSED'} + >>> check_estimator(MockClassifier, checks_to_run="check_get_params") + {'check_get_params(estimator=MockClassifier())': 'PASSED'} """ # check if estimator has soft dependencies installed _check_estimator_deps(estimator) diff --git a/aeon/testing/estimator_checking/_yield_classification_checks.py b/aeon/testing/estimator_checking/_yield_classification_checks.py index 2828063f34..582632b99d 100644 --- a/aeon/testing/estimator_checking/_yield_classification_checks.py +++ b/aeon/testing/estimator_checking/_yield_classification_checks.py @@ -154,9 +154,6 @@ def check_classifier_overrides_and_tags(estimator_class): f"Override _{method} instead." ) - # axis class parameter is for internal use only - assert "axis" not in estimator_class.__dict__ - # Test valid tag for X_inner_type X_inner_type = estimator_class.get_class_tag(tag_name="X_inner_type") if isinstance(X_inner_type, str): @@ -172,11 +169,6 @@ def check_classifier_overrides_and_tags(estimator_class): else: # must be a list assert any([t in valid_unequal_types for t in X_inner_type]) - # Must have at least one set to True - multi = estimator_class.get_class_tag(tag_name="capability:multivariate") - uni = estimator_class.get_class_tag(tag_name="capability:univariate") - assert multi or uni - valid_algorithm_types = [ "distance", "deeplearning", diff --git a/aeon/testing/estimator_checking/_yield_estimator_checks.py b/aeon/testing/estimator_checking/_yield_estimator_checks.py index 112fc96aee..0078c6dbfb 100644 --- a/aeon/testing/estimator_checking/_yield_estimator_checks.py +++ b/aeon/testing/estimator_checking/_yield_estimator_checks.py @@ -1,9 +1,8 @@ """Tests for all estimators.""" -import inspect +import math import numbers import pickle -import types from copy import deepcopy from functools import partial from inspect import getfullargspec, isclass, signature @@ -13,7 +12,6 @@ import pytest from numpy.testing import assert_array_almost_equal from sklearn.exceptions import NotFittedError -from sklearn.utils.estimator_checks import check_get_params_invariance from aeon.anomaly_detection.base import BaseAnomalyDetector from aeon.base import BaseAeonEstimator @@ -175,7 +173,7 @@ def _yield_estimator_checks(estimator_class, estimator_instances, datatypes): yield partial(check_create_test_instance, estimator_class=estimator_class) yield partial(check_inheritance, estimator_class=estimator_class) yield partial(check_has_common_interface, 
estimator_class=estimator_class) - yield partial(check_set_params_sklearn, estimator_class=estimator_class) + yield partial(check_set_params, estimator_class=estimator_class) yield partial(check_constructor, estimator_class=estimator_class) yield partial(check_estimator_class_tags, estimator_class=estimator_class) @@ -183,8 +181,6 @@ def _yield_estimator_checks(estimator_class, estimator_instances, datatypes): for i, estimator in enumerate(estimator_instances): # no data needed yield partial(check_get_params, estimator=estimator) - yield partial(check_set_params, estimator=estimator) - yield partial(check_clone, estimator=estimator) yield partial(check_repr, estimator=estimator) yield partial(check_estimator_tags, estimator=estimator) @@ -202,7 +198,9 @@ def _yield_estimator_checks(estimator_class, estimator_instances, datatypes): datatype=datatypes[i][0], ) yield partial( - check_fit_updates_state, estimator=estimator, datatype=datatypes[i][0] + check_fit_updates_state_and_cloning, + estimator=estimator, + datatype=datatypes[i][0], ) if not _get_tag(estimator, "fit_is_empty", default=False): @@ -230,13 +228,6 @@ def check_create_test_instance(estimator_class): _create_test_instance is the key method used to create test instances in testing. If this test does not pass, the validity of the other tests cannot be guaranteed. - - Also tests inheritance and super call logic in the constructor. - - Tests that: - * _create_test_instance results in an instance of estimator_class - * __init__ calls super.__init__ - * _tags_dynamic attribute for tag inspection is present after construction """ estimator = estimator_class._create_test_instance() @@ -246,38 +237,25 @@ def check_create_test_instance(estimator_class): f"found {type(estimator)}" ) - msg = ( - f"{estimator_class.__name__}.__init__ should call super().__init__, the " - "estimator does not produce the attributes this call would produce." - ) - assert hasattr(estimator, "_tags_dynamic"), msg - -# todo consider removing the multiple base class allowance. def check_inheritance(estimator_class): """Check that estimator inherits from BaseAeonEstimator.""" assert issubclass( estimator_class, BaseAeonEstimator ), f"object {estimator_class} is not a sub-class of BaseAeonEstimator." - if hasattr(estimator_class, "fit"): - assert issubclass(estimator_class, BaseAeonEstimator), ( - f"estimator: {estimator_class} has fit method, but" - f"is not a sub-class of BaseAeonEstimator." - ) - # Usually estimators inherit only from one BaseAeonEstimator type, but in some cases - # they may be predictor and transformer at the same time (e.g. pipelines) + # they may inherit both as part of a series/collection split n_base_types = sum( issubclass(estimator_class, cls) for cls in VALID_ESTIMATOR_BASES.values() ) + assert 2 >= n_base_types >= 1, "Estimator should inherit from 1 or 2 base types." - assert 2 >= n_base_types >= 1 - - # If the estimator inherits from more than one base estimator type, we check if - # one of them is a transformer base type + # Only transformers can inherit from multiple base types currently if n_base_types > 1: - assert issubclass(estimator_class, BaseTransformer) + assert issubclass( + estimator_class, BaseTransformer + ), "Only transformers can inherit from multiple base types." 
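# ---------------------------------------------------------------------------
# Editor's sketch (illustrative only, not part of the patch): the base-type
# counting in check_inheritance above can be modelled with hypothetical
# stand-in classes; real estimators use the aeon base classes registered in
# VALID_ESTIMATOR_BASES. Checks like this one are yielded as functools.partial
# objects with their arguments pre-bound, so a runner can invoke them all
# uniformly.
#
#     class _BaseSeries: ...          # stand-in for a series base type
#     class _BaseCollection: ...      # stand-in for a collection base type
#
#     class _DualTransformer(_BaseSeries, _BaseCollection):
#         """A transformer-style class inheriting from two base types."""
#
#     _valid_bases = (_BaseSeries, _BaseCollection)
#     n_base_types = sum(issubclass(_DualTransformer, b) for b in _valid_bases)
#     assert 2 >= n_base_types >= 1  # two bases only in the transformer case
# ---------------------------------------------------------------------------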
 def check_has_common_interface(estimator_class):
@@ -299,23 +277,36 @@ def check_has_common_interface(estimator_class):
         estimator_class.get_fitted_params
     )
 
+    # axis class parameter is for internal use only
+    assert (
+        "axis" not in estimator_class.__dict__
+    ), "axis should not be a class parameter"
 
-def check_set_params_sklearn(estimator_class):
-    """Check that set_params works correctly, mirrors sklearn check_set_params.
 
-    Instead of the "fuzz values" in sklearn's check_set_params,
-    we use the other test parameter settings (which are assumed valid).
-    This guarantees settings which play along with the __init__ content.
-    """
+def check_set_params(estimator_class):
+    """Check that set_params works correctly."""
+    # some parameters do not have default values, so we need to set them
     estimator = estimator_class._create_test_instance()
+    required_params_names = [
+        p.name
+        for p in signature(estimator_class.__init__).parameters.values()
+        # don't include self and *args, **kwargs
+        if p.name != "self" and p.kind not in [p.VAR_KEYWORD, p.VAR_POSITIONAL]
+        # has no default
+        and p.default == p.empty
+    ]
+    params = estimator.get_params()
+    init_params = {p: params[p] for p in params if p in required_params_names}
+
+    # default constructed instance except for required parameters
+    estimator = estimator_class(**init_params)
+
     test_params = estimator_class._get_test_params()
     if not isinstance(test_params, list):
         test_params = [test_params]
 
     for params in test_params:
-        # we construct the full parameter set for params
-        # params may only have parameters that are deviating from defaults
-        # in order to set non-default parameters back to defaults
+        # parameter sets may contain only parameters that deviate from defaults
         params_full = estimator.get_params(deep=False)
         params_full.update(params)
 
@@ -338,28 +329,37 @@ def check_constructor(estimator_class):
     """Check that the constructor has sklearn compatible signature and behaviour.
 
-    Based on sklearn check_estimator testing of __init__ logic.
-    Uses _create_test_instance to create an instance.
-    Assumes test_create_test_instance has passed and certified _create_test_instance.
- Tests that: * constructor has no varargs * tests that constructor constructs an instance of the class * tests that all parameters are set in init to an attribute of the same name * tests that parameter values are always copied to the attribute and not changed - * tests that default parameters are one of the following: - None, str, int, float, bool, tuple, function, joblib memory, numpy primitive - (other type parameters should be None, default handling should be by writing - the default to attribute of a different name, e.g., my_param_ not my_param) + * tests that default parameters are a valid type or callable """ - msg = "constructor __init__ should have no varargs" - assert getfullargspec(estimator_class.__init__).varkw is None, msg + assert ( + getfullargspec(estimator_class.__init__).varkw is None + ), "constructor __init__ should have no varargs" estimator = estimator_class._create_test_instance() - assert isinstance(estimator, estimator_class) - # Ensure that each parameter is set in init - init_params = inspect.signature(estimator_class.__init__).parameters + # ensure base class super is called in constructor + assert hasattr(estimator, "is_fitted"), ( + "Estimator should have an is_fitted attribute after init, if not make sure " + "you call super().__init__ in the constructor" + ) + assert ( + estimator.is_fitted is False + ), "Estimator is_fitted attribute should be set to False after init" + assert hasattr(estimator, "_tags_dynamic"), ( + "Estimator should have a _tags_dynamic attribute after init, if not make sure " + "you call super().__init__ in the constructor" + ) + assert isinstance( + estimator._tags_dynamic, dict + ), "Estimator _tags_dynamic attribute should be a dict after init" + + # ensure that each parameter is set in init + init_params = signature(estimator_class.__init__).parameters invalid_attr = set(init_params) - set(vars(estimator)) - {"self"} assert not invalid_attr, ( "Estimator %s should store all parameters" @@ -367,60 +367,62 @@ def check_constructor(estimator_class): "attributes `%s`." 
         % (estimator.__class__.__name__, sorted(invalid_attr))
     )
 
-    # Ensure that init does nothing but set parameters
-    # No logic/interaction with other parameters
-    def param_filter(p):
-        """Identify hyper parameters of an estimator."""
-        return p.name != "self" and p.kind not in [p.VAR_KEYWORD, p.VAR_POSITIONAL]
-
-    init_params = [
-        p for p in signature(estimator.__init__).parameters.values() if param_filter(p)
+    param_values = [
+        p
+        for p in init_params.values()
+        # don't include self and *args, **kwargs
+        if p.name != "self" and p.kind not in [p.VAR_KEYWORD, p.VAR_POSITIONAL]
     ]
+    required_params_names = [p.name for p in param_values if p.default == p.empty]
+    default_value_params = [p for p in param_values if p.default != p.empty]
     params = estimator.get_params()
+    init_params = {p: params[p] for p in params if p in required_params_names}
 
-    test_params = estimator_class._get_test_params()
-    if isinstance(test_params, list):
-        test_params = test_params[0]
-    test_params = test_params.keys()
-
-    init_params = [param for param in init_params if param.name not in test_params]
+    # default constructed instance except for required parameters
+    estimator = estimator_class(**init_params)
+    params = estimator.get_params()
 
-    for param in init_params:
-        assert param.default != param.empty, (
-            "parameter `%s` for %s has no default value and is not "
-            "set in _get_test_params" % (param.name, estimator.__class__.__name__)
+    for param in default_value_params:
+        allowed_types = {
+            str,
+            int,
+            float,
+            bool,
+            tuple,
+            type(None),
+            type,
+            np.float64,
+            np.int64,
+            np.nan,
+        }
+
+        assert type(param.default) in allowed_types or callable(param.default), (
+            f"Default value of parameter {param.name} is not callable or one of "
+            f"the allowed types: {allowed_types}"
        )
-        if type(param.default) is type:
-            assert param.default in [np.float64, np.int64]
-        else:
-            assert type(param.default) in [
-                str,
-                int,
-                float,
-                bool,
-                tuple,
-                type(None),
-                np.float64,
-                types.FunctionType,
-                joblib.Memory,
-            ]
 
         param_value = params[param.name]
+        msg = (
+            f"Parameter {param.name} was mutated on init. All parameters must be "
+            f"stored unchanged."
+        )
         if isinstance(param_value, np.ndarray):
-            np.testing.assert_array_equal(param_value, param.default)
+            np.testing.assert_array_equal(param_value, param.default, err_msg=msg)
         else:
-            if bool(isinstance(param_value, numbers.Real) and np.isnan(param_value)):
-                # Allows to set default parameters to np.nan
-                assert param_value is param.default, param.name
+            if (
+                not isinstance(param_value, numbers.Integral)
+                and isinstance(param_value, numbers.Real)
+                and math.isnan(param_value)
+            ):
+                # Allows setting default parameters to np.nan
+                assert param_value is param.default, msg
             else:
-                assert param_value == param.default, param.name
+                assert param_value == param.default, msg
 
 
 def check_estimator_class_tags(estimator_class):
     """Check conventions on estimator tags for class."""
-    # check get_class_tags method is retained from base
-    assert hasattr(estimator_class, "get_class_tags")
     all_tags = estimator_class.get_class_tags()
     assert isinstance(all_tags, dict)
     assert all(isinstance(key, str) for key in all_tags.keys())
@@ -446,56 +448,38 @@ def check_estimator_class_tags(estimator_class):
             f"estimator tags."
) + # Must have at least one set to True + multi = estimator_class.get_class_tag(tag_name="capability:multivariate") + uni = estimator_class.get_class_tag(tag_name="capability:univariate") + assert multi or uni, ( + "Estimator must have at least one of capability:multivariate or " + "capability:univariate set to True" + ) + def check_get_params(estimator): """Check that get_params works correctly.""" - params = estimator.get_params() - assert isinstance(params, dict) - check_get_params_invariance(estimator.__class__.__name__, estimator) - - -def check_set_params(estimator): - """Check that set_params works correctly.""" estimator = _clone_estimator(estimator) - params = estimator.get_params() - - assert ( - estimator.set_params(**params) is estimator - ), f"set_params of {type(estimator).__name__} does not return self" - - is_equal, equals_msg = deep_equals(estimator.get_params(), params, return_msg=True) - msg = ( - f"get_params result of {type(estimator).__name__} (x) does not match " - f"what was passed to set_params (y). Reason for discrepancy: {equals_msg}" - ) - assert is_equal, msg + params = estimator.get_params() + assert isinstance(params, dict) -def check_clone(estimator): - """Check that clone method does not raise exceptions and results in a clone. + shallow_params = estimator.get_params(deep=False) + deep_params = estimator.get_params(deep=True) - A clone of an object x is an object that: - * has same class and parameters as x - * is not identical with x - * is unfitted (even if x was fitted) - """ - est_clone = estimator.clone() - assert isinstance(est_clone, type(estimator)) - assert est_clone is not estimator - if hasattr(est_clone, "is_fitted"): - assert not est_clone.is_fitted + assert all(item in deep_params.items() for item in shallow_params.items()) -# todo roll into another test def check_repr(estimator): """Check that __repr__ call to instance does not raise exceptions.""" - repr(estimator) + estimator = _clone_estimator(estimator) + assert isinstance(repr(estimator), str) def check_estimator_tags(estimator): """Check conventions on estimator tags for test objects.""" - # check get_tags method is retained from base - assert hasattr(estimator, "get_tags") + estimator = _clone_estimator(estimator) + all_tags = estimator.get_tags() assert isinstance(all_tags, dict) assert all(isinstance(key, str) for key in all_tags.keys()) @@ -513,7 +497,9 @@ def check_estimator_tags(estimator): def check_dl_constructor_initializes_deeply(estimator): - """Test DL estimators that they pass custom parameters to underlying Network.""" + """Test deep learning estimators pass custom parameters to underlying Network.""" + estimator = _clone_estimator(estimator) + for key, value in estimator.__dict__.items(): assert vars(estimator)[key] == value # some keys are only relevant to the final model (eg: n_epochs) @@ -523,16 +509,10 @@ def check_dl_constructor_initializes_deeply(estimator): def check_non_state_changing_method(estimator, datatype): - """Check that non-state-changing methods behave as per interface contract. - - Check the following contract on non-state-changing methods: - 1. do not change state of the estimator, i.e., any attributes - (including hyper-parameters and fitted parameters) - 2. expected output type of the method matches actual output type - - only for abstract BaseAeonEstimator methods, common to all estimators. - List of BaseAeonEstimator methods tested: get_fitted_params - Subclass specific method outputs are tested in TestAll[estimatortype] class - 3. 
the state of method arguments does not change + """Check that non-state-changing methods behave correctly. + + Non-state-changing methods should not alter the estimator attributes or the + input arguments. We also check fit does not alter the input arguments here. """ estimator = _clone_estimator(estimator) @@ -566,33 +546,31 @@ def check_non_state_changing_method(estimator, datatype): ) -def check_fit_updates_state(estimator, datatype): +def check_fit_updates_state_and_cloning(estimator, datatype): """Check fit/update state change. - 1. Check estimator_instance calls base class constructor - 2. Check is_fitted attribute is set correctly to False before fit, at init - This is testing base class functionality, but its fast - 3. Check fit returns self - 4. Check is_fitted attribute is updated correctly to True after calling fit - 5. Check estimator hyper parameters are not changed in fit - """ - # Check that fit updates the is-fitted states - estimator = _clone_estimator(estimator) + We test clone here to avoid fitting again in a separate cloning test. - msg = ( - f"{type(estimator).__name__}.__init__ should call " - f"super({type(estimator).__name__}, self).__init__, " - "but that does not seem to be the case. Please ensure to call the " - f"parent class's constructor in {type(estimator).__name__}.__init__" - ) - assert hasattr(estimator, "is_fitted"), msg - - # Check is_fitted attribute is set correctly to False before fit, at init + Tests that: + * clone returns a new unfitted instance of the estimator + * fit returns self + * is_fitted attribute is updated correctly to True after calling fit + * estimator hyper parameters are not changed in fit + """ + # do some basic checks for cloning + estimator_clone = estimator.clone() + assert isinstance( + estimator_clone, type(estimator) + ), "Estimator clone should be of the same type as the original estimator" assert ( - not estimator.is_fitted - ), f"Estimator: {estimator} does not initiate attribute: is_fitted to False" + estimator_clone is not estimator + ), "Estimator clone should not be the same object as the original estimator" + assert ( + estimator_clone.is_fitted is False + ), "Estimator is_fitted attribute should be set to False after cloning and init" # Make a physical copy of the original estimator parameters before fitting. + estimator = estimator_clone original_params = deepcopy(estimator.get_params()) fitted_estimator = _run_estimator_method(estimator, "fit", datatype, "train") @@ -604,7 +582,7 @@ def check_fit_updates_state(estimator, datatype): # Check is_fitted attribute is updated correctly to True after calling fit assert ( - fitted_estimator.is_fitted + fitted_estimator.is_fitted is True ), f"Estimator: {estimator} does not update attribute: is_fitted during fit" # Compare the state of the model parameters with the original parameters @@ -617,28 +595,24 @@ def check_fit_updates_state(estimator, datatype): # that introspects recursively any subobjects to compute a checksum. # The only exception to this rule of immutable constructor parameters # is possible RandomState instance but in this check we explicitly - # fixed the random_state params recursively to be integer seeds. + # fixed the random_state params recursively to be integer seeds via clone. assert joblib.hash(new_value) == joblib.hash(original_value), ( "Estimator %s should not change or mutate " " the parameter %s from %s to %s during fit." 
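            # editor's note: joblib.hash is compared below rather than using
            # `==` because parameters may be numpy arrays, where `==` compares
            # elementwise and cannot be used directly in an assert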
% (estimator.__class__.__name__, param_name, original_value, new_value) ) + # check that estimator cloned from fitted estimator is not fitted + estimator_clone = estimator.clone() + assert ( + estimator_clone.is_fitted is False + ), "Estimator is_fitted attribute should be set to False after cloning" -def check_raises_not_fitted_error(estimator, datatype): - """Check exception raised for non-fit method calls to unfitted estimators. - - Tries to run all methods in NON_STATE_CHANGING_METHODS with valid scenario, - but before fit has been called on the estimator. - This should raise a NotFittedError if correctly caught, - normally by a self.check_is_fitted() call in the method's boilerplate. +def check_raises_not_fitted_error(estimator, datatype): + """Check exception raised for non-fit method calls to unfitted estimators.""" + estimator = _clone_estimator(estimator) - Raises - ------ - Exception if NotFittedError is not raised by non-state changing method - """ - # call methods without prior fitting and check that they raise NotFittedError for method in NON_STATE_CHANGING_METHODS: if hasattr(estimator, method): with pytest.raises(NotFittedError, match=r"has not been fitted"): @@ -648,6 +622,7 @@ def check_raises_not_fitted_error(estimator, datatype): def check_persistence_via_pickle(estimator, datatype): """Check that we can pickle all estimators.""" estimator = _clone_estimator(estimator, random_state=0) + _run_estimator_method(estimator, "fit", datatype, "train") results = [] @@ -664,14 +639,12 @@ def check_persistence_via_pickle(estimator, datatype): for method in NON_STATE_CHANGING_METHODS_ARRAYLIKE: if hasattr(estimator, method) and callable(getattr(estimator, method)): output = _run_estimator_method(estimator, method, datatype, "test") - assert_array_almost_equal( output, results[i], err_msg=f"Running {method} after fit twice with test " f"parameters gives different results.", ) - i += 1 @@ -696,48 +669,10 @@ def check_fit_deterministic(estimator, datatype): for method in NON_STATE_CHANGING_METHODS_ARRAYLIKE: if hasattr(estimator, method) and callable(getattr(estimator, method)): output = _run_estimator_method(estimator, method, datatype, "test") - assert_array_almost_equal( output, results[i], err_msg=f"Running {method} after fit twice with test " f"parameters gives different results.", ) - i += 1 - - -# def check_multiprocessing_idempotent(estimator): -# """Test that single and multi-process run results are identical. -# -# Check that running an estimator on a single process is no different to running -# it on multiple processes. We also check that we can set n_jobs=-1 to make use -# of all CPUs. The test is not really necessary though, as we rely on joblib for -# parallelization and can trust that it works as expected. 
-# """ -# method_nsc = method_nsc_arraylike -# params = estimator_instance.get_params() -# -# if "n_jobs" in params: -# # run on a single process -# # ----------------------- -# estimator = deepcopy(estimator_instance) -# estimator.set_params(n_jobs=1) -# set_random_state(estimator) -# result_single_process = scenario.run( -# estimator, method_sequence=["fit", method_nsc] -# ) -# -# # run on multiple processes -# # ------------------------- -# estimator = deepcopy(estimator_instance) -# estimator.set_params(n_jobs=-1) -# set_random_state(estimator) -# result_multiple_process = scenario.run( -# estimator, method_sequence=["fit", method_nsc] -# ) -# _assert_array_equal( -# result_single_process, -# result_multiple_process, -# err_msg="Results are not equal for n_jobs=1 and n_jobs=-1", -# ) diff --git a/aeon/testing/estimator_checking/_yield_regression_checks.py b/aeon/testing/estimator_checking/_yield_regression_checks.py index 48c5326975..bf6d2cb568 100644 --- a/aeon/testing/estimator_checking/_yield_regression_checks.py +++ b/aeon/testing/estimator_checking/_yield_regression_checks.py @@ -145,9 +145,6 @@ def check_regressor_overrides_and_tags(estimator_class): f"Override _{method} instead." ) - # axis class parameter is for internal use only - assert "axis" not in estimator_class.__dict__ - # Test valid tag for X_inner_type X_inner_type = estimator_class.get_class_tag(tag_name="X_inner_type") if isinstance(X_inner_type, str): @@ -163,11 +160,6 @@ def check_regressor_overrides_and_tags(estimator_class): else: # must be a list assert any([t in valid_unequal_types for t in X_inner_type]) - # Must have at least one set to True - multi = estimator_class.get_class_tag(tag_name="capability:multivariate") - uni = estimator_class.get_class_tag(tag_name="capability:univariate") - assert multi or uni - valid_algorithm_types = [ "distance", "deeplearning", diff --git a/aeon/testing/estimator_checking/tests/test_check_estimator.py b/aeon/testing/estimator_checking/tests/test_check_estimator.py index ed16d3835a..97f3fff55b 100644 --- a/aeon/testing/estimator_checking/tests/test_check_estimator.py +++ b/aeon/testing/estimator_checking/tests/test_check_estimator.py @@ -104,13 +104,13 @@ def test_check_estimator_subset_tests(): tests_to_run = [ "check_get_params", "check_set_params", - "check_clone", + "check_inheritance", ] tests_to_exclude = ["check_set_params"] expected_tests = [ + "check_inheritance(estimator_class=MockClassifier)", "check_get_params(estimator=MockClassifier())", - "check_clone(estimator=MockClassifier())", ] results = check_estimator( diff --git a/aeon/testing/testing_config.py b/aeon/testing/testing_config.py index 7caf96c752..cdc4e7730f 100644 --- a/aeon/testing/testing_config.py +++ b/aeon/testing/testing_config.py @@ -9,7 +9,7 @@ # per os/version default is False, can be set to True by pytest --prtesting True flag PR_TESTING = False -# Exclude estimators here for short term fixes +# exclude estimators here for short term fixes EXCLUDE_ESTIMATORS = [ "ClearSkyTransformer", # See #2071 @@ -76,20 +76,17 @@ "RDSTRegressor": ["check_regressor_against_expected_results"], } -# NON_STATE_CHANGING_METHODS = -# methods that should not change the state of the estimator, that is, they should -# not change fitted parameters or hyper-parameters. They are also the methods that -# "apply" the fitted estimator to data and useful for checking results. 
+ +# estimator methods post-fit that should not change the state of the estimator # non-state-changing methods that return an array-like output NON_STATE_CHANGING_METHODS_ARRAYLIKE = ( "predict", - "predict_var", "predict_proba", - "decision_function", "transform", ) +# all non-state-changing methods NON_STATE_CHANGING_METHODS = NON_STATE_CHANGING_METHODS_ARRAYLIKE + ( "get_fitted_params", )
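
# ---------------------------------------------------------------------------
# Editor's usage sketch (illustrative, not part of the patch). It mirrors the
# updated docstring example at the top of the patch; the import paths are
# assumed from aeon's testing utilities.

from aeon.testing.estimator_checking import check_estimator
from aeon.testing.mock_estimators import MockClassifier

# run a single consolidated check by name against the mock estimator
results = check_estimator(MockClassifier, checks_to_run="check_get_params")
print(results)  # {'check_get_params(estimator=MockClassifier())': 'PASSED'}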