Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions marray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,15 @@ def take(x, indices, /, *, axis=None):
indices_mask = getattr(indices, 'mask', xp.broadcast_to(xp.asarray(False), shape))
indices_data[indices_mask] = 0 # ensure valid index
data = xp.take(x.data, indices_data, axis=axis)
mask = xp.take(x.mask, indices_data, axis=axis) | indices_mask
return MArray(data, mask=mask)
mask = xp.take(x.mask, indices_data, axis=axis)
# align `indices_mask` along `axis`
# at this point, standard guarantees that `x` has at least 1 dim,
# `axis` is `None` only if `x.ndim == 1`, and `indices` was 1d.
working_axis = -1 if x.ndim == 1 else axis
new_shape = [1,] * x.ndim
new_shape[working_axis] = _get_size(indices_data)
indices_mask = xp.reshape(indices_mask, tuple(new_shape))
return MArray(data, mask=mask | indices_mask)
mod.take = take

def take_along_axis(x, indices, /, *, axis=-1):
Expand Down Expand Up @@ -609,7 +616,7 @@ def _cumulative_op(x, *args, _identity, _op, **kwargs):
x = asarray(x)
axis = kwargs.get('axis', None)
if axis is None:
x = mod.reshape(x, -1)
x = mod.reshape(x, (-1,))

data = xp.asarray(x.data, copy=True)
mask = x.mask
Expand Down
23 changes: 21 additions & 2 deletions marray/tests/test_marray.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,25 @@ def test_take_along_axis(dtype, xp, seed=None):
assert_equal(res, ref, xp=xp, seed=seed)


@pytest.mark.parametrize("dtype", dtypes_all)
@pytest.mark.parametrize("xp", xps)
@pass_exceptions(allowed=torch_exceptions)
def test_take(dtype, xp, seed=None):
mxp = marray.masked_namespace(xp)
marrays, masked_arrays, seed = get_arrays(1, dtype=dtype, xp=xp, seed=seed)
ndim = marrays[0].ndim
shape = marrays[0].shape

rng = np.random.default_rng(seed)
axis = rng.integers(-ndim, ndim)
index_size = rng.integers(100)
indices = rng.integers(shape[axis], size=index_size)

res = mxp.take(marrays[0], xp.asarray(indices), axis=axis)
ref = np.ma.take(masked_arrays[0], indices, axis=axis)
assert_equal(res, ref, xp=xp, seed=seed)


@pytest.mark.parametrize("dtype", dtypes_all)
@pytest.mark.parametrize('xp', xps)
@pass_exceptions(allowed=["object has no attribute 'to_device'"]) # torch/cupy
Expand Down Expand Up @@ -1451,5 +1470,5 @@ def test_gh99(xp):

def test_test():
# dev tool to reproduce a particular failure of a `parametrize`d test
seed = 91803015965563856304156452253329804912
test_nonzero("complex128", np, seed=seed)
seed = 98806759374046640850898260001383604577
test_take("bool", np, seed=seed)
Loading