diff --git a/onedal/datatypes/_data_conversion.py b/onedal/datatypes/_data_conversion.py index d1dedba81c..46384c19be 100644 --- a/onedal/datatypes/_data_conversion.py +++ b/onedal/datatypes/_data_conversion.py @@ -32,10 +32,14 @@ dpctl_available = False -def _apply_and_pass(func, *args): +def _apply_and_pass(func, *args, **kwargs): if len(args) == 1: - return func(args[0]) - return tuple(map(func, args)) + return func(args[0], **kwargs) if len(kwargs) > 0 else func(args[0]) + return ( + tuple(func(arg, **kwargs) for arg in args) + if len(kwargs) > 0 + else tuple(func(arg) for arg in args) + ) def from_table(*args): @@ -59,7 +63,7 @@ def to_table(*args): if _is_dpc_backend: from ..common._policy import _HostInteropPolicy - def _convert_to_supported(policy, *data): + def _convert_to_supported(policy, *data, xp=np): def func(x): return x @@ -71,13 +75,13 @@ def func(x): device = policy._queue.sycl_device def convert_or_pass(x): - if (x is not None) and (x.dtype == np.float64): + if (x is not None) and (x.dtype == xp.float64): warnings.warn( "Data will be converted into float32 from " "float64 because device does not support it", RuntimeWarning, ) - return x.astype(np.float32) + return x.astype(xp.float32) else: return x @@ -88,7 +92,7 @@ def convert_or_pass(x): else: - def _convert_to_supported(policy, *data): + def _convert_to_supported(policy, *data, xp=np): def func(x): return x