I'm also asking this in the JAX repo and a few Discord channels, but I haven't gotten an answer yet.
fp8 has hardware support only on GPUs with sm >= 89 (Ada and newer), such as the RTX 4090 or H100. I've seen people try to run it in PyTorch (e.g., this script) on older GPUs and get errors. But JAX can actually run it on older GPUs.
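One common way to support a low-precision type without native instructions is to upcast the inputs, compute in a wider type, and round the result back. The sketch below illustrates that idea for `float8_e4m3fn` in plain Python. It is not XLA's implementation, and it assumes the usual E4M3FN parameters (3 mantissa bits, exponent bias 7, largest finite value 448, no infinities). The names `round_to_e4m3fn` and `emulated_fp8_mul` are hypothetical, and saturating out-of-range results to ±448 is a simplification chosen here; real conversions may produce NaN instead.

```python
import math

# Assumed float8_e4m3fn parameters: 3 mantissa bits, exponent bias 7,
# smallest normal exponent -6, largest finite value 448, no infinities.
MAX_E4M3FN = 448.0


def round_to_e4m3fn(x: float) -> float:
    """Round a float to the nearest float8_e4m3fn value (ties to even)."""
    if math.isnan(x) or math.isinf(x):
        return math.nan                 # the format has no infinities
    if x == 0.0:
        return x                        # keeps the sign of -0.0
    _, e = math.frexp(abs(x))           # abs(x) = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, -6)                # subnormals share the exponent -6
    step = 2.0 ** (exp - 3)             # spacing of representable values
    q = round(abs(x) / step) * step     # built-in round() is ties-to-even
    q = min(q, MAX_E4M3FN)              # saturate (a choice of this sketch)
    return math.copysign(q, x)


def emulated_fp8_mul(a: float, b: float) -> float:
    """Multiply two fp8 values in double precision, then round back to fp8."""
    return round_to_e4m3fn(round_to_e4m3fn(a) * round_to_e4m3fn(b))


print(emulated_fp8_mul(1.125, 1.125))  # 1.25 (exact product is 1.265625)
```

The product of two 4-significant-bit values fits exactly in a double, so the only rounding happens when the result is converted back to the fp8 grid.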
I tried to run
and I can see that the dtype is `f8E4M3FN` in the HLO IR. Then I set `XLA_FLAGS="--xla_dump_to=..."` and read `module_0005.jit_f.ir-no-opt.ll`. If the dtype above is `float32`, the LLVM IR is relatively simple. But if it's float8, the LLVM IR is much longer and contains instructions like `load i8`. So I assume there is some compiler pass in XLA that does the emulation?

`module_0005.jit_f.ir-no-opt.ll` with `dtype=jnp.float32`
`module_0005.jit_f.ir-no-opt.ll` with `dtype=jnp.float8_e4m3fn`
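For context on the `load i8` instructions: each `float8_e4m3fn` element occupies a single byte (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits; no infinities, and only the bit patterns `0x7F` and `0xFF` encode NaN). So it is plausible that code generated for a GPU without fp8 instructions loads each element as an 8-bit integer and converts it to a wider float with ordinary arithmetic. Below is a minimal pure-Python sketch of that decoding, written from the format definition rather than taken from XLA; the function name `e4m3fn_to_float` is hypothetical.

```python
import math


def e4m3fn_to_float(byte: int) -> float:
    """Decode one float8_e4m3fn bit pattern (0..255) into a Python float."""
    sign = -1.0 if byte & 0x80 else 1.0   # bit 7: sign
    exp = (byte >> 3) & 0xF               # bits 6..3: exponent, bias 7
    man = byte & 0x7                      # bits 2..0: mantissa
    if exp == 0xF and man == 0x7:
        return math.nan                   # only S.1111.111 is NaN
    if exp == 0:
        # Subnormal: no implicit leading 1, fixed exponent 1 - 7 = -6.
        return sign * (man / 8.0) * 2.0 ** -6
    # Normal: implicit leading 1.
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)


print(e4m3fn_to_float(0x38))  # 1.0
print(e4m3fn_to_float(0x7E))  # 448.0 (largest finite value)
print(e4m3fn_to_float(0x01))  # 0.001953125 (smallest subnormal, 2**-9)
```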