array.at.set is incredibly slow for complex128 dtype #24872
Comments
As it happens, we have a workaround in JAX to avoid this slow behavior for scatter-add and scatter-sub, but not for scatter-update. It should be pretty easy to make it work for scatter-update as well. (The issue is that 128-bit scatters are currently expensive in XLA, because NVIDIA GPUs don't have a 16-byte atomic write operation.)
Actually, thinking about this a bit more, it's somewhat problematic to split into real and imaginary parts. If there are multiple updates to the same index, then it's unspecified which update "wins". If we performed updates to both real and imaginary parts separately, you might get the real part of one and the imaginary part of another. Only if you promised us the indices are non-overlapping would it be safe for us to do that. Is that true in your case? It's easier for
I see the issue. Yes, in our case the indices are non-overlapping, so the two approaches give exactly the same result. Maybe the solution is to emit a warning that scatter-update is slow with the complex128 dtype and to suggest updating the real and imaginary parts separately?
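To make the overlapping-index concern concrete, here is a small sketch that simulates one possible outcome by hand. It does not run a real scatter; it only assumes that, with duplicate indices, the real-part scatter and the imaginary-part scatter could each pick a different "winner". The names `updates` and `torn` are illustrative only.

```python
import jax
import jax.numpy as jnp

# complex128 is only available when 64-bit mode is enabled.
jax.config.update("jax_enable_x64", True)

# Two updates that both target the same index.
updates = jnp.array([1.0 + 10.0j, 2.0 + 20.0j])

# Simulated split scatter: suppose the real-part scatter keeps the first
# update while the imaginary-part scatter keeps the second. Which update
# wins among duplicates is unspecified, so this combination is possible.
real_winner = updates[0].real
imag_winner = updates[1].imag
torn = jax.lax.complex(real_winner, imag_winner)

# The stored value matches neither of the original updates.
matches_any_update = bool(jnp.any(updates == torn))
```

With non-overlapping indices there is only one candidate per slot, so this mixing cannot occur.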
Can you try specifying https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html That may well fix the problem.
It doesn't fix the issue.
set time= 0.07504415512084961
I agree that this might be a natural way to implement the fix.
Description
For some reason, array.at.set is incredibly slow with complex128 datatypes. Here I show it is much faster to split the arrays into real and imaginary parts before calling array.at.set and then recombine them into a complex array afterwards.
set time= 0.07047343254089355
complex set time= 0.0006287097930908203
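As a minimal sketch of the workaround described above (not the original reproduction script), the following compares a direct complex128 `.at[].set` with a version that scatters the real and imaginary parts separately. The function names, array size, and index pattern are assumptions chosen for illustration, and the split version is only equivalent when the indices do not repeat.

```python
import jax
import jax.numpy as jnp

# complex128 is only available when 64-bit mode is enabled.
jax.config.update("jax_enable_x64", True)


def set_direct(x, idx, vals):
    # Single scatter-update on the complex128 array.
    return x.at[idx].set(vals)


def set_split(x, idx, vals):
    # Scatter real and imaginary parts as two float64 updates, then
    # recombine. Only safe when idx contains no repeated entries.
    re = x.real.at[idx].set(vals.real)
    im = x.imag.at[idx].set(vals.imag)
    return jax.lax.complex(re, im)


n = 1000  # illustrative size
x = jnp.zeros(n, dtype=jnp.complex128)
idx = jnp.arange(0, n, 2)  # unique, non-overlapping indices
vals = jnp.arange(idx.size, dtype=jnp.float64) * (1.0 + 2.0j)

direct = set_direct(x, idx, vals)
split = set_split(x, idx, vals)
```

To time the two variants fairly on an accelerator, wrap each in `jax.jit`, run it once to compile, and call `.block_until_ready()` on the result before stopping the clock.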
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.35
jaxlib: 0.4.34
numpy: 2.0.2
python: 3.12.4 | packaged by Anaconda, Inc. | (main, Jun 18 2024, 15:12:24) [GCC 11.2.0]
device info: NVIDIA H100 PCIe-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='workergpu158', release='6.1.97.1.fi', version='#1 SMP Tue Jul 9 06:21:23 EDT 2024', machine='x86_64')