From c0d6a04b7666ff58fc92d5fee78772bf11012a8a Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Wed, 4 May 2022 22:59:06 -0700 Subject: [PATCH] remove jnp.array case for handling buffers w/ aval=None This functionality was added in #8134, but was superceded by later changes which ensured that we never produce DeviceArrays with their 'aval' property set to None (even when indexing ShardedDeviceArrays with integers, which used to be a problem case). --- jax/_src/numpy/lax_numpy.py | 29 +++++++++++++---------------- tests/pmap_test.py | 7 ++++++- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 17a7328ae78c..7b10b17b1d07 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -1806,15 +1806,16 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0): lax_internal._check_user_dtype_supported(dtype, "array") # Here we make a judgment call: we only return a weakly-typed array when the - # input object itself is weakly typed. That ensures asarray(x) is a no-op whenever - # x is weak, but avoids introducing weak types with something like array([1, 2, 3]) + # input object itself is weakly typed. That ensures asarray(x) is a no-op + # whenever x is weak, but avoids introducing weak types with something like + # array([1, 2, 3]) weak_type = dtype is None and dtypes.is_weakly_typed(object) - # For Python scalar literals, call coerce_to_array to catch any overflow errors. - # We don't use dtypes.is_python_scalar because we don't want this triggering for - # traced values. We do this here because it matters whether or not dtype is None. - # We don't assign the result because we want the raw object to be used for type - # inference below. + # For Python scalar literals, call coerce_to_array to catch any overflow + # errors. We don't use dtypes.is_python_scalar because we don't want this + # triggering for traced values. We do this here because it matters whether or + # not dtype is None. We don't assign the result because we want the raw object + # to be used for type inference below. if isinstance(object, (bool, int, float, complex)): _ = dtypes.coerce_to_array(object, dtype) @@ -1838,17 +1839,13 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0): ndarray_types = (device_array.DeviceArray, core.Tracer) if not _any(isinstance(leaf, ndarray_types) for leaf in leaves): - # TODO(jakevdp): falling back to numpy here fails to overflow for lists containing - # large integers; see discussion in https://github.com/google/jax/pull/6047. - # More correct would be to call coerce_to_array on each leaf, but this may have - # performance implications. + # TODO(jakevdp): falling back to numpy here fails to overflow for lists + # containing large integers; see discussion in + # https://github.com/google/jax/pull/6047. More correct would be to call + # coerce_to_array on each leaf, but this may have performance implications. out = np.array(object, dtype=dtype, ndmin=ndmin, copy=False) elif isinstance(object, ndarray_types): - if object.aval is None: - # object is a raw buffer; convert to device array on its current device. - aval = ShapedArray(object.xla_shape().dimensions(), object.dtype, - weak_type=bool(getattr(object, "weak_type", False))) - object = device_array.make_device_array(aval, object.device(), object) + assert object.aval is not None out = _array_copy(object) if copy else object elif isinstance(object, (list, tuple)): if object: diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 35d2ec84fcc8..0004751e4544 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -117,7 +117,12 @@ def pmap(self): def testDeviceBufferToArray(self): sda = self.pmap(lambda x: x)(jnp.ones((jax.device_count(), 2))) - buf = sda.device_buffers[-1] + + # Changed in https://github.com/google/jax/pull/10584 not to access + # sda.device_buffers, which isn't supported, and instead ensure fast slices + # of the arrays returned by pmap are set up correctly. + # buf = sda.device_buffers[-1] + buf = sda[-1] view = jnp.array(buf, copy=False) self.assertArraysEqual(sda[-1], view)