Skip to content

Commit

Permalink
Run common tests on all estimators
Browse files Browse the repository at this point in the history
  • Loading branch information
betatim committed Oct 16, 2024
1 parent fde0ee8 commit adb0dfd
Show file tree
Hide file tree
Showing 2 changed files with 286 additions and 1 deletion.
266 changes: 266 additions & 0 deletions python/cuml/cuml/internals/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
"""Utilities to discover cuml estimators."""

# This code was taken from scikit-learn and edited for cuml
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
#
# Copyright (c) 2021-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import inspect
import pkgutil
import sys
import warnings
from functools import wraps
from importlib import import_module
from operator import itemgetter
from pathlib import Path


_MODULE_TO_IGNORE = {
"tests",
"base",
"conftest",
"common",
"_thirdparty",
"benchmark",
"dask",
"randomforest_common",
"solvers",
"fil",
"internals",
"tsa",
}


def ignore_warnings(obj=None, category=Warning):
"""Context manager and decorator to ignore warnings.
Note: Using this (in both variants) will clear all warnings
from all python modules loaded. In case you need to test
cross-module-warning-logging, this is not your tool of choice.
Parameters
----------
obj : callable, default=None
callable where you want to ignore the warnings.
category : warning class, default=Warning
The category to filter. If Warning, all categories will be muted.
Examples
--------
>>> import warnings
>>> from cuml.internals.utils import ignore_warnings
>>> with ignore_warnings():
... warnings.warn('buhuhuhu')
>>> def nasty_warn():
... warnings.warn('buhuhuhu')
... print(42)
>>> ignore_warnings(nasty_warn)()
42
"""
if isinstance(obj, type) and issubclass(obj, Warning):
# Avoid common pitfall of passing category as the first positional
# argument which result in the test not being run
warning_name = obj.__name__
raise ValueError(
"'obj' should be a callable where you want to ignore warnings. "
"You passed a warning class instead: 'obj={warning_name}'. "
"If you want to pass a warning class to ignore_warnings, "
"you should use 'category={warning_name}'".format(
warning_name=warning_name
)
)
elif callable(obj):
return _IgnoreWarnings(category=category)(obj)
else:
return _IgnoreWarnings(category=category)


class _IgnoreWarnings:
"""Improved and simplified Python warnings context manager and decorator.
This class allows the user to ignore the warnings raised by a function.
Copied from Python 2.7.5 and modified as required.
Parameters
----------
category : tuple of warning class, default=Warning
The category to filter. By default, all the categories will be muted.
"""

def __init__(self, category):
self._record = True
self._module = sys.modules["warnings"]
self._entered = False
self.log = []
self.category = category

def __call__(self, fn):
"""Decorator to catch and hide warnings without visual nesting."""

@wraps(fn)
def wrapper(*args, **kwargs):
with warnings.catch_warnings():
warnings.simplefilter("ignore", self.category)
return fn(*args, **kwargs)

return wrapper

def __repr__(self):
args = []
if self._record:
args.append("record=True")
if self._module is not sys.modules["warnings"]:
args.append("module=%r" % self._module)
name = type(self).__name__
return "%s(%s)" % (name, ", ".join(args))

def __enter__(self):
if self._entered:
raise RuntimeError("Cannot enter %r twice" % self)
self._entered = True
self._filters = self._module.filters
self._module.filters = self._filters[:]
self._showwarning = self._module.showwarning
warnings.simplefilter("ignore", self.category)

def __exit__(self, *exc_info):
if not self._entered:
raise RuntimeError("Cannot exit %r without entering first" % self)
self._module.filters = self._filters
self._module.showwarning = self._showwarning
self.log[:] = []


def all_estimators(type_filter=None):
"""Get a list of all estimators from `cuml`.
This function crawls the module and gets all classes that inherit
from BaseEstimator. Classes that are defined in test-modules are not
included.
Parameters
----------
type_filter : {"classifier", "regressor", "cluster", "transformer"} \
or list of such str, default=None
Which kind of estimators should be returned. If None, no filter is
applied and all estimators are returned. Possible values are
'classifier', 'regressor', 'cluster' and 'transformer' to get
estimators only of these specific types, or a list of these to
get the estimators that fit at least one of the types.
Returns
-------
estimators : list of tuples
List of (name, class), where ``name`` is the class name as string
and ``class`` is the actual type of the class.
"""
# lazy import to avoid circular imports
from .base import Base as BaseEstimator
from .mixins import ClassifierMixin, RegressorMixin, ClusterMixin

def is_abstract(c):
if not (hasattr(c, "__abstractmethods__")):
return False
if not len(c.__abstractmethods__):
return False
return True

all_classes = []
root = str(Path(__file__).parent.parent) # sklearn package
# Ignore deprecation warnings triggered at import time and from walking
# packages
with ignore_warnings(category=FutureWarning):
for _, module_name, _ in pkgutil.walk_packages(
path=[root], prefix="cuml."
):
module_parts = module_name.split(".")
if (
any(part in _MODULE_TO_IGNORE for part in module_parts)
or "._" in module_name
):
continue

module = import_module(module_name)
classes_ = inspect.getmembers(module, inspect.isclass)
# Use the __name__ of each class instead of the name used in the module
classes_ = [
(est_cls.__name__, est_cls)
for name, est_cls in classes_
if not name.startswith("_")
]
classes = []
# A second round of filtering. Needed to make sure classes that are
# defined in ignored modules are skipped even if they are exposed
# via non-ignored modules.
for name, klass in classes_:
module_name_ = klass.__module__
module_parts_ = module_name_.split(".")
if (
any(part in _MODULE_TO_IGNORE for part in module_parts_)
or "._" in module_name_
):
continue
else:
classes.append((name, klass))

all_classes.extend(classes)

all_classes = set(all_classes)

estimators = [
c
for c in all_classes
if (issubclass(c[1], BaseEstimator) and c[0] != "BaseEstimator")
]
# get rid of abstract base classes
estimators = [c for c in estimators if not is_abstract(c[1])]

if type_filter is not None:
if not isinstance(type_filter, list):
type_filter = [type_filter]
else:
type_filter = list(type_filter) # copy
filtered_estimators = []
filters = {
"classifier": ClassifierMixin,
"regressor": RegressorMixin,
# "transformer": TransformerMixin,
"cluster": ClusterMixin,
}
for name, mixin in filters.items():
if name in type_filter:
type_filter.remove(name)
filtered_estimators.extend(
[est for est in estimators if issubclass(est[1], mixin)]
)
estimators = filtered_estimators
if type_filter:
raise ValueError(
"Parameter type_filter must be 'classifier', "
"'regressor', 'transformer', 'cluster' or "
"None, got"
f" {repr(type_filter)}."
)

# drop duplicates, sort for reproducibility
# itemgetter is used to ensure the sort does not extend to the 2nd item of
# the tuple
return sorted(set(estimators), key=itemgetter(0))
21 changes: 20 additions & 1 deletion python/cuml/cuml/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,29 @@

from sklearn.utils import estimator_checks

from cuml.internals.utils import all_estimators
from cuml import LogisticRegression


@estimator_checks.parametrize_with_checks([LogisticRegression()])
DEFAULT_PARAMETERS = {
"MulticlassClassifier": dict(estimator=LogisticRegression()),
"OneVsOneClassifier": dict(estimator=LogisticRegression()),
"OneVsRestClassifier": dict(estimator=LogisticRegression()),
}


def constructed_estimators():
"""Build list of instances of all estimators in cuml"""
for name, Estimator in all_estimators(
type_filter=["classifier", "regressor", "cluster"]
):
if name in DEFAULT_PARAMETERS:
yield Estimator(**DEFAULT_PARAMETERS[name])
else:
yield Estimator()


@estimator_checks.parametrize_with_checks(list(constructed_estimators()))
def test_sklearn_compatible_estimator(estimator, check):
# Check that all estimators pass the "common estimator" checks
# provided by scikit-learn
Expand Down

0 comments on commit adb0dfd

Please sign in to comment.