Skip to content

Commit b61e9c3

Browse files
authored
BUG: Fix is_jax_array for jax>=0.8.2 (#369)
* Fix is_jax_array for jax>=0.8.2 * Skip jax test if not installed * Fix and test array_api_obj, is_writable_array, is_lazy_array * Add comments on jax.core.Tracer detection limitations
1 parent 88535e5 commit b61e9c3

File tree

2 files changed

+36
-4
lines changed

2 files changed

+36
-4
lines changed

array_api_compat/common/_helpers.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,17 @@ def is_jax_array(x: object) -> TypeIs[jax.Array]:
235235
is_pydata_sparse_array
236236
"""
237237
cls = cast(Hashable, type(x))
238-
return _issubclass_fast(cls, "jax", "Array") or _is_jax_zero_gradient_array(x)
238+
# We test for jax.core.Tracer here to identify jax arrays during jit tracing. From jax 0.8.2 on,
239+
# tracers are not a subclass of jax.Array anymore. Note that tracers can also represent
240+
# non-array values and a fully correct implementation would need to use isinstance checks. Since
241+
# we use hash-based caching with type names as keys, we cannot use instance checks without
242+
# losing performance here. For more information, see
243+
# https://github.com/data-apis/array-api-compat/pull/369 and the corresponding issue.
244+
return (
245+
_issubclass_fast(cls, "jax", "Array")
246+
or _issubclass_fast(cls, "jax.core", "Tracer")
247+
or _is_jax_zero_gradient_array(x)
248+
)
239249

240250

241251
def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
@@ -296,6 +306,7 @@ def _is_array_api_cls(cls: type) -> bool:
296306
or _issubclass_fast(cls, "sparse", "SparseArray")
297307
# TODO: drop support for jax<0.4.32 which didn't have __array_namespace__
298308
or _issubclass_fast(cls, "jax", "Array")
309+
or _issubclass_fast(cls, "jax.core", "Tracer") # see is_jax_array for limitations
299310
)
300311

301312

@@ -934,6 +945,7 @@ def _is_writeable_cls(cls: type) -> bool | None:
934945
if (
935946
_issubclass_fast(cls, "numpy", "generic")
936947
or _issubclass_fast(cls, "jax", "Array")
948+
or _issubclass_fast(cls, "jax.core", "Tracer") # see is_jax_array for limitations
937949
or _issubclass_fast(cls, "sparse", "SparseArray")
938950
):
939951
return False
@@ -973,6 +985,7 @@ def _is_lazy_cls(cls: type) -> bool | None:
973985
return False
974986
if (
975987
_issubclass_fast(cls, "jax", "Array")
988+
or _issubclass_fast(cls, "jax.core", "Tracer") # see is_jax_array for limitations
976989
or _issubclass_fast(cls, "dask.array", "Array")
977990
or _issubclass_fast(cls, "ndonnx", "Array")
978991
):

tests/test_jax.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
from numpy.testing import assert_equal
22
import pytest
33

4-
from array_api_compat import device, to_device
4+
from array_api_compat import (
5+
device,
6+
to_device,
7+
is_jax_array,
8+
is_lazy_array,
9+
is_array_api_obj,
10+
is_writeable_array,
11+
)
512

613
try:
714
import jax
@@ -13,7 +20,7 @@
1320

1421

1522
@pytest.mark.parametrize(
16-
"func",
23+
"func",
1724
[
1825
lambda x: jnp.zeros(1, device=device(x)),
1926
lambda x: jnp.zeros_like(jnp.ones(1, device=device(x))),
@@ -26,7 +33,7 @@
2633
),
2734
),
2835
lambda x: to_device(jnp.zeros(1), device(x)),
29-
]
36+
],
3037
)
3138
def test_device_jit(func):
3239
# Test work around to https://github.com/jax-ml/jax/issues/26000
@@ -36,3 +43,15 @@ def test_device_jit(func):
3643
x = jnp.ones(1)
3744
assert_equal(func(x), jnp.asarray([0]))
3845
assert_equal(jax.jit(func)(x), jnp.asarray([0]))
46+
47+
48+
def test_inside_jit():
49+
# Test if jax arrays are handled correctly inside jax.jit.
50+
# Jax tracers are not a subclass of jax.Array from 0.8.2 on. We explicitly test that
51+
# tracers are handled appropriately. For limitations, see is_jax_array() docstring.
52+
# Reference issue: https://github.com/data-apis/array-api-compat/issues/368
53+
x = jnp.asarray([1, 2, 3])
54+
assert jax.jit(is_jax_array)(x)
55+
assert jax.jit(is_array_api_obj)(x)
56+
assert not jax.jit(is_writeable_array)(x)
57+
assert jax.jit(is_lazy_array)(x)

0 commit comments

Comments
 (0)