Skip to content

Commit 13181a6

Browse files
committed
Merge branch 'main' into apply
2 parents 514b9f7 + 904a7e2 commit 13181a6

File tree

8 files changed

+86
-135
lines changed

8 files changed

+86
-135
lines changed

pixi.lock

Lines changed: 48 additions & 73 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -125,22 +125,21 @@ python = "~=3.10.0"
125125
python = "~=3.13.0"
126126

127127
# Backends that can run on CPU-only hosts
128-
[tool.pixi.feature.backends.target.linux-64.dependencies]
128+
[tool.pixi.feature.backends.dependencies]
129129
pytorch = "*"
130130
dask = "*"
131-
sparse = ">=0.15"
131+
numba = "*" # sparse dependency
132+
133+
[tool.pixi.feature.backends.pypi-dependencies]
134+
sparse = { version = ">= 0.16.0b3" }
135+
136+
[tool.pixi.feature.backends.target.linux-64.dependencies]
132137
jax = "*"
133138

134139
[tool.pixi.feature.backends.target.osx-arm64.dependencies]
135-
pytorch = "*"
136-
dask = "*"
137-
sparse = ">=0.15"
138140
jax = "*"
139141

140142
[tool.pixi.feature.backends.target.win-64.dependencies]
141-
pytorch = "*"
142-
dask = "*"
143-
sparse = ">=0.15"
144143
# jax = "*" # unavailable
145144

146145
# Backends that require a GPU host and a CUDA driver

src/array_api_extra/_delegation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def pad(
125125
pad_width: int | tuple[int, int] | Sequence[tuple[int, int]],
126126
mode: Literal["constant"] = "constant",
127127
*,
128-
constant_values: bool | int | float | complex = 0,
128+
constant_values: complex = 0,
129129
xp: ModuleType | None = None,
130130
) -> Array:
131131
"""
@@ -168,7 +168,7 @@ def pad(
168168
pad_width = xp.flip(pad_width, axis=(0,)).flatten()
169169
return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
170170

171-
if _delegate(xp, Backend.NUMPY, Backend.JAX, Backend.CUPY):
171+
if _delegate(xp, Backend.NUMPY, Backend.JAX, Backend.CUPY, Backend.SPARSE):
172172
return xp.pad(x, pad_width, mode, constant_values=constant_values)
173173

174174
return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)

src/array_api_extra/_lib/_funcs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,7 @@ def pad(
575575
x: Array,
576576
pad_width: int | tuple[int, int] | Sequence[tuple[int, int]],
577577
*,
578-
constant_values: bool | int | float | complex = 0,
578+
constant_values: complex = 0,
579579
xp: ModuleType,
580580
) -> Array: # numpydoc ignore=PR01,RT01
581581
"""See docstring in `array_api_extra._delegation.py`."""

src/array_api_extra/_lib/_testing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,8 @@ def xp_assert_close(
166166
import numpy as np # pylint: disable=import-outside-toplevel
167167

168168
if is_pydata_sparse_namespace(xp):
169-
actual = actual.to_dense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
170-
desired = desired.to_dense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
169+
actual = actual.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
170+
desired = desired.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
171171

172172
# JAX uses `np.testing`
173173
assert isinstance(rtol, float)

tests/test_funcs.py

Lines changed: 21 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import contextlib
21
import math
32
import warnings
43
from types import ModuleType
@@ -24,7 +23,7 @@
2423
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
2524
from array_api_extra._lib._utils._compat import device as get_device
2625
from array_api_extra._lib._utils._helpers import eager_shape, ndindex
27-
from array_api_extra._lib._utils._typing import Array, Device
26+
from array_api_extra._lib._utils._typing import Device
2827
from array_api_extra.testing import lazy_xp_function
2928

3029
# some xp backends are untyped
@@ -42,7 +41,6 @@
4241
lazy_xp_function(sinc, static_argnames="xp")
4342

4443

45-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims")
4644
class TestAtLeastND:
4745
def test_0D(self, xp: ModuleType):
4846
x = xp.asarray(1.0)
@@ -69,7 +67,7 @@ def test_1D(self, xp: ModuleType):
6967
xp_assert_equal(y, xp.asarray([[0, 1]]))
7068

7169
y = atleast_nd(x, ndim=5)
72-
xp_assert_equal(y, xp.reshape(xp.arange(2), (1, 1, 1, 1, 2)))
70+
xp_assert_equal(y, xp.asarray([[[[[0, 1]]]]]))
7371

7472
def test_2D(self, xp: ModuleType):
7573
x = xp.asarray([[3.0]])
@@ -218,8 +216,10 @@ def test_xp(self, xp: ModuleType):
218216
)
219217

220218

219+
@pytest.mark.skip_xp_backend(
220+
Backend.SPARSE, reason="read-only backend without .at support"
221+
)
221222
class TestCreateDiagonal:
222-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
223223
def test_1d_from_numpy(self, xp: ModuleType):
224224
# from np.diag tests
225225
vals = 100 * xp.arange(5, dtype=xp.float64)
@@ -235,7 +235,6 @@ def test_1d_from_numpy(self, xp: ModuleType):
235235
xp_assert_equal(create_diagonal(vals, offset=2), b)
236236
xp_assert_equal(create_diagonal(vals, offset=-2), c)
237237

238-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
239238
@pytest.mark.parametrize("n", range(1, 10))
240239
@pytest.mark.parametrize("offset", range(1, 10))
241240
def test_1d_from_scipy(self, xp: ModuleType, n: int, offset: int):
@@ -251,7 +250,6 @@ def test_0d_raises(self, xp: ModuleType):
251250
with pytest.raises(ValueError, match="1-dimensional"):
252251
_ = create_diagonal(xp.asarray(1))
253252

254-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
255253
@pytest.mark.parametrize(
256254
"shape",
257255
[
@@ -277,38 +275,24 @@ def test_nd(self, xp: ModuleType, shape: tuple[int, ...]):
277275
for i in ndindex(*eager_shape(c)):
278276
xp_assert_equal(c[i], b[i[:-1]] if i[-2] == i[-1] else zero)
279277

280-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
281278
def test_device(self, xp: ModuleType, device: Device):
282279
x = xp.asarray([1, 2, 3], device=device)
283280
assert get_device(create_diagonal(x)) == device
284281

285-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
286282
def test_xp(self, xp: ModuleType):
287283
x = xp.asarray([1, 2])
288284
y = create_diagonal(x, xp=xp)
289285
xp_assert_equal(y, xp.asarray([[1, 0], [0, 2]]))
290286

291287

292288
class TestExpandDims:
293-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims")
294-
@pytest.mark.xfail_xp_backend(Backend.DASK, reason="tuple index out of range")
295-
@pytest.mark.xfail_xp_backend(Backend.TORCH, reason="tuple index out of range")
296-
def test_functionality(self, xp: ModuleType):
297-
def _squeeze_all(b: Array) -> Array:
298-
"""Mimics `np.squeeze(b)`. `xpx.squeeze`?"""
299-
for axis in range(b.ndim):
300-
with contextlib.suppress(ValueError):
301-
b = xp.squeeze(b, axis=axis)
302-
return b
303-
304-
s = (2, 3, 4, 5)
305-
a = xp.empty(s)
289+
def test_single_axis(self, xp: ModuleType):
290+
"""Trivial case where xpx.expand_dims doesn't add anything to xp.expand_dims"""
291+
a = xp.empty((2, 3, 4, 5))
306292
for axis in range(-5, 4):
307293
b = expand_dims(a, axis=axis)
308-
assert b.shape[axis] == 1
309-
assert _squeeze_all(b).shape == s
294+
xp_assert_equal(b, xp.expand_dims(a, axis=axis))
310295

311-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims")
312296
def test_axis_tuple(self, xp: ModuleType):
313297
a = xp.empty((3, 3, 3))
314298
assert expand_dims(a, axis=(0, 1, 2)).shape == (1, 1, 1, 3, 3, 3)
@@ -317,8 +301,7 @@ def test_axis_tuple(self, xp: ModuleType):
317301
assert expand_dims(a, axis=(0, -3, -5)).shape == (1, 1, 3, 1, 3, 3)
318302

319303
def test_axis_out_of_range(self, xp: ModuleType):
320-
s = (2, 3, 4, 5)
321-
a = xp.empty(s)
304+
a = xp.empty((2, 3, 4, 5))
322305
with pytest.raises(IndexError, match="out of bounds"):
323306
_ = expand_dims(a, axis=-6)
324307
with pytest.raises(IndexError, match="out of bounds"):
@@ -341,12 +324,10 @@ def test_positive_negative_repeated(self, xp: ModuleType):
341324
with pytest.raises(ValueError, match="Duplicate dimensions"):
342325
_ = expand_dims(a, axis=(3, -3))
343326

344-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims")
345327
def test_device(self, xp: ModuleType, device: Device):
346328
x = xp.asarray([1, 2, 3], device=device)
347329
assert get_device(expand_dims(x, axis=0)) == device
348330

349-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims")
350331
def test_xp(self, xp: ModuleType):
351332
x = xp.asarray([1, 2, 3])
352333
y = expand_dims(x, axis=(0, 1, 2), xp=xp)
@@ -513,7 +494,6 @@ def test_xp(self, xp: ModuleType):
513494
xp_assert_equal(isclose(a, b, xp=xp), xp.asarray([True, False]))
514495

515496

516-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims")
517497
class TestKron:
518498
def test_basic(self, xp: ModuleType):
519499
# Using 0-dimensional array
@@ -572,6 +552,7 @@ def test_kron_shape(
572552
k = kron(a, b)
573553
assert k.shape == expected_shape
574554

555+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
575556
def test_python_scalar(self, xp: ModuleType):
576557
a = 1
577558
# Test no dtype promotion to xp.asarray(a); use b.dtype
@@ -614,25 +595,27 @@ def test_xp(self, xp: ModuleType):
614595
xp_assert_equal(nunique(a, xp=xp), xp.asarray(3))
615596

616597

617-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange, no device")
618598
class TestPad:
619599
def test_simple(self, xp: ModuleType):
620-
a = xp.arange(1, 4)
600+
a = xp.asarray([1, 2, 3])
621601
padded = pad(a, 2)
622602
xp_assert_equal(padded, xp.asarray([0, 0, 1, 2, 3, 0, 0]))
623603

604+
@pytest.mark.xfail_xp_backend(
605+
Backend.SPARSE, reason="constant_values can only be equal to fill value"
606+
)
624607
def test_fill_value(self, xp: ModuleType):
625-
a = xp.arange(1, 4)
608+
a = xp.asarray([1, 2, 3])
626609
padded = pad(a, 2, constant_values=42)
627610
xp_assert_equal(padded, xp.asarray([42, 42, 1, 2, 3, 42, 42]))
628611

629612
def test_ndim(self, xp: ModuleType):
630-
a = xp.reshape(xp.arange(2 * 3 * 4), (2, 3, 4))
613+
a = xp.asarray(np.reshape(np.arange(2 * 3 * 4), (2, 3, 4)))
631614
padded = pad(a, 2)
632615
assert padded.shape == (6, 7, 8)
633616

634617
def test_mode_not_implemented(self, xp: ModuleType):
635-
a = xp.arange(3)
618+
a = xp.asarray([1, 2, 3])
636619
with pytest.raises(NotImplementedError, match="Only `'constant'`"):
637620
_ = pad(a, 2, mode="edge") # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
638621

@@ -645,7 +628,7 @@ def test_xp(self, xp: ModuleType):
645628
xp_assert_equal(padded, xp.asarray(0))
646629

647630
def test_tuple_width(self, xp: ModuleType):
648-
a = xp.reshape(xp.arange(12), (3, 4))
631+
a = xp.asarray(np.reshape(np.arange(12), (3, 4)))
649632
padded = pad(a, (1, 0))
650633
assert padded.shape == (4, 5)
651634

@@ -656,7 +639,7 @@ def test_tuple_width(self, xp: ModuleType):
656639
_ = pad(a, [(1, 2, 3)]) # type: ignore[list-item] # pyright: ignore[reportArgumentType]
657640

658641
def test_sequence_of_tuples_width(self, xp: ModuleType):
659-
a = xp.reshape(xp.arange(12), (3, 4))
642+
a = xp.asarray(np.reshape(np.arange(12), (3, 4)))
660643

661644
padded = pad(a, ((1, 0), (0, 2)))
662645
assert padded.shape == (4, 6)
@@ -678,7 +661,7 @@ def test_sequence_of_tuples_width(self, xp: ModuleType):
678661
)
679662

680663

681-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in asarray()")
664+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no argsort")
682665
class TestSetDiff1D:
683666
@pytest.mark.xfail_xp_backend(Backend.DASK, reason="NaN-shaped arrays")
684667
@pytest.mark.xfail_xp_backend(

tests/test_helpers.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,8 @@
1717
lazy_xp_function(in1d, jax_jit=False, static_argnames=("assume_unique", "invert", "xp"))
1818

1919

20+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no unique_inverse")
2021
class TestIn1D:
21-
@pytest.mark.xfail_xp_backend(
22-
Backend.SPARSE, reason="no unique_inverse, no device kwarg in asarray()"
23-
)
2422
# cover both code paths
2523
@pytest.mark.parametrize(
2624
"n",
@@ -42,19 +40,15 @@ def test_no_invert_assume_unique(self, xp: ModuleType, n: int):
4240
actual = in1d(x1, x2)
4341
xp_assert_equal(actual, expected)
4442

45-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in asarray")
4643
def test_device(self, xp: ModuleType, device: Device):
4744
x1 = xp.asarray([3, 8, 20], device=device)
4845
x2 = xp.asarray([2, 3, 4], device=device)
4946
assert get_device(in1d(x1, x2)) == device
5047

5148
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
52-
@pytest.mark.xfail_xp_backend(
53-
Backend.SPARSE, reason="no arange, no device kwarg in asarray"
54-
)
5549
def test_xp(self, xp: ModuleType):
5650
x1 = xp.asarray([1, 6])
57-
x2 = xp.arange(5)
51+
x2 = xp.asarray([0, 1, 2, 3, 4])
5852
expected = xp.asarray([True, False])
5953
actual = in1d(x1, x2, xp=xp)
6054
xp_assert_equal(actual, expected)
@@ -90,7 +84,7 @@ class TestAsArrays:
9084
],
9185
)
9286
def test_array_vs_scalar(
93-
self, dtype: str, b: int | float | complex, defined: bool, xp: ModuleType
87+
self, dtype: str, b: complex, defined: bool, xp: ModuleType
9488
):
9589
a = xp.asarray(1, dtype=getattr(xp, dtype))
9690

@@ -158,7 +152,7 @@ def test_ndindex(shape: tuple[int, ...]):
158152
assert tuple(ndindex(*shape)) == tuple(np.ndindex(*shape))
159153

160154

161-
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array")
155+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="index by sparse array")
162156
def test_eager_shape(xp: ModuleType, library: Backend):
163157
a = xp.asarray([1, 2, 3])
164158
# Lazy arrays, like Dask, have an eager shape until you slice them with

tests/test_testing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def test_assert_close_tolerance(xp: ModuleType):
7272

7373

7474
@param_assert_equal_close
75-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no bool indexing")
75+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="index by sparse array")
7676
def test_assert_close_equal_none_shape(xp: ModuleType, func: Callable[..., None]): # type: ignore[no-any-explicit]
7777
"""On dask and other lazy backends, test that a shape with NaN's or None's
7878
can be compared to a real shape.

0 commit comments

Comments
 (0)