Status: Open
Labels: bug (Something isn't working)
Description
The following code gives the error ❌:
```python
import jax
from jax.experimental.ode import odeint
from jax.experimental.sparse import BCOO

s = jax.numpy.ones([4])

def cost_fn(param):
    # BCOO matrix created inside the traced (jit + grad) function
    h0 = BCOO.fromdense(jax.numpy.eye(4), nse=4)

    def hamiltonian(t):
        return h0 * param[0]

    def dynamics(y, t):
        return hamiltonian(t) @ y

    s1 = odeint(dynamics, s, jax.numpy.array([0.0, 1.0]))
    return s1[0, 0]

gf = jax.jit(jax.grad(cost_fn))
param = jax.numpy.ones([1])
gf(param)
```
The error is:
```
JaxStackTraceBeforeTransformation: jax._src.interpreters.xla.InvalidInputException: Argument 'JitTracer<int32[4,2]>' of type <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'> is not a valid JAX type.
```
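The tracer named in the error has shape `int32[4,2]`. For a 4×4 BCOO matrix with `nse=4`, that matches the shape and dtype of the `indices` leaf (one row per stored element, one column per sparse dimension), while `data` is a float array of shape `(4,)`. This suggests, without proving it, that the failure involves the integer index buffer created inside the traced `cost_fn` rather than the float values being differentiated. A quick eager check of the leaf shapes, assuming the default 32-bit configuration:

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

# Same construction as in the failing cost_fn, but evaluated eagerly (no tracing).
h0 = BCOO.fromdense(jnp.eye(4), nse=4)

# `data` holds the stored values; `indices` holds their (row, col) coordinates.
print(h0.data.shape, h0.data.dtype)        # (4,) float32
print(h0.indices.shape, h0.indices.dtype)  # (4, 2) int32
```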
Replacing the sparse matrix with a dense matrix works as expected ✅:
```python
import jax
from jax.experimental.ode import odeint

s = jax.numpy.ones([4])

def cost_fn(param):
    # Dense matrix instead of BCOO
    h0 = jax.numpy.eye(4)

    def hamiltonian(t):
        return h0 * param[0]

    def dynamics(y, t):
        return hamiltonian(t) @ y

    s1 = odeint(dynamics, s, jax.numpy.array([0.0, 1.0]))
    return s1[0, 0]

gf = jax.jit(jax.grad(cost_fn))
param = jax.numpy.ones([1])
gf(param)
```
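Because the dense variant runs, it can serve as a reference for checking gradients once the sparse path works. Note that `s1[0, 0]` is the state at `t = 0`, i.e. the initial condition `s[0]`, so its gradient with respect to `param` is expected to be zero. With dy/dt = p·y and y(0) = 1, the first component at t = 1 is e^p, so d/dp y(1) = e at p = 1. A minimal sketch, using a hypothetical `cost_fn_final` that reads the last time point instead:

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

s = jnp.ones([4])
ts = jnp.array([0.0, 1.0])

def cost_fn_final(param):
    # Hypothetical variant of the dense cost_fn: read y(t=1) instead of y(t=0).
    h0 = jnp.eye(4)

    def dynamics(y, t):
        return (h0 * param[0]) @ y

    s1 = odeint(dynamics, s, ts)
    return s1[-1, 0]  # s1[0] would just be the initial condition s

param = jnp.ones([1])
dparam = jax.jit(jax.grad(cost_fn_final))(param)
print(dparam)  # close to [e], since y(1)[0] = exp(p) and d/dp exp(p) = exp(p)
```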
Moving the matrix definition outside of the cost function, or inside the `hamiltonian` function, also works as expected ✅:
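```python
import jax
from jax.experimental.ode import odeint
from jax.experimental.sparse import BCOO

s = jax.numpy.ones([4])
# BCOO created once at module level, outside the traced cost_fn
h0 = BCOO.fromdense(jax.numpy.eye(4), nse=4)

def cost_fn(param):
    def hamiltonian(t):
        return h0 * param[0]

    def dynamics(y, t):
        return hamiltonian(t) @ y

    s1 = odeint(dynamics, s, jax.numpy.array([0.0, 1.0]))
    return s1[0, 0]

gf = jax.jit(jax.grad(cost_fn))
param = jax.numpy.ones([1])
gf(param)
```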
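```python
import jax
from jax.experimental.ode import odeint
from jax.experimental.sparse import BCOO

s = jax.numpy.ones([4])

def cost_fn(param):
    def hamiltonian(t):
        # BCOO created inside hamiltonian, i.e. inside the ODE right-hand side
        h0 = BCOO.fromdense(jax.numpy.eye(4), nse=4)
        return h0 * param[0]

    def dynamics(y, t):
        return hamiltonian(t) @ y

    s1 = odeint(dynamics, s, jax.numpy.array([0.0, 1.0]))
    return s1[0, 0]

gf = jax.jit(jax.grad(cost_fn))
param = jax.numpy.ones([1])
gf(param)
```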
Also, replacing `odeint` with the ODE solver from diffrax works as expected ✅.
System info (python version, jaxlib version, accelerator, etc.)
jax 0.7.0
jaxlib 0.7.0
CPU version
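For reference, the version and device details above can be collected in one step with `jax.print_environment_info()`, which prints the jax, jaxlib, numpy and Python versions plus the detected devices; passing `return_string=True` returns the same report as a string instead of printing it. A short sketch:

```python
import jax

# Prints jax, jaxlib, numpy and Python versions plus the detected devices.
jax.print_environment_info()

# The same report as a string, e.g. for pasting into an issue.
info = jax.print_environment_info(return_string=True)
```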