Replies: 2 comments
I think this function will work, but it's super inefficient
inefficient drop diagonal op
If we want to create an operation that extracts the diagonal, we can use the JAX NumPy API to write a function that works with 'non-static' array shapes.
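For comparison, here is a minimal sketch of the diagonal case under `jit`. `take_diagonal` is just an illustrative name. It works because the output length, `min(rows, cols)`, follows from the input's static shape and does not depend on the values.

```python
import jax
import jax.numpy as jnp

@jax.jit
def take_diagonal(x):
    # The result has min(rows, cols) elements, which is known from x's
    # static shape at trace time, so this compiles without trouble.
    return jnp.diagonal(x)

x = jnp.arange(9.0).reshape(3, 3)
print(take_diagonal(x))  # [0. 4. 8.]
```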
If we want to take all elements that are not on the diagonal, in NumPy we might use a boolean mask. However, this doesn't work under `jit`: the mask isn't a concrete (static) value, so JAX can't know the size of the result at trace time. (Selecting elements with a boolean mask is covered in the docs: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html)
This raises: `IndexError: Array boolean indices must be concrete.`
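A minimal reproduction, assuming a square 2-D input (the function names are made up for illustration). For dropping the diagonal, the output size is actually known in advance: n*(n-1) elements. Building the mask with plain NumPy instead of `jax.numpy` keeps it a concrete compile-time constant. This follows the "NumPy for static values, `jax.numpy` for traced values" pattern from the linked docs, so the indexing goes through.

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def drop_diagonal_traced_mask(x):
    # jnp.eye inside jit produces a traced value, so the boolean mask is
    # not concrete and the number of selected elements is unknown.
    mask = ~jnp.eye(x.shape[0], dtype=bool)
    return x[mask]

@jax.jit
def drop_diagonal_static_mask(x):
    n = x.shape[0]  # shapes are static under jit, so n is a Python int
    # Plain NumPy mask: a compile-time constant. np.nonzero turns it into
    # integer index arrays of known length n*(n-1), in row-major order.
    rows, cols = np.nonzero(~np.eye(n, dtype=bool))
    return x[rows, cols].reshape(n, n - 1)

x = jnp.arange(9.0).reshape(3, 3)
try:
    drop_diagonal_traced_mask(x)
    traced_mask_failed = False
except IndexError:
    # Recent JAX raises NonConcreteBooleanIndexError, a subclass of IndexError.
    traced_mask_failed = True
print(traced_mask_failed)  # True
print(drop_diagonal_static_mask(x))
```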
Is this kind of operation possible in JAX? More generally, is it possible to create a new tensor consisting of only some of the elements of an existing tensor? Are there hacks to get around this?
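Two possible workarounds, sketched under stated assumptions. `drop_diagonal_reshape` and `select_positive` are hypothetical names, not JAX APIs.

1. For an (n, n) matrix, a reshape trick removes the diagonal using only static reshapes and slices. It may be cheaper than an index-based gather, but this hasn't been benchmarked.
2. When the number of kept elements genuinely depends on the values, the usual approach is to fix an upper bound on the output size. `jnp.nonzero` accepts a static `size=` (plus `fill_value=`) so it can run under `jit`, and the unused slots are padded.

```python
import jax
import jax.numpy as jnp

@jax.jit
def drop_diagonal_reshape(x):
    # Assumes x is square, shape (n, n). Dropping the first flat element
    # shifts every diagonal entry into the last column of an (n - 1, n + 1)
    # view, so slicing that column off removes the diagonal.
    n = x.shape[0]
    flat = x.reshape(-1)[1:]
    return flat.reshape(n - 1, n + 1)[:, :-1].reshape(n, n - 1)

@jax.jit
def select_positive(v):
    # Data-dependent selection on a 1-D array: how many elements survive
    # depends on the values, so fix the output length to v.shape[0] via
    # size= and pad the unused slots.
    mask = v > 0
    (idx,) = jnp.nonzero(mask, size=v.shape[0], fill_value=0)
    count = jnp.sum(mask)
    kept = v[idx]
    # Slots at position >= count are padding (they read v[0]); zero them.
    padded = jnp.where(jnp.arange(v.shape[0]) < count, kept, 0.0)
    return padded, count

x = jnp.arange(9.0).reshape(3, 3)
print(drop_diagonal_reshape(x))

vals, count = select_positive(jnp.array([-1.0, 2.0, -3.0, 4.0]))
print(vals, count)  # [2. 4. 0. 0.] 2
```

If the diagonal only needs to be ignored (e.g. in a reduction) rather than removed, keeping the full shape with `jnp.where(jnp.eye(n, dtype=bool), 0, x)` avoids the dynamic-shape problem entirely.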