From 07d688b1ef7ba87769a7c938c4126b48286bcaf5 Mon Sep 17 00:00:00 2001 From: Ian Faust Date: Mon, 23 Sep 2024 17:22:57 +0200 Subject: [PATCH] [fix] update ```BasicStatistics``` and ```IncrementalBasicStatistics``` to follow additional sklearn conventions (#2038) * Update test_basic_statistics.py * Update test_basic_statistics.py * Update basic_statistics.py * formatting * forgotten switch * handle all better * make updates to IncBS * n_features_in_ safety * convert to 2d data * add daal_check_version import * formatting --- .../basic_statistics/basic_statistics.py | 69 +++++++++++++++++-- .../incremental_basic_statistics.py | 20 ++++-- .../tests/test_basic_statistics.py | 19 +++++ .../test_incremental_basic_statistics.py | 20 ++++++ 4 files changed, 114 insertions(+), 14 deletions(-) diff --git a/sklearnex/basic_statistics/basic_statistics.py b/sklearnex/basic_statistics/basic_statistics.py index 546d52b5b3..8a53e5898b 100644 --- a/sklearnex/basic_statistics/basic_statistics.py +++ b/sklearnex/basic_statistics/basic_statistics.py @@ -14,6 +14,8 @@ # limitations under the License. # ============================================================================== +import warnings + import numpy as np from sklearn.base import BaseEstimator from sklearn.utils import check_array @@ -26,6 +28,9 @@ from .._device_offload import dispatch from .._utils import PatchingConditionsChain +if sklearn_check_version("1.2"): + from sklearn.utils._param_validation import StrOptions + @control_n_jobs(decorated_methods=["fit"]) class BasicStatistics(BaseEstimator): @@ -62,23 +67,69 @@ class BasicStatistics(BaseEstimator): """ def __init__(self, result_options="all"): - self.options = result_options + self.result_options = result_options _onedal_basic_statistics = staticmethod(onedal_BasicStatistics) + if sklearn_check_version("1.2"): + _parameter_constraints: dict = { + "result_options": [ + StrOptions( + { + "all", + "min", + "max", + "sum", + "mean", + "variance", + "variation", + "sum_squares", + "standard_deviation", + "sum_squares_centered", + "second_order_raw_moment", + } + ), + list, + ], + } + def _save_attributes(self): assert hasattr(self, "_onedal_estimator") - if self.options == "all": + if self.result_options == "all": result_options = onedal_BasicStatistics.get_all_result_options() else: - result_options = self.options + result_options = self.result_options if isinstance(result_options, str): - setattr(self, result_options, getattr(self._onedal_estimator, result_options)) + setattr( + self, + result_options + "_", + getattr(self._onedal_estimator, result_options), + ) elif isinstance(result_options, list): for option in result_options: - setattr(self, option, getattr(self._onedal_estimator, option)) + setattr(self, option + "_", getattr(self._onedal_estimator, option)) + + def __getattr__(self, attr): + if self.result_options == "all": + result_options = onedal_BasicStatistics.get_all_result_options() + else: + result_options = self.result_options + is_deprecated_attr = ( + isinstance(result_options, str) and (attr == result_options) + ) or (isinstance(result_options, list) and (attr in result_options)) + if is_deprecated_attr: + warnings.warn( + "Result attributes without a trailing underscore were deprecated in version 2025.1 and will be removed in 2026.0" + ) + attr += "_" + if attr in self.__dict__: + return self.__dict__[attr] + + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{attr}'" + ) def _onedal_supported(self, method_name, *data): patching_status = PatchingConditionsChain( @@ -90,6 +141,9 @@ def _onedal_supported(self, method_name, *data): _onedal_gpu_supported = _onedal_supported def _onedal_fit(self, X, sample_weight=None, queue=None): + if sklearn_check_version("1.2"): + self._validate_params() + if sklearn_check_version("1.0"): X = self._validate_data(X, dtype=[np.float64, np.float32], ensure_2d=False) else: @@ -99,16 +153,17 @@ def _onedal_fit(self, X, sample_weight=None, queue=None): sample_weight = _check_sample_weight(sample_weight, X) onedal_params = { - "result_options": self.options, + "result_options": self.result_options, } if not hasattr(self, "_onedal_estimator"): self._onedal_estimator = self._onedal_basic_statistics(**onedal_params) self._onedal_estimator.fit(X, sample_weight, queue) self._save_attributes() + self.n_features_in_ = X.shape[1] if len(X.shape) > 1 else 1 def fit(self, X, y=None, *, sample_weight=None): - """Compute statistics with X, using minibatches of size batch_size. + """Calculate statistics of X. Parameters ---------- diff --git a/sklearnex/basic_statistics/incremental_basic_statistics.py b/sklearnex/basic_statistics/incremental_basic_statistics.py index 2ffa143421..ca0bf2db60 100644 --- a/sklearnex/basic_statistics/incremental_basic_statistics.py +++ b/sklearnex/basic_statistics/incremental_basic_statistics.py @@ -32,6 +32,7 @@ from sklearn.utils._param_validation import Interval, StrOptions import numbers +import warnings @control_n_jobs(decorated_methods=["partial_fit", "_onedal_finalize_fit"]) @@ -175,6 +176,9 @@ def _onedal_partial_fit(self, X, sample_weight=None, queue=None): self._need_to_finalize = True def _onedal_fit(self, X, sample_weight=None, queue=None): + if sklearn_check_version("1.2"): + self._validate_params() + if sklearn_check_version("1.0"): X = self._validate_data(X, dtype=[np.float64, np.float32]) else: @@ -198,9 +202,6 @@ def _onedal_fit(self, X, sample_weight=None, queue=None): weights_batch = sample_weight[batch] if sample_weight is not None else None self._onedal_partial_fit(X_batch, weights_batch, queue=queue) - if sklearn_check_version("1.2"): - self._validate_params() - self.n_features_in_ = X.shape[1] self._onedal_finalize_fit(queue=queue) @@ -209,13 +210,18 @@ def _onedal_fit(self, X, sample_weight=None, queue=None): def __getattr__(self, attr): result_options = self.__dict__["result_options"] + sattr = attr.removesuffix("_") is_statistic_attr = ( - isinstance(result_options, str) and (attr == result_options) - ) or (isinstance(result_options, list) and (attr in result_options)) + isinstance(result_options, str) and (sattr == result_options) + ) or (isinstance(result_options, list) and (sattr in result_options)) if is_statistic_attr: if self._need_to_finalize: self._onedal_finalize_fit() - return getattr(self._onedal_estimator, attr) + if sattr == attr: + warnings.warn( + "Result attributes without a trailing underscore were deprecated in version 2025.1 and will be removed in 2026.0" + ) + return getattr(self._onedal_estimator, sattr) if attr in self.__dict__: return self.__dict__[attr] @@ -256,7 +262,7 @@ def partial_fit(self, X, sample_weight=None): return self def fit(self, X, y=None, sample_weight=None): - """Compute statistics with X, using minibatches of size batch_size. + """Calculate statistics of X using minibatches of size batch_size. Parameters ---------- diff --git a/sklearnex/basic_statistics/tests/test_basic_statistics.py b/sklearnex/basic_statistics/tests/test_basic_statistics.py index 8abbd6db1d..125402f4e2 100644 --- a/sklearnex/basic_statistics/tests/test_basic_statistics.py +++ b/sklearnex/basic_statistics/tests/test_basic_statistics.py @@ -18,6 +18,7 @@ import pytest from numpy.testing import assert_allclose +from daal4py.sklearn._utils import daal_check_version from onedal.basic_statistics.tests.test_basic_statistics import ( expected_max, expected_mean, @@ -249,3 +250,21 @@ def test_1d_input_on_random_data(dataframe, queue, option, data_size, weighted, tol = fp32tol if res.dtype == np.float32 else fp64tol assert_allclose(gtr, res, atol=tol) + + +def test_warning(): + basicstat = BasicStatistics("all") + data = np.array([0, 1]) + + basicstat.fit(data) + for i in basicstat._onedal_estimator.get_all_result_options(): + with pytest.warns( + UserWarning, + match="Result attributes without a trailing underscore were deprecated in version 2025.1 and will be removed in 2026.0", + ) as warn_record: + getattr(basicstat, i) + + if daal_check_version((2026, "P", 0)): + assert len(warn_record) == 0, i + else: + assert len(warn_record) == 1, i diff --git a/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py b/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py index 0931e4b524..f5e8bf63e8 100644 --- a/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +++ b/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py @@ -18,6 +18,7 @@ import pytest from numpy.testing import assert_allclose +from daal4py.sklearn._utils import daal_check_version from onedal.basic_statistics.tests.test_basic_statistics import ( expected_max, expected_mean, @@ -382,3 +383,22 @@ def test_fit_all_option_on_random_data( gtr = function(X) tol = fp32tol if res.dtype == np.float32 else fp64tol assert_allclose(gtr, res, atol=tol) + + +def test_warning(): + basicstat = IncrementalBasicStatistics("all") + # Only 2d inputs supported into IncrementalBasicStatistics + data = np.array([[0.0], [1.0]]) + + basicstat.fit(data) + for i in basicstat._onedal_estimator.get_all_result_options(): + with pytest.warns( + UserWarning, + match="Result attributes without a trailing underscore were deprecated in version 2025.1 and will be removed in 2026.0", + ) as warn_record: + getattr(basicstat, i) + + if daal_check_version((2026, "P", 0)): + assert len(warn_record) == 0, i + else: + assert len(warn_record) == 1, i