Skip to content

Commit

Permalink
refactor onedal/datatypes/_data_conversion.py
Browse files Browse the repository at this point in the history
  • Loading branch information
samir-nasibli committed Oct 11, 2024
1 parent 8844f0e commit 3771fc2
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions onedal/datatypes/_data_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,14 @@
dpctl_available = False


def _apply_and_pass(func, *args):
def _apply_and_pass(func, *args, **kwargs):
if len(args) == 1:
return func(args[0])
return tuple(map(func, args))
return func(args[0], **kwargs) if len(kwargs) > 0 else func(args[0])
return (
tuple(func(arg, **kwargs) for arg in args)
if len(kwargs) > 0
else tuple(func(arg) for arg in args)
)


def from_table(*args):
Expand All @@ -59,7 +63,7 @@ def to_table(*args):
if _is_dpc_backend:
from ..common._policy import _HostInteropPolicy

def _convert_to_supported(policy, *data):
def _convert_to_supported(policy, *data, xp=np):
def func(x):
return x

Expand All @@ -71,13 +75,13 @@ def func(x):
device = policy._queue.sycl_device

def convert_or_pass(x):
if (x is not None) and (x.dtype == np.float64):
if (x is not None) and (x.dtype == xp.float64):
warnings.warn(
"Data will be converted into float32 from "
"float64 because device does not support it",
RuntimeWarning,
)
return x.astype(np.float32)
return x.astype(xp.float32)
else:
return x

Expand All @@ -88,7 +92,7 @@ def convert_or_pass(x):

else:

def _convert_to_supported(policy, *data):
def _convert_to_supported(policy, *data, xp=np):
def func(x):
return x

Expand Down

0 comments on commit 3771fc2

Please sign in to comment.