diff --git a/marray/__init__.py b/marray/__init__.py index 8edaea4..205bae1 100644 --- a/marray/__init__.py +++ b/marray/__init__.py @@ -388,8 +388,15 @@ def take(x, indices, /, *, axis=None): indices_mask = getattr(indices, 'mask', xp.broadcast_to(xp.asarray(False), shape)) indices_data[indices_mask] = 0 # ensure valid index data = xp.take(x.data, indices_data, axis=axis) - mask = xp.take(x.mask, indices_data, axis=axis) | indices_mask - return MArray(data, mask=mask) + mask = xp.take(x.mask, indices_data, axis=axis) + # align `indices_mask` along `axis` + # at this point, standard guarantees that `x` has at least 1 dim, + # `axis` is `None` only if `x.ndim == 1`, and `indices` was 1d. + working_axis = -1 if x.ndim == 1 else axis + new_shape = [1,] * x.ndim + new_shape[working_axis] = _get_size(indices_data) + indices_mask = xp.reshape(indices_mask, tuple(new_shape)) + return MArray(data, mask=mask | indices_mask) mod.take = take def take_along_axis(x, indices, /, *, axis=-1): @@ -609,7 +616,7 @@ def _cumulative_op(x, *args, _identity, _op, **kwargs): x = asarray(x) axis = kwargs.get('axis', None) if axis is None: - x = mod.reshape(x, -1) + x = mod.reshape(x, (-1,)) data = xp.asarray(x.data, copy=True) mask = x.mask diff --git a/marray/tests/test_marray.py b/marray/tests/test_marray.py index 71d5d32..e72d9eb 100644 --- a/marray/tests/test_marray.py +++ b/marray/tests/test_marray.py @@ -593,6 +593,25 @@ def test_take_along_axis(dtype, xp, seed=None): assert_equal(res, ref, xp=xp, seed=seed) +@pytest.mark.parametrize("dtype", dtypes_all) +@pytest.mark.parametrize("xp", xps) +@pass_exceptions(allowed=torch_exceptions) +def test_take(dtype, xp, seed=None): + mxp = marray.masked_namespace(xp) + marrays, masked_arrays, seed = get_arrays(1, dtype=dtype, xp=xp, seed=seed) + ndim = marrays[0].ndim + shape = marrays[0].shape + + rng = np.random.default_rng(seed) + axis = rng.integers(-ndim, ndim) + index_size = rng.integers(100) + indices = rng.integers(shape[axis], size=index_size) + + res = mxp.take(marrays[0], xp.asarray(indices), axis=axis) + ref = np.ma.take(masked_arrays[0], indices, axis=axis) + assert_equal(res, ref, xp=xp, seed=seed) + + @pytest.mark.parametrize("dtype", dtypes_all) @pytest.mark.parametrize('xp', xps) @pass_exceptions(allowed=["object has no attribute 'to_device'"]) # torch/cupy @@ -1451,5 +1470,5 @@ def test_gh99(xp): def test_test(): # dev tool to reproduce a particular failure of a `parametrize`d test - seed = 91803015965563856304156452253329804912 - test_nonzero("complex128", np, seed=seed) + seed = 98806759374046640850898260001383604577 + test_take("bool", np, seed=seed)