jax.jit makes training very slow #3636
-
Hi, I have a simple dummy module like this.
If I use … But if I want to use a real float value for `v`, I have to use `static_argnames=["v"]` in `jax.jit`.
The training speed is about 13 s/it, that is, 13 seconds per iteration. I found no caveats about slowness when using the …
Replies: 2 comments
-
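What sort of values are you passing in as `v` during the training loop? Since you defined `static_argnames=["v"]`, `jax.jit` will trigger a re-compile every time you pass in a different value for `v` in `train_step`. Read more about it here: <https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html#caching>.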
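A minimal sketch of the caching behaviour described in this reply. The original module was not posted, so `train_step(params, v)` here is a hypothetical stand-in. A Python counter incremented inside the function only runs while JAX traces it, so it effectively counts compilations: with `v` static, every new value of `v` costs a fresh trace and compile.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0  # incremented only when JAX traces (and therefore compiles) train_step


# Hypothetical train_step; v is marked static, so its *value* is part of the cache key.
@partial(jax.jit, static_argnames=["v"])
def train_step(params, v):
    global trace_count
    trace_count += 1  # Python side effect: runs during tracing, not on cached calls
    return params * v


params = jnp.ones(3)
for v in [0.1, 0.2, 0.3, 0.1]:
    train_step(params, v=v)

print(trace_count)  # 3: one compile per distinct v; the repeated 0.1 hits the cache
```

With a random float for `v` on every step, no value ever repeats, so every iteration pays the full compile cost.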
-
`v` is a random float, so it triggers a recompile every time... Lesson learnt.
Thanks
On Thu, 25 Jan 2024, 15:22 Marcus Chiam wrote:
What sort of values are you passing in as v during the training loop?
Since you defined static_argnames=["v"], jax.jit will trigger a
re-compile every time you pass in a different value for v in train_step.
Read more about it here
<https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html#caching>.
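One way to avoid the recompiles, sketched under the assumption that `v` is only used as a number inside the (unseen) `train_step` and not in Python control flow: drop `static_argnames` and pass `v` as an ordinary argument. JAX then traces it as a scalar, so only its shape and dtype (not its value) go into the cache key. `train_step` is again a hypothetical stand-in.

```python
import random

import jax
import jax.numpy as jnp

trace_count = 0  # counts traces; stays at 1 if the compiled function is reused


# No static_argnames: v is traced as a scalar, so a new value does not force a recompile.
@jax.jit
def train_step(params, v):
    global trace_count
    trace_count += 1  # runs only while tracing
    return params * v


rng = random.Random(0)
params = jnp.ones(3)
for _ in range(5):
    v = rng.random()  # a fresh random Python float every step
    out = train_step(params, v)

print(trace_count)  # 1: every call after the first reuses the compiled function
```

If `v` does drive a Python `if` inside the function, the traced version will raise an error on that branch; replacing it with `jnp.where` or `jax.lax.cond` keeps `v` traced.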