
Data corruption with JAX plugin #5617

Open

kvablack opened this issue Sep 3, 2024 · 3 comments

Assignees: awolant
Labels: bug (Something isn't working), JAX (Issues related to DALI and JAX integration)

kvablack commented Sep 3, 2024

Version

nvidia-dali-cuda120==1.40.0, jax==0.4.31

Describe the bug.

We recently discovered that when we used DALI, our training curves were mysteriously worse. We were able to fix this by adding a jax.block_until_ready() call after each train step, to foil JAX's asynchronous dispatch. I therefore hypothesized that there was some sort of data corruption going on, caused by a lack of synchronization between JAX and DALI.
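
For concreteness, here is a minimal sketch of that workaround; train_step, params, and the batches below are illustrative stand-ins, not our actual training code:

import jax
import jax.numpy as jnp

@jax.jit
def train_step(params, batch):
    # dummy update standing in for a real optimizer step
    return params + batch.mean()

params = jnp.zeros(())
for step in range(10):
    batch = jnp.ones((128,)) * step  # stand-in for a batch coming from DALI
    params = train_step(params, batch)
    # wait for the step to finish before the data loader may recycle buffers
    jax.block_until_ready(params)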

If I understand correctly, when you request the next element, the JAX plugin roughly performs the following steps:

def get_next_element(pipe: dali.Pipeline):
    # gets the next element
    element = pipe.share_outputs()
    # copies to JAX memory
    element = [jax.dlpack.from_dlpack(x.as_tensor()._expose_dlpack_capsule(), copy=True) for x in element]
    # tells DALI that we are done with the output buffers
    pipe.release_outputs()
    # schedules the next fetch
    pipe.schedule_run()
    return element

I believe the problem is that jax.dlpack.from_dlpack and jnp.copy are both asynchronous, so pipe.release_outputs() is called before the copy has actually completed. When you then request the next element, DALI can overwrite the output buffers before JAX is done reading from them.
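
To see the asynchrony in isolation, here is a small self-contained demo (timings are machine-dependent, but on a GPU the dispatch returns well before the matmul finishes):

import time
import jax
import jax.numpy as jnp

x = jnp.ones((8192, 8192))
t0 = time.perf_counter()
y = x @ x  # returns almost immediately; the matmul is still queued on the device
print(f"dispatch returned after {time.perf_counter() - t0:.4f}s")
jax.block_until_ready(y)  # only now is the computation guaranteed to be done
print(f"compute finished after {time.perf_counter() - t0:.4f}s")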

I was able to fix the problem by bypassing the DALI JAX plugin and adding the following line to my code:

def get_next_element(pipe: dali.Pipeline):
    # gets the next element
    element = pipe.share_outputs()
    # copies to JAX memory
    element = [jax.dlpack.from_dlpack(x.as_tensor()._expose_dlpack_capsule(), copy=True) for x in element]
    # wait until the copy is done (THIS IS THE MISSING STEP)
    element = jax.block_until_ready(element)
    # tells DALI that we are done with the output buffers
    pipe.release_outputs()
    # schedules the next fetch
    pipe.schedule_run()
    return element

Of course, I suspect this forces JAX to flush the entire GPU pipeline, which is somewhat inefficient. However, I don't think there's any way around this without deeper integration between JAX and DALI (specifically, I think DALI needs to provide a DLPack deleter, although I'm no expert).
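
For reference, here is a minimal illustration of the DLPack handoff in question, using NumPy as a stand-in producer (DALI's internal capsule export is not modeled; this only shows the consumer side of the contract):

import jax
import numpy as np

producer = np.arange(8, dtype=np.int32)
# Under the hood, the exported capsule carries a deleter; the consumer is
# expected to invoke it once it no longer needs the producer's memory.
consumer = jax.dlpack.from_dlpack(producer)
print(consumer.sum())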

Minimum reproducible example

No response

Relevant log output

No response

Other/Misc.

No response

Check for duplicates

  • I have searched the open bugs/issues and have found no duplicates for this bug report
kvablack added the bug label Sep 3, 2024
awolant assigned awolant and unassigned szalpal Sep 4, 2024
awolant added the JAX label Sep 4, 2024
awolant (Contributor) commented Sep 4, 2024

Hello @kvablack,

thank you for reporting this issue. Do you have a standalone script that reproduces it reliably? It would help a lot, as I have had no luck reproducing it so far.

As far as I remember, DALI should pass a DLPack deleter to JAX, to be called when the capsule is no longer needed. There might be some issue with it, as you pointed out.

kvablack (Author) commented Sep 4, 2024

Here you go; this reproduces on a 4090. You do need a couple of things to trigger the issue, namely large enough arrays and a fake "train step".

import jax
import jax.numpy as jnp
import numpy as np
from nvidia import dali
import tqdm

# this needs to be large to trigger the bug.
ARR_SIZE = 2**16


class ExternalSource:
    def __call__(self, sample_info: dali.types.SampleInfo):
        return [
            np.full((ARR_SIZE), fill_value=sample_info.idx_in_epoch, dtype=np.int32),
            np.full((ARR_SIZE), fill_value=sample_info.idx_in_epoch, dtype=np.int32),
        ]


def get_pipe():
    @dali.pipeline_def(
        batch_size=2,
        num_threads=1,
        prefetch_queue_depth=6,
        py_start_method="spawn",
    )
    def pipeline():
        outputs = dali.fn.external_source(
            source=ExternalSource(),
            num_outputs=2,
            batch=False,
            parallel=True,
        )

        outputs = [arr.gpu() for arr in outputs]

        return tuple(outputs)

    pipe = pipeline(device_id=0)
    pipe.build()
    pipe.schedule_run()
    return pipe


def get_next(pipe: dali.Pipeline):
    outputs = pipe.share_outputs()
    element = [jax.dlpack.from_dlpack(x.as_tensor()._expose_dlpack_capsule(), copy=True) for x in outputs]
    # UNCOMMENT THIS LINE TO MAKE THE ASSERTIONS PASS.
    # jax.block_until_ready(element)
    pipe.release_outputs()
    pipe.schedule_run()
    return element


@jax.jit
def f(x):
    # Matmul to simulate the train step; interestingly, this is required to reproduce the bug.
    return [y @ y.T for y in x]


if __name__ == "__main__":
    pipe = get_pipe()
    batches = []
    for _ in tqdm.trange(10):
        batches.append(f(get_next(pipe)))

    for i in tqdm.trange(len(batches)):
        for elem in batches[i]:
            # `x` is what the ExternalSource should have returned
            x = jnp.broadcast_to(jnp.arange(elem.shape[0])[:, None] + i * elem.shape[0], (elem.shape[0], ARR_SIZE))
            # `y` is what the "train step" should have returned
            y = x @ x.T
            # check that it matches the batch that came out of the pipeline
            assert jnp.all(y == elem), f"Failed on batch {i}\n\nExpected:\n{y}\n\nGot:\n{elem}"

quanvuong commented

+1
