Understanding the source code of jax.lax | tanh #24240
Hi there, I'm trying to understand how the Does it simply call numpy's/XLA's tanh operation? If I were to optimize the So far, from the repo, I've gotten this, but I couldn't make much of it:

Lines 337 to 339 in 8ef41a6

Any pointers appreciated, TIA!
It's not so straightforward to follow the full stack, but the basic flow is that a JAX function gets lowered to StableHLO, which you can inspect as follows:

```python
import jax
import jax.numpy as jnp

# Lower jnp.tanh for a float32[5] input and print the StableHLO module.
print(jax.jit(jnp.tanh).lower(jnp.linspace(-1, 1, 5)).as_text())
```

Since StableHLO has a `tanh` intrinsic, the intermediate representation is (relatively) simple:
Then, depending on the specific platform that you are targeting, XLA will compile this program differently. For example, on CPU, I believe the tanh will be passed to LLVM, which (I expect) will emit SIMD operations. So, the tl;dr is that it's not really straightforward to track down exactly how specific operations are compiled. If you want to contribute optimizations like this, those would live in LLVM or XLA rather than JAX itself. Hope this helps!
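To go one step further than the `lower(...).as_text()` snippet, here is a minimal sketch (assuming a recent JAX release) that prints three successive stages for the same input. The jaxpr should show `jnp.tanh` reducing to the `tanh` primitive from `jax.lax`. The StableHLO module is what JAX hands to XLA. The HLO from `.compile().as_text()` is the result after XLA's backend-specific passes for your default device. That last stage varies by backend, and `Compiled.as_text()` is documented as possibly returning `None` where it is unavailable, so the code guards against that. None of this shows the LLVM IR or machine code XLA finally emits, so it only gets you part of the way down the stack.

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(-1, 1, 5)

# Stage 1: the jaxpr. jnp.tanh promotes its input and binds the lax `tanh` primitive.
jaxpr = jax.make_jaxpr(jnp.tanh)(x)
print(jaxpr)

# Stage 2: the StableHLO module JAX passes to XLA (same as the snippet above).
lowered = jax.jit(jnp.tanh).lower(x)
stablehlo_text = lowered.as_text()
print(stablehlo_text)

# Stage 3: the HLO after XLA's optimization passes for the current default
# backend (CPU, GPU, or TPU). as_text() may return None on some backends.
compiled = lowered.compile()
optimized_hlo_text = compiled.as_text()
if optimized_hlo_text is not None:
    print(optimized_hlo_text)
else:
    print("Optimized HLO text is not available on this backend.")
```

Comparing stage 2 with stage 3 on different devices is a quick way to see where XLA starts to diverge per platform; anything below stage 3 lives in XLA's backend code generators rather than in JAX.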