Skip to content

Commit

Permalink
re-orient to new ban
Browse files Browse the repository at this point in the history
  • Loading branch information
icfaust committed Oct 11, 2024
1 parent 76fbf85 commit f9707e9
Showing 1 changed file with 13 additions and 20 deletions.
33 changes: 13 additions & 20 deletions sklearnex/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,18 @@ def wrap(*args, **kwargs):
return wrap


def test_class_trailing_underscore_ban(monkeypatch):
"""Trailing underscores are defined for sklearn to be signatures of a fitted
estimator instance, sklearnex extends this to the classes as well"""
monkeypatch.setattr(pkgutil, "walk_packages", _sklearnex_walk(pkgutil.walk_packages))
estimators = all_estimators() # list of tuples
for name, obj in estimators:
if "preview" not in obj.__module__ and "daal4py" not in obj.__module__:
assert all(
[i.startswith("_") or not i.endswith("_") for i in dir(obj)]
), f"{name} contains class attributes which have a trailing underscore but no leading one"


def test_all_estimators_covered(monkeypatch):
"""Check that all estimators defined in sklearnex are available in either the
patch map or covered in special testing via SPECIAL_INSTANCES. The estimator
Expand Down Expand Up @@ -339,26 +351,7 @@ def n_jobs_check(text, estimator, method):
), f"verify if {method} should be in control_n_jobs' decorated_methods for {estimator}"


def fitted_check(text, estimator, method):
"""The estimator should verify that it has been fitted for any non fit* method"""
# remove the _get_backend function from sklearnex from considered _get_backend
if "fit" in method:
pytest.skip(f"{method} fits the estimator and is exempt from fitted_check")

count = len(
[
i
for i in range(len(text[0]))
if text[0][i] == "check_is_fitted" and "sklearn" in text[2][i]
]
)

assert bool(
count
), f"sklearn's 'check_is_fitted' should be used in {estimator}.{method}"


DESIGN_RULES = [n_jobs_check, fitted_check]
DESIGN_RULES = [n_jobs_check]


if sklearn_check_version("1.0"):
Expand Down

0 comments on commit f9707e9

Please sign in to comment.