From 12fb8a48392ecb97c9483231ff46a50f30b0c585 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 25 Jul 2024 18:39:25 +0100 Subject: [PATCH] [MNT] sync differential testing utilities with `sktime` (#434) This PR syncs differential testing utilities with `sktime` anticipating a joint refactor to `skbase`. Counterpart of https://github.com/sktime/sktime/pull/6840 Modules affected: * `skpro.tests.test_switch` * `skpro.utils.git_diff` --- skpro/tests/test_switch.py | 68 ++++++++++++++++++++++++++++++-------- skpro/utils/git_diff.py | 7 +++- 2 files changed, 60 insertions(+), 15 deletions(-) diff --git a/skpro/tests/test_switch.py b/skpro/tests/test_switch.py index f87fff4b..c309eb83 100644 --- a/skpro/tests/test_switch.py +++ b/skpro/tests/test_switch.py @@ -1,6 +1,8 @@ # copyright: skpro developers, BSD-3-Clause License (see LICENSE file) -# based on utility from sktime of the same name -"""Switch utility for determining whether tests for a class should be run or not.""" +"""Switch utility for determining whether tests for a class should be run or not. + +Module does not contain tests, only test utilities. +""" __author__ = ["fkiraly"] @@ -27,17 +29,17 @@ def run_test_for_class(cls, return_reason=False): 2. Condition 2: If the module containing the class/func has changed according to is_class_changed, - or one of the modules containing any parent classes in sktime, + or one of the modules containing any parent classes in the local package, then condition 2 is met. 3. Condition 3: - If the object is an sktime ``BaseObject``, and one of the test classes + If the object is an skpro ``BaseObject``, and one of the test classes covering the class have changed, then condition 3 is met. 4. Condition 4: - If the object is an sktime ``BaseObject``, and the package requirements + If the object is an skpro ``BaseObject``, and the package requirements for any of its dependencies have changed in ``pyproject.toml``, condition 4 is met. @@ -51,7 +53,7 @@ def run_test_for_class(cls, return_reason=False): If ``ONLY_CHANGED_MODULES`` is False, this condition is always True. Also checks whether the class or function is on the exclude override list, - EXCLUDE_ESTIMATORS in sktime.tests._config (a list of strings, of names). + EXCLUDE_ESTIMATORS in skpro.tests._config (a list of strings, of names). If so, the tests are always skipped, irrespective of the other conditions. Parameters @@ -160,7 +162,6 @@ class for which to determine whether it should be tested If multiple reasons are present, the first one in the above list is returned. """ - from skpro.tests.test_all_estimators import ONLY_CHANGED_MODULES from skpro.utils.git_diff import get_packages_with_changed_specs, is_class_changed from skpro.utils.validation._dependencies import _check_estimator_deps @@ -174,18 +175,18 @@ def _required_deps_present(obj): else: return True - def _is_class_changed_or_parents(cls): - """Check if class or any of its sktime parents have changed, return bool.""" + def _is_class_changed_or_local_parents(cls): + """Check if class or any of its local parents have changed, return bool.""" # if cls is a function, not a class, default to is_class_changed if not isclass(cls): return is_class_changed(cls) # now we know cls is a class, so has an mro cls_and_parents = getmro(cls) - cls_and_parents = [ + cls_and_local_parents = [ x for x in cls_and_parents if x.__module__.startswith(LOCAL_PACKAGE) ] - return any(is_class_changed(x) for x in cls_and_parents) + return any(is_class_changed(x) for x in cls_and_local_parents) def _tests_covering_class_changed(cls): """Check if any of the tests covering cls have changed, return bool.""" @@ -231,7 +232,7 @@ def _is_impacted_by_pyproject_change(cls): return True, "True_pyproject_change" # Condition 3: - # if the object is an sktime BaseObject, and one of the test classes + # if the object is an skpro BaseObject, and one of the test classes # covering the class have changed, then run the test cond3 = _tests_covering_class_changed(cls) if cond3: @@ -239,11 +240,50 @@ def _is_impacted_by_pyproject_change(cls): # Condition 2: # any of the modules containing any of the classes in the list have changed - # or any of the modules containing any parent classes in sktime have changed - cond2 = _is_class_changed_or_parents(cls) + # or any of the modules containing any parent classes in local package have changed + cond2 = _is_class_changed_or_local_parents(cls) if cond2: return True, "True_changed_class" # if none of the conditions are met, do not run the test # reason is that there was no change return False, "False_no_change" + + +def run_test_module_changed(module): + """Check if test should run based on module changes + + This switch can be used to decorate tests not pertaining to a specific class. + + The function can be used to switch tests on and off + based on whether a target module has changed. + + This checks whether the module ``module``, or any of its child modules, + have changed. + + If ``ONLY_CHANGED_MODULES`` is False, the test is always run, + i.e., this function always returns True. + + Parameters + ---------- + module : string, or list of strings + modules to check for changes, e.g., ``skpro.regression`` + + Returns + ------- + bool : switch to run or skip the test + True iff: at least one of the modules or its submodules have changed, + or if ``ONLY_CHANGED_MODULES`` is False + """ + from skpro.tests.test_all_estimators import ONLY_CHANGED_MODULES + from skpro.utils.git_diff import is_module_changed + + # if ONLY_CHANGED_MODULES is off: always True + # tests are always run if soft dependencies are present + if not ONLY_CHANGED_MODULES: + return True + + if not isinstance(module, (list, tuple)): + module = [module] + + return any(is_module_changed(mod) for mod in module) diff --git a/skpro/utils/git_diff.py b/skpro/utils/git_diff.py index 0d53af9c..8e3987af 100644 --- a/skpro/utils/git_diff.py +++ b/skpro/utils/git_diff.py @@ -45,7 +45,10 @@ def get_path_from_module(module_str): raise ImportError( f"Error in get_path_from_module, module '{module_str}' not found." ) - return module_spec.origin + module_path = module_spec.origin + if module_path.endswith("__init__.py"): + return module_path[:-11] + return module_path except Exception as e: raise ImportError(f"Error finding module '{module_str}'") from e @@ -54,6 +57,8 @@ def get_path_from_module(module_str): def is_module_changed(module_str): """Check if a module has changed compared to the main branch. + If a child module has changed, the parent module is considered changed as well. + Parameters ---------- module_str : str