11"""Pytest fixtures."""
22
3- from collections .abc import Callable
3+ from collections .abc import Callable , Generator
44from contextlib import suppress
55from functools import partial , wraps
66from types import ModuleType
@@ -104,7 +104,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
104104@pytest .fixture
105105def xp (
106106 library : Backend , request : pytest .FixtureRequest , monkeypatch : pytest .MonkeyPatch
107- ) -> ModuleType : # numpydoc ignore=PR01,RT03
107+ ) -> Generator [ ModuleType ] : # numpydoc ignore=PR01,RT03
108108 """
109109 Parameterized fixture that iterates on all libraries.
110110
@@ -113,7 +113,25 @@ def xp(
113113 The current array namespace.
114114 """
115115 if library == Backend .NUMPY_READONLY :
116- return NumPyReadOnly () # type: ignore[return-value] # pyright: ignore[reportReturnType]
116+ yield NumPyReadOnly () # type: ignore[return-value] # pyright: ignore[reportReturnType]
117+ return
118+
119+ if (
120+ library in (Backend .ARRAY_API_STRICT , Backend .ARRAY_API_STRICTEST )
121+ and np .__version__ < "1.26"
122+ ):
123+ pytest .skip ("array_api_strict is untested on NumPy <1.26" )
124+
125+ if library == Backend .ARRAY_API_STRICTEST :
126+ xp = pytest .importorskip ("array_api_strict" )
127+ with xp .ArrayAPIStrictFlags (
128+ boolean_indexing = False ,
129+ data_dependent_shapes = False ,
130+ enabled_extensions = (),
131+ ):
132+ yield xp
133+ return
134+
117135 xp = pytest .importorskip (library .value )
118136 # Possibly wrap module with array_api_compat
119137 xp = array_namespace (xp .empty (0 ))
@@ -122,16 +140,15 @@ def xp(
122140 # in the global scope of the module containing the test function.
123141 patch_lazy_xp_functions (request , monkeypatch , xp = xp )
124142
125- if library == Backend .ARRAY_API_STRICT and np .__version__ < "1.26" :
126- pytest .skip ("array_api_strict is untested on NumPy <1.26" )
127-
128143 if library == Backend .JAX :
129144 import jax
130145
131146 # suppress unused-ignore to run mypy in -e lint as well as -e dev
132147 jax .config .update ("jax_enable_x64" , True ) # type: ignore[no-untyped-call,unused-ignore]
148+ yield xp
149+ return
133150
134- return xp
151+ yield xp
135152
136153
137154@pytest .fixture (params = [Backend .DASK ]) # Can select the test with `pytest -k dask`
0 commit comments