Skip to content

Commit 0f4920f

Browse files
committed
WIP: ENH: xpx.at support for sparse with bool mask
1 parent 70c7c80 commit 0f4920f

File tree

2 files changed

+52
-20
lines changed

2 files changed

+52
-20
lines changed

src/array_api_extra/_lib/_at.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
array_namespace,
1414
is_dask_array,
1515
is_jax_array,
16+
is_pydata_sparse_array,
1617
is_writeable_array,
1718
)
1819
from ._utils._helpers import meta_namespace
@@ -287,11 +288,12 @@ def _op(
287288
writeable = None if copy else is_writeable_array(x)
288289

289290
# JAX inside jax.jit doesn't support in-place updates with boolean
290-
# masks; Dask exclusively supports __setitem__ but not iops.
291+
# masks; Dask exclusively supports __setitem__ but not iops;
292+
# Sparse doesn't support in-place updates full-stop.
291293
# We can handle the common special case of 0-dimensional y
292294
# with where(idx, y, x) instead.
293295
if (
294-
(is_dask_array(idx) or is_jax_array(idx))
296+
(is_dask_array(idx) or is_jax_array(idx) or is_pydata_sparse_array(idx))
295297
and idx.dtype == xp.bool
296298
and idx.shape == x.shape
297299
):
@@ -337,7 +339,7 @@ def _op(
337339
if writeable is None:
338340
writeable = is_writeable_array(x)
339341
if not writeable:
340-
# sparse crashes here
342+
# sparse with idx other than bool mask or shaped y crashes here
341343
msg = f"Can't update read-only array {x}"
342344
raise ValueError(msg)
343345

tests/test_at.py

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,9 @@
1616
from array_api_extra._lib._utils._typing import Array, SetIndex
1717
from array_api_extra.testing import lazy_xp_function
1818

19-
pytestmark = [
20-
pytest.mark.skip_xp_backend(
21-
Backend.SPARSE, reason="read-only backend without .at support"
22-
)
23-
]
19+
sparse_xfail = pytest.mark.skip_xp_backend(
20+
Backend.SPARSE, reason="read-only backend without .at support"
21+
)
2422

2523

2624
def at_op(
@@ -100,16 +98,28 @@ def assert_copy(
10098
(_AtOp.MULTIPLY, 2.0, [10.0, 40.0, 60.0]),
10199
(_AtOp.DIVIDE, 2.0, [10.0, 10.0, 15.0]),
102100
(_AtOp.POWER, 2.0, [10.0, 400.0, 900.0]),
103-
(_AtOp.MIN, 25.0, [10.0, 20.0, 25.0]),
104-
(_AtOp.MAX, 25.0, [10.0, 25.0, 30.0]),
101+
pytest.param(
102+
_AtOp.MIN,
103+
25.0,
104+
[10.0, 20.0, 25.0],
105+
# test passes when copy=False
106+
marks=pytest.mark.skip(reason="no minimum"),
107+
),
108+
pytest.param(
109+
_AtOp.MAX,
110+
25.0,
111+
[10.0, 25.0, 30.0],
112+
# test passes when copy=False
113+
marks=pytest.mark.skip(reason="no maximum"),
114+
),
105115
],
106116
)
107117
@pytest.mark.parametrize(
108118
("bool_mask", "x_ndim", "y_ndim"),
109119
[
110-
(False, 1, 0),
111-
(False, 1, 1),
112-
(True, 1, 0), # Uses xp.where(idx, y, x) on JAX and Dask
120+
pytest.param(False, 1, 0, marks=sparse_xfail),
121+
pytest.param(False, 1, 1, marks=sparse_xfail),
122+
(True, 1, 0), # Uses xp.where(idx, y, x) on JAX, Dask, and Sparse
113123
pytest.param(
114124
*(True, 1, 1),
115125
marks=(
@@ -119,9 +129,12 @@ def assert_copy(
119129
pytest.mark.xfail_xp_backend(
120130
Backend.DASK, reason="bool mask update with shaped rhs"
121131
),
132+
pytest.mark.skip_xp_backend( # test passes when copy=False
133+
Backend.SPARSE, reason="bool mask update with shaped rhs"
134+
),
122135
),
123136
),
124-
(False, 0, 0),
137+
pytest.param(False, 0, 0, marks=sparse_xfail),
125138
(True, 0, 0),
126139
],
127140
)
@@ -158,8 +171,9 @@ def test_update_ops(
158171
xp_assert_equal(z, xp.asarray(expect))
159172

160173

174+
@sparse_xfail
161175
@pytest.mark.parametrize("op", list(_AtOp))
162-
def test_copy_default(xp: ModuleType, library: Backend, op: _AtOp):
176+
def test_copy_default(xp: ModuleType, op: _AtOp):
163177
"""
164178
Test that the default copy behaviour is False for writeable arrays
165179
and True for read-only ones.
@@ -170,6 +184,12 @@ def test_copy_default(xp: ModuleType, library: Backend, op: _AtOp):
170184
with assert_copy(x, None, expect_copy):
171185
_ = meth(2.0)
172186

187+
188+
@pytest.mark.parametrize("op", list(_AtOp))
189+
def test_copy_default_bool_mask(xp: ModuleType, library: Backend, op: _AtOp):
190+
if op in (_AtOp.MIN, _AtOp.MAX) and library is Backend.SPARSE:
191+
pytest.xfail("no minimum/maximum")
192+
173193
x = xp.asarray([1.0, 10.0, 20.0])
174194
# Dask's default copy value is True for bool masks,
175195
# even if the arrays are writeable.
@@ -215,7 +235,7 @@ def test_alternate_index_syntax():
215235

216236

217237
@pytest.mark.parametrize("copy", [True, None])
218-
@pytest.mark.parametrize("bool_mask", [False, True])
238+
@pytest.mark.parametrize("bool_mask", [pytest.param(False, marks=sparse_xfail), True])
219239
@pytest.mark.parametrize("op", list(_AtOp))
220240
def test_incompatible_dtype(
221241
xp: ModuleType,
@@ -255,9 +275,19 @@ def test_incompatible_dtype(
255275
elif library is Backend.DASK:
256276
z = at_op(x, idx, op, 1.1, copy=copy)
257277

258-
elif library is Backend.ARRAY_API_STRICT and op is not _AtOp.SET:
259-
with pytest.raises(Exception, match=r"cast|promote|dtype"):
260-
_ = at_op(x, idx, op, 1.1, copy=copy)
278+
elif library is Backend.SPARSE:
279+
if op in (_AtOp.MIN, _AtOp.MAX):
280+
pytest.xfail("no minimum/maximum")
281+
z = at_op(x, idx, op, 1.1, copy=copy)
282+
283+
elif library is Backend.ARRAY_API_STRICT:
284+
if op is _AtOp.SET:
285+
z = at_op(x, idx, op, 1.1, copy=copy)
286+
else:
287+
with pytest.raises(Exception, match=r"cast|promote|dtype"):
288+
_ = at_op(x, idx, op, 1.1, copy=copy)
289+
290+
# numpy, torch, and cupy
261291

262292
elif op in (_AtOp.SET, _AtOp.MIN, _AtOp.MAX):
263293
# There is no __i<op>__ version of these operations
@@ -305,7 +335,7 @@ def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
305335
),
306336
],
307337
)
308-
@pytest.mark.parametrize("bool_mask", [False, True])
338+
@pytest.mark.parametrize("bool_mask", [pytest.param(False, marks=sparse_xfail), True])
309339
def test_gh134(xp: ModuleType, bool_mask: bool, copy: bool | None):
310340
"""
311341
Test that xpx.at doesn't encroach in a bug of dask.array.Array.__setitem__, which

0 commit comments

Comments
 (0)