Skip to content

Commit d8befd5

Browse files
committed
Add comments on jax.core.Tracer detection limitations
1 parent bcf350f commit d8befd5

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

array_api_compat/common/_helpers.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,12 @@ def is_jax_array(x: object) -> TypeIs[jax.Array]:
235235
is_pydata_sparse_array
236236
"""
237237
cls = cast(Hashable, type(x))
238+
# We test for jax.core.Tracer here to identify jax arrays during jit tracing. From jax 0.8.2 on,
239+
# tracers are not a subclass of jax.Array anymore. Note that tracers can also represent
240+
# non-array values and a fully correct implementation would need to use isinstance checks. Since
241+
# we use hash-based caching with type names as keys, we cannot use instance checks without
242+
# losing performance here. For more information, see
243+
# https://github.com/data-apis/array-api-compat/pull/369 and the corresponding issue.
238244
return (
239245
_issubclass_fast(cls, "jax", "Array")
240246
or _issubclass_fast(cls, "jax.core", "Tracer")
@@ -300,7 +306,7 @@ def _is_array_api_cls(cls: type) -> bool:
300306
or _issubclass_fast(cls, "sparse", "SparseArray")
301307
# TODO: drop support for jax<0.4.32 which didn't have __array_namespace__
302308
or _issubclass_fast(cls, "jax", "Array")
303-
or _issubclass_fast(cls, "jax.core", "Tracer")
309+
or _issubclass_fast(cls, "jax.core", "Tracer") # see is_jax_array for limitations
304310
)
305311

306312

@@ -939,7 +945,7 @@ def _is_writeable_cls(cls: type) -> bool | None:
939945
if (
940946
_issubclass_fast(cls, "numpy", "generic")
941947
or _issubclass_fast(cls, "jax", "Array")
942-
or _issubclass_fast(cls, "jax.core", "Tracer")
948+
or _issubclass_fast(cls, "jax.core", "Tracer") # see is_jax_array for limitations
943949
or _issubclass_fast(cls, "sparse", "SparseArray")
944950
):
945951
return False
@@ -979,7 +985,7 @@ def _is_lazy_cls(cls: type) -> bool | None:
979985
return False
980986
if (
981987
_issubclass_fast(cls, "jax", "Array")
982-
or _issubclass_fast(cls, "jax.core", "Tracer")
988+
or _issubclass_fast(cls, "jax.core", "Tracer") # see is_jax_array for limitations
983989
or _issubclass_fast(cls, "dask.array", "Array")
984990
or _issubclass_fast(cls, "ndonnx", "Array")
985991
):

tests/test_jax.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,10 @@ def test_device_jit(func):
4646

4747

4848
def test_inside_jit():
49-
jax = pytest.importorskip("jax")
50-
import jax.numpy as jnp
51-
49+
# Test if jax arrays are handled correctly inside jax.jit.
50+
# Jax tracers are not a subclass of jax.Array from 0.8.2 on. We explicitly test that
51+
# tracers are handled appropriately. For limitations, see is_jax_array() docstring.
52+
# Reference issue: https://github.com/data-apis/array-api-compat/issues/368
5253
x = jnp.asarray([1, 2, 3])
5354
assert jax.jit(is_jax_array)(x)
5455
assert jax.jit(is_array_api_obj)(x)

0 commit comments

Comments
 (0)