JaxStackTraceBeforeTransformation error for odeint with sparse BCOO matrix #31792

@refraction-ray

Description

The following code raises an error ❌:

import jax
from jax.experimental.ode import odeint
from jax.experimental.sparse import BCOO

s = jax.numpy.ones([4])

def cost_fn(param):
    h0 = BCOO.fromdense(jax.numpy.eye(4), nse=4)

    def hamiltonian(t):
        return h0 * param[0]
    
    def dynamics(y, t):
        return hamiltonian(t) @ y

    s1 = odeint(dynamics, s, jax.numpy.array([0.0, 1.0]))
    return s1[0, 0]

gf = jax.jit(jax.grad(cost_fn))
param = jax.numpy.ones([1])
gf(param)

The error is:

JaxStackTraceBeforeTransformation: jax._src.interpreters.xla.InvalidInputException: Argument 'JitTracer<int32[4,2]>' of type <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'> is not a valid JAX type.

Replacing the sparse matrix with a dense matrix works as expected ✅:

s = jax.numpy.ones([4])

def cost_fn(param):
    h0 = jax.numpy.eye(4)

    def hamiltonian(t):
        return h0 * param[0]
    
    def dynamics(y, t):
        return hamiltonian(t) @ y

    s1 = odeint(dynamics, s, jax.numpy.array([0.0, 1.0]))
    return s1[0, 0]

gf = jax.jit(jax.grad(cost_fn))
param = jax.numpy.ones([1])
gf(param)
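
As a sanity check on the dense path, the gradient can be compared against a closed-form value. This is only a sketch under stated assumptions: the cost in the reproducer reads s1[0, 0], which is the state at the first requested time (the initial condition), so its gradient with respect to param should be zero. The variant below reads the state at t = 1 instead, where the solution of dy/dt = param[0] * y with y(0) = 1 is exp(param[0]), so the gradient should be close to exp(param[0]). The name final_state_cost and the explicit rtol/atol values are illustrative choices, not part of the original report.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def final_state_cost(param):
    # Same dense dynamics as the working example: dy/dt = (param[0] * I) @ y
    h0 = jnp.eye(4)

    def dynamics(y, t):
        return (h0 * param[0]) @ y

    y0 = jnp.ones([4])
    # rtol/atol are illustrative values chosen for float32, not from the report.
    ys = odeint(dynamics, y0, jnp.array([0.0, 1.0]), rtol=1e-6, atol=1e-6)
    # Assumption for illustration: read the state at t = 1 (last row), not t = 0.
    return ys[-1, 0]


param = jnp.ones([1])
grad_final = jax.jit(jax.grad(final_state_cost))(param)

# Closed form: y(1) = exp(p), so d/dp y(1) = exp(p).
expected = jnp.exp(param)
print(grad_final, expected)
```

If the dense gradient matches the closed form, the same cost function is a reasonable reference for checking any sparse variant that runs without error.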

Moving the matrix definition outside of the cost function, or inside the hamiltonian function, also works as expected ✅:

s = jax.numpy.ones([4])
h0 = BCOO.fromdense(jax.numpy.eye(4), nse=4)

def cost_fn(param):

    def hamiltonian(t):
        return h0 * param[0]
    
    def dynamics(y, t):
        return hamiltonian(t) @ y

    s1 = odeint(dynamics, s, jax.numpy.array([0.0, 1.0]))
    return s1[0, 0]

gf = jax.jit(jax.grad(cost_fn))
param = jax.numpy.ones([1])
gf(param)

# Second variant: define h0 inside hamiltonian instead.
s = jax.numpy.ones([4])

def cost_fn(param):

    def hamiltonian(t):
        h0 = BCOO.fromdense(jax.numpy.eye(4), nse=4)
        return h0 * param[0]
    
    def dynamics(y, t):
        return hamiltonian(t) @ y

    s1 = odeint(dynamics, s, jax.numpy.array([0.0, 1.0]))
    return s1[0, 0]

gf = jax.jit(jax.grad(cost_fn))
param = jax.numpy.ones([1])
gf(param)

Also, replacing odeint with the ODE solver from diffrax works as expected ✅.

System info (python version, jaxlib version, accelerator, etc.)

jax 0.7.0
jaxlib 0.7.0

CPU version

Labels: bug
