Skip to content

Commit

Permalink
first attempt at cupy nplike for virtual arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
ikrommyd committed Feb 5, 2025
1 parent 7c59dba commit 261b1ac
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 5 deletions.
15 changes: 14 additions & 1 deletion src/awkward/_backends/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
from collections.abc import Collection

from awkward._backends.backend import Backend
from awkward._nplikes.cupy import Cupy
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpy_like import NumpyLike, NumpyMetadata
from awkward._nplikes.virtual import VirtualArray
from awkward._typing import Callable, TypeAlias, TypeVar, cast
from awkward._util import UNSET, Sentinel

np = NumpyMetadata.instance()
numpy = Numpy.instance()
cupy = Cupy.instance()

D = TypeVar("D")
T = TypeVar("T")
Expand Down Expand Up @@ -70,7 +73,17 @@ def common_backend(backends: Collection[Backend]) -> Backend:


def backend_of_obj(obj, default: D | Sentinel = UNSET) -> Backend | D:
cls = type(obj)
if isinstance(obj, VirtualArray):
if obj.nplike is numpy:
cls = numpy.ndarray
elif obj.nplike is cupy:
cls = cupy.ndarray
else:
raise ValueError(
f"Only numpy and cupy nplikes are supported for VirtualArray. Received {type(obj.nplike)}"
)
else:
cls = type(obj)
try:
lookup = _type_to_backend_lookup[cls]
return lookup(obj)
Expand Down
15 changes: 12 additions & 3 deletions src/awkward/_nplikes/cupy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from awkward._nplikes.numpy_like import ArrayLike
from awkward._nplikes.placeholder import PlaceholderArray
from awkward._nplikes.shape import ShapeItem
from awkward._nplikes.virtual import VirtualArray, materialize_if_virtual
from awkward._typing import TYPE_CHECKING, Final

if TYPE_CHECKING:
Expand Down Expand Up @@ -47,8 +48,8 @@ def ndarray(self):
def frombuffer(
self, buffer, *, dtype: DTypeLike | None = None, count: ShapeItem = -1
) -> ArrayLike:
assert not isinstance(buffer, PlaceholderArray)
assert not isinstance(count, PlaceholderArray)
assert not isinstance(buffer, (PlaceholderArray, VirtualArray))
assert not isinstance(count, (PlaceholderArray, VirtualArray))
np_array = numpy.frombuffer(buffer, dtype=dtype, count=count)
return self._module.asarray(np_array)

Expand All @@ -57,6 +58,7 @@ def array_equal(
) -> bool:
assert not isinstance(x1, PlaceholderArray)
assert not isinstance(x2, PlaceholderArray)
x1, x2 = materialize_if_virtual(x1, x2)
if x1.shape != x2.shape:
return False
else:
Expand All @@ -67,6 +69,7 @@ def repeat(
):
assert not isinstance(x, PlaceholderArray)
assert not isinstance(repeats, PlaceholderArray)
x, repeats = materialize_if_virtual(x, repeats)
if axis is not None:
raise NotImplementedError(f"repeat for CuPy with axis={axis!r}")
# https://github.com/cupy/cupy/issues/3849
Expand All @@ -91,6 +94,7 @@ def all(
maybe_out: ArrayLike | None = None,
) -> ArrayLike:
assert not isinstance(x, PlaceholderArray)
(x,) = materialize_if_virtual(x)
out = self._module.all(x, axis=axis, out=maybe_out)
if axis is None and isinstance(out, self._module.ndarray):
return out.item()
Expand All @@ -106,6 +110,7 @@ def any(
maybe_out: ArrayLike | None = None,
) -> ArrayLike:
assert not isinstance(x, PlaceholderArray)
(x,) = materialize_if_virtual(x)
out = self._module.any(x, axis=axis, out=maybe_out)
if axis is None and isinstance(out, self._module.ndarray):
return out.item()
Expand All @@ -116,6 +121,7 @@ def count_nonzero(
self, x: ArrayLike, *, axis: ShapeItem | tuple[ShapeItem, ...] | None = None
) -> ArrayLike:
assert not isinstance(x, PlaceholderArray)
(x,) = materialize_if_virtual(x)
assert isinstance(axis, int) or axis is None
out = self._module.count_nonzero(x, axis=axis)
if axis is None and isinstance(out, self._module.ndarray):
Expand All @@ -132,6 +138,7 @@ def min(
maybe_out: ArrayLike | None = None,
) -> ArrayLike:
assert not isinstance(x, PlaceholderArray)
(x,) = materialize_if_virtual(x)
out = self._module.min(x, axis=axis, out=maybe_out)
if axis is None and isinstance(out, self._module.ndarray):
return out.item()
Expand All @@ -147,6 +154,7 @@ def max(
maybe_out: ArrayLike | None = None,
) -> ArrayLike:
assert not isinstance(x, PlaceholderArray)
(x,) = materialize_if_virtual(x)
out = self._module.max(x, axis=axis, out=maybe_out)
if axis is None and isinstance(out, self._module.ndarray):
return out.item()
Expand All @@ -166,7 +174,8 @@ def is_own_array_type(cls, type_: type) -> bool:
return module == "cupy"

def is_c_contiguous(self, x: ArrayLike) -> bool:
if isinstance(x, PlaceholderArray):
# TODO: Should this materialize virtual arrays?
if isinstance(x, (PlaceholderArray, VirtualArray)):
return True
else:
return x.flags["C_CONTIGUOUS"] # type: ignore[attr-defined]
13 changes: 12 additions & 1 deletion src/awkward/_nplikes/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from awkward._nplikes.array_like import ArrayLike
from awkward._nplikes.numpy_like import NumpyLike
from awkward._nplikes.virtual import VirtualArray
from awkward._typing import Any, TypeVar, cast
from awkward._util import UNSET, Sentinel

Expand Down Expand Up @@ -38,7 +39,17 @@ def nplike_of_obj(
if it is set, otherwise `Numpy.instance()`.
"""

cls = type(obj)
if isinstance(obj, VirtualArray):
if type(obj.nplike).__name__ == "Numpy":
cls = obj.nplike.ndarray
elif type(obj.nplike).__name__ == "Cupy":
cls = obj.nplike.ndarray
else:
raise ValueError(
f"Only numpy and cupy nplikes are supported for VirtualArray. Received {type(obj.nplike)}"
)
else:
cls = type(obj)
try:
return _type_to_nplike[cls]
except KeyError:
Expand Down
5 changes: 5 additions & 0 deletions src/awkward/_nplikes/virtual.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ def __init__(
], # annotation (should) make clear that it's a callable without(!) arguments that returns an ArrayLike
form_key: str | None = None,
) -> None:
if type(nplike).__name__ not in ("Numpy", "Cupy"):
raise ValueError(
f"Only numpy and cupy nplikes are supported for VirtualArray. Received {type(nplike)}"
)

# array metadata
self._nplike = nplike
self._shape = shape
Expand Down

0 comments on commit 261b1ac

Please sign in to comment.