Skip to content

Commit 55dfe8b

Browse files
committed
WIP lazy_apply_elementwise
1 parent 9631f32 commit 55dfe8b

File tree

4 files changed

+174
-23
lines changed

4 files changed

+174
-23
lines changed

docs/api-lazy.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ lazy backends, e.g. Dask or JAX:
1010
:toctree: generated
1111
1212
lazy_apply
13+
lazy_apply_elementwise
1314
testing.lazy_xp_function
1415
testing.patch_lazy_xp_functions
1516
```

src/array_api_extra/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
setdiff1d,
1515
sinc,
1616
)
17-
from ._lib._lazy import lazy_apply
17+
from ._lib._lazy import lazy_apply, lazy_apply_elementwise
1818

1919
__version__ = "0.7.1.dev0"
2020

@@ -31,6 +31,7 @@
3131
"isclose",
3232
"kron",
3333
"lazy_apply",
34+
"lazy_apply_elementwise",
3435
"nunique",
3536
"pad",
3637
"setdiff1d",

src/array_api_extra/_lib/_lazy.py

Lines changed: 160 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
from __future__ import annotations
44

55
import math
6+
import operator
67
from collections.abc import Callable, Sequence
78
from functools import partial, wraps
89
from types import ModuleType
9-
from typing import TYPE_CHECKING, Any, ParamSpec, TypeAlias, cast, overload
10+
from typing import TYPE_CHECKING, Any, TypeAlias, cast, overload
1011

1112
from ._funcs import broadcast_shapes
1213
from ._utils import _compat
@@ -27,41 +28,39 @@
2728
# Sphinx hack
2829
NumPyObject = Any
2930

30-
P = ParamSpec("P")
31-
3231

3332
@overload
34-
def lazy_apply( # type: ignore[decorated-any, valid-type]
35-
func: Callable[P, Array | ArrayLike],
33+
def lazy_apply( # type: ignore[explicit-any,decorated-any]
34+
func: Callable[..., Array | ArrayLike],
3635
*args: Array | complex | None,
3736
shape: tuple[int | None, ...] | None = None,
3837
dtype: DType | None = None,
3938
as_numpy: bool = False,
4039
xp: ModuleType | None = None,
41-
**kwargs: P.kwargs, # pyright: ignore[reportGeneralTypeIssues]
40+
**kwargs: Any,
4241
) -> Array: ... # numpydoc ignore=GL08
4342

4443

4544
@overload
46-
def lazy_apply( # type: ignore[decorated-any, valid-type]
47-
func: Callable[P, Sequence[Array | ArrayLike]],
45+
def lazy_apply( # type: ignore[explicit-any,decorated-any]
46+
func: Callable[..., Sequence[Array | ArrayLike]],
4847
*args: Array | complex | None,
4948
shape: Sequence[tuple[int | None, ...]],
5049
dtype: Sequence[DType] | None = None,
5150
as_numpy: bool = False,
5251
xp: ModuleType | None = None,
53-
**kwargs: P.kwargs, # pyright: ignore[reportGeneralTypeIssues]
52+
**kwargs: Any,
5453
) -> tuple[Array, ...]: ... # numpydoc ignore=GL08
5554

5655

57-
def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
58-
func: Callable[P, Array | ArrayLike | Sequence[Array | ArrayLike]],
56+
def lazy_apply( # type: ignore[explicit-any] # numpydoc ignore=GL07,SA04
57+
func: Callable[..., Array | ArrayLike | Sequence[Array | ArrayLike]],
5958
*args: Array | complex | None,
6059
shape: tuple[int | None, ...] | Sequence[tuple[int | None, ...]] | None = None,
6160
dtype: DType | Sequence[DType] | None = None,
6261
as_numpy: bool = False,
6362
xp: ModuleType | None = None,
64-
**kwargs: P.kwargs, # pyright: ignore[reportGeneralTypeIssues]
63+
**kwargs: Any,
6564
) -> Array | tuple[Array, ...]:
6665
"""
6766
Lazily apply an eager function.
@@ -157,10 +156,11 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
157156
The outputs will also be returned as a single chunk and you should consider
158157
rechunking them into smaller chunks afterwards.
159158
160-
If you want to distribute the calculation across multiple workers, you
161-
should use :func:`dask.array.map_blocks`, :func:`dask.array.map_overlap`,
162-
:func:`dask.array.blockwise`, or a native Dask wrapper instead of
163-
`lazy_apply`.
159+
If you want to distribute the calculation across multiple workers and your
160+
function is elementwise, you should use :func:`lazy_apply_elementwise` instead.
161+
If the function is not elementwise, you should consider writing an ad-hoc
162+
variant for Dask using primitives like :func:`dask.array.blockwise`,
163+
:func:`dask.array.map_overlap`, or a native Dask algorithm.
164164
165165
Dask wrapping around other backends
166166
If ``as_numpy=False``, `func` will receive in input eager arrays of the meta
@@ -181,9 +181,9 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
181181
182182
See Also
183183
--------
184+
lazy_apply_elementwise
184185
jax.transfer_guard
185186
jax.pure_callback
186-
dask.array.map_blocks
187187
dask.array.map_overlap
188188
dask.array.blockwise
189189
"""
@@ -235,7 +235,7 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
235235
if is_dask_namespace(xp):
236236
import dask
237237

238-
metas: list[Array] = [arg._meta for arg in array_args] # pylint: disable=protected-access # pyright: ignore[reportAttributeAccessIssue]
238+
metas: list[Array] = [arg._meta for arg in array_args] # type: ignore[attr-defined] # pylint: disable=protected-access # pyright: ignore[reportAttributeAccessIssue]
239239
meta_xp = array_namespace(*metas)
240240

241241
wrapped = dask.delayed( # type: ignore[attr-defined] # pyright: ignore[reportPrivateImportUsage]
@@ -350,3 +350,145 @@ def wrapper( # type: ignore[decorated-any,explicit-any]
350350
return (xp.asarray(out, device=device),)
351351

352352
return wrapper
353+
354+
355+
@overload
356+
def lazy_apply_elementwise( # type: ignore[explicit-any,decorated-any]
357+
func: Callable[..., Array | ArrayLike],
358+
*args: Array | complex | None,
359+
dtype: DType | None = None,
360+
as_numpy: bool = False,
361+
xp: ModuleType | None = None,
362+
**kwargs: Any,
363+
) -> Array: ... # numpydoc ignore=GL08
364+
365+
366+
@overload
367+
def lazy_apply_elementwise( # type: ignore[explicit-any,decorated-any]
368+
func: Callable[..., Sequence[Array | ArrayLike]],
369+
*args: Array | complex | None,
370+
dtype: Sequence[DType | None],
371+
as_numpy: bool = False,
372+
xp: ModuleType | None = None,
373+
**kwargs: Any,
374+
) -> tuple[Array, ...]: ... # numpydoc ignore=GL08
375+
376+
377+
def lazy_apply_elementwise( # type: ignore[explicit-any]
378+
func: Callable[..., Array | ArrayLike | Sequence[Array | ArrayLike]],
379+
*args: Array | complex | None,
380+
dtype: DType | Sequence[DType | None] | None = None,
381+
as_numpy: bool = False,
382+
xp: ModuleType | None = None,
383+
**kwargs: Any,
384+
) -> Array | tuple[Array, ...]:
385+
"""
386+
Lazily apply an eager elementwise function.
387+
388+
This is a variant of :func:`lazy_apply` which expects `func` to be elementwise, e.g.
389+
each output point must depend exclusively from the corresponding input point in each
390+
inputarray. This can result in faster execution on some backends.
391+
392+
Parameters
393+
----------
394+
func : callable
395+
As in `lazy_apply`, but in addition it must be elementwise.
396+
*args : Array | int | float | complex | bool | None
397+
As in `lazy_apply`.
398+
dtype : DType | Sequence[DType | None], optional
399+
Output dtype or sequence of output dtypes, one for each output of `func`.
400+
dtype(s) must belong to the same array namespace as the input arrays.
401+
This also informs how many outputs the function has.
402+
Default: assume a single output and infer the result type(s) from
403+
the input arrays.
404+
as_numpy : bool, optional
405+
As in `lazy_apply`.
406+
xp : array_namespace, optional
407+
The standard-compatible namespace for `args`. Default: infer.
408+
**kwargs : Any, optional
409+
As in `lazy_apply`.
410+
411+
Returns
412+
-------
413+
Array | tuple[Array, ...]
414+
The result(s) of `func` applied to the input arrays, wrapped in the same
415+
array namespace as the inputs.
416+
If dtype is omitted or a single dtype, return a single array.
417+
Otherwise, return a tuple of arrays.
418+
419+
See Also
420+
--------
421+
lazy_apply : General version of this function.
422+
dask.array.map_blocks : Dask version of this function.
423+
424+
Notes
425+
-----
426+
Unlike in :func:`lazy_apply`, you can't define output shapes that aren't
427+
broadcasted from the input arrays.
428+
429+
Dask
430+
Unlike :func:`dask.array.map_blocks`, this function allows for multiple outputs.
431+
432+
Dask wrapping around other backends
433+
If ``as_numpy=False``, `func` will receive in input eager arrays of the meta
434+
namespace, as defined by the ``._meta`` attribute of the input Dask arrays. The
435+
outputs of `func` will be wrapped by the meta namespace, and then wrapped again
436+
by Dask.
437+
438+
All other backends
439+
This function is identical to :func:`lazy_apply`.
440+
"""
441+
args_not_none = [arg for arg in args if arg is not None]
442+
array_args = [arg for arg in args_not_none if not is_python_scalar(arg)]
443+
if not array_args:
444+
msg = "Must have at least one argument array"
445+
raise ValueError(msg)
446+
if xp is None:
447+
xp = array_namespace(*array_args)
448+
449+
# Normalize and validate dtype
450+
dtypes: list[DType]
451+
452+
if isinstance(dtype, Sequence):
453+
multi_output = True
454+
if None in dtype:
455+
rtype = xp.result_type(*args_not_none)
456+
dtypes = [d or rtype for d in dtype]
457+
else:
458+
dtypes = list(dtype) # pyright: ignore[reportUnknownArgumentType]
459+
else:
460+
multi_output = False
461+
dtypes = [dtype]
462+
del dtype
463+
464+
if not is_dask_namespace(xp):
465+
shape = broadcast_shapes(*(arg.shape for arg in array_args))
466+
return lazy_apply( # pyright: ignore[reportCallIssue]
467+
func, # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
468+
*args,
469+
shape=[shape] * len(dtypes) if multi_output else shape, # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
470+
dtype=dtypes if multi_output else dtypes[0],
471+
as_numpy=as_numpy,
472+
xp=xp,
473+
**kwargs,
474+
)
475+
476+
# Use da.map_blocks.
477+
# We need to handle multiple outputs, which map_blocks can't.
478+
479+
metas: list[Array] = [arg._meta for arg in array_args] # type: ignore[attr-defined] # pylint: disable=protected-access # pyright: ignore[reportAttributeAccessIssue]
480+
meta_xp = array_namespace(*metas)
481+
482+
wrapped = _lazy_apply_wrapper(func, as_numpy, multi_output, meta_xp)
483+
wrapped = partial(wrapped, **kwargs)
484+
485+
# Hack map_blocks to handle multiple outputs. This intermediate output has bugos
486+
# dtype and meta, but dask.array will never know as long as we always provide
487+
# explicit dtype and meta.
488+
temp = xp.map_blocks(wrapped, *args, dtype=dtypes[0], meta=metas[0])
489+
out = tuple(
490+
temp.map_blocks(operator.itemgetter(i), dtype=dtype, meta=metas[0])
491+
for i, dtype in enumerate(dtypes)
492+
)
493+
494+
return out if multi_output else out[0]

tests/test_lazy.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66

77
import array_api_extra as xpx # Let some tests bypass lazy_xp_function
8-
from array_api_extra import lazy_apply
8+
from array_api_extra import lazy_apply, lazy_apply_elementwise
99
from array_api_extra._lib import Backend
1010
from array_api_extra._lib._testing import xp_assert_equal
1111
from array_api_extra._lib._utils import _compat
@@ -342,7 +342,7 @@ def eager(
342342
return x + 1
343343

344344
# Use explicit namespace to bypass monkey-patching by lazy_xp_function
345-
return xpx.lazy_apply( # pyright: ignore[reportCallIssue]
345+
return xpx.lazy_apply(
346346
eager,
347347
x,
348348
z={0: [1, 2]},
@@ -419,6 +419,13 @@ def f(x: Array) -> Array:
419419
with pytest.raises(ValueError, match="multiple shapes but only one dtype"):
420420
_ = lazy_apply(f, x, shape=[(1,), (2,)], dtype=np.int32) # type: ignore[call-overload] # pyright: ignore[reportCallIssue,reportArgumentType]
421421
with pytest.raises(ValueError, match="single shape but multiple dtypes"):
422-
_ = lazy_apply(f, x, shape=(1,), dtype=[np.int32, np.int64])
422+
_ = lazy_apply(f, x, shape=(1,), dtype=[np.int32, np.int64]) # type: ignore[call-overload]
423423
with pytest.raises(ValueError, match="2 shapes and 1 dtypes"):
424-
_ = lazy_apply(f, x, shape=[(1,), (2,)], dtype=[np.int32]) # type: ignore[arg-type] # pyright: ignore[reportCallIssue,reportArgumentType]
424+
_ = lazy_apply(f, x, shape=[(1,), (2,)], dtype=[np.int32]) # type: ignore[call-overload] # pyright: ignore[reportCallIssue,reportArgumentType]
425+
426+
with pytest.raises(ValueError, match="at least one argument array"):
427+
_ = lazy_apply_elementwise(f, xp=np)
428+
with pytest.raises(ValueError, match="at least one argument array"):
429+
_ = lazy_apply_elementwise(f, 1, xp=np)
430+
with pytest.raises(ValueError, match="at least one argument array"):
431+
_ = lazy_apply_elementwise(f)

0 commit comments

Comments
 (0)