From adb0dfde1063f6fb64e99b1d182a3c315d4b9cdb Mon Sep 17 00:00:00 2001 From: Tim Head Date: Wed, 16 Oct 2024 07:46:52 +0000 Subject: [PATCH] Run common tests on all estimators --- python/cuml/cuml/internals/utils.py | 266 ++++++++++++++++++++++++++ python/cuml/cuml/tests/test_common.py | 21 +- 2 files changed, 286 insertions(+), 1 deletion(-) create mode 100644 python/cuml/cuml/internals/utils.py diff --git a/python/cuml/cuml/internals/utils.py b/python/cuml/cuml/internals/utils.py new file mode 100644 index 0000000000..32628cb6f9 --- /dev/null +++ b/python/cuml/cuml/internals/utils.py @@ -0,0 +1,266 @@ +"""Utilities to discover cuml estimators.""" + +# This code was taken from scikit-learn and edited for cuml +# Authors: The scikit-learn developers +# SPDX-License-Identifier: BSD-3-Clause +# +# Copyright (c) 2021-2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import inspect +import pkgutil +import sys +import warnings +from functools import wraps +from importlib import import_module +from operator import itemgetter +from pathlib import Path + + +_MODULE_TO_IGNORE = { + "tests", + "base", + "conftest", + "common", + "_thirdparty", + "benchmark", + "dask", + "randomforest_common", + "solvers", + "fil", + "internals", + "tsa", +} + + +def ignore_warnings(obj=None, category=Warning): + """Context manager and decorator to ignore warnings. + + Note: Using this (in both variants) will clear all warnings + from all python modules loaded. In case you need to test + cross-module-warning-logging, this is not your tool of choice. + + Parameters + ---------- + obj : callable, default=None + callable where you want to ignore the warnings. + category : warning class, default=Warning + The category to filter. If Warning, all categories will be muted. + + Examples + -------- + >>> import warnings + >>> from cuml.internals.utils import ignore_warnings + >>> with ignore_warnings(): + ... warnings.warn('buhuhuhu') + + >>> def nasty_warn(): + ... warnings.warn('buhuhuhu') + ... print(42) + + >>> ignore_warnings(nasty_warn)() + 42 + """ + if isinstance(obj, type) and issubclass(obj, Warning): + # Avoid common pitfall of passing category as the first positional + # argument which result in the test not being run + warning_name = obj.__name__ + raise ValueError( + "'obj' should be a callable where you want to ignore warnings. " + "You passed a warning class instead: 'obj={warning_name}'. " + "If you want to pass a warning class to ignore_warnings, " + "you should use 'category={warning_name}'".format( + warning_name=warning_name + ) + ) + elif callable(obj): + return _IgnoreWarnings(category=category)(obj) + else: + return _IgnoreWarnings(category=category) + + +class _IgnoreWarnings: + """Improved and simplified Python warnings context manager and decorator. + + This class allows the user to ignore the warnings raised by a function. + Copied from Python 2.7.5 and modified as required. + + Parameters + ---------- + category : tuple of warning class, default=Warning + The category to filter. By default, all the categories will be muted. + + """ + + def __init__(self, category): + self._record = True + self._module = sys.modules["warnings"] + self._entered = False + self.log = [] + self.category = category + + def __call__(self, fn): + """Decorator to catch and hide warnings without visual nesting.""" + + @wraps(fn) + def wrapper(*args, **kwargs): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", self.category) + return fn(*args, **kwargs) + + return wrapper + + def __repr__(self): + args = [] + if self._record: + args.append("record=True") + if self._module is not sys.modules["warnings"]: + args.append("module=%r" % self._module) + name = type(self).__name__ + return "%s(%s)" % (name, ", ".join(args)) + + def __enter__(self): + if self._entered: + raise RuntimeError("Cannot enter %r twice" % self) + self._entered = True + self._filters = self._module.filters + self._module.filters = self._filters[:] + self._showwarning = self._module.showwarning + warnings.simplefilter("ignore", self.category) + + def __exit__(self, *exc_info): + if not self._entered: + raise RuntimeError("Cannot exit %r without entering first" % self) + self._module.filters = self._filters + self._module.showwarning = self._showwarning + self.log[:] = [] + + +def all_estimators(type_filter=None): + """Get a list of all estimators from `cuml`. + + This function crawls the module and gets all classes that inherit + from BaseEstimator. Classes that are defined in test-modules are not + included. + + Parameters + ---------- + type_filter : {"classifier", "regressor", "cluster", "transformer"} \ + or list of such str, default=None + Which kind of estimators should be returned. If None, no filter is + applied and all estimators are returned. Possible values are + 'classifier', 'regressor', 'cluster' and 'transformer' to get + estimators only of these specific types, or a list of these to + get the estimators that fit at least one of the types. + + Returns + ------- + estimators : list of tuples + List of (name, class), where ``name`` is the class name as string + and ``class`` is the actual type of the class. + + """ + # lazy import to avoid circular imports + from .base import Base as BaseEstimator + from .mixins import ClassifierMixin, RegressorMixin, ClusterMixin + + def is_abstract(c): + if not (hasattr(c, "__abstractmethods__")): + return False + if not len(c.__abstractmethods__): + return False + return True + + all_classes = [] + root = str(Path(__file__).parent.parent) # sklearn package + # Ignore deprecation warnings triggered at import time and from walking + # packages + with ignore_warnings(category=FutureWarning): + for _, module_name, _ in pkgutil.walk_packages( + path=[root], prefix="cuml." + ): + module_parts = module_name.split(".") + if ( + any(part in _MODULE_TO_IGNORE for part in module_parts) + or "._" in module_name + ): + continue + + module = import_module(module_name) + classes_ = inspect.getmembers(module, inspect.isclass) + # Use the __name__ of each class instead of the name used in the module + classes_ = [ + (est_cls.__name__, est_cls) + for name, est_cls in classes_ + if not name.startswith("_") + ] + classes = [] + # A second round of filtering. Needed to make sure classes that are + # defined in ignored modules are skipped even if they are exposed + # via non-ignored modules. + for name, klass in classes_: + module_name_ = klass.__module__ + module_parts_ = module_name_.split(".") + if ( + any(part in _MODULE_TO_IGNORE for part in module_parts_) + or "._" in module_name_ + ): + continue + else: + classes.append((name, klass)) + + all_classes.extend(classes) + + all_classes = set(all_classes) + + estimators = [ + c + for c in all_classes + if (issubclass(c[1], BaseEstimator) and c[0] != "BaseEstimator") + ] + # get rid of abstract base classes + estimators = [c for c in estimators if not is_abstract(c[1])] + + if type_filter is not None: + if not isinstance(type_filter, list): + type_filter = [type_filter] + else: + type_filter = list(type_filter) # copy + filtered_estimators = [] + filters = { + "classifier": ClassifierMixin, + "regressor": RegressorMixin, + # "transformer": TransformerMixin, + "cluster": ClusterMixin, + } + for name, mixin in filters.items(): + if name in type_filter: + type_filter.remove(name) + filtered_estimators.extend( + [est for est in estimators if issubclass(est[1], mixin)] + ) + estimators = filtered_estimators + if type_filter: + raise ValueError( + "Parameter type_filter must be 'classifier', " + "'regressor', 'transformer', 'cluster' or " + "None, got" + f" {repr(type_filter)}." + ) + + # drop duplicates, sort for reproducibility + # itemgetter is used to ensure the sort does not extend to the 2nd item of + # the tuple + return sorted(set(estimators), key=itemgetter(0)) diff --git a/python/cuml/cuml/tests/test_common.py b/python/cuml/cuml/tests/test_common.py index 09da1f5216..2011e8cc08 100644 --- a/python/cuml/cuml/tests/test_common.py +++ b/python/cuml/cuml/tests/test_common.py @@ -16,10 +16,29 @@ from sklearn.utils import estimator_checks +from cuml.internals.utils import all_estimators from cuml import LogisticRegression -@estimator_checks.parametrize_with_checks([LogisticRegression()]) +DEFAULT_PARAMETERS = { + "MulticlassClassifier": dict(estimator=LogisticRegression()), + "OneVsOneClassifier": dict(estimator=LogisticRegression()), + "OneVsRestClassifier": dict(estimator=LogisticRegression()), +} + + +def constructed_estimators(): + """Build list of instances of all estimators in cuml""" + for name, Estimator in all_estimators( + type_filter=["classifier", "regressor", "cluster"] + ): + if name in DEFAULT_PARAMETERS: + yield Estimator(**DEFAULT_PARAMETERS[name]) + else: + yield Estimator() + + +@estimator_checks.parametrize_with_checks(list(constructed_estimators())) def test_sklearn_compatible_estimator(estimator, check): # Check that all estimators pass the "common estimator" checks # provided by scikit-learn