33from __future__ import annotations
44
55import math
6+ import operator
67from collections .abc import Callable , Sequence
78from functools import partial , wraps
89from types import ModuleType
9- from typing import TYPE_CHECKING , Any , ParamSpec , TypeAlias , cast , overload
10+ from typing import TYPE_CHECKING , Any , TypeAlias , cast , overload
1011
1112from ._funcs import broadcast_shapes
1213from ._utils import _compat
2728 # Sphinx hack
2829 NumPyObject = Any
2930
30- P = ParamSpec ("P" )
31-
3231
3332@overload
34- def lazy_apply ( # type: ignore[decorated -any, valid-type ]
35- func : Callable [P , Array | ArrayLike ],
33+ def lazy_apply ( # type: ignore[explicit -any,decorated-any ]
34+ func : Callable [... , Array | ArrayLike ],
3635 * args : Array | complex | None ,
3736 shape : tuple [int | None , ...] | None = None ,
3837 dtype : DType | None = None ,
3938 as_numpy : bool = False ,
4039 xp : ModuleType | None = None ,
41- ** kwargs : P . kwargs , # pyright: ignore[reportGeneralTypeIssues]
40+ ** kwargs : Any ,
4241) -> Array : ... # numpydoc ignore=GL08
4342
4443
4544@overload
46- def lazy_apply ( # type: ignore[decorated -any, valid-type ]
47- func : Callable [P , Sequence [Array | ArrayLike ]],
45+ def lazy_apply ( # type: ignore[explicit -any,decorated-any ]
46+ func : Callable [... , Sequence [Array | ArrayLike ]],
4847 * args : Array | complex | None ,
4948 shape : Sequence [tuple [int | None , ...]],
5049 dtype : Sequence [DType ] | None = None ,
5150 as_numpy : bool = False ,
5251 xp : ModuleType | None = None ,
53- ** kwargs : P . kwargs , # pyright: ignore[reportGeneralTypeIssues]
52+ ** kwargs : Any ,
5453) -> tuple [Array , ...]: ... # numpydoc ignore=GL08
5554
5655
57- def lazy_apply ( # type: ignore[valid-type ] # numpydoc ignore=GL07,SA04
58- func : Callable [P , Array | ArrayLike | Sequence [Array | ArrayLike ]],
56+ def lazy_apply ( # type: ignore[explicit-any ] # numpydoc ignore=GL07,SA04
57+ func : Callable [... , Array | ArrayLike | Sequence [Array | ArrayLike ]],
5958 * args : Array | complex | None ,
6059 shape : tuple [int | None , ...] | Sequence [tuple [int | None , ...]] | None = None ,
6160 dtype : DType | Sequence [DType ] | None = None ,
6261 as_numpy : bool = False ,
6362 xp : ModuleType | None = None ,
64- ** kwargs : P . kwargs , # pyright: ignore[reportGeneralTypeIssues]
63+ ** kwargs : Any ,
6564) -> Array | tuple [Array , ...]:
6665 """
6766 Lazily apply an eager function.
@@ -157,10 +156,11 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
157156 The outputs will also be returned as a single chunk and you should consider
158157 rechunking them into smaller chunks afterwards.
159158
160- If you want to distribute the calculation across multiple workers, you
161- should use :func:`dask.array.map_blocks`, :func:`dask.array.map_overlap`,
162- :func:`dask.array.blockwise`, or a native Dask wrapper instead of
163- `lazy_apply`.
159+ If you want to distribute the calculation across multiple workers and your
160+ function is elementwise, you should use :func:`lazy_apply_elementwise` instead.
161+ If the function is not elementwise, you should consider writing an ad-hoc
162+ variant for Dask using primitives like :func:`dask.array.blockwise`,
163+ :func:`dask.array.map_overlap`, or a native Dask algorithm.
164164
165165 Dask wrapping around other backends
166166 If ``as_numpy=False``, `func` will receive in input eager arrays of the meta
@@ -181,9 +181,9 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
181181
182182 See Also
183183 --------
184+ lazy_apply_elementwise
184185 jax.transfer_guard
185186 jax.pure_callback
186- dask.array.map_blocks
187187 dask.array.map_overlap
188188 dask.array.blockwise
189189 """
@@ -235,7 +235,7 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
235235 if is_dask_namespace (xp ):
236236 import dask
237237
238- metas : list [Array ] = [arg ._meta for arg in array_args ] # pylint: disable=protected-access # pyright: ignore[reportAttributeAccessIssue]
238+ metas : list [Array ] = [arg ._meta for arg in array_args ] # type: ignore[attr-defined] # pylint: disable=protected-access # pyright: ignore[reportAttributeAccessIssue]
239239 meta_xp = array_namespace (* metas )
240240
241241 wrapped = dask .delayed ( # type: ignore[attr-defined] # pyright: ignore[reportPrivateImportUsage]
@@ -350,3 +350,145 @@ def wrapper( # type: ignore[decorated-any,explicit-any]
350350 return (xp .asarray (out , device = device ),)
351351
352352 return wrapper
353+
354+
355+ @overload
356+ def lazy_apply_elementwise ( # type: ignore[explicit-any,decorated-any]
357+ func : Callable [..., Array | ArrayLike ],
358+ * args : Array | complex | None ,
359+ dtype : DType | None = None ,
360+ as_numpy : bool = False ,
361+ xp : ModuleType | None = None ,
362+ ** kwargs : Any ,
363+ ) -> Array : ... # numpydoc ignore=GL08
364+
365+
366+ @overload
367+ def lazy_apply_elementwise ( # type: ignore[explicit-any,decorated-any]
368+ func : Callable [..., Sequence [Array | ArrayLike ]],
369+ * args : Array | complex | None ,
370+ dtype : Sequence [DType | None ],
371+ as_numpy : bool = False ,
372+ xp : ModuleType | None = None ,
373+ ** kwargs : Any ,
374+ ) -> tuple [Array , ...]: ... # numpydoc ignore=GL08
375+
376+
377+ def lazy_apply_elementwise ( # type: ignore[explicit-any]
378+ func : Callable [..., Array | ArrayLike | Sequence [Array | ArrayLike ]],
379+ * args : Array | complex | None ,
380+ dtype : DType | Sequence [DType | None ] | None = None ,
381+ as_numpy : bool = False ,
382+ xp : ModuleType | None = None ,
383+ ** kwargs : Any ,
384+ ) -> Array | tuple [Array , ...]:
385+ """
386+ Lazily apply an eager elementwise function.
387+
388+ This is a variant of :func:`lazy_apply` which expects `func` to be elementwise, e.g.
389+ each output point must depend exclusively from the corresponding input point in each
390+ inputarray. This can result in faster execution on some backends.
391+
392+ Parameters
393+ ----------
394+ func : callable
395+ As in `lazy_apply`, but in addition it must be elementwise.
396+ *args : Array | int | float | complex | bool | None
397+ As in `lazy_apply`.
398+ dtype : DType | Sequence[DType | None], optional
399+ Output dtype or sequence of output dtypes, one for each output of `func`.
400+ dtype(s) must belong to the same array namespace as the input arrays.
401+ This also informs how many outputs the function has.
402+ Default: assume a single output and infer the result type(s) from
403+ the input arrays.
404+ as_numpy : bool, optional
405+ As in `lazy_apply`.
406+ xp : array_namespace, optional
407+ The standard-compatible namespace for `args`. Default: infer.
408+ **kwargs : Any, optional
409+ As in `lazy_apply`.
410+
411+ Returns
412+ -------
413+ Array | tuple[Array, ...]
414+ The result(s) of `func` applied to the input arrays, wrapped in the same
415+ array namespace as the inputs.
416+ If dtype is omitted or a single dtype, return a single array.
417+ Otherwise, return a tuple of arrays.
418+
419+ See Also
420+ --------
421+ lazy_apply : General version of this function.
422+ dask.array.map_blocks : Dask version of this function.
423+
424+ Notes
425+ -----
426+ Unlike in :func:`lazy_apply`, you can't define output shapes that aren't
427+ broadcasted from the input arrays.
428+
429+ Dask
430+ Unlike :func:`dask.array.map_blocks`, this function allows for multiple outputs.
431+
432+ Dask wrapping around other backends
433+ If ``as_numpy=False``, `func` will receive in input eager arrays of the meta
434+ namespace, as defined by the ``._meta`` attribute of the input Dask arrays. The
435+ outputs of `func` will be wrapped by the meta namespace, and then wrapped again
436+ by Dask.
437+
438+ All other backends
439+ This function is identical to :func:`lazy_apply`.
440+ """
441+ args_not_none = [arg for arg in args if arg is not None ]
442+ array_args = [arg for arg in args_not_none if not is_python_scalar (arg )]
443+ if not array_args :
444+ msg = "Must have at least one argument array"
445+ raise ValueError (msg )
446+ if xp is None :
447+ xp = array_namespace (* array_args )
448+
449+ # Normalize and validate dtype
450+ dtypes : list [DType ]
451+
452+ if isinstance (dtype , Sequence ):
453+ multi_output = True
454+ if None in dtype :
455+ rtype = xp .result_type (* args_not_none )
456+ dtypes = [d or rtype for d in dtype ]
457+ else :
458+ dtypes = list (dtype ) # pyright: ignore[reportUnknownArgumentType]
459+ else :
460+ multi_output = False
461+ dtypes = [dtype ]
462+ del dtype
463+
464+ if not is_dask_namespace (xp ):
465+ shape = broadcast_shapes (* (arg .shape for arg in array_args ))
466+ return lazy_apply ( # pyright: ignore[reportCallIssue]
467+ func , # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
468+ * args ,
469+ shape = [shape ] * len (dtypes ) if multi_output else shape , # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
470+ dtype = dtypes if multi_output else dtypes [0 ],
471+ as_numpy = as_numpy ,
472+ xp = xp ,
473+ ** kwargs ,
474+ )
475+
476+ # Use da.map_blocks.
477+ # We need to handle multiple outputs, which map_blocks can't.
478+
479+ metas : list [Array ] = [arg ._meta for arg in array_args ] # type: ignore[attr-defined] # pylint: disable=protected-access # pyright: ignore[reportAttributeAccessIssue]
480+ meta_xp = array_namespace (* metas )
481+
482+ wrapped = _lazy_apply_wrapper (func , as_numpy , multi_output , meta_xp )
483+ wrapped = partial (wrapped , ** kwargs )
484+
485+ # Hack map_blocks to handle multiple outputs. This intermediate output has bugos
486+ # dtype and meta, but dask.array will never know as long as we always provide
487+ # explicit dtype and meta.
488+ temp = xp .map_blocks (wrapped , * args , dtype = dtypes [0 ], meta = metas [0 ])
489+ out = tuple (
490+ temp .map_blocks (operator .itemgetter (i ), dtype = dtype , meta = metas [0 ])
491+ for i , dtype in enumerate (dtypes )
492+ )
493+
494+ return out if multi_output else out [0 ]
0 commit comments