File tree Expand file tree Collapse file tree 2 files changed +14
-1
lines changed
Expand file tree Collapse file tree 2 files changed +14
-1
lines changed Original file line number Diff line number Diff line change @@ -235,7 +235,11 @@ def is_jax_array(x: object) -> TypeIs[jax.Array]:
235235 is_pydata_sparse_array
236236 """
237237 cls = cast (Hashable , type (x ))
238- return _issubclass_fast (cls , "jax" , "Array" ) or _is_jax_zero_gradient_array (x )
238+ return (
239+ _issubclass_fast (cls , "jax" , "Array" )
240+ or _issubclass_fast (cls , "jax.core" , "Tracer" )
241+ or _is_jax_zero_gradient_array (x )
242+ )
239243
240244
241245def is_pydata_sparse_array (x : object ) -> TypeIs [sparse .SparseArray ]:
Original file line number Diff line number Diff line change @@ -56,6 +56,15 @@ def test_is_xp_array(library, func):
5656 assert is_array_api_obj (x )
5757
5858
59+ def test_is_jax_array_jitted ():
60+ import jax
61+ import jax .numpy as jnp
62+
63+ x = jnp .asarray ([1 , 2 , 3 ])
64+ assert is_jax_array (x )
65+ assert jax .jit (lambda y : is_jax_array (y ))(x )
66+
67+
5968@pytest .mark .parametrize ('library' , is_namespace_functions .keys ())
6069@pytest .mark .parametrize ('func' , is_namespace_functions .values ())
6170def test_is_xp_namespace (library , func ):
You can’t perform that action at this time.
0 commit comments