Skip to content

Commit fd6b9d8

Browse files
committed
Merge branch 'main' into pad-delegate
2 parents 71edc05 + ec890f1 commit fd6b9d8

File tree

17 files changed

+585
-206
lines changed

17 files changed

+585
-206
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ repos:
4444
- repo: https://github.com/astral-sh/ruff-pre-commit
4545
rev: "v0.8.2"
4646
hooks:
47+
- id: ruff-format
4748
- id: ruff
4849
args: ["--fix", "--show-fixes"]
49-
- id: ruff-format
5050

5151
- repo: https://github.com/codespell-project/codespell
5252
rev: "v2.3.0"

docs/api-reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
create_diagonal
1313
expand_dims
1414
kron
15+
nunique
1516
pad
1617
setdiff1d
1718
sinc

pixi.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ filterwarnings = [
184184
]
185185
log_cli_level = "INFO"
186186
testpaths = ["tests"]
187-
187+
markers = ["skip_xp_backend(library, *, reason=None): Skip test for a specific backend"]
188188

189189
# Coverage
190190

@@ -303,6 +303,7 @@ messages_control.disable = [
303303
"line-too-long",
304304
"missing-module-docstring",
305305
"missing-function-docstring",
306+
"too-many-lines",
306307
"wrong-import-position",
307308
]
308309

@@ -319,6 +320,7 @@ checks = [
319320
exclude = [ # don't report on objects that match any of these regex
320321
'.*test_at.*',
321322
'.*test_funcs.*',
323+
'.*test_testing.*',
322324
'.*test_utils.*',
323325
'.*test_version.*',
324326
'.*test_vendor.*',

src/array_api_extra/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
create_diagonal,
99
expand_dims,
1010
kron,
11+
nunique,
1112
setdiff1d,
1213
sinc,
1314
)
@@ -23,6 +24,7 @@
2324
"create_diagonal",
2425
"expand_dims",
2526
"kron",
27+
"nunique",
2628
"pad",
2729
"setdiff1d",
2830
"sinc",

src/array_api_extra/_lib/_funcs.py

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
44
from __future__ import annotations
55

6+
import math
67
import operator
78
import warnings
89
from collections.abc import Callable
@@ -21,6 +22,7 @@
2122
"create_diagonal",
2223
"expand_dims",
2324
"kron",
25+
"nunique",
2426
"pad",
2527
"setdiff1d",
2628
"sinc",
@@ -210,8 +212,12 @@ def create_diagonal(
210212
raise ValueError(err_msg)
211213
n = x.shape[0] + abs(offset)
212214
diag = xp.zeros(n**2, dtype=x.dtype, device=_compat.device(x))
213-
i = offset if offset >= 0 else abs(offset) * n
214-
diag[i : min(n * (n - offset), diag.shape[0]) : n + 1] = x
215+
216+
start = offset if offset >= 0 else abs(offset) * n
217+
stop = min(n * (n - offset), diag.shape[0])
218+
step = n + 1
219+
diag = at(diag)[start:stop:step].set(x)
220+
215221
return xp.reshape(diag, (n, n))
216222

217223

@@ -403,9 +409,8 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
403409
result = xp.multiply(a_arr, b_arr)
404410

405411
# Reshape back and return
406-
a_shape = xp.asarray(a_shape)
407-
b_shape = xp.asarray(b_shape)
408-
return xp.reshape(result, tuple(xp.multiply(a_shape, b_shape)))
412+
res_shape = tuple(a_s * b_s for a_s, b_s in zip(a_shape, b_shape, strict=True))
413+
return xp.reshape(result, res_shape)
409414

410415

411416
def setdiff1d(
@@ -565,7 +570,7 @@ def pad(
565570
if isinstance(pad_width, tuple):
566571
pad_width = [pad_width] * x_ndim
567572

568-
# https://github.com/data-apis/array-api-extra/pull/82#discussion_r1905688819
573+
# https://github.com/python/typeshed/issues/13376
569574
slices: list[slice] = [] # type: ignore[no-any-explicit]
570575
newshape: list[int] = []
571576
for ax, w_tpl in enumerate(pad_width):
@@ -592,8 +597,43 @@ def pad(
592597
dtype=x.dtype,
593598
device=_compat.device(x),
594599
)
595-
padded[tuple(slices)] = x
596-
return padded
600+
return at(padded, tuple(slices)).set(x)
601+
602+
603+
def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
    """
    Return how many distinct values an array contains.

    The result is produced without requiring the number of unique values to
    be known eagerly, which keeps the function usable on lazy backends such
    as JAX and Dask.

    Parameters
    ----------
    x : Array
        Input array.
    xp : array_namespace, optional
        The standard-compatible namespace for `x`. Default: infer.

    Returns
    -------
    array: 0-dimensional integer array
        The number of unique elements in `x`. It can be lazy.
    """
    namespace = array_namespace(x) if xp is None else xp

    if is_jax_array(x):
        # The `size=` keyword is a JAX extension, not part of the standard:
        # https://github.com/data-apis/array-api/issues/883
        # The counts are requested with the full input size and the non-zero
        # ones are tallied (presumably the surplus slots hold zero counts --
        # confirm against the JAX documentation).
        _, jax_counts = namespace.unique_counts(x, size=_compat.size(x))
        return namespace.astype(jax_counts, namespace.bool).sum()

    _, counts = namespace.unique_counts(x)
    n_unique = _compat.size(counts)
    # FIXME https://github.com/data-apis/array-api-compat/pull/231
    size_is_known = n_unique is not None and not math.isnan(n_unique)
    if size_is_known:
        # Eager backend: the length of `counts` already is the answer.
        return namespace.asarray(n_unique, device=_compat.device(x))
    # Unknown size (e.g. Dask, ndonnx): count the non-zero entries lazily.
    return namespace.astype(counts, namespace.bool).sum()
597637

598638

599639
class _AtOp(Enum):
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
"""
2+
Testing utilities.
3+
4+
Note that this is private API; don't expect it to be stable.
5+
"""
6+
7+
from types import ModuleType
8+
9+
from ._utils._compat import (
10+
array_namespace,
11+
is_cupy_namespace,
12+
is_pydata_sparse_namespace,
13+
is_torch_namespace,
14+
)
15+
from ._utils._typing import Array
16+
17+
__all__ = ["xp_assert_close", "xp_assert_equal"]
18+
19+
20+
def _check_ns_shape_dtype(
    actual: Array, desired: Array
) -> ModuleType:  # numpydoc ignore=RT03
    """
    Assert that namespace, shape and dtype of the two arrays match.

    Parameters
    ----------
    actual : Array
        The array produced by the tested function.
    desired : Array
        The expected array (typically hardcoded).

    Returns
    -------
    Arrays namespace.

    Raises
    ------
    AssertionError
        If the namespaces, the shapes or the dtypes of the two arrays differ.
    """
    actual_xp = array_namespace(actual)  # Raises on scalars and lists
    desired_xp = array_namespace(desired)

    # The checks run from coarsest to finest, so the first mismatch reported
    # is the most fundamental one.
    msg = f"namespaces do not match: {actual_xp} != {desired_xp}"
    assert actual_xp == desired_xp, msg

    msg = f"shapes do not match: {actual.shape} != {desired.shape}"
    assert actual.shape == desired.shape, msg

    msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}"
    assert actual.dtype == desired.dtype, msg

    return desired_xp
50+
51+
52+
def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
    """
    Array-API compatible version of `np.testing.assert_array_equal`.

    The namespace, shape and dtype of the two arrays are compared first; the
    element-wise comparison is then delegated to the testing helper that
    suits the backend.

    Parameters
    ----------
    actual : Array
        The array produced by the tested function.
    desired : Array
        The expected array (typically hardcoded).
    err_msg : str, optional
        Error message to display on failure.
    """
    namespace = _check_ns_shape_dtype(actual, desired)

    if is_cupy_namespace(namespace):
        namespace.testing.assert_array_equal(actual, desired, err_msg=err_msg)
        return

    if is_torch_namespace(namespace):
        # PyTorch recommends zero tolerances as the way to test for exact
        # equality. The dtype comparison is switched off here because it has
        # already been performed by `_check_ns_shape_dtype`.
        torch_msg = err_msg or None
        namespace.testing.assert_close(
            actual,
            desired,
            rtol=0,
            atol=0,
            equal_nan=True,
            check_dtype=False,
            msg=torch_msg,
        )
        return

    import numpy as np  # pylint: disable=import-outside-toplevel

    lhs, rhs = actual, desired
    if is_pydata_sparse_namespace(namespace):
        # Sparse arrays are converted to dense before the comparison.
        lhs, rhs = actual.todense(), desired.todense()

    # Every remaining backend goes through `np.testing` (JAX uses it too).
    np.testing.assert_array_equal(lhs, rhs, err_msg=err_msg)
90+
91+
92+
def xp_assert_close(
    actual: Array,
    desired: Array,
    *,
    rtol: float | None = None,
    atol: float = 0,
    err_msg: str = "",
) -> None:
    """
    Array-API compatible version of `np.testing.assert_allclose`.

    Parameters
    ----------
    actual : Array
        The array produced by the tested function.
    desired : Array
        The expected array (typically hardcoded).
    rtol : float, optional
        Relative tolerance. Default: dtype-dependent.
    atol : float, optional
        Absolute tolerance. Default: 0.
    err_msg : str, optional
        Error message to display on failure.

    Raises
    ------
    AssertionError
        If the namespaces, shapes or dtypes of the two arrays differ, or if
        their values are not equal within the requested tolerances.
    """
    xp = _check_ns_shape_dtype(actual, desired)

    floating = xp.isdtype(actual.dtype, ("real floating", "complex floating"))
    if rtol is None and floating:
        # multiplier of 4 is used as for `np.float64` this puts the default `rtol`
        # roughly half way between sqrt(eps) and the default for
        # `numpy.testing.assert_allclose`, 1e-7
        rtol = xp.finfo(actual.dtype).eps ** 0.5 * 4
    elif rtol is None:
        rtol = 1e-7

    if is_cupy_namespace(xp):
        xp.testing.assert_allclose(
            actual, desired, rtol=rtol, atol=atol, err_msg=err_msg
        )
    elif is_torch_namespace(xp):
        xp.testing.assert_close(
            actual, desired, rtol=rtol, atol=atol, equal_nan=True, msg=err_msg or None
        )
    else:
        import numpy as np  # pylint: disable=import-outside-toplevel

        if is_pydata_sparse_namespace(xp):
            # Same dense-conversion method that `xp_assert_equal` uses.
            actual = actual.todense()
            desired = desired.todense()

        # JAX uses `np.testing`
        # Coerce instead of asserting the type: a caller may legitimately
        # pass an integer tolerance, and the dtype-dependent default may be a
        # backend scalar rather than a built-in float.
        np.testing.assert_allclose(
            actual, desired, rtol=float(rtol), atol=atol, err_msg=err_msg
        )

src/array_api_extra/_lib/_utils/_compat.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
is_jax_array,
1111
is_jax_namespace,
1212
is_numpy_namespace,
13+
is_pydata_sparse_namespace,
1314
is_torch_namespace,
1415
is_writeable_array,
16+
size,
1517
)
1618
except ImportError:
1719
from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs]
@@ -21,8 +23,10 @@
2123
is_jax_array,
2224
is_jax_namespace,
2325
is_numpy_namespace,
26+
is_pydata_sparse_namespace,
2427
is_torch_namespace,
2528
is_writeable_array,
29+
size,
2630
)
2731

2832
__all__ = [
@@ -32,6 +36,8 @@
3236
"is_jax_array",
3337
"is_jax_namespace",
3438
"is_numpy_namespace",
39+
"is_pydata_sparse_namespace",
3540
"is_torch_namespace",
3641
"is_writeable_array",
42+
"size",
3743
]

src/array_api_extra/_lib/_utils/_compat.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,6 @@ def is_jax_namespace(xp: ModuleType, /) -> bool: ...
2323
def is_numpy_namespace(xp: ModuleType, /) -> bool: ...
2424
def is_torch_namespace(xp: ModuleType, /) -> bool: ...
2525
def is_jax_array(x: object, /) -> bool: ...
26+
def is_pydata_sparse_namespace(xp: ModuleType, /) -> bool: ...
2627
def is_writeable_array(x: object, /) -> bool: ...
28+
def size(x: Array, /) -> int | None: ...

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@ def in1d(
5656
order = xp.argsort(ar, stable=True)
5757
reverse_order = xp.argsort(order, stable=True)
5858
sar = xp.take(ar, order, axis=0)
59-
if sar.size >= 1:
59+
ar_size = _compat.size(sar)
60+
assert ar_size is not None, "xp.unique*() on lazy backends raises"
61+
if ar_size >= 1:
6062
bool_ar = sar[1:] != sar[:-1] if invert else sar[1:] == sar[:-1]
6163
else:
6264
bool_ar = xp.asarray([False]) if invert else xp.asarray([True])

0 commit comments

Comments
 (0)