From d8bf640af2e3920fefed186be9c6c1d0e8cd7171 Mon Sep 17 00:00:00 2001 From: Md Shafiul Alam Date: Tue, 8 Oct 2024 09:27:47 -0400 Subject: [PATCH 01/14] basic-stat-sparse-test --- .../basic_statistics/basic_statistics.py | 2 +- .../tests/test_basic_statistics.py | 45 +++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/sklearnex/basic_statistics/basic_statistics.py b/sklearnex/basic_statistics/basic_statistics.py index 8b1dc3d02a..15d73416ce 100644 --- a/sklearnex/basic_statistics/basic_statistics.py +++ b/sklearnex/basic_statistics/basic_statistics.py @@ -150,7 +150,7 @@ def _onedal_fit(self, X, sample_weight=None, queue=None): self._validate_params() if sklearn_check_version("1.0"): - X = validate_data(self, X, dtype=[np.float64, np.float32], ensure_2d=False) + X = validate_data(self, X, dtype=[np.float64, np.float32], ensure_2d=False, accept_sparse="csr") else: X = check_array(X, dtype=[np.float64, np.float32]) diff --git a/sklearnex/basic_statistics/tests/test_basic_statistics.py b/sklearnex/basic_statistics/tests/test_basic_statistics.py index 125402f4e2..afa9ee7641 100644 --- a/sklearnex/basic_statistics/tests/test_basic_statistics.py +++ b/sklearnex/basic_statistics/tests/test_basic_statistics.py @@ -178,6 +178,51 @@ def test_multiple_options_on_random_data( assert_allclose(gtr_sum, res_sum, atol=tol) +@pytest.mark.parametrize("queue", get_queues()) +@pytest.mark.parametrize("row_count", [100, 1000]) +@pytest.mark.parametrize("column_count", [10, 100]) +@pytest.mark.parametrize("weighted", [True, False]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_multiple_options_on_random_sparse_data( + queue, row_count, column_count, weighted, dtype +): + seed = 77 + gen = np.random.default_rng(seed) + X_dense, _ = make_blobs(n_samples=row_count, n_features=column_count, random_state=random_state) + density = 0.05 + X_sparse = csr_matrix(X_dense * (np.random.rand(*X_dense.shape) < density)) + + if weighted: + weights = gen.uniform(low=-0.5, high=1.0, size=row_count) + weights = weights.astype(dtype=dtype) + basicstat = BasicStatistics(result_options=["mean", "max", "sum"]) + + if weighted: + result = basicstat.fit(X_sparse, sample_weight=weights) + else: + result = basicstat.fit(X_sparse) + + res_mean, res_max, res_sum = result.mean, result.max, result.sum + if weighted: + weighted_data = np.diag(weights) @ X_dense + gtr_mean, gtr_max, gtr_sum = ( + expected_mean(weighted_data), + expected_max(weighted_data), + expected_sum(weighted_data), + ) + else: + gtr_mean, gtr_max, gtr_sum = ( + expected_mean(X_dense), + expected_max(X_dense), + expected_sum(X_dense), + ) + + tol = 5e-4 if res_mean.dtype == np.float32 else 1e-7 + assert_allclose(gtr_mean, res_mean, atol=tol) + assert_allclose(gtr_max, res_max, atol=tol) + assert_allclose(gtr_sum, res_sum, atol=tol) + + @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues()) @pytest.mark.parametrize("row_count", [100, 1000]) @pytest.mark.parametrize("column_count", [10, 100]) From f1fb204dc6498a2f05ea15f3e64ebc10fb8e62e3 Mon Sep 17 00:00:00 2001 From: Md Shafiul Alam Date: Tue, 8 Oct 2024 09:33:25 -0400 Subject: [PATCH 02/14] lint --- sklearnex/basic_statistics/basic_statistics.py | 8 +++++++- sklearnex/basic_statistics/tests/test_basic_statistics.py | 6 ++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/sklearnex/basic_statistics/basic_statistics.py b/sklearnex/basic_statistics/basic_statistics.py index 15d73416ce..396f0a5544 100644 --- a/sklearnex/basic_statistics/basic_statistics.py +++ b/sklearnex/basic_statistics/basic_statistics.py @@ -150,7 +150,13 @@ def _onedal_fit(self, X, sample_weight=None, queue=None): self._validate_params() if sklearn_check_version("1.0"): - X = validate_data(self, X, dtype=[np.float64, np.float32], ensure_2d=False, accept_sparse="csr") + X = validate_data( + self, + X, + dtype=[np.float64, np.float32], + ensure_2d=False, + accept_sparse="csr", + ) else: X = check_array(X, dtype=[np.float64, np.float32]) diff --git a/sklearnex/basic_statistics/tests/test_basic_statistics.py b/sklearnex/basic_statistics/tests/test_basic_statistics.py index afa9ee7641..6ac4fde5f2 100644 --- a/sklearnex/basic_statistics/tests/test_basic_statistics.py +++ b/sklearnex/basic_statistics/tests/test_basic_statistics.py @@ -188,10 +188,12 @@ def test_multiple_options_on_random_sparse_data( ): seed = 77 gen = np.random.default_rng(seed) - X_dense, _ = make_blobs(n_samples=row_count, n_features=column_count, random_state=random_state) + X_dense, _ = make_blobs( + n_samples=row_count, n_features=column_count, random_state=random_state + ) density = 0.05 X_sparse = csr_matrix(X_dense * (np.random.rand(*X_dense.shape) < density)) - + if weighted: weights = gen.uniform(low=-0.5, high=1.0, size=row_count) weights = weights.astype(dtype=dtype) From 49804f9ac0c0a4d3aabeae0eaab34d76e8601c96 Mon Sep 17 00:00:00 2001 From: Md Shafiul Alam Date: Tue, 8 Oct 2024 09:53:18 -0400 Subject: [PATCH 03/14] fix --- sklearnex/basic_statistics/tests/test_basic_statistics.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearnex/basic_statistics/tests/test_basic_statistics.py b/sklearnex/basic_statistics/tests/test_basic_statistics.py index 6ac4fde5f2..f31d678d94 100644 --- a/sklearnex/basic_statistics/tests/test_basic_statistics.py +++ b/sklearnex/basic_statistics/tests/test_basic_statistics.py @@ -28,6 +28,7 @@ from onedal.tests.utils._dataframes_support import ( _convert_to_dataframe, get_dataframes_and_queues, + get_queues, ) from sklearnex.basic_statistics import BasicStatistics From 353cf0c89f23e6f5bca058085caab95c87520b2a Mon Sep 17 00:00:00 2001 From: Md Shafiul Alam Date: Tue, 8 Oct 2024 10:31:20 -0400 Subject: [PATCH 04/14] import --- sklearnex/basic_statistics/tests/test_basic_statistics.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearnex/basic_statistics/tests/test_basic_statistics.py b/sklearnex/basic_statistics/tests/test_basic_statistics.py index f31d678d94..bc0e13aadb 100644 --- a/sklearnex/basic_statistics/tests/test_basic_statistics.py +++ b/sklearnex/basic_statistics/tests/test_basic_statistics.py @@ -17,6 +17,7 @@ import numpy as np import pytest from numpy.testing import assert_allclose +from sklearn.datasets import make_blobs from daal4py.sklearn._utils import daal_check_version from onedal.basic_statistics.tests.test_basic_statistics import ( From 7df2a3fb3b32de5d56e8925548be0724a9bac5a3 Mon Sep 17 00:00:00 2001 From: Md Shafiul Alam Date: Tue, 8 Oct 2024 13:28:06 -0400 Subject: [PATCH 05/14] minor fix --- sklearnex/basic_statistics/tests/test_basic_statistics.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearnex/basic_statistics/tests/test_basic_statistics.py b/sklearnex/basic_statistics/tests/test_basic_statistics.py index bc0e13aadb..d9bb87493d 100644 --- a/sklearnex/basic_statistics/tests/test_basic_statistics.py +++ b/sklearnex/basic_statistics/tests/test_basic_statistics.py @@ -189,6 +189,7 @@ def test_multiple_options_on_random_sparse_data( queue, row_count, column_count, weighted, dtype ): seed = 77 + random_state = 42 gen = np.random.default_rng(seed) X_dense, _ = make_blobs( n_samples=row_count, n_features=column_count, random_state=random_state From 793ab2388b42a07c6ea09d25b1df93d316b8dfa4 Mon Sep 17 00:00:00 2001 From: Md Shafiul Alam Date: Tue, 8 Oct 2024 14:04:48 -0400 Subject: [PATCH 06/14] import --- sklearnex/basic_statistics/tests/test_basic_statistics.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearnex/basic_statistics/tests/test_basic_statistics.py b/sklearnex/basic_statistics/tests/test_basic_statistics.py index d9bb87493d..f75745d255 100644 --- a/sklearnex/basic_statistics/tests/test_basic_statistics.py +++ b/sklearnex/basic_statistics/tests/test_basic_statistics.py @@ -17,6 +17,7 @@ import numpy as np import pytest from numpy.testing import assert_allclose +from scipy.sparse import csr_matrix from sklearn.datasets import make_blobs from daal4py.sklearn._utils import daal_check_version From f5210ea70d83f6876ed84f539c8d8d3eb6334c99 Mon Sep 17 00:00:00 2001 From: Md Shafiul Alam Date: Tue, 8 Oct 2024 14:42:35 -0400 Subject: [PATCH 07/14] add dense --- sklearnex/basic_statistics/tests/test_basic_statistics.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearnex/basic_statistics/tests/test_basic_statistics.py b/sklearnex/basic_statistics/tests/test_basic_statistics.py index f75745d255..4a18ffb48f 100644 --- a/sklearnex/basic_statistics/tests/test_basic_statistics.py +++ b/sklearnex/basic_statistics/tests/test_basic_statistics.py @@ -192,11 +192,12 @@ def test_multiple_options_on_random_sparse_data( seed = 77 random_state = 42 gen = np.random.default_rng(seed) - X_dense, _ = make_blobs( + X, _ = make_blobs( n_samples=row_count, n_features=column_count, random_state=random_state ) density = 0.05 - X_sparse = csr_matrix(X_dense * (np.random.rand(*X_dense.shape) < density)) + X_sparse = csr_matrix(X * (np.random.rand(*X.shape) < density)) + X_dense = X_sparse.toarray() if weighted: weights = gen.uniform(low=-0.5, high=1.0, size=row_count) From 168a8978be9c73fb586aab5be443c9211e7dee80 Mon Sep 17 00:00:00 2001 From: Md Shafiul Alam Date: Wed, 9 Oct 2024 05:02:35 -0400 Subject: [PATCH 08/14] exclude max --- .../basic_statistics/tests/test_basic_statistics.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/sklearnex/basic_statistics/tests/test_basic_statistics.py b/sklearnex/basic_statistics/tests/test_basic_statistics.py index 4a18ffb48f..8cfca88fdc 100644 --- a/sklearnex/basic_statistics/tests/test_basic_statistics.py +++ b/sklearnex/basic_statistics/tests/test_basic_statistics.py @@ -202,7 +202,7 @@ def test_multiple_options_on_random_sparse_data( if weighted: weights = gen.uniform(low=-0.5, high=1.0, size=row_count) weights = weights.astype(dtype=dtype) - basicstat = BasicStatistics(result_options=["mean", "max", "sum"]) + basicstat = BasicStatistics(result_options=["mean", "sum"]) if weighted: result = basicstat.fit(X_sparse, sample_weight=weights) @@ -212,21 +212,18 @@ def test_multiple_options_on_random_sparse_data( res_mean, res_max, res_sum = result.mean, result.max, result.sum if weighted: weighted_data = np.diag(weights) @ X_dense - gtr_mean, gtr_max, gtr_sum = ( + gtr_mean, gtr_sum = ( expected_mean(weighted_data), - expected_max(weighted_data), expected_sum(weighted_data), ) else: - gtr_mean, gtr_max, gtr_sum = ( + gtr_mean, gtr_sum = ( expected_mean(X_dense), - expected_max(X_dense), expected_sum(X_dense), ) tol = 5e-4 if res_mean.dtype == np.float32 else 1e-7 assert_allclose(gtr_mean, res_mean, atol=tol) - assert_allclose(gtr_max, res_max, atol=tol) assert_allclose(gtr_sum, res_sum, atol=tol) From 8f50d854ec43a010948e704f63996dd1fa070fdf Mon Sep 17 00:00:00 2001 From: Md Shafiul Alam Date: Wed, 9 Oct 2024 16:28:55 -0400 Subject: [PATCH 09/14] remove test for max --- sklearnex/basic_statistics/tests/test_basic_statistics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearnex/basic_statistics/tests/test_basic_statistics.py b/sklearnex/basic_statistics/tests/test_basic_statistics.py index 8cfca88fdc..21b16038fc 100644 --- a/sklearnex/basic_statistics/tests/test_basic_statistics.py +++ b/sklearnex/basic_statistics/tests/test_basic_statistics.py @@ -209,7 +209,7 @@ def test_multiple_options_on_random_sparse_data( else: result = basicstat.fit(X_sparse) - res_mean, res_max, res_sum = result.mean, result.max, result.sum + res_mean, res_sum = result.mean, result.sum if weighted: weighted_data = np.diag(weights) @ X_dense gtr_mean, gtr_sum = ( From ecd34e3888d2cd71427b7cf7cb484e4381b14ac1 Mon Sep 17 00:00:00 2001 From: Md Shafiul Alam Date: Wed, 9 Oct 2024 17:19:44 -0400 Subject: [PATCH 10/14] turn of weighted test --- sklearnex/basic_statistics/tests/test_basic_statistics.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearnex/basic_statistics/tests/test_basic_statistics.py b/sklearnex/basic_statistics/tests/test_basic_statistics.py index 21b16038fc..0784125e70 100644 --- a/sklearnex/basic_statistics/tests/test_basic_statistics.py +++ b/sklearnex/basic_statistics/tests/test_basic_statistics.py @@ -184,7 +184,7 @@ def test_multiple_options_on_random_data( @pytest.mark.parametrize("queue", get_queues()) @pytest.mark.parametrize("row_count", [100, 1000]) @pytest.mark.parametrize("column_count", [10, 100]) -@pytest.mark.parametrize("weighted", [True, False]) +# @pytest.mark.parametrize("weighted", [True, False]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_multiple_options_on_random_sparse_data( queue, row_count, column_count, weighted, dtype @@ -199,6 +199,7 @@ def test_multiple_options_on_random_sparse_data( X_sparse = csr_matrix(X * (np.random.rand(*X.shape) < density)) X_dense = X_sparse.toarray() + weighted = False if weighted: weights = gen.uniform(low=-0.5, high=1.0, size=row_count) weights = weights.astype(dtype=dtype) From 49f9ad7f75e78a65a4666ad34076698353dac418 Mon Sep 17 00:00:00 2001 From: Md Shafiul Alam Date: Thu, 10 Oct 2024 19:05:05 -0700 Subject: [PATCH 11/14] test without weighted --- sklearnex/basic_statistics/tests/test_basic_statistics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearnex/basic_statistics/tests/test_basic_statistics.py b/sklearnex/basic_statistics/tests/test_basic_statistics.py index 0784125e70..4d0674ad0e 100644 --- a/sklearnex/basic_statistics/tests/test_basic_statistics.py +++ b/sklearnex/basic_statistics/tests/test_basic_statistics.py @@ -187,7 +187,7 @@ def test_multiple_options_on_random_data( # @pytest.mark.parametrize("weighted", [True, False]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_multiple_options_on_random_sparse_data( - queue, row_count, column_count, weighted, dtype + queue, row_count, column_count, dtype ): seed = 77 random_state = 42 From 32955d4d3a072e84e2c4af4e6eb8cc7eb345e509 Mon Sep 17 00:00:00 2001 From: Md Shafiul Alam Date: Fri, 11 Oct 2024 00:06:27 -0700 Subject: [PATCH 12/14] improve bs sparse test --- .../basic_statistics/basic_statistics.py | 11 +++ .../tests/test_basic_statistics.py | 69 +++++++++++++++---- sklearnex/tests/test_run_to_run_stability.py | 8 ++- 3 files changed, 72 insertions(+), 16 deletions(-) diff --git a/sklearnex/basic_statistics/basic_statistics.py b/sklearnex/basic_statistics/basic_statistics.py index 396f0a5544..4d8bd820bf 100644 --- a/sklearnex/basic_statistics/basic_statistics.py +++ b/sklearnex/basic_statistics/basic_statistics.py @@ -17,6 +17,7 @@ import warnings import numpy as np +from fromonedal.utils import _is_csr from sklearn.base import BaseEstimator from sklearn.utils import check_array from sklearn.utils.validation import _check_sample_weight @@ -140,6 +141,16 @@ def _onedal_supported(self, method_name, *data): patching_status = PatchingConditionsChain( f"sklearnex.basic_statistics.{self.__class__.__name__}.{method_name}" ) + + X, sample_weight = data[0], data[1] + patching_status.and_conditions( + [ + ( + _is_csr(X) and sample_weight is not None, + "sample_weight is not supported for CSR data format.", + ), + ] + ) return patching_status _onedal_cpu_supported = _onedal_supported diff --git a/sklearnex/basic_statistics/tests/test_basic_statistics.py b/sklearnex/basic_statistics/tests/test_basic_statistics.py index 4d0674ad0e..3cf372b316 100644 --- a/sklearnex/basic_statistics/tests/test_basic_statistics.py +++ b/sklearnex/basic_statistics/tests/test_basic_statistics.py @@ -24,7 +24,14 @@ from onedal.basic_statistics.tests.test_basic_statistics import ( expected_max, expected_mean, + expected_min, + expected_second_order_raw_moment, + expected_standard_deviation, expected_sum, + expected_sum_squares, + expected_sum_squares_centered, + expected_variance, + expected_variation, options_and_tests, ) from onedal.tests.utils._dataframes_support import ( @@ -184,13 +191,16 @@ def test_multiple_options_on_random_data( @pytest.mark.parametrize("queue", get_queues()) @pytest.mark.parametrize("row_count", [100, 1000]) @pytest.mark.parametrize("column_count", [10, 100]) -# @pytest.mark.parametrize("weighted", [True, False]) +@pytest.mark.parametrize("weighted", [True, False]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_multiple_options_on_random_sparse_data( - queue, row_count, column_count, dtype + queue, row_count, column_count, weighted, dtype ): seed = 77 random_state = 42 + + if weighted: + pytest.skip("Weighted sparse computation is not supported for sparse data") gen = np.random.default_rng(seed) X, _ = make_blobs( n_samples=row_count, n_features=column_count, random_state=random_state @@ -199,33 +209,62 @@ def test_multiple_options_on_random_sparse_data( X_sparse = csr_matrix(X * (np.random.rand(*X.shape) < density)) X_dense = X_sparse.toarray() - weighted = False if weighted: weights = gen.uniform(low=-0.5, high=1.0, size=row_count) weights = weights.astype(dtype=dtype) - basicstat = BasicStatistics(result_options=["mean", "sum"]) + + options = [ + "sum", + "max", + "min", + "mean", + "standard_deviation" "variance", + "sum_squares", + "sum_squares_centered", + "second_order_raw_moment", + ] + basicstat = BasicStatistics(result_options=options) + if result_option == "max": + pytest.skip("There is a bug in oneDAL's max computations on GPU") if weighted: result = basicstat.fit(X_sparse, sample_weight=weights) else: result = basicstat.fit(X_sparse) - res_mean, res_sum = result.mean, result.sum if weighted: weighted_data = np.diag(weights) @ X_dense - gtr_mean, gtr_sum = ( - expected_mean(weighted_data), - expected_sum(weighted_data), - ) + + gtr_sum = expected_sum(weighted_data) + gtr_min = expected_min(weighted_data) + gtr_mean = expected_mean(weighted_data) + gtr_std = expected_standard_deviation(weighted_data) + gtr_var = expected_variance(weighted_data) + gtr_variation = expected_variation(weighted_data) + gtr_ss = expected_sum_squares(weighted_data) + gtr_ssc = expected_sum_squares_centered(weighted_data) + gtr_seconf_moment = expected_second_order_raw_moment(weighted_data) else: - gtr_mean, gtr_sum = ( - expected_mean(X_dense), - expected_sum(X_dense), - ) + gtr_sum = expected_sum(X_dense) + gtr_min = expected_min(X_dense) + gtr_mean = expected_mean(X_dense) + gtr_std = expected_standard_deviation(X_dense) + gtr_var = expected_variance(X_dense) + gtr_variation = expected_variation(X_dense) + gtr_ss = expected_sum_squares(X_dense) + gtr_ssc = expected_sum_squares_centered(X_dense) + gtr_seconf_moment = expected_second_order_raw_moment(X_dense) tol = 5e-4 if res_mean.dtype == np.float32 else 1e-7 - assert_allclose(gtr_mean, res_mean, atol=tol) - assert_allclose(gtr_sum, res_sum, atol=tol) + assert_allclose(gtr_sum, result.sum_, atol=tol) + assert_allclose(gtr_min, result.min_, atol=tol) + assert_allclose(gtr_mean, result.mean_, atol=tol) + assert_allclose(gtr_std, result.standard_deviation_, atol=tol) + assert_allclose(gtr_var, result.variance_, atol=tol) + assert_allclose(gtr_variation, result.variation_, atol=tol) + assert_allclose(gtr_ss, result.sum_squares_, atol=tol) + assert_allclose(gtr_ssc, result.sum_squares_centered_, atol=tol) + assert_allclose(gtr_seconf_moment, result.second_order_raw_moment_, atol=tol) @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues()) diff --git a/sklearnex/tests/test_run_to_run_stability.py b/sklearnex/tests/test_run_to_run_stability.py index bae2d27a83..1e374ca8b0 100755 --- a/sklearnex/tests/test_run_to_run_stability.py +++ b/sklearnex/tests/test_run_to_run_stability.py @@ -117,7 +117,13 @@ def _run_test(estimator, method, datasets): _sparse_instances = [SVC()] -if daal_check_version((2024, "P", 700)): # Test for > 2024.7.0 +if daal_check_version((2025, "P", 100)): # Test for >= 2025.1.0 + _sparse_instances.extend( + [ + BasicStatistics(), + ] + ) +if daal_check_version((2024, "P", 700)): # Test for >= 2024.7.0 _sparse_instances.extend( [ KMeans(), From 2bf51fb81f4d230a45b2588ad814a0e087f34ece Mon Sep 17 00:00:00 2001 From: Md Shafiul Alam Date: Fri, 11 Oct 2024 01:36:59 -0700 Subject: [PATCH 13/14] minor fix --- sklearnex/basic_statistics/basic_statistics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearnex/basic_statistics/basic_statistics.py b/sklearnex/basic_statistics/basic_statistics.py index 4d8bd820bf..c6c8ff7329 100644 --- a/sklearnex/basic_statistics/basic_statistics.py +++ b/sklearnex/basic_statistics/basic_statistics.py @@ -17,7 +17,7 @@ import warnings import numpy as np -from fromonedal.utils import _is_csr +from onedal.utils import _is_csr from sklearn.base import BaseEstimator from sklearn.utils import check_array from sklearn.utils.validation import _check_sample_weight From f5eb5c07b1cb5ecf1f110895d44a2b31d5bec826 Mon Sep 17 00:00:00 2001 From: Md Shafiul Alam Date: Fri, 11 Oct 2024 01:43:01 -0700 Subject: [PATCH 14/14] minor --- sklearnex/basic_statistics/basic_statistics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearnex/basic_statistics/basic_statistics.py b/sklearnex/basic_statistics/basic_statistics.py index c6c8ff7329..975b934852 100644 --- a/sklearnex/basic_statistics/basic_statistics.py +++ b/sklearnex/basic_statistics/basic_statistics.py @@ -17,7 +17,6 @@ import warnings import numpy as np -from onedal.utils import _is_csr from sklearn.base import BaseEstimator from sklearn.utils import check_array from sklearn.utils.validation import _check_sample_weight @@ -25,6 +24,7 @@ from daal4py.sklearn._n_jobs_support import control_n_jobs from daal4py.sklearn._utils import sklearn_check_version from onedal.basic_statistics import BasicStatistics as onedal_BasicStatistics +from onedal.utils import _is_csr from .._device_offload import dispatch from .._utils import PatchingConditionsChain