Skip to content

Commit

Permalink
Merge pull request #10584 from mattjj:remove-jnp-array-handling-raw-b…
Browse files Browse the repository at this point in the history
…uffers

PiperOrigin-RevId: 448720084
  • Loading branch information
jax authors committed May 14, 2022
2 parents 86899ee + c0d6a04 commit f26133c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 17 deletions.
29 changes: 13 additions & 16 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1806,15 +1806,16 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0):
lax_internal._check_user_dtype_supported(dtype, "array")

# Here we make a judgment call: we only return a weakly-typed array when the
# input object itself is weakly typed. That ensures asarray(x) is a no-op whenever
# x is weak, but avoids introducing weak types with something like array([1, 2, 3])
# input object itself is weakly typed. That ensures asarray(x) is a no-op
# whenever x is weak, but avoids introducing weak types with something like
# array([1, 2, 3])
weak_type = dtype is None and dtypes.is_weakly_typed(object)

# For Python scalar literals, call coerce_to_array to catch any overflow errors.
# We don't use dtypes.is_python_scalar because we don't want this triggering for
# traced values. We do this here because it matters whether or not dtype is None.
# We don't assign the result because we want the raw object to be used for type
# inference below.
# For Python scalar literals, call coerce_to_array to catch any overflow
# errors. We don't use dtypes.is_python_scalar because we don't want this
# triggering for traced values. We do this here because it matters whether or
# not dtype is None. We don't assign the result because we want the raw object
# to be used for type inference below.
if isinstance(object, (bool, int, float, complex)):
_ = dtypes.coerce_to_array(object, dtype)

Expand All @@ -1838,17 +1839,13 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0):
ndarray_types = (device_array.DeviceArray, core.Tracer)

if not _any(isinstance(leaf, ndarray_types) for leaf in leaves):
# TODO(jakevdp): falling back to numpy here fails to overflow for lists containing
# large integers; see discussion in https://github.com/google/jax/pull/6047.
# More correct would be to call coerce_to_array on each leaf, but this may have
# performance implications.
# TODO(jakevdp): falling back to numpy here fails to overflow for lists
# containing large integers; see discussion in
# https://github.com/google/jax/pull/6047. More correct would be to call
# coerce_to_array on each leaf, but this may have performance implications.
out = np.array(object, dtype=dtype, ndmin=ndmin, copy=False)
elif isinstance(object, ndarray_types):
if object.aval is None:
# object is a raw buffer; convert to device array on its current device.
aval = ShapedArray(object.xla_shape().dimensions(), object.dtype,
weak_type=bool(getattr(object, "weak_type", False)))
object = device_array.make_device_array(aval, object.device(), object)
assert object.aval is not None
out = _array_copy(object) if copy else object
elif isinstance(object, (list, tuple)):
if object:
Expand Down
7 changes: 6 additions & 1 deletion tests/pmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,12 @@ def pmap(self):

def testDeviceBufferToArray(self):
sda = self.pmap(lambda x: x)(jnp.ones((jax.device_count(), 2)))
buf = sda.device_buffers[-1]

# Changed in https://github.com/google/jax/pull/10584 not to access
# sda.device_buffers, which isn't supported, and instead ensure fast slices
# of the arrays returned by pmap are set up correctly.
# buf = sda.device_buffers[-1]
buf = sda[-1]

view = jnp.array(buf, copy=False)
self.assertArraysEqual(sda[-1], view)
Expand Down

0 comments on commit f26133c

Please sign in to comment.