-
Reading https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html more closely, the answer really seems to be no.
-
Python control flow based on the values within a traced array cannot be used with JIT or pmap, but you can often re-express your code in such a way that it can. In your case, you could use jnp.where:

    import jax.numpy as jnp

    def f(x):
        x = x**2 + x
        return jnp.where(jnp.sum(x) < 3, 3 * x ** 2, -4 * x)

There is more information available on this in 🔪 JAX - The Sharp Bits 🔪: Control Flow.
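A minimal sketch of how the jnp.where version above might be run under pmap, assuming one input row per local device (the input values and the names `xs`, `ys`, and `f_cond` are illustrative only, not from the thread). Each device evaluates the condition on its own shard, so the branch choice is per device and involves no collective. The `jax.lax.cond` variant is an alternative covered in the linked Control Flow section; it is included here only for comparison.

```python
import jax
import jax.numpy as jnp
from jax import lax

def f(x):
    x = x**2 + x
    # jnp.where computes both branches and selects based on the traced condition.
    return jnp.where(jnp.sum(x) < 3, 3 * x**2, -4 * x)

def f_cond(x):
    # Alternative (illustrative): lax.cond traces both branches but runs only one.
    x = x**2 + x
    return lax.cond(jnp.sum(x) < 3, lambda v: 3 * v**2, lambda v: -4 * v, x)

n_dev = jax.local_device_count()
# One row of two values per device; the leading axis must not exceed the device count.
xs = 0.5 * jnp.arange(n_dev * 2, dtype=jnp.float32).reshape(n_dev, 2)

ys = jax.pmap(f)(xs)            # shape (n_dev, 2)
ys_cond = jax.pmap(f_cond)(xs)  # same result via lax.cond
print(ys)
```

Both versions keep the branch inside the traced computation, which is why they compose with jit and pmap where a Python `if` does not.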
-
Can I use pmap to distribute a function that contains some jittable subfunctions and, at the end, control flow that depends on a traced input (not collective, i.e. not dependent on data from other devices, and not static)?
Or do pmapped functions have to contain only jittable code?
Thanks!
docs https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html under heading 'pmap and jit'
control flow https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow
potentially related #5895
Example
The docs imply (without stating it directly) that dynamic control flow cannot be used with pmap. Is there a way to distribute code whose control flow depends on traced variables?
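To make the question concrete, here is a hypothetical sketch of the kind of function being asked about (the name `f_python_if` and the input values are assumptions for illustration, not the original example). It runs eagerly on a single shard, but under pmap the Python `if` tries to convert a traced value to a bool during tracing, which raises a JAX concretization error (a subclass of `TypeError`).

```python
import jax
import jax.numpy as jnp

def f_python_if(x):  # hypothetical name, for illustration only
    x = x**2 + x
    # Python-level branch on a value computed from the (traced) input.
    if jnp.sum(x) < 3:
        return 3 * x**2
    return -4 * x

n_dev = jax.local_device_count()
xs = jnp.ones((n_dev, 2))  # one row per local device

# Eager execution on one shard works: the condition is a concrete value.
eager = f_python_if(xs[0])

# Under pmap the function is traced, so the `if` cannot be evaluated.
try:
    jax.pmap(f_python_if)(xs)
    raised = False
except TypeError as err:
    raised = True
    print(type(err).__name__)
```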