Skip to content

Commit

Permalink
improve bs sparse test
Browse files Browse the repository at this point in the history
  • Loading branch information
md-shafiul-alam committed Oct 11, 2024
1 parent 49f9ad7 commit 32955d4
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 16 deletions.
11 changes: 11 additions & 0 deletions sklearnex/basic_statistics/basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import warnings

import numpy as np
from onedal.utils import _is_csr
from sklearn.base import BaseEstimator
from sklearn.utils import check_array
from sklearn.utils.validation import _check_sample_weight
Expand Down Expand Up @@ -140,6 +141,16 @@ def _onedal_supported(self, method_name, *data):
patching_status = PatchingConditionsChain(
f"sklearnex.basic_statistics.{self.__class__.__name__}.{method_name}"
)

X, sample_weight = data[0], data[1]
patching_status.and_conditions(
[
(
_is_csr(X) and sample_weight is not None,
"sample_weight is not supported for CSR data format.",
),
]
)
return patching_status

_onedal_cpu_supported = _onedal_supported
Expand Down
69 changes: 54 additions & 15 deletions sklearnex/basic_statistics/tests/test_basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,14 @@
from onedal.basic_statistics.tests.test_basic_statistics import (
expected_max,
expected_mean,
expected_min,
expected_second_order_raw_moment,
expected_standard_deviation,
expected_sum,
expected_sum_squares,
expected_sum_squares_centered,
expected_variance,
expected_variation,
options_and_tests,
)
from onedal.tests.utils._dataframes_support import (
Expand Down Expand Up @@ -184,13 +191,16 @@ def test_multiple_options_on_random_data(
@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("row_count", [100, 1000])
@pytest.mark.parametrize("column_count", [10, 100])
# @pytest.mark.parametrize("weighted", [True, False])
@pytest.mark.parametrize("weighted", [True, False])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_multiple_options_on_random_sparse_data(
queue, row_count, column_count, dtype
queue, row_count, column_count, weighted, dtype
):
seed = 77
random_state = 42

if weighted:
pytest.skip("Weighted sparse computation is not supported for sparse data")
gen = np.random.default_rng(seed)
X, _ = make_blobs(
n_samples=row_count, n_features=column_count, random_state=random_state
Expand All @@ -199,33 +209,62 @@ def test_multiple_options_on_random_sparse_data(
X_sparse = csr_matrix(X * (np.random.rand(*X.shape) < density))
X_dense = X_sparse.toarray()

weighted = False
if weighted:
weights = gen.uniform(low=-0.5, high=1.0, size=row_count)
weights = weights.astype(dtype=dtype)
basicstat = BasicStatistics(result_options=["mean", "sum"])

options = [
"sum",
"max",
"min",
"mean",
"standard_deviation",
"variance",
"sum_squares",
"sum_squares_centered",
"second_order_raw_moment",
]
basicstat = BasicStatistics(result_options=options)
if result_option == "max":
pytest.skip("There is a bug in oneDAL's max computations on GPU")

if weighted:
result = basicstat.fit(X_sparse, sample_weight=weights)
else:
result = basicstat.fit(X_sparse)

res_mean, res_sum = result.mean, result.sum
if weighted:
weighted_data = np.diag(weights) @ X_dense
gtr_mean, gtr_sum = (
expected_mean(weighted_data),
expected_sum(weighted_data),
)

gtr_sum = expected_sum(weighted_data)
gtr_min = expected_min(weighted_data)
gtr_mean = expected_mean(weighted_data)
gtr_std = expected_standard_deviation(weighted_data)
gtr_var = expected_variance(weighted_data)
gtr_variation = expected_variation(weighted_data)
gtr_ss = expected_sum_squares(weighted_data)
gtr_ssc = expected_sum_squares_centered(weighted_data)
gtr_seconf_moment = expected_second_order_raw_moment(weighted_data)
else:
gtr_mean, gtr_sum = (
expected_mean(X_dense),
expected_sum(X_dense),
)
gtr_sum = expected_sum(X_dense)
gtr_min = expected_min(X_dense)
gtr_mean = expected_mean(X_dense)
gtr_std = expected_standard_deviation(X_dense)
gtr_var = expected_variance(X_dense)
gtr_variation = expected_variation(X_dense)
gtr_ss = expected_sum_squares(X_dense)
gtr_ssc = expected_sum_squares_centered(X_dense)
gtr_seconf_moment = expected_second_order_raw_moment(X_dense)

tol = 5e-4 if res_mean.dtype == np.float32 else 1e-7
assert_allclose(gtr_mean, res_mean, atol=tol)
assert_allclose(gtr_sum, res_sum, atol=tol)
assert_allclose(gtr_sum, result.sum_, atol=tol)
assert_allclose(gtr_min, result.min_, atol=tol)
assert_allclose(gtr_mean, result.mean_, atol=tol)
assert_allclose(gtr_std, result.standard_deviation_, atol=tol)
assert_allclose(gtr_var, result.variance_, atol=tol)
assert_allclose(gtr_variation, result.variation_, atol=tol)
assert_allclose(gtr_ss, result.sum_squares_, atol=tol)
assert_allclose(gtr_ssc, result.sum_squares_centered_, atol=tol)
assert_allclose(gtr_seconf_moment, result.second_order_raw_moment_, atol=tol)


@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
Expand Down
8 changes: 7 additions & 1 deletion sklearnex/tests/test_run_to_run_stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,13 @@ def _run_test(estimator, method, datasets):


_sparse_instances = [SVC()]
if daal_check_version((2024, "P", 700)): # Test for > 2024.7.0
if daal_check_version((2025, "P", 100)): # Test for >= 2025.1.0
_sparse_instances.extend(
[
BasicStatistics(),
]
)
if daal_check_version((2024, "P", 700)): # Test for >= 2024.7.0
_sparse_instances.extend(
[
KMeans(),
Expand Down

0 comments on commit 32955d4

Please sign in to comment.