Skip to content

Commit

Permalink
[fix] update BasicStatistics and ``IncrementalBasicStatistics`` to follow additional sklearn conventions (#2038)
Browse files Browse the repository at this point in the history

* Update test_basic_statistics.py

* Update test_basic_statistics.py

* Update basic_statistics.py

* formatting

* forgotten switch

* handle all better

* make updates to IncBS

* n_features_in_ safety

* convert to 2d data

* add daal_check_version import

* formatting
  • Loading branch information
icfaust authored Sep 23, 2024
1 parent 0b8bec8 commit 07d688b
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 14 deletions.
69 changes: 62 additions & 7 deletions sklearnex/basic_statistics/basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# limitations under the License.
# ==============================================================================

import warnings

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.utils import check_array
Expand All @@ -26,6 +28,9 @@
from .._device_offload import dispatch
from .._utils import PatchingConditionsChain

if sklearn_check_version("1.2"):
from sklearn.utils._param_validation import StrOptions


@control_n_jobs(decorated_methods=["fit"])
class BasicStatistics(BaseEstimator):
Expand Down Expand Up @@ -62,23 +67,69 @@ class BasicStatistics(BaseEstimator):
"""

def __init__(self, result_options="all"):
self.options = result_options
self.result_options = result_options

_onedal_basic_statistics = staticmethod(onedal_BasicStatistics)

# Declarative constraints for the ``result_options`` constructor parameter.
# Defined only for scikit-learn >= 1.2, the same version gate under which
# ``StrOptions`` is imported and ``self._validate_params()`` is called in
# ``_onedal_fit``.
# NOTE(review): presumably ``_validate_params()`` is what reads this mapping
# (scikit-learn's parameter-validation convention) - confirm.
if sklearn_check_version("1.2"):
_parameter_constraints: dict = {
"result_options": [
# Either a single string: "all" or one of the statistic names below ...
StrOptions(
{
"all",
"min",
"max",
"sum",
"mean",
"variance",
"variation",
"sum_squares",
"standard_deviation",
"sum_squares_centered",
"second_order_raw_moment",
}
),
# ... or a list. The bare ``list`` entry only constrains the container
# type; the list's elements are not checked by this declaration.
list,
],
}

def _save_attributes(self):
assert hasattr(self, "_onedal_estimator")

if self.options == "all":
if self.result_options == "all":
result_options = onedal_BasicStatistics.get_all_result_options()
else:
result_options = self.options
result_options = self.result_options

if isinstance(result_options, str):
setattr(self, result_options, getattr(self._onedal_estimator, result_options))
setattr(
self,
result_options + "_",
getattr(self._onedal_estimator, result_options),
)
elif isinstance(result_options, list):
for option in result_options:
setattr(self, option, getattr(self._onedal_estimator, option))
setattr(self, option + "_", getattr(self._onedal_estimator, option))

def __getattr__(self, attr):
    """Resolve attributes that regular lookup did not find.

    Provides backward compatibility for result attributes requested
    without a trailing underscore (for example ``mean`` instead of
    ``mean_``): a deprecation warning is emitted and the lookup is
    redirected to the underscore-suffixed name.

    Parameters
    ----------
    attr : str
        Name of the attribute being looked up.

    Returns
    -------
    object
        The value stored in the instance ``__dict__`` under the
        (possibly redirected) attribute name.

    Raises
    ------
    AttributeError
        If the (possibly redirected) name is not stored on the instance.
    """
    # Read the parameter straight from ``__dict__``. Going through
    # ``self.result_options`` would re-enter ``__getattr__`` whenever the
    # parameter is not set yet (e.g. on an instance created with ``__new__``
    # only) and recurse until RecursionError instead of raising
    # AttributeError.
    state = self.__dict__
    result_options = state.get("result_options")
    if result_options == "all":
        result_options = onedal_BasicStatistics.get_all_result_options()

    if isinstance(result_options, str):
        is_deprecated_attr = attr == result_options
    elif isinstance(result_options, list):
        is_deprecated_attr = attr in result_options
    else:
        is_deprecated_attr = False

    if is_deprecated_attr:
        # stacklevel=2 attributes the warning to the code that accessed the
        # deprecated attribute rather than to this method.
        warnings.warn(
            "Result attributes without a trailing underscore were deprecated in version 2025.1 and will be removed in 2026.0",
            stacklevel=2,
        )
        attr += "_"
    if attr in state:
        return state[attr]

    raise AttributeError(
        f"'{self.__class__.__name__}' object has no attribute '{attr}'"
    )

def _onedal_supported(self, method_name, *data):
patching_status = PatchingConditionsChain(
Expand All @@ -90,6 +141,9 @@ def _onedal_supported(self, method_name, *data):
_onedal_gpu_supported = _onedal_supported

def _onedal_fit(self, X, sample_weight=None, queue=None):
if sklearn_check_version("1.2"):
self._validate_params()

if sklearn_check_version("1.0"):
X = self._validate_data(X, dtype=[np.float64, np.float32], ensure_2d=False)
else:
Expand All @@ -99,16 +153,17 @@ def _onedal_fit(self, X, sample_weight=None, queue=None):
sample_weight = _check_sample_weight(sample_weight, X)

onedal_params = {
"result_options": self.options,
"result_options": self.result_options,
}

if not hasattr(self, "_onedal_estimator"):
self._onedal_estimator = self._onedal_basic_statistics(**onedal_params)
self._onedal_estimator.fit(X, sample_weight, queue)
self._save_attributes()
self.n_features_in_ = X.shape[1] if len(X.shape) > 1 else 1

def fit(self, X, y=None, *, sample_weight=None):
"""Compute statistics with X, using minibatches of size batch_size.
"""Calculate statistics of X.
Parameters
----------
Expand Down
20 changes: 13 additions & 7 deletions sklearnex/basic_statistics/incremental_basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from sklearn.utils._param_validation import Interval, StrOptions

import numbers
import warnings


@control_n_jobs(decorated_methods=["partial_fit", "_onedal_finalize_fit"])
Expand Down Expand Up @@ -175,6 +176,9 @@ def _onedal_partial_fit(self, X, sample_weight=None, queue=None):
self._need_to_finalize = True

def _onedal_fit(self, X, sample_weight=None, queue=None):
if sklearn_check_version("1.2"):
self._validate_params()

if sklearn_check_version("1.0"):
X = self._validate_data(X, dtype=[np.float64, np.float32])
else:
Expand All @@ -198,9 +202,6 @@ def _onedal_fit(self, X, sample_weight=None, queue=None):
weights_batch = sample_weight[batch] if sample_weight is not None else None
self._onedal_partial_fit(X_batch, weights_batch, queue=queue)

if sklearn_check_version("1.2"):
self._validate_params()

self.n_features_in_ = X.shape[1]

self._onedal_finalize_fit(queue=queue)
Expand All @@ -209,13 +210,18 @@ def _onedal_fit(self, X, sample_weight=None, queue=None):

def __getattr__(self, attr):
result_options = self.__dict__["result_options"]
sattr = attr.removesuffix("_")
is_statistic_attr = (
isinstance(result_options, str) and (attr == result_options)
) or (isinstance(result_options, list) and (attr in result_options))
isinstance(result_options, str) and (sattr == result_options)
) or (isinstance(result_options, list) and (sattr in result_options))
if is_statistic_attr:
if self._need_to_finalize:
self._onedal_finalize_fit()
return getattr(self._onedal_estimator, attr)
if sattr == attr:
warnings.warn(
"Result attributes without a trailing underscore were deprecated in version 2025.1 and will be removed in 2026.0"
)
return getattr(self._onedal_estimator, sattr)
if attr in self.__dict__:
return self.__dict__[attr]

Expand Down Expand Up @@ -256,7 +262,7 @@ def partial_fit(self, X, sample_weight=None):
return self

def fit(self, X, y=None, sample_weight=None):
"""Compute statistics with X, using minibatches of size batch_size.
"""Calculate statistics of X using minibatches of size batch_size.
Parameters
----------
Expand Down
19 changes: 19 additions & 0 deletions sklearnex/basic_statistics/tests/test_basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pytest
from numpy.testing import assert_allclose

from daal4py.sklearn._utils import daal_check_version
from onedal.basic_statistics.tests.test_basic_statistics import (
expected_max,
expected_mean,
Expand Down Expand Up @@ -249,3 +250,21 @@ def test_1d_input_on_random_data(dataframe, queue, option, data_size, weighted,

tol = fp32tol if res.dtype == np.float32 else fp64tol
assert_allclose(gtr, res, atol=tol)


def test_warning():
    """Check the deprecation warning for result attributes without underscore.

    Before version 2026.0 each un-suffixed result attribute must emit exactly
    one deprecation ``UserWarning``; from 2026.0 on, none is expected.
    """
    # Local import: the warnings are recorded manually because
    # ``pytest.warns`` itself fails when no matching warning is raised, which
    # made the "no warning expected" branch impossible to satisfy.
    import warnings

    message = "Result attributes without a trailing underscore were deprecated in version 2025.1 and will be removed in 2026.0"
    basicstat = BasicStatistics("all")
    data = np.array([0, 1])

    basicstat.fit(data)
    expected_count = 0 if daal_check_version((2026, "P", 0)) else 1
    for i in basicstat._onedal_estimator.get_all_result_options():
        with warnings.catch_warnings(record=True) as warn_record:
            # Record every warning, regardless of globally configured filters.
            warnings.simplefilter("always")
            getattr(basicstat, i)

        deprecations = [
            w
            for w in warn_record
            if issubclass(w.category, UserWarning) and message in str(w.message)
        ]
        assert len(deprecations) == expected_count, i
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pytest
from numpy.testing import assert_allclose

from daal4py.sklearn._utils import daal_check_version
from onedal.basic_statistics.tests.test_basic_statistics import (
expected_max,
expected_mean,
Expand Down Expand Up @@ -382,3 +383,22 @@ def test_fit_all_option_on_random_data(
gtr = function(X)
tol = fp32tol if res.dtype == np.float32 else fp64tol
assert_allclose(gtr, res, atol=tol)


def test_warning():
    """Check the deprecation warning for result attributes without underscore.

    Before version 2026.0 each un-suffixed result attribute must emit exactly
    one deprecation ``UserWarning``; from 2026.0 on, none is expected.
    """
    # Local import: the warnings are recorded manually because
    # ``pytest.warns`` itself fails when no matching warning is raised, which
    # made the "no warning expected" branch impossible to satisfy.
    import warnings

    message = "Result attributes without a trailing underscore were deprecated in version 2025.1 and will be removed in 2026.0"
    basicstat = IncrementalBasicStatistics("all")
    # Only 2d inputs supported into IncrementalBasicStatistics
    data = np.array([[0.0], [1.0]])

    basicstat.fit(data)
    expected_count = 0 if daal_check_version((2026, "P", 0)) else 1
    for i in basicstat._onedal_estimator.get_all_result_options():
        with warnings.catch_warnings(record=True) as warn_record:
            # Record every warning, regardless of globally configured filters.
            warnings.simplefilter("always")
            getattr(basicstat, i)

        deprecations = [
            w
            for w in warn_record
            if issubclass(w.category, UserWarning) and message in str(w.message)
        ]
        assert len(deprecations) == expected_count, i

0 comments on commit 07d688b

Please sign in to comment.