-
Notifications
You must be signed in to change notification settings - Fork 527
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
286 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,266 @@ | ||
"""Utilities to discover cuml estimators.""" | ||
|
||
# This code was taken from scikit-learn and edited for cuml | ||
# Authors: The scikit-learn developers | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# | ||
# Copyright (c) 2021-2024, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
import inspect | ||
import pkgutil | ||
import sys | ||
import warnings | ||
from functools import wraps | ||
from importlib import import_module | ||
from operator import itemgetter | ||
from pathlib import Path | ||
|
||
|
||
# Module / sub-package names that estimator discovery skips.
# ``all_estimators`` compares every component of a dotted module path against
# this set, so one entry excludes that module and everything nested below it.
_MODULE_TO_IGNORE = {
    "_thirdparty",
    "base",
    "benchmark",
    "common",
    "conftest",
    "dask",
    "fil",
    "internals",
    "randomforest_common",
    "solvers",
    "tests",
    "tsa",
}
|
||
|
||
def ignore_warnings(obj=None, category=Warning):
    """Silence warnings, either as a context manager or as a decorator.

    Called with a callable, the callable is returned wrapped so that
    warnings of ``category`` are ignored while it runs. Called without a
    callable, an object usable in a ``with`` statement is returned instead.

    Parameters
    ----------
    obj : callable, default=None
        Callable in which warnings should be ignored.
    category : warning class, default=Warning
        The category to filter. With ``Warning``, every category is muted.

    Returns
    -------
    The wrapped callable when ``obj`` is callable, otherwise a context
    manager / decorator instance.

    Raises
    ------
    ValueError
        If ``obj`` is itself a warning class, which almost always means the
        category was passed positionally by mistake.

    Examples
    --------
    >>> import warnings
    >>> from cuml.internals.utils import ignore_warnings
    >>> with ignore_warnings():
    ...     warnings.warn('buhuhuhu')
    >>> def nasty_warn():
    ...     warnings.warn('buhuhuhu')
    ...     print(42)
    >>> ignore_warnings(nasty_warn)()
    42
    """
    # Guard against the common pitfall of passing the category as the first
    # positional argument: a warning class is callable, so it would otherwise
    # be silently treated as the function to decorate.
    got_warning_class = isinstance(obj, type) and issubclass(obj, Warning)
    if got_warning_class:
        raise ValueError(
            "'obj' should be a callable where you want to ignore warnings. "
            "You passed a warning class instead: 'obj={warning_name}'. "
            "If you want to pass a warning class to ignore_warnings, "
            "you should use 'category={warning_name}'".format(
                warning_name=obj.__name__
            )
        )

    suppressor = _IgnoreWarnings(category=category)
    return suppressor(obj) if callable(obj) else suppressor
|
||
|
||
class _IgnoreWarnings:
    """Context manager and decorator that ignores warnings of a category.

    Used as a decorator (``_IgnoreWarnings(cat)(fn)``) it returns a wrapper
    that runs ``fn`` with an "ignore" filter for ``category`` installed.
    Used in a ``with`` statement it installs the same filter for the body of
    the block. In both cases the previous warning state is put back by
    ``warnings.catch_warnings`` afterwards.

    An instance can be entered only once as a context manager; the decorator
    form can be called any number of times.

    Parameters
    ----------
    category : warning class
        The category to filter. It is forwarded unchanged to
        ``warnings.simplefilter("ignore", category)``.
    """

    def __init__(self, category):
        self._record = True
        self._module = sys.modules["warnings"]
        self._entered = False
        # The stdlib context manager doing the save/restore; created on entry.
        self._catcher = None
        self.log = []
        self.category = category

    def __call__(self, fn):
        """Decorator to catch and hide warnings without visual nesting."""

        @wraps(fn)
        def wrapper(*args, **kwargs):
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", self.category)
                return fn(*args, **kwargs)

        return wrapper

    def __repr__(self):
        args = []
        if self._record:
            args.append("record=True")
        if self._module is not sys.modules["warnings"]:
            args.append("module=%r" % self._module)
        name = type(self).__name__
        return "%s(%s)" % (name, ", ".join(args))

    def __enter__(self):
        if self._entered:
            raise RuntimeError("Cannot enter %r twice" % self)
        self._entered = True
        # Delegate the save/restore of the warnings state to the standard
        # library instead of swapping ``warnings.filters`` by hand. This keeps
        # the context-manager path consistent with the decorator path above
        # and lets the stdlib do its own bookkeeping when filters change.
        self._catcher = warnings.catch_warnings()
        self._catcher.__enter__()
        warnings.simplefilter("ignore", self.category)

    def __exit__(self, *exc_info):
        if not self._entered:
            raise RuntimeError("Cannot exit %r without entering first" % self)
        self._catcher.__exit__(*exc_info)
        self.log[:] = []
|
||
|
||
def all_estimators(type_filter=None):
    """Get a list of all estimators from `cuml`.

    This function walks the `cuml` package, imports every module that is not
    ignored, and collects the public classes that inherit from ``Base``.
    Classes are skipped when the module they are defined in has a path
    component listed in ``_MODULE_TO_IGNORE`` or is private (contains
    ``"._"``), even if they are re-exported from a non-ignored module.
    Abstract classes are skipped as well.

    Parameters
    ----------
    type_filter : str or list of str, default=None
        Which kind of estimators should be returned. If None, no filter is
        applied and all estimators are returned. Possible values are
        'classifier', 'regressor' and 'cluster' to get estimators only of
        these specific types, or a list (tuple and set are accepted too) of
        these to get the estimators that fit at least one of the types.

    Returns
    -------
    estimators : list of tuples
        List of (name, class), where ``name`` is the class name as string
        and ``class`` is the actual type of the class. The list is free of
        duplicates and sorted by name.

    Raises
    ------
    ValueError
        If ``type_filter`` contains a value that is not a supported type.
    """
    # lazy import to avoid circular imports
    from .base import Base as BaseEstimator
    from .mixins import ClassifierMixin, RegressorMixin, ClusterMixin

    def is_abstract(c):
        # A class is abstract when it still has unimplemented abstract methods.
        return bool(getattr(c, "__abstractmethods__", ()))

    def is_ignored(dotted_name):
        # Private modules and anything at or below an ignored module name.
        return "._" in dotted_name or any(
            part in _MODULE_TO_IGNORE for part in dotted_name.split(".")
        )

    all_classes = []
    # Directory two levels above this file; it is walked as the cuml package.
    root = str(Path(__file__).parent.parent)
    # Ignore deprecation warnings triggered at import time and from walking
    # packages
    with ignore_warnings(category=FutureWarning):
        for _, module_name, _ in pkgutil.walk_packages(
            path=[root], prefix="cuml."
        ):
            if is_ignored(module_name):
                continue

            module = import_module(module_name)
            for name, est_cls in inspect.getmembers(module, inspect.isclass):
                if name.startswith("_"):
                    continue
                # Second round of filtering: skip classes that are defined in
                # ignored modules even if they are exposed via this one.
                if is_ignored(est_cls.__module__):
                    continue
                # Use the class' own __name__ rather than the alias under
                # which the module exposes it.
                all_classes.append((est_cls.__name__, est_cls))

    estimators = [
        c
        for c in set(all_classes)
        if issubclass(c[1], BaseEstimator)
        and c[0] != "BaseEstimator"
        and c[1] is not BaseEstimator
        # get rid of abstract base classes
        and not is_abstract(c[1])
    ]

    if type_filter is not None:
        filters = {
            "classifier": ClassifierMixin,
            "regressor": RegressorMixin,
            "cluster": ClusterMixin,
        }
        if isinstance(type_filter, (list, tuple, set, frozenset)):
            requested = list(type_filter)
        else:
            requested = [type_filter]

        # Membership is tested against a tuple so that unhashable junk values
        # end up in the ValueError below instead of raising a TypeError.
        supported = tuple(filters)
        unknown = [kind for kind in requested if kind not in supported]
        if unknown:
            raise ValueError(
                "Parameter type_filter must be 'classifier', "
                "'regressor', 'cluster' or "
                "None, got"
                f" {repr(unknown)}."
            )

        # Repeated entries in ``requested`` are harmless here: each mixin is
        # selected at most once.
        wanted_mixins = tuple(
            mixin for kind, mixin in filters.items() if kind in requested
        )
        estimators = [
            est for est in estimators if issubclass(est[1], wanted_mixins)
        ]

    # drop duplicates, sort for reproducibility. The key is built from
    # strings only (classes are not orderable); module and qualified name
    # break ties between distinct classes that share a name.
    return sorted(
        set(estimators),
        key=lambda est: (est[0], est[1].__module__, est[1].__qualname__),
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters