From 3771fc292d02b0481966d2910d841ba140501239 Mon Sep 17 00:00:00 2001 From: Samir Nasibli Date: Fri, 11 Oct 2024 14:48:06 -0700 Subject: [PATCH] refactor onedal/datatypes/_data_conversion.py --- onedal/datatypes/_data_conversion.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/onedal/datatypes/_data_conversion.py b/onedal/datatypes/_data_conversion.py index d1dedba81c..46384c19be 100644 --- a/onedal/datatypes/_data_conversion.py +++ b/onedal/datatypes/_data_conversion.py @@ -32,10 +32,14 @@ dpctl_available = False -def _apply_and_pass(func, *args): +def _apply_and_pass(func, *args, **kwargs): if len(args) == 1: - return func(args[0]) - return tuple(map(func, args)) + return func(args[0], **kwargs) if len(kwargs) > 0 else func(args[0]) + return ( + tuple(func(arg, **kwargs) for arg in args) + if len(kwargs) > 0 + else tuple(func(arg) for arg in args) + ) def from_table(*args): @@ -59,7 +63,7 @@ def to_table(*args): if _is_dpc_backend: from ..common._policy import _HostInteropPolicy - def _convert_to_supported(policy, *data): + def _convert_to_supported(policy, *data, xp=np): def func(x): return x @@ -71,13 +75,13 @@ def func(x): device = policy._queue.sycl_device def convert_or_pass(x): - if (x is not None) and (x.dtype == np.float64): + if (x is not None) and (x.dtype == xp.float64): warnings.warn( "Data will be converted into float32 from " "float64 because device does not support it", RuntimeWarning, ) - return x.astype(np.float32) + return x.astype(xp.float32) else: return x @@ -88,7 +92,7 @@ def convert_or_pass(x): else: - def _convert_to_supported(policy, *data): + def _convert_to_supported(policy, *data, xp=np): def func(x): return x