1616from array_api_extra ._lib ._utils ._typing import Array , SetIndex
1717from array_api_extra .testing import lazy_xp_function
1818
19- pytestmark = [
20- pytest .mark .skip_xp_backend (
21- Backend .SPARSE , reason = "read-only backend without .at support"
22- )
23- ]
19+ sparse_xfail = pytest .mark .skip_xp_backend (
20+ Backend .SPARSE , reason = "read-only backend without .at support"
21+ )
2422
2523
2624def at_op (
@@ -100,16 +98,28 @@ def assert_copy(
10098 (_AtOp .MULTIPLY , 2.0 , [10.0 , 40.0 , 60.0 ]),
10199 (_AtOp .DIVIDE , 2.0 , [10.0 , 10.0 , 15.0 ]),
102100 (_AtOp .POWER , 2.0 , [10.0 , 400.0 , 900.0 ]),
103- (_AtOp .MIN , 25.0 , [10.0 , 20.0 , 25.0 ]),
104- (_AtOp .MAX , 25.0 , [10.0 , 25.0 , 30.0 ]),
101+ pytest .param (
102+ _AtOp .MIN ,
103+ 25.0 ,
104+ [10.0 , 20.0 , 25.0 ],
105+ # test passes when copy=False
106+ marks = pytest .mark .skip (reason = "no minimum" ),
107+ ),
108+ pytest .param (
109+ _AtOp .MAX ,
110+ 25.0 ,
111+ [10.0 , 25.0 , 30.0 ],
112+ # test passes when copy=False
113+ marks = pytest .mark .skip (reason = "no maximum" ),
114+ ),
105115 ],
106116)
107117@pytest .mark .parametrize (
108118 ("bool_mask" , "x_ndim" , "y_ndim" ),
109119 [
110- (False , 1 , 0 ),
111- (False , 1 , 1 ),
112- (True , 1 , 0 ), # Uses xp.where(idx, y, x) on JAX and Dask
120+ pytest . param (False , 1 , 0 , marks = sparse_xfail ),
121+ pytest . param (False , 1 , 1 , marks = sparse_xfail ),
122+ (True , 1 , 0 ), # Uses xp.where(idx, y, x) on JAX, Dask, and Sparse
113123 pytest .param (
114124 * (True , 1 , 1 ),
115125 marks = (
@@ -119,9 +129,12 @@ def assert_copy(
119129 pytest .mark .xfail_xp_backend (
120130 Backend .DASK , reason = "bool mask update with shaped rhs"
121131 ),
132+ pytest .mark .skip_xp_backend ( # test passes when copy=False
133+ Backend .SPARSE , reason = "bool mask update with shaped rhs"
134+ ),
122135 ),
123136 ),
124- (False , 0 , 0 ),
137+ pytest . param (False , 0 , 0 , marks = sparse_xfail ),
125138 (True , 0 , 0 ),
126139 ],
127140)
@@ -158,8 +171,9 @@ def test_update_ops(
158171 xp_assert_equal (z , xp .asarray (expect ))
159172
160173
174+ @sparse_xfail
161175@pytest .mark .parametrize ("op" , list (_AtOp ))
162- def test_copy_default (xp : ModuleType , library : Backend , op : _AtOp ):
176+ def test_copy_default (xp : ModuleType , op : _AtOp ):
163177 """
164178 Test that the default copy behaviour is False for writeable arrays
165179 and True for read-only ones.
@@ -170,6 +184,12 @@ def test_copy_default(xp: ModuleType, library: Backend, op: _AtOp):
170184 with assert_copy (x , None , expect_copy ):
171185 _ = meth (2.0 )
172186
187+
188+ @pytest .mark .parametrize ("op" , list (_AtOp ))
189+ def test_copy_default_bool_mask (xp : ModuleType , library : Backend , op : _AtOp ):
190+ if op in (_AtOp .MIN , _AtOp .MAX ) and library is Backend .SPARSE :
191+ pytest .xfail ("no minimum/maximum" )
192+
173193 x = xp .asarray ([1.0 , 10.0 , 20.0 ])
174194 # Dask's default copy value is True for bool masks,
175195 # even if the arrays are writeable.
@@ -215,7 +235,7 @@ def test_alternate_index_syntax():
215235
216236
217237@pytest .mark .parametrize ("copy" , [True , None ])
218- @pytest .mark .parametrize ("bool_mask" , [False , True ])
238+ @pytest .mark .parametrize ("bool_mask" , [pytest . param ( False , marks = sparse_xfail ) , True ])
219239@pytest .mark .parametrize ("op" , list (_AtOp ))
220240def test_incompatible_dtype (
221241 xp : ModuleType ,
@@ -255,9 +275,19 @@ def test_incompatible_dtype(
255275 elif library is Backend .DASK :
256276 z = at_op (x , idx , op , 1.1 , copy = copy )
257277
258- elif library is Backend .ARRAY_API_STRICT and op is not _AtOp .SET :
259- with pytest .raises (Exception , match = r"cast|promote|dtype" ):
260- _ = at_op (x , idx , op , 1.1 , copy = copy )
278+ elif library is Backend .SPARSE :
279+ if op in (_AtOp .MIN , _AtOp .MAX ):
280+ pytest .xfail ("no minimum/maximum" )
281+ z = at_op (x , idx , op , 1.1 , copy = copy )
282+
283+ elif library is Backend .ARRAY_API_STRICT :
284+ if op is _AtOp .SET :
285+ z = at_op (x , idx , op , 1.1 , copy = copy )
286+ else :
287+ with pytest .raises (Exception , match = r"cast|promote|dtype" ):
288+ _ = at_op (x , idx , op , 1.1 , copy = copy )
289+
290+ # numpy, torch, and cupy
261291
262292 elif op in (_AtOp .SET , _AtOp .MIN , _AtOp .MAX ):
263293 # There is no __i<op>__ version of these operations
@@ -305,7 +335,7 @@ def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
305335 ),
306336 ],
307337)
308- @pytest .mark .parametrize ("bool_mask" , [False , True ])
338+ @pytest .mark .parametrize ("bool_mask" , [pytest . param ( False , marks = sparse_xfail ) , True ])
309339def test_gh134 (xp : ModuleType , bool_mask : bool , copy : bool | None ):
310340 """
311341 Test that xpx.at doesn't encroach in a bug of dask.array.Array.__setitem__, which
0 commit comments