Is there any way to keep the original variable names, e.g. x and y, in the JAXPR?

This is not possible in general; for example:

```python
import jax

def f(x):
  for i in range(5):
    x += 1
  return x

print(jax.make_jaxpr(f)(1))
```

which prints:

```
{ lambda ; a:i32[]. let
    b:i32[] = add a 1:i32[]
    c:i32[] = add b 1:i32[]
    d:i32[] = add c 1:i32[]
    e:i32[] = add d 1:i32[]
    f:i32[] = add e 1:i32[]
  in (f,) }
```
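The same structure can be checked programmatically. The sketch below assumes a recent JAX release, where `jax.make_jaxpr` returns a `ClosedJaxpr` whose `.jaxpr.eqns` list holds one equation per primitive call. It shows that the Python loop is unrolled into five `add` equations, each binding a distinct new output variable rather than rebinding a single `x`.

```python
import jax

def f(x):
  for i in range(5):
    x += 1
  return x

closed = jax.make_jaxpr(f)(1)
eqns = closed.jaxpr.eqns

# One `add` equation per loop iteration: the Python loop is unrolled at trace time.
print([eqn.primitive.name for eqn in eqns])

# Every equation binds a fresh output variable (SSA style), so the single
# Python name `x` corresponds to five different jaxpr variables.
outvars = [eqn.outvars[0] for eqn in eqns]
print(len({id(v) for v in outvars}))
```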

If we were trying to replicate the Python variable names, every jaxpr variable in this expression would have to be named x.

Now, you could imagine various workarounds for this (e.g. calling them x_1, x_2, etc.), but then you need workarounds for your workarounds (what if there is already a variable called x_1?).
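To make the collision problem concrete, here is a minimal sketch of the suffix scheme. The helper `fresh_name` is hypothetical and is not part of JAX; it only illustrates that even a simple renaming rule needs a collision check, and that the resulting numbering stops matching the source once a user-defined `x_1` already exists.

```python
def fresh_name(base: str, taken: set[str]) -> str:
    """Hypothetical helper: return `base` if unused, else the first unused `base_<n>` (n = 1, 2, ...)."""
    if base not in taken:
        taken.add(base)
        return base
    n = 1
    while f"{base}_{n}" in taken:
        n += 1
    name = f"{base}_{n}"
    taken.add(name)
    return name

# Assumed scenario: the user's own code already has a variable named x_1.
taken = {"x_1"}
# The input x plus the five rebindings of x in the loop above.
names = [fresh_name("x", taken) for _ in range(6)]
print(names)
```

Because `x_1` is already taken, the generated names skip it, so the printed jaxpr would no longer line up with anything the user wrote.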

T…

Answer selected by rturrado
Category: Q&A
4 participants