Skip to content

Commit 1cf4974

Browse files
committed
WIP tougher restrictions on array_api_strict
1 parent 1e3614e commit 1cf4974

File tree

2 files changed

+25
-7
lines changed

2 files changed

+25
-7
lines changed

src/array_api_extra/_lib/_backends.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class Backend(Enum): # numpydoc ignore=PR01,PR02 # type: ignore[no-subclass-an
2424
"""
2525

2626
ARRAY_API_STRICT = "array_api_strict", _compat.is_array_api_strict_namespace
27+
ARRAY_API_STRICTEST = "array_api_strictest", _compat.is_array_api_strict_namespace
2728
NUMPY = "numpy", _compat.is_numpy_namespace
2829
NUMPY_READONLY = "numpy_readonly", _compat.is_numpy_namespace
2930
CUPY = "cupy", _compat.is_cupy_namespace

tests/conftest.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Pytest fixtures."""
22

3-
from collections.abc import Callable
3+
from collections.abc import Callable, Generator
44
from contextlib import suppress
55
from functools import partial, wraps
66
from types import ModuleType
@@ -104,7 +104,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
104104
@pytest.fixture
105105
def xp(
106106
library: Backend, request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
107-
) -> ModuleType: # numpydoc ignore=PR01,RT03
107+
) -> Generator[ModuleType]: # numpydoc ignore=PR01,RT03
108108
"""
109109
Parameterized fixture that iterates on all libraries.
110110
@@ -113,7 +113,25 @@ def xp(
113113
The current array namespace.
114114
"""
115115
if library == Backend.NUMPY_READONLY:
116-
return NumPyReadOnly() # type: ignore[return-value] # pyright: ignore[reportReturnType]
116+
yield NumPyReadOnly() # type: ignore[return-value] # pyright: ignore[reportReturnType]
117+
return
118+
119+
if (
120+
library in (Backend.ARRAY_API_STRICT, Backend.ARRAY_API_STRICTEST)
121+
and np.__version__ < "1.26"
122+
):
123+
pytest.skip("array_api_strict is untested on NumPy <1.26")
124+
125+
if library == Backend.ARRAY_API_STRICTEST:
126+
xp = pytest.importorskip("array_api_strict")
127+
with xp.ArrayAPIStrictFlags(
128+
boolean_indexing=False,
129+
data_dependent_shapes=False,
130+
enabled_extensions=(),
131+
):
132+
yield xp
133+
return
134+
117135
xp = pytest.importorskip(library.value)
118136
# Possibly wrap module with array_api_compat
119137
xp = array_namespace(xp.empty(0))
@@ -122,16 +140,15 @@ def xp(
122140
# in the global scope of the module containing the test function.
123141
patch_lazy_xp_functions(request, monkeypatch, xp=xp)
124142

125-
if library == Backend.ARRAY_API_STRICT and np.__version__ < "1.26":
126-
pytest.skip("array_api_strict is untested on NumPy <1.26")
127-
128143
if library == Backend.JAX:
129144
import jax
130145

131146
# suppress unused-ignore to run mypy in -e lint as well as -e dev
132147
jax.config.update("jax_enable_x64", True) # type: ignore[no-untyped-call,unused-ignore]
148+
yield xp
149+
return
133150

134-
return xp
151+
yield xp
135152

136153

137154
@pytest.fixture(params=[Backend.DASK]) # Can select the test with `pytest -k dask`

0 commit comments

Comments
 (0)