array.at.set is incredibly slow for complex128 dtype #24872

Open
chrisrothUT opened this issue Nov 13, 2024 · 5 comments
Labels
bug Something isn't working

Comments

@chrisrothUT

chrisrothUT commented Nov 13, 2024

Description

array.at.set is very slow for arrays with complex128 dtype. The example below shows that it is roughly 100x faster to split the arrays into real and imaginary parts, call array.at.set on each part, and then recombine them into a complex array.

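from time import time

import jax
from jax import numpy as jnp

# complex128 requires 64-bit mode
jax.config.update("jax_enable_x64", True)

# Direct scatter-update on the complex array (the slow path)
@jax.jit
def set(x, x2, inds):
  return x.at[inds].set(x2)

# Workaround: scatter real and imaginary parts separately, then recombine
@jax.jit
def complex_set(x, x2, inds):
  return jax.lax.complex(x.real.at[inds].set(x2.real), x.imag.at[inds].set(x2.imag))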

x = jnp.zeros([10000000],dtype=jnp.complex128)
x2 = jnp.zeros([10000],dtype=jnp.complex128)
inds = jnp.arange(10000)

set(x,x2,inds)
complex_set(x,x2,inds)

t = time()
jax.block_until_ready(set(x,x2,inds))
print('set time=', time()-t)

t = time()
jax.block_until_ready(complex_set(x,x2,inds))
print('complex set time=', time()-t)

set time= 0.07047343254089355
complex set time= 0.0006287097930908203
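
Single-shot wall-clock timings like the two above can be noisy, and because JAX dispatches work asynchronously, a timed call may also absorb leftover work from unsynchronized warm-up calls. A minimal sketch of a more robust measurement follows. The helper name `bench`, the `repeats` parameter, and the small array sizes are assumptions for illustration and are not part of the original report.

```python
import time

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # complex128 requires 64-bit mode


def bench(fn, *args, repeats=20):
    """Hypothetical helper: mean seconds per call, excluding compilation."""
    # Warm-up call triggers jit compilation; block so it is fully finished.
    jax.block_until_ready(fn(*args))
    start = time.perf_counter()
    for _ in range(repeats):
        # Block on every call so asynchronous dispatch does not hide the cost.
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / repeats


@jax.jit
def direct_set(x, x2, inds):
    return x.at[inds].set(x2)


# Small sizes keep the sketch quick; the report uses 10_000_000 and 10_000.
x = jnp.zeros(1000, dtype=jnp.complex128)
x2 = jnp.ones(10, dtype=jnp.complex128)
inds = jnp.arange(10)

mean_seconds = bench(direct_set, x, x2, inds)
print(f"direct set: {mean_seconds:.6f} s per call")
```

The same helper can be pointed at the split version to compare the two under identical conditions.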

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.35
jaxlib: 0.4.34
numpy: 2.0.2
python: 3.12.4 | packaged by Anaconda, Inc. | (main, Jun 18 2024, 15:12:24) [GCC 11.2.0]
device info: NVIDIA H100 PCIe-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='workergpu158', release='6.1.97.1.fi', version='#1 SMP Tue Jul 9 06:21:23 EDT 2024', machine='x86_64')

chrisrothUT added the bug label on Nov 13, 2024
@hawkinsp
Collaborator

hawkinsp commented Nov 13, 2024

As it happens, we have a workaround in JAX to avoid this slow behavior for scatter-add and scatter-sub, but not for scatter-update. It should be pretty easy to make it work for scatter-update as well.

(The issue is that 128-bit scatters are currently expensive in XLA, because NVIDIA GPUs don't have a 16-byte atomic write operation.)
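
To make the sizes concrete, here is a small sketch that prints the element size of the dtypes involved. It uses NumPy only to inspect dtypes and says nothing about how XLA lowers the scatter; it just shows that a complex128 element is 16 bytes while each of its float64 halves is 8 bytes.

```python
import numpy as np

# Element sizes: a complex128 value is two float64 values stored side by side.
for name in ("float64", "complex64", "complex128"):
    dt = np.dtype(name)
    print(f"{name}: {dt.itemsize} bytes = {dt.itemsize * 8} bits")
```

Splitting a complex128 array into real and imaginary parts therefore turns one 128-bit scatter into two 64-bit scatters.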

@hawkinsp
Collaborator

Actually, thinking about this a bit more, it's somewhat problematic to split into real and imaginary parts.

If there are multiple updates to the same index, then it's unspecified which update "wins". If we performed updates to both real and imaginary parts separately, you might get the real part of one and the imaginary part of another. Only if you promised us the indices are non-overlapping would it be safe for us to do that. Is that true in your case?

It's easier for add and sub because those are associative operations; we can apply the updates in any order and still get the same result, up to floating point error.
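
The hazard with overlapping indices can be illustrated with a pure-Python simulation. This is only a model of "last write wins in an unspecified order", not JAX or XLA code, and the helper `scatter_set` is hypothetical. It enumerates every order in which two colliding updates could be applied, once for a whole-element scatter and once for independent real and imaginary scatters.

```python
from itertools import permutations


def scatter_set(values, indices, updates, order):
    """Apply updates[k] to values[indices[k]] in the given order; last write wins."""
    out = list(values)
    for k in order:
        out[indices[k]] = updates[k]
    return out


# Two updates collide on slot 0.
indices = [0, 0]
updates = [1 + 2j, 3 + 4j]
orders = list(permutations(range(len(indices))))

# Whole-element scatter: the slot always ends up holding one complete update.
whole = {scatter_set([0j], indices, updates, o)[0] for o in orders}

# Split scatter: the real and imaginary scatters may each resolve the
# collision in a different order, so the halves can come from different updates.
split = set()
for order_re in orders:
    for order_im in orders:
        re = scatter_set([0.0], indices, [u.real for u in updates], order_re)[0]
        im = scatter_set([0.0], indices, [u.imag for u in updates], order_im)[0]
        split.add(complex(re, im))

print(sorted(whole, key=lambda z: (z.real, z.imag)))
print(sorted(split - whole, key=lambda z: (z.real, z.imag)))
```

The second line printed lists the mixed values (1+4j) and (3+2j), which no whole-element scatter could produce. With non-overlapping indices there is no collision to resolve, so the two approaches agree.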

@chrisrothUT
Author

I see the issue. Yes, in our case the indices are non-overlapping so these functions are strictly the same.

Maybe the solution is to provide a warning about how scatter-update is slow with complex128 dtype and suggest updating the real and imaginary parts separately?

@hawkinsp
Collaborator

Can you try specifying unique_indices=True as an argument to set?

https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

That may well fix the problem.
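
For reference, a minimal sketch of what the call looks like, assuming a small array and hand-picked non-overlapping indices (both chosen here for illustration). It only checks that the default scatter, the scatter with unique_indices=True, and the split real/imaginary scatter produce the same values when no index repeats. It makes no claim about speed.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # complex128 requires 64-bit mode

x = jnp.zeros(8, dtype=jnp.complex128)
updates = jnp.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=jnp.complex128)
inds = jnp.array([0, 3, 6])  # no index appears twice

# Default scatter-update.
default = x.at[inds].set(updates)

# Same update, with the caller promising that the indices do not overlap.
promised_unique = x.at[inds].set(updates, unique_indices=True)

# Split version: scatter real and imaginary parts separately, then recombine.
split = jax.lax.complex(
    x.real.at[inds].set(updates.real),
    x.imag.at[inds].set(updates.imag),
)

print(bool(jnp.array_equal(default, promised_unique)))
print(bool(jnp.array_equal(default, split)))
```

Note that unique_indices=True is a promise from the caller; passing it with overlapping indices is not checked here and should be avoided.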

@chrisrothUT
Author

chrisrothUT commented Nov 15, 2024

It doesn't fix the issue.

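from time import time

import jax
from jax import numpy as jnp

# complex128 requires 64-bit mode
jax.config.update("jax_enable_x64", True)

# Direct scatter-update, now promising that the indices do not overlap
@jax.jit
def set(x, x2, inds):
  return x.at[inds].set(x2, unique_indices=True)

# Workaround: scatter real and imaginary parts separately, then recombine
@jax.jit
def complex_set(x, x2, inds):
  return jax.lax.complex(x.real.at[inds].set(x2.real), x.imag.at[inds].set(x2.imag))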

x = jnp.zeros([10000000],dtype=jnp.complex128)
x2 = jnp.zeros([10000],dtype=jnp.complex128)
inds = jnp.arange(10000)

set(x,x2,inds)
complex_set(x,x2,inds)

t = time()
jax.block_until_ready(set(x,x2,inds))
print('set time=', time()-t)

t = time()
jax.block_until_ready(complex_set(x,x2,inds))
print('complex set time=', time()-t)

set time= 0.07504415512084961
complex set time= 0.0005309581756591797

I agree that this might be a natural way to implement the fix.
