Optimize data loading for cell classification #493
Conversation
I think one of the issues with the actions is Project-MONAI/MONAI#8277. It lists an older numpy as a dependency even though it's compatible with the latest numpy.

It's updated on main though, so hopefully the next release will work ok.

Looks like the next release should be out soon: Project-MONAI/MONAI#8421

There seems to have been some movement on the MONAI repo last week, getting closer to a new release, but it's not quite there yet. Just commenting here to remind ourselves that this PR still exists :)

I'm hoping to start working on this again this week! 🤞

The slow release cadence of monai has made me rethink having it as a dependency. @matham, do you have an idea of how complicated it would be to remove the dependency from this PR?

Originally, when I looked at it, I saw 3 options, one being the original scipy transforms. I also looked at extracting the relevant code from monai, which should be doable, but I didn't think you'd want to have to maintain that, so I didn't do it. But perhaps you do prefer it? There are also other libraries that can handle 3d transformations.

Hmm, let's go ahead with monai, and vendor the necessary code down the line if we feel the need to.

I think we should probably wait. In the meantime, would you mind finishing adding the tests, so we can review it all in one go?

Also, I should add - thanks again for all your hard work on improving cellfinder @matham!
    batch_size = self.batch_size

    if shuffle:
        rng = np.random.default_rng()

Code scanning notice (SonarCloud, severity: Low): results that depend on random number generation should be reproducible.
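One way to address that notice (a minimal sketch, not code from this PR; the function name and optional seed parameter are invented for illustration) is to let callers pass a seed through to the generator:

```python
from typing import Optional

import numpy as np


def shuffled_indices(n_samples: int, seed: Optional[int] = None) -> np.ndarray:
    """Return a shuffled index array; passing a seed makes the order reproducible."""
    rng = np.random.default_rng(seed)  # seeded generator -> reproducible shuffles
    indices = np.arange(n_samples)
    rng.shuffle(indices)
    return indices
```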
This pull request has been mentioned on Image.sc Forum. There might be relevant details there: https://forum.image.sc/t/brainmapper-cell-classification-processing-time/117138/2
Description
This is still a WIP and needs a bit of polish.
What is this PR
Why is this PR needed?
When testing cell classification with our large whole brains (600GB+, 2.5 million+ potential cells), it took 32+ hours to run inference on all the potential cells. Looking at task manager, I could see that the GPU was barely utilized. This is typically an indication that the data was not read or processed quickly enough to keep the GPU fed.
With my changes, it only took <1.5 hours. Also, when testing training the model with the original serial2p dataset, it took 79.6 min / epoch for main vs 24.33 min / epoch with the changes. And it used less than half the RAM of main.

What does this PR do?
Very broadly, this PR refactors the data loading and augmentation classes so the various features can be shared among different datasets (dask arrays, tiff files containing cubes etc). But more importantly it adds caching to data loading. It also re-works the multi-processing of workers loading the data.
Consider a bunch of points with various z values. We need to get a cuboid centered on each point; say in z that is 20 planes. Then for every point we'd access the Dask array to get 20 planes and extract the cuboid, so for each point we potentially read 20 planes. During inference we already sort the points so we read sequential z values, but unless Dask did some caching we'd still be reading the same data constantly, and there's certainly overhead.

Instead, we cache the last n planes in memory, so when reading z sequentially we save a lot of disk/Dask access. This also matters when we split processing among workers that may access neighboring z points, which would defeat any caching Dask (or the OS) does; keeping our own cache at twice the size of a cube makes this a non-issue for us. A sketch of the idea is shown below.

Additionally, there's a dedicated thread in the main process that does the data loading. All the workers (sub-processes) communicate with it via queues by sending it a buffer that the thread fills in with the loaded data, which the worker then processes (e.g. augment, resize). PyTorch tensors can be shared among processes in memory, so this saves us a bunch of data copying.
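To make the plane-caching idea concrete, here is a minimal sketch (the class and method names are invented for illustration; the actual implementation in this PR differs):

```python
from collections import OrderedDict

import numpy as np


class PlaneCache:
    """Keep the most recently read z-planes in memory so sequential access is cheap."""

    def __init__(self, volume, max_planes: int = 40):
        self.volume = volume              # lazy (z, y, x) array, e.g. a dask array
        self.max_planes = max_planes      # e.g. twice the cuboid depth
        self._planes = OrderedDict()      # z index -> materialized plane

    def get_plane(self, z: int) -> np.ndarray:
        plane = self._planes.get(z)
        if plane is None:
            plane = np.asarray(self.volume[z])        # one disk/dask read per plane
            self._planes[z] = plane
            if len(self._planes) > self.max_planes:
                self._planes.popitem(last=False)      # evict the oldest plane
        return plane

    def get_cuboid(self, z, y, x, depth=20, height=50, width=50) -> np.ndarray:
        """Stack cached planes around z and crop around (y, x)."""
        planes = [self.get_plane(zi) for zi in range(z - depth // 2, z + depth // 2)]
        cube = np.stack(planes)
        return cube[:, y - height // 2 : y + height // 2, x - width // 2 : x + width // 2]
```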
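And a rough sketch of the shared-buffer handoff between a worker process and the loader thread, assuming torch.multiprocessing queues (the queue names and exact protocol here are invented; only the fill-a-shared-tensor-in-place idea is from the description above):

```python
import torch
import torch.multiprocessing as mp


def worker_loop(request_queue: mp.Queue, filled_queue: mp.Queue, cuboid_shape):
    # Allocate the buffer once and put it in shared memory; the loader thread
    # writes into it in place, so no pixel data is copied through the queue.
    buffer = torch.empty(cuboid_shape)
    buffer.share_memory_()

    for point_id in range(10):                  # stand-in for the worker's batch loop
        request_queue.put((point_id, buffer))   # ask the loader to fill this buffer
        filled_queue.get()                      # loader signals the buffer is ready
        # ... augment / resize `buffer` and run it through the model here ...
```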
Important changes:
- Added a max_workers parameter. Unlike elsewhere, where we say how many cores we want to keep free, here it makes more sense to say at most how many workers we want to use, because additional workers above about 6 (in my testing) were just wasted due to multiprocessing overhead.
- Added monai as a dependency so we can use their random transforms during data augmentation. Their transforms are proper 3d.
- Our data has one axis order, monai has its own order, and the model expects the data in yet another order. In main this order is hardcoded, but I added an explicitly defined ordering for each of them so we can automatically re-order the data as needed (see the sketch below). It needs more code, but prevents mental complexity.
- As in main, the way we center the cuboid on a point is different in x, y vs z. I kept this difference so as to stay compatible with main as much as possible.

Work to be done:
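Picking up the axis-ordering bullet above: a minimal sketch of the idea, where each component declares its order and data is transposed between them automatically (the names and the specific orders below are illustrative, not the ones used in the PR):

```python
import numpy as np

DATASET_ORDER = ("z", "y", "x", "c")   # e.g. how cuboids come out of the dataset
MONAI_ORDER = ("c", "z", "y", "x")     # e.g. the order the monai transforms work in
MODEL_ORDER = ("c", "y", "x", "z")     # e.g. the order the network expects


def reorder(data: np.ndarray, src: tuple, dst: tuple) -> np.ndarray:
    """Transpose `data` from the `src` axis order to the `dst` axis order."""
    return np.transpose(data, [src.index(axis) for axis in dst])


cube = np.zeros((20, 50, 50, 1))                       # a (z, y, x, c) cuboid
cube_for_model = reorder(cube, DATASET_ORDER, MODEL_ORDER)
print(cube_for_model.shape)                            # (1, 50, 50, 20)
```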
References
None
How has this PR been tested?
I tested it by training on the original serial2p dataset. You can see the train/test loss/accuracy to see it correctly trained:

train/val accuracy:

loss:

I would like to do more comparisons to see how points are (potentially) differently classified between main and this branch.
Is this a breaking change?
The only potential breaking change is that the augmentation is different. That shouldn't affect existing model inference, only newly trained models, but we need to double-check this.
We may also want to consider retraining the model.
Does this PR require an update to the documentation?
We potentially only need to document the new max_workers parameter.

Checklist: