Skip to content

Commit 00d1f89

Browse files
committed
Fix is_jax_array for jax>=0.8.2
1 parent 88535e5 commit 00d1f89

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

array_api_compat/common/_helpers.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,11 @@ def is_jax_array(x: object) -> TypeIs[jax.Array]:
235235
is_pydata_sparse_array
236236
"""
237237
cls = cast(Hashable, type(x))
238-
return _issubclass_fast(cls, "jax", "Array") or _is_jax_zero_gradient_array(x)
238+
return (
239+
_issubclass_fast(cls, "jax", "Array")
240+
or _issubclass_fast(cls, "jax.core", "Tracer")
241+
or _is_jax_zero_gradient_array(x)
242+
)
239243

240244

241245
def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:

tests/test_common.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,15 @@ def test_is_xp_array(library, func):
5656
assert is_array_api_obj(x)
5757

5858

59+
def test_is_jax_array_jitted():
60+
import jax
61+
import jax.numpy as jnp
62+
63+
x = jnp.asarray([1, 2, 3])
64+
assert is_jax_array(x)
65+
assert jax.jit(lambda y: is_jax_array(y))(x)
66+
67+
5968
@pytest.mark.parametrize('library', is_namespace_functions.keys())
6069
@pytest.mark.parametrize('func', is_namespace_functions.values())
6170
def test_is_xp_namespace(library, func):

0 commit comments

Comments
 (0)