Conversation

@matham (Contributor) commented Mar 14, 2025

Description

This is still a WIP and needs a bit of polish.

What is this PR

  • Bug fix
  • Addition of a new feature
  • Other

Why is this PR needed?

When testing cell classification on our large whole brains (600+ GB, 2.5+ million potential cells), it took 32+ hours to run inference on all the potential cells. Looking at Task Manager, I could see that the GPU was barely utilized, which typically indicates that data is not being read or processed quickly enough to keep the GPU fed.

With my changes, it only took <1.5 hours.

Also, when training the model on the original serial2p dataset, main took 79.6 min/epoch vs 24.33 min/epoch with these changes, and the changes used less than half the RAM of main.

What does this PR do?

Very broadly, this PR refactors the data loading and augmentation classes so the various features can be shared among different datasets (Dask arrays, TIFF files containing cubes, etc.). More importantly, it adds caching to data loading, and it also re-works the multiprocessing of the workers loading the data.

Consider a set of points with various z values. For each point we need a cuboid centered on it; say the cuboid spans 20 planes in z. Then for every point we'd access the Dask array to get those 20 planes and extract the cuboid, so each point potentially reads 20 planes. During inference we already sorted the points so z is read sequentially, but unless Dask did some caching, we'd still be re-reading the same planes constantly, and there's certainly overhead.

Instead, we cache the last n planes in memory, so when reading z sequentially we avoid most of the disk/Dask access. This also matters when processing is split among workers that may access neighboring z points, which would otherwise defeat any Dask/OS caching; keeping our cache at twice the size of a cube makes this a non-problem for us.
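Schematically, the cache works something like the following minimal sketch (hypothetical, not the PR's actual classes; `PlaneCache`, `cache_depth`, and the shapes are illustrative):

```python
# Hypothetical sketch: keep the last `cache_depth` z-planes in memory so
# sequential point reads hit the cache instead of going back to Dask/disk.
from collections import OrderedDict

import numpy as np

class PlaneCache:
    def __init__(self, volume, cache_depth):
        self.volume = volume            # (z, y, x) Dask array (or similar)
        self.cache_depth = cache_depth  # e.g. twice the cuboid depth
        self._planes = OrderedDict()    # z index -> in-memory plane

    def get_plane(self, z):
        plane = self._planes.get(z)
        if plane is None:
            plane = np.asarray(self.volume[z])  # the only Dask/disk access
            self._planes[z] = plane
            if len(self._planes) > self.cache_depth:
                self._planes.popitem(last=False)  # evict the oldest plane
        return plane

    def get_cuboid(self, z0, y0, x0, depth, height, width):
        # Stack `depth` (mostly cached) planes and crop around the point.
        planes = [self.get_plane(z) for z in range(z0, z0 + depth)]
        return np.stack(planes)[:, y0:y0 + height, x0:x0 + width]
```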

Additionally, there's a dedicated thread in the main process that does the data loading. The workers (sub-processes) communicate with it via queues, sending it a buffer that the thread fills with the loaded data; the worker then processes it (e.g. augmenting, resizing). PyTorch tensors can be shared among processes in memory, so this saves a bunch of data copying.
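Roughly, the pattern looks like the sketch below, assuming torch.multiprocessing queues; `load_cuboid`, the shape, and the wiring are placeholders rather than the PR's actual code:

```python
import torch
import torch.multiprocessing as mp

CUBE_SHAPE = (1, 20, 50, 50)  # channel, z, y, x (illustrative)

def load_cuboid(point):
    # Placeholder for the real (cached) Dask read.
    return torch.zeros(CUBE_SHAPE)

def loader_thread(request_queue):
    # Runs as a thread in the main process and owns all disk/Dask access.
    while True:
        item = request_queue.get()
        if item is None:  # sentinel: shut down
            break
        buffer, point, reply_queue = item
        buffer.copy_(load_cuboid(point))  # fill the shared buffer in place
        reply_queue.put(point)            # signal the worker; no pixels copied

def worker(request_queue, reply_queue, points):
    # Runs in a subprocess. Because the tensor is in shared memory, the
    # loader thread in the main process writes directly into it.
    buffer = torch.empty(CUBE_SHAPE).share_memory_()
    for point in points:
        request_queue.put((buffer, point, reply_queue))
        reply_queue.get()      # wait for the loader to fill the buffer
        cube = buffer.clone()  # take a copy, then augment/resize locally
```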


Important changes:

  • The only new (potential) external parameter is max_workers. Unlike elsewhere, where we specify how many cores to keep free, here it makes more sense to cap the number of workers, because in my testing anything above about 6 workers was wasted on multiprocessing overhead.
  • Added monai as a dependency so we can use its random transforms during data augmentation; its transforms are properly 3d.
  • The data has 4 dimensions: channel, x, y, and z. Each component has its own required order: torch has one order, the data is saved on disk in another, monai has its own, and the model expects yet another. In main this order is hardcoded, but I added an explicitly defined ordering for each so we can automatically re-order the data as needed (see the first sketch after this list). It takes more code, but removes mental overhead.
  • When using arrays for the data, main filters out any points on the edge of the 3d stack. However, it didn't take into account that the input data has to be rescaled to the network's expected size, so the filtering was incorrect for points near the boundary when the data's voxel size differs from the network's (see the previous point). This has been fixed (see the second sketch after this list).
  • For a given point, we find a cuboid centered around it. In main, the way we center it differs between x, y and z. I kept this difference to stay as compatible with main as possible.
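To make the axis-ordering point concrete, here is a minimal sketch; the `reorder` helper and the order strings are hypothetical, not this PR's actual API:

```python
import torch

def reorder(tensor, src, dst):
    """Permute `tensor` from axis order `src` to `dst`, e.g. "zyxc" -> "cxyz"."""
    perm = [src.index(axis) for axis in dst]
    return tensor.permute(*perm)

DISK_ORDER = "zyxc"   # assumed on-disk order, for illustration only
MODEL_ORDER = "cxyz"  # assumed network input order, for illustration only

cube = torch.rand(20, 50, 50, 1)               # a z, y, x, channel cube
cube = reorder(cube, DISK_ORDER, MODEL_ORDER)  # now shaped (1, 50, 50, 20)
```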
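And for the boundary-filtering fix, a sketch of the idea under the same caveat (all names here are illustrative):

```python
import numpy as np

def keep_in_bounds(points, stack_shape, network_cuboid, network_voxel, data_voxel):
    """Keep points whose cuboid, measured in *data* voxels, fits in the stack.

    points: (n, 3) integer array of z, y, x centers; the other arguments are
    z, y, x triples (stack shape in voxels, cuboid size in voxels, voxel sizes).
    """
    # The cuboid extent in data voxels depends on the voxel-size ratio; this
    # rescaling is what main's filtering ignored.
    extent = np.ceil(
        np.asarray(network_cuboid) * np.asarray(network_voxel) / np.asarray(data_voxel)
    ).astype(int)
    half = extent // 2
    lo_ok = (points - half >= 0).all(axis=1)
    hi_ok = (points + (extent - half) <= np.asarray(stack_shape)).all(axis=1)
    return points[lo_ok & hi_ok]
```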

Work to be done:

  • Add docs
  • Add tests
  • Better compare cell classification between main and this branch. Specifically, check that the existing weights classify cells similarly on both.
  • Compare a model trained with main vs one trained with this branch to see how similarly they classify the same cells.

References

None

How has this PR been tested?

I tested it by training on the original serial2p dataset. The train/val accuracy and loss curves show that it trained correctly:

train/val accuracy: [image]

loss: [image]

I would like to do more comparisons to see how points are (potentially) differently classified between main and this branch.

Is this a breaking change?

The only potential breaking change is that the augmentation is different. That shouldn't affect existing model inference, only newly trained models. At the same time, we need to double-check this.

We may also want to consider retraining the model?

Does this PR require an update to the documentation?

Potentially we only need to document the new max_workers parameter.

Checklist:

  • The code has been tested locally
  • Tests have been added to cover all new functionality (unit & integration)
  • The documentation has been updated to reflect any changes
  • The code has been formatted with pre-commit

@matham matham marked this pull request as draft March 14, 2025 06:11
@matham (Contributor, Author) commented Mar 14, 2025

I think one of the issues with the actions is Project-MONAI/MONAI#8277. It lists an older numpy as a dependency even though it's compatible with the latest numpy.

@adamltyson (Member)

It's updated on main though, so hopefully the next release will work ok.

@adamltyson (Member)

Looks like the next release should be out soon: Project-MONAI/MONAI#8421

@alessandrofelder (Member) commented May 1, 2025

There seems to have been some movement on the MONAI repo last week, getting closer to a new release, but it's not quite there yet. Just commenting here to remind ourselves that this PR still exists :)

@matham (Contributor, Author) commented May 5, 2025

I'm hoping to start working on this again this week! 🤞

@adamltyson (Member)

The slow release cadence of monai has made me rethink having it as a dependency. @matham, do you have an idea of how complicated it would be to remove the dependency from this PR?

@matham (Contributor, Author) commented May 6, 2025

Originally, when I looked at it, I saw 3 options: the original scipy, TorchIO, and monai.

  1. scipy didn't seem to do true 3d transformations; rather, it treated a cube as a stack of 2d images, each transformed independently.
  2. TorchIO uses SimpleITK, which I didn't love: it adds another large transitive dependency and has to move data between torch and ITK. Plus it has more dependencies overall: deprecated, humanize, nibabel, numpy, packaging, rich, scipy, simpleitk, torch, tqdm, typer, wrapt, markdown-it-py, pygments, filelock, typing-extensions, networkx, jinja2, fsspec, sympy, mpmath, colorama, typer-slim, typer-cli, click, shellingham, mdurl, MarkupSafe.
  3. monai works natively with torch without major dependencies (numpy, torch, filelock, typing-extensions, networkx, jinja2, fsspec, sympy, mpmath, MarkupSafe), so I went with that.

I looked at extracting the relevant code from monai, which should be doable. But I didn't think you'd want to maintain that, so I didn't do it. But perhaps you'd prefer that? Ultimately, with monai it comes down to building transforms and applying them (see the sketch at the end of this comment). My linear algebra is not amazing, though, especially in 3d. But I can give it a try?

There are other libraries that can handle 3d transformations, e.g. pytransform3d, or even scipy, but they transform vectors, not volumes. You could use them to transform volumes, but that seems roundabout and prone to issues.
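For context, "building transforms and applying them" with monai looks roughly like this; the specific transforms and parameters are only illustrative, not necessarily what this PR uses:

```python
import torch
from monai.transforms import Compose, RandAffine, RandFlip

augment = Compose([
    RandFlip(prob=0.5, spatial_axis=0),  # random flip along one spatial axis
    RandAffine(                          # a true 3d rotation/scale of the volume
        prob=0.5,
        rotate_range=(0.1, 0.1, 0.1),    # radians, one range per spatial axis
        scale_range=(0.05, 0.05, 0.05),
    ),
])

cube = torch.rand(1, 20, 50, 50)  # channel-first 3d volume, as monai expects
augmented = augment(cube)
```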

@adamltyson (Member)

Hmm, let's go ahead with monai, and vendor the necessary code down the line if we feel the need to.

@matham matham marked this pull request as ready for review May 24, 2025 20:02
@matham (Contributor, Author) commented May 24, 2025

  1. I still need to add some tests, and maybe a little documentation here and there.
  2. I removed monai from the dependencies and manually install it with pip install monai --no-deps on the CI. I'm not sure what we want to do: is it ok to merge in this state, or do we need to wait for a new monai release? I'm assuming the latter?

@adamltyson (Member)

> I removed monai from the dependencies and manually install it with pip install monai --no-deps on the CI. I'm not sure what we want to do: is it ok to merge in this state, or do we need to wait for a new monai release? I'm assuming the latter?

I think we should probably wait, although:

  • I will ask the monai developers if they have an ETA on v1.5
  • This PR is pretty big, so it may take a while for us to review it too (the new monai release may be out by then!)

Would you mind finishing adding the tests, so we can review it all in one go?

@adamltyson (Member)

Also, I should add - thanks again for all your hard work on improving cellfinder @matham!

Code scanning notice (SonarCloud) on:

```python
batch_size = self.batch_size

if shuffle:
    rng = np.random.default_rng()
```

Results that depend on random number generation should be reproducible (Low): provide a seed for this random generator.
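For illustration, addressing the notice would look something like the sketch below, with a hypothetical `seed` parameter threaded through (not necessarily how this PR resolves it):

```python
import numpy as np

def iter_batches(indices, batch_size, shuffle=False, seed=None):
    # Seeding the generator makes the shuffling (and thus any results that
    # depend on it) reproducible, per the SonarCloud notice.
    if shuffle:
        rng = np.random.default_rng(seed)
        indices = rng.permutation(indices)
    for i in range(0, len(indices), batch_size):
        yield indices[i:i + batch_size]
```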
@imagesc-bot

This pull request has been mentioned on Image.sc Forum. There might be relevant details there:

https://forum.image.sc/t/brainmapper-cell-classification-processing-time/117138/2
