Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Array API dispatching #2096

Draft
wants to merge 30 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
8bedde1
ENH: array api dispatching
samir-nasibli Oct 2, 2024
b11fcf3
Deselect some scikit-learn Array API tests
samir-nasibli Oct 4, 2024
467634a
Merge branch 'intel:main' into enh/array_api_dispatching
samir-nasibli Oct 4, 2024
31030f7
Merge branch 'intel:main' into enh/array_api_dispatching
samir-nasibli Oct 8, 2024
943796e
deselect more tests
samir-nasibli Oct 8, 2024
ef42daa
deselect more tests
samir-nasibli Oct 8, 2024
3bc755d
disabled tests for
samir-nasibli Oct 8, 2024
76f1876
fix the deselection comment
samir-nasibli Oct 8, 2024
ce0b8e1
disabled test for Ridge regression
samir-nasibli Oct 8, 2024
404e8c0
Disabled tests and added comment
samir-nasibli Oct 8, 2024
ced43bf
ENH: Array API dispatching
samir-nasibli Oct 8, 2024
968365f
Merge branch 'intel:main' into enh/array_api_dispatching_testing
samir-nasibli Oct 9, 2024
c395d03
Revert adding dpctl into Array API conformance testing
samir-nasibli Oct 9, 2024
9271479
Merge branch 'enh/array_api_dispatching_testing' of https://github.co…
samir-nasibli Oct 9, 2024
5784c25
minor refactoring onedal _array_api
samir-nasibli Oct 9, 2024
8d7f664
add tests
samir-nasibli Oct 9, 2024
63d8f30
addressed memory usage tests
samir-nasibli Oct 9, 2024
6bd0280
Address some array api test fails
samir-nasibli Oct 9, 2024
90411e7
linting
samir-nasibli Oct 9, 2024
2b7bbc5
addressed test_get_namespace
samir-nasibli Oct 9, 2024
b7b8f03
adding test case for validate_data check with Array API inputs
samir-nasibli Oct 9, 2024
169009d
minor refactoring
samir-nasibli Oct 9, 2024
9ca118c
addressed test_patch_map_match fail
samir-nasibli Oct 9, 2024
7ddcf40
Added docstrings for get_namespace
samir-nasibli Oct 9, 2024
ec90d43
docstrings for Array API tests
samir-nasibli Oct 9, 2024
6e7e547
updated minimal scikit-learn version for Array API dispatching
samir-nasibli Oct 9, 2024
e5db839
updated minimal scikit-learn version for Array API dispatching in _de…
samir-nasibli Oct 9, 2024
f99a92b
fix test test_get_namespace_with_config_context
samir-nasibli Oct 9, 2024
8844f0e
Merge branch 'intel:main' into enh/array_api_dispatching_testing
samir-nasibli Oct 10, 2024
3771fc2
refactor onedal/datatypes/_data_conversion.py
samir-nasibli Oct 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions onedal/utils/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,19 @@ def _get_sycl_namespace(*arrays):
"""Get namespace of sycl arrays."""

# sycl support designed to work regardless of array_api_dispatch sklearn global value
sycl_type = {type(x): x for x in arrays if hasattr(x, "__sycl_usm_array_interface__")}
sua_iface = {type(x): x for x in arrays if hasattr(x, "__sycl_usm_array_interface__")}

if len(sycl_type) > 1:
raise ValueError(f"Multiple SYCL types for array inputs: {sycl_type}")
if len(sua_iface) > 1:
raise ValueError(f"Multiple SYCL types for array inputs: {sua_iface}")

if sycl_type:
(X,) = sycl_type.values()
if sua_iface:
(X,) = sua_iface.values()

if hasattr(X, "__array_namespace__"):
return sycl_type, X.__array_namespace__(), True
return sua_iface, X.__array_namespace__(), True
elif dpnp_available and isinstance(X, dpnp.ndarray):
return sycl_type, dpnp, False
return sua_iface, dpnp, False
else:
raise ValueError(f"SYCL type not recognized: {sycl_type}")
raise ValueError(f"SYCL type not recognized: {sua_iface}")

return sycl_type, None, False
return sua_iface, None, False
5 changes: 3 additions & 2 deletions sklearnex/_device_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@

from functools import wraps

from daal4py.sklearn._utils import sklearn_check_version
from onedal._device_offload import (
_copy_to_usm,
_get_global_queue,
_transfer_to_host,
dpnp_available,
)
from onedal.utils._array_api import _asarray, _is_numpy_namespace
from onedal.utils._array_api import _asarray

if dpnp_available:
import dpnp
Expand Down Expand Up @@ -74,7 +75,7 @@ def dispatch(obj, method_name, branches, *args, **kwargs):
return branches[backend](obj, *hostargs, **hostkwargs, queue=q)
if backend == "sklearn":
if (
"array_api_dispatch" in get_config()
sklearn_check_version("1.2")
and get_config()["array_api_dispatch"]
and "array_api_support" in obj._get_tags()
and obj._get_tags()["array_api_support"]
Expand Down
25 changes: 25 additions & 0 deletions sklearnex/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ def get_patch_map_core(preview=False):
from ._config import get_config as get_config_sklearnex
from ._config import set_config as set_config_sklearnex

if sklearn_check_version("1.2"):
import sklearn.utils._array_api as _array_api_module

if sklearn_check_version("1.2.1"):
from .utils.parallel import _FuncWrapper as _FuncWrapper_sklearnex
else:
Expand Down Expand Up @@ -165,6 +168,10 @@ def get_patch_map_core(preview=False):
from .svm import NuSVC as NuSVC_sklearnex
from .svm import NuSVR as NuSVR_sklearnex

if sklearn_check_version("1.2"):
from .utils._array_api import _convert_to_numpy as _convert_to_numpy_sklearnex
from .utils._array_api import get_namespace as get_namespace_sklearnex

# DBSCAN
mapping.pop("dbscan")
mapping["dbscan"] = [[(cluster_module, "DBSCAN", DBSCAN_sklearnex), None]]
Expand Down Expand Up @@ -440,6 +447,24 @@ def get_patch_map_core(preview=False):
mapping["_funcwrapper"] = [
[(parallel_module, "_FuncWrapper", _FuncWrapper_sklearnex), None]
]
if sklearn_check_version("1.2"):
# Necessary for array_api support
mapping["get_namespace"] = [
[
(
_array_api_module,
"get_namespace",
get_namespace_sklearnex,
),
None,
]
]
mapping["_convert_to_numpy"] = [
[
(_array_api_module, "_convert_to_numpy", _convert_to_numpy_sklearnex),
None,
]
]
return mapping


Expand Down
2 changes: 2 additions & 0 deletions sklearnex/tests/test_memory_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@


CPU_SKIP_LIST = (
"_convert_to_numpy", # additional memory allocation is expected proportional to the input data
"TSNE", # too slow for using in testing on common data size
"config_context", # does not malloc
"get_config", # does not malloc
Expand All @@ -59,6 +60,7 @@
)

GPU_SKIP_LIST = (
"_convert_to_numpy", # additional memory allocation is expected proportional to the input data
"TSNE", # too slow for using in testing on common data size
"RandomForestRegressor", # too slow for using in testing on common data size
"KMeans", # does not support GPU offloading
Expand Down
121 changes: 79 additions & 42 deletions sklearnex/utils/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,64 +19,101 @@
import numpy as np

from daal4py.sklearn._utils import sklearn_check_version
from onedal.utils._array_api import _get_sycl_namespace
from onedal.utils._array_api import _asarray, _get_sycl_namespace

# TODO:
# check the version of skl.
samir-nasibli marked this conversation as resolved.
Show resolved Hide resolved
if sklearn_check_version("1.2"):
from sklearn.utils._array_api import get_namespace as sklearn_get_namespace
from sklearn.utils._array_api import _convert_to_numpy as _sklearn_convert_to_numpy

from onedal._device_offload import dpctl_available, dpnp_available

def get_namespace(*arrays):
"""Get namespace of arrays.
if dpctl_available:
import dpctl.tensor as dpt

Introspect `arrays` arguments and return their common Array API
compatible namespace object, if any. NumPy 1.22 and later can
construct such containers using the `numpy.array_api` namespace
for instance.
if dpnp_available:
import dpnp

This function will return the namespace of SYCL-related arrays
which define the __sycl_usm_array_interface__ attribute
regardless of array_api support, the configuration of
array_api_dispatch, or scikit-learn version.

See: https://numpy.org/neps/nep-0047-array-api-standard.html
def _convert_to_numpy(array, xp):
"""Convert X into a NumPy ndarray on the CPU."""
xp_name = xp.__name__

If `arrays` are regular numpy arrays, an instance of the
`_NumPyApiWrapper` compatibility wrapper is returned instead.
if dpctl_available and xp_name in {
"dpctl.tensor",
}:
return dpt.to_numpy(array)
elif dpnp_available and isinstance(array, dpnp.ndarray):
return dpnp.asnumpy(array)
elif sklearn_check_version("1.2"):
return _sklearn_convert_to_numpy(array, xp)
else:
return _asarray(array, xp)

Namespace support is not enabled by default. To enable it
call:

sklearn.set_config(array_api_dispatch=True)
if sklearn_check_version("1.5"):

or:
def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None):
"""Get namespace of arrays.

with sklearn.config_context(array_api_dispatch=True):
# your code here
TBD

Otherwise an instance of the `_NumPyApiWrapper`
compatibility wrapper is always returned irrespective of
the fact that arrays implement the `__array_namespace__`
protocol or not.
Parameters
----------
*arrays : array objects
Array objects.

Parameters
----------
*arrays : array objects
Array objects.
Returns
-------
namespace : module
Namespace shared by array objects.

Returns
-------
namespace : module
Namespace shared by array objects.
is_array_api : bool
True if the arrays are containers that implement the Array API spec.
"""

is_array_api : bool
True if the arrays are containers that implement the Array API spec.
"""
usm_iface, xp_sycl_namespace, is_array_api_compliant = _get_sycl_namespace(
*arrays
)

sycl_type, xp, is_array_api_compliant = _get_sycl_namespace(*arrays)
if usm_iface:
return xp_sycl_namespace, is_array_api_compliant
elif sklearn_check_version("1.2"):
return sklearn_get_namespace(
*arrays, remove_none=remove_none, remove_types=remove_types, xp=xp
)
else:
return np, False

if sycl_type:
return xp, is_array_api_compliant
elif sklearn_check_version("1.2"):
return sklearn_get_namespace(*arrays)
else:
return np, False
else:

def get_namespace(*arrays):
"""Get namespace of arrays.

TBD

Parameters
----------
*arrays : array objects
Array objects.

Returns
-------
namespace : module
Namespace shared by array objects.

is_array_api : bool
True if the arrays are containers that implement the Array API spec.
"""

usm_iface, xp_sycl_namespace, is_array_api_compliant = _get_sycl_namespace(
*arrays
)

if usm_iface:
return xp_sycl_namespace, is_array_api_compliant
elif sklearn_check_version("1.2"):
return sklearn_get_namespace(*arrays)
else:
return np, False
Loading
Loading