
jax.jit makes training very slow #3636

Answered by chiamp
davidshen84 asked this question in Q&A

What sort of values are you passing in as v during the training loop?
Since you defined static_argnames=["v"], jax.jit will trigger a re-compile every time you pass in a different value for v in train_step. Read more about it here.
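A minimal sketch of the effect described above follows. It assumes a toy train step that only scales its input; the function bodies, the trace_counts dict and the loop are hypothetical, not the code from the question. The counter increments are plain Python side effects, so they run only when JAX traces the function, which makes them a rough count of compilations. Marking v as static forces one trace per distinct value, while passing v as an ordinary traced argument compiles once and reuses the result.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical counters: the increments below are plain Python side effects,
# so they run only while JAX traces (and then compiles) each function.
trace_counts = {"static": 0, "dynamic": 0}


@partial(jax.jit, static_argnames=["v"])
def train_step_static(x, v):
    trace_counts["static"] += 1
    return x * v


@jax.jit
def train_step_dynamic(x, v):
    trace_counts["dynamic"] += 1
    return x * v


x = jnp.ones(3)
for step in range(5):
    # A new static value of v is a cache miss, so JAX retraces and recompiles.
    train_step_static(x, v=step)
    # Here v is traced: every call sees the same shape and dtype, so the
    # version compiled on the first call is reused.
    train_step_dynamic(x, step)

print(trace_counts)  # {'static': 5, 'dynamic': 1}
```

If v really has to stay static (for example, because it drives a Python if or determines an array shape), keeping the set of distinct values small lets jax.jit reuse the version it has already compiled for each value.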

Answer selected by chiamp