I would like to compute an einsum according to the following formula:
```python
import jax

n = 8192
arrays = [jax.random.normal(key=jax.random.PRNGKey(0), shape=(n, n)) for _ in range(6)]
formula = 'ij,ik,il,jk,jl,kl->ij'
```
I want to express the computation as 4 nested for loops over the indices i, j, k, l, without creating any intermediate arrays. As far as einsum_path is concerned, I can do this by passing the einsum path directly as [(0, 1, 2, 3, 4, 5)] via the optimize kwarg.
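For reference, the four nested loops described above can be sketched as follows. This is a minimal illustration, not the JAX code path: it assumes NumPy as a stand-in, a tiny `n = 4` instead of 8192, and a hypothetical helper name `einsum_loops`. It checks the loops against `np.einsum` with an explicit single-step path (NumPy spells this as a list starting with the marker `'einsum_path'`).

```python
import numpy as np

FORMULA = "ij,ik,il,jk,jl,kl->ij"


def einsum_loops(a, b, c, d, e, f):
    """Hypothetical reference helper (not part of JAX or NumPy).

    Computes out[i, j] = sum over k, l of
        a[i, j] * b[i, k] * c[i, l] * d[j, k] * e[j, l] * f[k, l]
    with four nested loops and no intermediate arrays.
    """
    n = a.shape[0]
    out = np.zeros((n, n), dtype=a.dtype)
    for i in range(n):
        for j in range(n):
            acc = 0.0
            for k in range(n):
                for l in range(n):
                    acc += (a[i, j] * b[i, k] * c[i, l]
                            * d[j, k] * e[j, l] * f[k, l])
            out[i, j] = acc
    return out


n = 4  # assumption: tiny size so the O(n**4) Python loops finish quickly
rng = np.random.default_rng(0)
arrays = [rng.standard_normal((n, n)) for _ in range(6)]

# Contract all six operands in a single step, i.e. one fused loop nest.
single_step = np.einsum(
    FORMULA, *arrays, optimize=["einsum_path", (0, 1, 2, 3, 4, 5)]
)

assert np.allclose(einsum_loops(*arrays), single_step)
```

The sketch only pins down what the single-step contraction is expected to compute; it says nothing about how JAX handles the same path.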
However, when I run the einsum in JAX with this path, I get a NotImplementedError, raised next to a source comment that says "# if this is actually reachable, open an issue!":
https://github.com/jax-ml/jax/blob/main/jax/_src/numpy/lax_numpy.py#L9775