Skip to content

Commit

Permalink
we should not instantiate cupy for everyone :)
Browse files Browse the repository at this point in the history
  • Loading branch information
ikrommyd committed Feb 6, 2025
1 parent 261b1ac commit 01fa4dc
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 8 deletions.
10 changes: 4 additions & 6 deletions src/awkward/_backends/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from collections.abc import Collection

from awkward._backends.backend import Backend
from awkward._nplikes.cupy import Cupy
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpy_like import NumpyLike, NumpyMetadata
from awkward._nplikes.virtual import VirtualArray
Expand All @@ -14,7 +13,6 @@

np = NumpyMetadata.instance()
numpy = Numpy.instance()
cupy = Cupy.instance()

D = TypeVar("D")
T = TypeVar("T")
Expand Down Expand Up @@ -74,10 +72,10 @@ def common_backend(backends: Collection[Backend]) -> Backend:

def backend_of_obj(obj, default: D | Sentinel = UNSET) -> Backend | D:
if isinstance(obj, VirtualArray):
if obj.nplike is numpy:
cls = numpy.ndarray
elif obj.nplike is cupy:
cls = cupy.ndarray
if type(obj.nplike).__name__ == "Numpy":
cls = obj.nplike.ndarray
elif type(obj.nplike).__name__ == "Cupy":
cls = obj.nplike.ndarray
else:
raise ValueError(
f"Only numpy and cupy nplikes are supported for VirtualArray. Received {type(obj.nplike)}"
Expand Down
2 changes: 0 additions & 2 deletions src/awkward/_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,6 @@ class JaxKernel(NumpyKernel):
def __call__(self, *args) -> None:
assert len(args) == len(self._impl.argtypes)

args = materialize_if_virtual(*args)

if not any(Jax.is_tracer_type(type(arg)) for arg in args):
return super().__call__(*args)

Expand Down

0 comments on commit 01fa4dc

Please sign in to comment.