1- import contextlib
21import math
32import warnings
43from types import ModuleType
2423from array_api_extra ._lib ._testing import xp_assert_close , xp_assert_equal
2524from array_api_extra ._lib ._utils ._compat import device as get_device
2625from array_api_extra ._lib ._utils ._helpers import eager_shape , ndindex
27- from array_api_extra ._lib ._utils ._typing import Array , Device
26+ from array_api_extra ._lib ._utils ._typing import Device
2827from array_api_extra .testing import lazy_xp_function
2928
3029# some xp backends are untyped
4241lazy_xp_function (sinc , static_argnames = "xp" )
4342
4443
45- @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no expand_dims" )
4644class TestAtLeastND :
4745 def test_0D (self , xp : ModuleType ):
4846 x = xp .asarray (1.0 )
@@ -69,7 +67,7 @@ def test_1D(self, xp: ModuleType):
6967 xp_assert_equal (y , xp .asarray ([[0 , 1 ]]))
7068
7169 y = atleast_nd (x , ndim = 5 )
72- xp_assert_equal (y , xp .reshape ( xp . arange ( 2 ), ( 1 , 1 , 1 , 1 , 2 ) ))
70+ xp_assert_equal (y , xp .asarray ([[[[[ 0 , 1 ]]]]] ))
7371
7472 def test_2D (self , xp : ModuleType ):
7573 x = xp .asarray ([[3.0 ]])
@@ -218,8 +216,10 @@ def test_xp(self, xp: ModuleType):
218216 )
219217
220218
219+ @pytest .mark .skip_xp_backend (
220+ Backend .SPARSE , reason = "read-only backend without .at support"
221+ )
221222class TestCreateDiagonal :
222- @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no device kwarg in zeros()" )
223223 def test_1d_from_numpy (self , xp : ModuleType ):
224224 # from np.diag tests
225225 vals = 100 * xp .arange (5 , dtype = xp .float64 )
@@ -235,7 +235,6 @@ def test_1d_from_numpy(self, xp: ModuleType):
235235 xp_assert_equal (create_diagonal (vals , offset = 2 ), b )
236236 xp_assert_equal (create_diagonal (vals , offset = - 2 ), c )
237237
238- @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no device kwarg in zeros()" )
239238 @pytest .mark .parametrize ("n" , range (1 , 10 ))
240239 @pytest .mark .parametrize ("offset" , range (1 , 10 ))
241240 def test_1d_from_scipy (self , xp : ModuleType , n : int , offset : int ):
@@ -251,7 +250,6 @@ def test_0d_raises(self, xp: ModuleType):
251250 with pytest .raises (ValueError , match = "1-dimensional" ):
252251 _ = create_diagonal (xp .asarray (1 ))
253252
254- @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no device kwarg in zeros()" )
255253 @pytest .mark .parametrize (
256254 "shape" ,
257255 [
@@ -277,38 +275,24 @@ def test_nd(self, xp: ModuleType, shape: tuple[int, ...]):
277275 for i in ndindex (* eager_shape (c )):
278276 xp_assert_equal (c [i ], b [i [:- 1 ]] if i [- 2 ] == i [- 1 ] else zero )
279277
280- @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no device kwarg in zeros()" )
281278 def test_device (self , xp : ModuleType , device : Device ):
282279 x = xp .asarray ([1 , 2 , 3 ], device = device )
283280 assert get_device (create_diagonal (x )) == device
284281
285- @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no device kwarg in zeros()" )
286282 def test_xp (self , xp : ModuleType ):
287283 x = xp .asarray ([1 , 2 ])
288284 y = create_diagonal (x , xp = xp )
289285 xp_assert_equal (y , xp .asarray ([[1 , 0 ], [0 , 2 ]]))
290286
291287
292288class TestExpandDims :
293- @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no expand_dims" )
294- @pytest .mark .xfail_xp_backend (Backend .DASK , reason = "tuple index out of range" )
295- @pytest .mark .xfail_xp_backend (Backend .TORCH , reason = "tuple index out of range" )
296- def test_functionality (self , xp : ModuleType ):
297- def _squeeze_all (b : Array ) -> Array :
298- """Mimics `np.squeeze(b)`. `xpx.squeeze`?"""
299- for axis in range (b .ndim ):
300- with contextlib .suppress (ValueError ):
301- b = xp .squeeze (b , axis = axis )
302- return b
303-
304- s = (2 , 3 , 4 , 5 )
305- a = xp .empty (s )
289+ def test_single_axis (self , xp : ModuleType ):
290+ """Trivial case where xpx.expand_dims doesn't add anything to xp.expand_dims"""
291+ a = xp .empty ((2 , 3 , 4 , 5 ))
306292 for axis in range (- 5 , 4 ):
307293 b = expand_dims (a , axis = axis )
308- assert b .shape [axis ] == 1
309- assert _squeeze_all (b ).shape == s
294+ xp_assert_equal (b , xp .expand_dims (a , axis = axis ))
310295
311- @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no expand_dims" )
312296 def test_axis_tuple (self , xp : ModuleType ):
313297 a = xp .empty ((3 , 3 , 3 ))
314298 assert expand_dims (a , axis = (0 , 1 , 2 )).shape == (1 , 1 , 1 , 3 , 3 , 3 )
@@ -317,8 +301,7 @@ def test_axis_tuple(self, xp: ModuleType):
317301 assert expand_dims (a , axis = (0 , - 3 , - 5 )).shape == (1 , 1 , 3 , 1 , 3 , 3 )
318302
319303 def test_axis_out_of_range (self , xp : ModuleType ):
320- s = (2 , 3 , 4 , 5 )
321- a = xp .empty (s )
304+ a = xp .empty ((2 , 3 , 4 , 5 ))
322305 with pytest .raises (IndexError , match = "out of bounds" ):
323306 _ = expand_dims (a , axis = - 6 )
324307 with pytest .raises (IndexError , match = "out of bounds" ):
@@ -341,12 +324,10 @@ def test_positive_negative_repeated(self, xp: ModuleType):
341324 with pytest .raises (ValueError , match = "Duplicate dimensions" ):
342325 _ = expand_dims (a , axis = (3 , - 3 ))
343326
344- @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no expand_dims" )
345327 def test_device (self , xp : ModuleType , device : Device ):
346328 x = xp .asarray ([1 , 2 , 3 ], device = device )
347329 assert get_device (expand_dims (x , axis = 0 )) == device
348330
349- @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no expand_dims" )
350331 def test_xp (self , xp : ModuleType ):
351332 x = xp .asarray ([1 , 2 , 3 ])
352333 y = expand_dims (x , axis = (0 , 1 , 2 ), xp = xp )
@@ -513,7 +494,6 @@ def test_xp(self, xp: ModuleType):
513494 xp_assert_equal (isclose (a , b , xp = xp ), xp .asarray ([True , False ]))
514495
515496
516- @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no expand_dims" )
517497class TestKron :
518498 def test_basic (self , xp : ModuleType ):
519499 # Using 0-dimensional array
@@ -572,6 +552,7 @@ def test_kron_shape(
572552 k = kron (a , b )
573553 assert k .shape == expected_shape
574554
555+ @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no isdtype" )
575556 def test_python_scalar (self , xp : ModuleType ):
576557 a = 1
577558 # Test no dtype promotion to xp.asarray(a); use b.dtype
@@ -614,25 +595,27 @@ def test_xp(self, xp: ModuleType):
614595 xp_assert_equal (nunique (a , xp = xp ), xp .asarray (3 ))
615596
616597
617- @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no arange, no device" )
618598class TestPad :
619599 def test_simple (self , xp : ModuleType ):
620- a = xp .arange ( 1 , 4 )
600+ a = xp .asarray ([ 1 , 2 , 3 ] )
621601 padded = pad (a , 2 )
622602 xp_assert_equal (padded , xp .asarray ([0 , 0 , 1 , 2 , 3 , 0 , 0 ]))
623603
604+ @pytest .mark .xfail_xp_backend (
605+ Backend .SPARSE , reason = "constant_values can only be equal to fill value"
606+ )
624607 def test_fill_value (self , xp : ModuleType ):
625- a = xp .arange ( 1 , 4 )
608+ a = xp .asarray ([ 1 , 2 , 3 ] )
626609 padded = pad (a , 2 , constant_values = 42 )
627610 xp_assert_equal (padded , xp .asarray ([42 , 42 , 1 , 2 , 3 , 42 , 42 ]))
628611
629612 def test_ndim (self , xp : ModuleType ):
630- a = xp .reshape (xp .arange (2 * 3 * 4 ), (2 , 3 , 4 ))
613+ a = xp .asarray ( np . reshape (np .arange (2 * 3 * 4 ), (2 , 3 , 4 ) ))
631614 padded = pad (a , 2 )
632615 assert padded .shape == (6 , 7 , 8 )
633616
634617 def test_mode_not_implemented (self , xp : ModuleType ):
635- a = xp .arange ( 3 )
618+ a = xp .asarray ([ 1 , 2 , 3 ] )
636619 with pytest .raises (NotImplementedError , match = "Only `'constant'`" ):
637620 _ = pad (a , 2 , mode = "edge" ) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
638621
@@ -645,7 +628,7 @@ def test_xp(self, xp: ModuleType):
645628 xp_assert_equal (padded , xp .asarray (0 ))
646629
647630 def test_tuple_width (self , xp : ModuleType ):
648- a = xp .reshape (xp .arange (12 ), (3 , 4 ))
631+ a = xp .asarray ( np . reshape (np .arange (12 ), (3 , 4 ) ))
649632 padded = pad (a , (1 , 0 ))
650633 assert padded .shape == (4 , 5 )
651634
@@ -656,7 +639,7 @@ def test_tuple_width(self, xp: ModuleType):
656639 _ = pad (a , [(1 , 2 , 3 )]) # type: ignore[list-item] # pyright: ignore[reportArgumentType]
657640
658641 def test_sequence_of_tuples_width (self , xp : ModuleType ):
659- a = xp .reshape (xp .arange (12 ), (3 , 4 ))
642+ a = xp .asarray ( np . reshape (np .arange (12 ), (3 , 4 ) ))
660643
661644 padded = pad (a , ((1 , 0 ), (0 , 2 )))
662645 assert padded .shape == (4 , 6 )
@@ -678,7 +661,7 @@ def test_sequence_of_tuples_width(self, xp: ModuleType):
678661)
679662
680663
681- @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no device kwarg in asarray() " )
664+ @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no argsort " )
682665class TestSetDiff1D :
683666 @pytest .mark .xfail_xp_backend (Backend .DASK , reason = "NaN-shaped arrays" )
684667 @pytest .mark .xfail_xp_backend (
0 commit comments