Data corruption with JAX plugin #5617
Comments
Hello @kvablack, thank you for reporting this issue. Do you have a standalone script to reproduce it reliably? It would help a lot, as I have had no luck reproducing it so far. As far as I remember, DALI should pass DLPack
Here you go, reproduces on a 4090. You do need a couple of things to trigger the issue -- namely, large enough arrays and a fake "train step".

```python
import jax
import jax.numpy as jnp
import numpy as np
from nvidia import dali
import tqdm

# this needs to be large to trigger the bug.
ARR_SIZE = 2**16


class ExternalSource:
    def __call__(self, sample_info: dali.types.SampleInfo):
        return [
            np.full((ARR_SIZE), fill_value=sample_info.idx_in_epoch, dtype=np.int32),
            np.full((ARR_SIZE), fill_value=sample_info.idx_in_epoch, dtype=np.int32),
        ]


def get_pipe():
    @dali.pipeline_def(
        batch_size=2,
        num_threads=1,
        prefetch_queue_depth=6,
        py_start_method="spawn",
    )
    def pipeline():
        outputs = dali.fn.external_source(
            source=ExternalSource(),
            num_outputs=2,
            batch=False,
            parallel=True,
        )
        outputs = [arr.gpu() for arr in outputs]
        return tuple(outputs)

    pipe = pipeline(device_id=0)
    pipe.build()
    pipe.schedule_run()
    return pipe


def get_next(pipe: dali.Pipeline):
    outputs = pipe.share_outputs()
    element = [
        jax.dlpack.from_dlpack(x.as_tensor()._expose_dlpack_capsule(), copy=True)
        for x in outputs
    ]
    # UNCOMMENT THIS LINE TO MAKE THE ASSERTIONS PASS.
    # jax.block_until_ready(element)
    pipe.release_outputs()
    pipe.schedule_run()
    return element


@jax.jit
def f(x):
    # matmul to simulate the train step. interestingly, this is required to reproduce the bug.
    return [y @ y.T for y in x]


if __name__ == "__main__":
    pipe = get_pipe()
    batches = []
    for _ in tqdm.trange(10):
        batches.append(f(get_next(pipe)))

    for i in tqdm.trange(len(batches)):
        for elem in batches[i]:
            # `x` is what the ExternalSource should have returned
            x = jnp.broadcast_to(
                jnp.arange(elem.shape[0])[:, None] + i * elem.shape[0],
                (elem.shape[0], ARR_SIZE),
            )
            # `y` is what the "train step" should have returned
            y = x @ x.T
            # check that it matches the batch that came out of the pipeline
            assert jnp.all(y == elem), f"Failed on batch {i}\n\nExpected:\n{y}\n\nGot:\n{elem}"
```
+1
Version
nvidia-dali-cuda120==1.40.0, jax==0.4.31
Describe the bug.
We recently discovered that when we used DALI, our training curves were mysteriously worse. We were able to fix it by adding a jax.block_until_ready() call after each train step, to foil JAX's asynchronous dispatch. I therefore hypothesized that there was some sort of data corruption going on, caused by a lack of synchronization between JAX and DALI.

If I understand correctly, when you request the next element, the JAX plugin roughly does the following steps (sketched in code below):

1. pipe.share_outputs() to borrow DALI's output buffers.
2. jax.dlpack.from_dlpack on each output, followed by jnp.copy into a JAX-owned buffer.
3. pipe.release_outputs() and pipe.schedule_run() to hand the buffers back to DALI and start the next iteration.
I believe the problem is that jax.dlpack.from_dlpack and jnp.copy are both asynchronous. Therefore, pipe.release_outputs() is called before the copy actually occurs. When you request the next element, DALI can overwrite the output buffers before JAX is done reading from them.

I was able to fix the problem by bypassing the DALI JAX plugin and adding a jax.block_until_ready() call on the copied outputs before releasing them back to DALI.
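In context, the fix looks roughly like this (a sketch based on the get_next() helper from the repro, not the plugin itself; the essential point is that the synchronization happens before pipe.release_outputs()):

```python
import jax


def get_next_fixed(pipe):
    outputs = pipe.share_outputs()
    element = [
        jax.dlpack.from_dlpack(x.as_tensor()._expose_dlpack_capsule(), copy=True)
        for x in outputs
    ]
    jax.block_until_ready(element)  # wait for the async device copies to finish
    pipe.release_outputs()          # only now is it safe for DALI to reuse the buffers
    pipe.schedule_run()
    return element
```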
Of course, I suspect this forces JAX to flush the entire GPU pipeline, which is somewhat inefficient. However, I don't think there's any way around this without deeper integration between JAX and DALI (specifically, I think DALI needs to provide a DLPack deleter, although I'm no expert).

Minimum reproducible example
No response
Relevant log output
No response
Other/Misc.
No response
Check for duplicates