@@ -235,6 +235,12 @@ def is_jax_array(x: object) -> TypeIs[jax.Array]:
235235 is_pydata_sparse_array
236236 """
237237 cls = cast (Hashable , type (x ))
238+ # We test for jax.core.Tracer here to identify jax arrays during jit tracing. From jax 0.8.2 on,
239+ # tracers are not a subclass of jax.Array anymore. Note that tracers can also represent
240+ # non-array values and a fully correct implementation would need to use isinstance checks. Since
241+ # we use hash-based caching with type names as keys, we cannot use instance checks without
242+ # losing performance here. For more information, see
243+ # https://github.com/data-apis/array-api-compat/pull/369 and the corresponding issue.
238244 return (
239245 _issubclass_fast (cls , "jax" , "Array" )
240246 or _issubclass_fast (cls , "jax.core" , "Tracer" )
@@ -300,7 +306,7 @@ def _is_array_api_cls(cls: type) -> bool:
300306 or _issubclass_fast (cls , "sparse" , "SparseArray" )
301307 # TODO: drop support for jax<0.4.32 which didn't have __array_namespace__
302308 or _issubclass_fast (cls , "jax" , "Array" )
303- or _issubclass_fast (cls , "jax.core" , "Tracer" )
309+ or _issubclass_fast (cls , "jax.core" , "Tracer" ) # see is_jax_array for limitations
304310 )
305311
306312
@@ -939,7 +945,7 @@ def _is_writeable_cls(cls: type) -> bool | None:
939945 if (
940946 _issubclass_fast (cls , "numpy" , "generic" )
941947 or _issubclass_fast (cls , "jax" , "Array" )
942- or _issubclass_fast (cls , "jax.core" , "Tracer" )
948+ or _issubclass_fast (cls , "jax.core" , "Tracer" ) # see is_jax_array for limitations
943949 or _issubclass_fast (cls , "sparse" , "SparseArray" )
944950 ):
945951 return False
@@ -979,7 +985,7 @@ def _is_lazy_cls(cls: type) -> bool | None:
979985 return False
980986 if (
981987 _issubclass_fast (cls , "jax" , "Array" )
982- or _issubclass_fast (cls , "jax.core" , "Tracer" )
988+ or _issubclass_fast (cls , "jax.core" , "Tracer" ) # see is_jax_array for limitations
983989 or _issubclass_fast (cls , "dask.array" , "Array" )
984990 or _issubclass_fast (cls , "ndonnx" , "Array" )
985991 ):
0 commit comments