Asynchronous prefetching #5583
Comments
Hi @quanvuong, Thank you for reaching out.
That does seem to speed things up quite a bit. Instantiating the data iterator is still quite slow because I'm using the JAX integration, which requires starting the worker pool with "spawn" (on an 8×H100 node, starting all the worker pools can take 20 minutes). Do you have advice on how to improve the speed here?
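For context, a minimal sketch (with made-up names and shapes) of the kind of setup being discussed: a parallel external source whose Python worker pool is started with the "spawn" method, as needed when combining DALI with JAX. The worker-pool start-up that is slow here happens when the pipeline starts its Python workers.

```python
import numpy as np
from nvidia.dali import pipeline_def, fn, types


def make_sample(sample_info):
    # Stand-in per-sample callback; the real one would load and decode data.
    return np.full((3, 224, 224), sample_info.idx_in_epoch, dtype=np.float32)


@pipeline_def(batch_size=32, num_threads=4, device_id=0,
              py_num_workers=8, py_start_method="spawn")
def my_pipeline():
    return fn.external_source(source=make_sample, parallel=True, batch=False,
                              dtype=types.FLOAT)


pipe = my_pipeline()
# Spawning the Python worker processes is the step whose cost is discussed
# in this thread; it can also be triggered explicitly before build().
pipe.start_py_workers()
pipe.build()
```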
Hi @quanvuong,
This is surprising and not expected. Can you share a self-contained repro we can run on our end for debugging?
I'm working on the self-contained repro; in the meantime, I have narrowed it down to these lines as the slow operation. Specifically, going from s0 to s1 takes 80 seconds (in nvidia/dali/_multiproc/pool.py).
Hi @quanvuong, Because you are referring to nvidia/dali/_multiproc/pool.py, I assume you are using parallel external source. If that's the case, you can check https://docs.nvidia.com/deeplearning/dali/user-guide/docs/examples/general/data_loading/parallel_external_source.html#Serialization-and-heavy-setup to try to work around it by making sure that the serialized object is lighter.
Yes, I am using parallel external source in my iterator (with the JAX integration). I have moved the heavy setup to get_state as recommended, and that reduces the time taken to instantiate the data iterator by 10-15%, but it is still quite slow. Any advice? Is there a profiler that I can use? We are running on 8 H100 nodes, and instantiating the data iterator takes more than 10 minutes (about 2 minutes per GPU).
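For readers following along, a hedged sketch of the "heavy setup" pattern from the linked docs (class, attribute, and helper names are illustrative, not from the original report): the expensive state is excluded from pickling and recreated lazily inside each worker, so the serialized callback stays small.

```python
import numpy as np


def load_big_index(path):
    # Stand-in for the expensive setup (opening files, building indices, ...).
    return np.zeros((1000, 3, 224, 224), dtype=np.float32)


class HeavySetupCallback:
    def __init__(self, index_path):
        self.index_path = index_path
        self._resource = None  # created lazily in the worker, never pickled

    def __getstate__(self):
        # Drop the heavy member so the serialized callback stays light.
        state = self.__dict__.copy()
        state["_resource"] = None
        return state

    def __call__(self, sample_info):
        if self._resource is None:
            # Runs once per worker process, after unpickling.
            self._resource = load_big_index(self.index_path)
        idx = sample_info.idx_in_epoch % len(self._resource)
        return np.asarray(self._resource[idx], dtype=np.float32)
```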
Hi @quanvuong, Can you bisect which part of the external source callback is slow (start with the callback that just calls …)?
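One way to start that bisection (an assumption on my part, not the exact callback referred to above) is to swap the real callback for a stub that does no work; if worker start-up is still slow with the stub, the cost is in the pool itself rather than in serializing or setting up the callback.

```python
import numpy as np


def stub_callback(sample_info):
    # Returns a constant sample without touching any heavy state.
    return np.zeros((3, 224, 224), dtype=np.float32)
```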
Hi @quanvuong, To make sure that the serialization is no longer the main factor contributing to the start-up time, you could use a custom pickler that wraps pickle and provides you with some more information, such as the size of the callback once it is serialized.
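A sketch of such a wrapper (the class name is made up): it delegates to pickle and prints the size of each serialized callback, and could be plugged in through the pipeline's py_callback_pickler argument described in the parallel external source docs.

```python
import pickle


class ReportingPickler:
    @staticmethod
    def dumps(obj, *args, **kwargs):
        data = pickle.dumps(obj, *args, **kwargs)
        print(f"pickled {type(obj).__name__}: {len(data)} bytes")
        return data

    @staticmethod
    def loads(data, *args, **kwargs):
        return pickle.loads(data, *args, **kwargs)


# e.g. @pipeline_def(..., py_callback_pickler=ReportingPickler)
```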
Another thing that may contribute to the total start-up time of the workers (although I would expect it to show as …)
Is this a new feature, an improvement, or a change to existing functionality?
New Feature
How would you describe the priority of this feature request
Nice to have (e.g. Adoption is possible, the feature will enhance the use-case even more).
Please provide a clear description of the problem this feature solves
Instantiating the DALI data loader for JAX takes a long time because of prefetching.
Feature Description
It would be nice to have an asynchronous prefetching feature, so that I can interleave jitting the model with the prefetching operations.
Describe your ideal solution
Have two functions:
def start_async_prefetch(self): ...
def block_till_ready_async_prefetch(self): ...
With these two functions, I can control when prefetching happens (see the usage sketch below).
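A hypothetical usage sketch of the proposed API (the two method names come from this request; build_iterator, train_step, and example_batch are placeholders), showing how prefetching could overlap with JIT compilation:

```python
import jax

iterator = build_iterator()                 # placeholder for the DALI/JAX iterator
iterator.start_async_prefetch()             # prefetching starts in the background

jitted_step = jax.jit(train_step)           # placeholder training step
jitted_step.lower(example_batch).compile()  # compile while prefetching runs

iterator.block_till_ready_async_prefetch()  # block only when the data is needed
```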
Describe any alternatives you have considered
No response
Additional context
No response
Check for duplicates