@@ -235,7 +235,17 @@ def is_jax_array(x: object) -> TypeIs[jax.Array]:
235235 is_pydata_sparse_array
236236 """
237237 cls = cast (Hashable , type (x ))
238- return _issubclass_fast (cls , "jax" , "Array" ) or _is_jax_zero_gradient_array (x )
238+ # We test for jax.core.Tracer here to identify jax arrays during jit tracing. From jax 0.8.2 on,
239+ # tracers are not a subclass of jax.Array anymore. Note that tracers can also represent
240+ # non-array values and a fully correct implementation would need to use isinstance checks. Since
241+ # we use hash-based caching with type names as keys, we cannot use instance checks without
242+ # losing performance here. For more information, see
243+ # https://github.com/data-apis/array-api-compat/pull/369 and the corresponding issue.
244+ return (
245+ _issubclass_fast (cls , "jax" , "Array" )
246+ or _issubclass_fast (cls , "jax.core" , "Tracer" )
247+ or _is_jax_zero_gradient_array (x )
248+ )
239249
240250
241251def is_pydata_sparse_array (x : object ) -> TypeIs [sparse .SparseArray ]:
@@ -296,6 +306,7 @@ def _is_array_api_cls(cls: type) -> bool:
296306 or _issubclass_fast (cls , "sparse" , "SparseArray" )
297307 # TODO: drop support for jax<0.4.32 which didn't have __array_namespace__
298308 or _issubclass_fast (cls , "jax" , "Array" )
309+ or _issubclass_fast (cls , "jax.core" , "Tracer" ) # see is_jax_array for limitations
299310 )
300311
301312
@@ -934,6 +945,7 @@ def _is_writeable_cls(cls: type) -> bool | None:
934945 if (
935946 _issubclass_fast (cls , "numpy" , "generic" )
936947 or _issubclass_fast (cls , "jax" , "Array" )
948+ or _issubclass_fast (cls , "jax.core" , "Tracer" ) # see is_jax_array for limitations
937949 or _issubclass_fast (cls , "sparse" , "SparseArray" )
938950 ):
939951 return False
@@ -973,6 +985,7 @@ def _is_lazy_cls(cls: type) -> bool | None:
973985 return False
974986 if (
975987 _issubclass_fast (cls , "jax" , "Array" )
988+ or _issubclass_fast (cls , "jax.core" , "Tracer" ) # see is_jax_array for limitations
976989 or _issubclass_fast (cls , "dask.array" , "Array" )
977990 or _issubclass_fast (cls , "ndonnx" , "Array" )
978991 ):
0 commit comments