Conversation

@matham (Contributor) commented Mar 14, 2025

Description

This is still a WIP and needs a bit of polish.

What is this PR

  • Bug fix
  • Addition of a new feature
  • Other

Why is this PR needed?

When testing cell classification on our large whole brains (600 GB+, 2.5 million+ potential cells), it took 32+ hours to run inference on all the potential cells. Looking at Task Manager, I could see that the GPU was barely utilized, which typically indicates that the data was not being read or processed quickly enough to keep the GPU fed.

With my changes, it took under 1.5 hours.

Also, when training the model on the original serial2p dataset, it took 79.6 min/epoch on main vs 24.33 min/epoch with these changes, and it used less than half the RAM of main.

What does this PR do?

Very broadly, this PR refactors the data-loading and augmentation classes so the various features can be shared among different dataset types (dask arrays, tiff files containing cubes, etc.). More importantly, it adds caching to data loading, and it reworks the multiprocessing of the workers that load the data.

Consider a set of points at various z values. For each point we need to extract a cuboid centered on it; say it spans 20 planes in z. For every point we would then access the dask array for 20 planes and extract the cuboid, so each point potentially triggers 20 plane reads. During inference we already sort the points so we read sequentially in z, but unless dask does some caching of its own, we would still be re-reading the same planes constantly, and there is certainly overhead.

Instead, we cache the last n planes in memory, so when reading sequentially in z we save a lot of disk/dask access. This also matters when processing is split among workers that may access neighboring z values, which would otherwise defeat any dask/OS caching. Keeping our cache at twice the size of a cube makes this a non-issue for us.
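A minimal sketch of the caching idea (class and parameter names are illustrative, not the actual classes in this PR):

```python
import numpy as np

class PlaneCache:
    """Keep the most recent z-planes in memory so sequential cuboid
    reads hit RAM instead of going back to dask/disk."""

    def __init__(self, dask_array, max_planes=40):  # ~2x a 20-plane cube
        self.data = dask_array      # full stack, shape (z, y, x)
        self.max_planes = max_planes
        self.planes = {}            # z index -> loaded 2d plane

    def get_plane(self, z):
        if z not in self.planes:
            if len(self.planes) >= self.max_planes:
                self.planes.pop(min(self.planes))       # evict oldest z
            self.planes[z] = np.asarray(self.data[z])   # single dask read
        return self.planes[z]

    def get_cuboid(self, x, y, z, half_x=25, half_y=25, half_z=10):
        # With sorted points, most of these planes are already cached.
        stack = np.stack(
            [self.get_plane(zi) for zi in range(z - half_z, z + half_z)]
        )
        return stack[:, y - half_y:y + half_y, x - half_x:x + half_x]
```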

Additionally, a dedicated thread in the main process does the data loading. The worker subprocesses communicate with it via queues, sending it a buffer that the thread fills with the loaded data; the worker then processes it (e.g. augments and resizes). PyTorch tensors can be shared among processes in memory, so this saves us a lot of data copying.
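Roughly, the shared-buffer pattern works as below (a simplified sketch; the queue layout and shapes are assumptions, not this PR's exact implementation):

```python
import torch
import torch.multiprocessing as mp

def worker(request_q, done_q):
    # A shared-memory tensor: the loader thread in the main process
    # writes into it directly, so no cube data is pickled or copied
    # across the process boundary.
    buf = torch.empty(1, 20, 50, 50)
    buf.share_memory_()
    request_q.put(buf)      # hand the buffer to the loader
    done_q.get()            # wait until the loader has filled it
    cube = buf.clone()      # take a local copy, then augment/resize
    print(cube.shape)

def loader(request_q, done_q):
    # Stand-in for the dedicated loading thread: fill the shared
    # buffer in place (random data here instead of a cached dask read).
    buf = request_q.get()
    buf.copy_(torch.rand_like(buf))
    done_q.put(True)

if __name__ == "__main__":
    request_q, done_q = mp.Queue(), mp.Queue()
    p = mp.Process(target=worker, args=(request_q, done_q))
    p.start()
    loader(request_q, done_q)  # in the PR this runs on its own thread
    p.join()
```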


Important changes:

  • The only new (potential) external parameter is max_workers. Unlike elsewhere, where we specify how many cores to keep free, here it makes more sense to cap how many workers to use: in my testing, workers beyond about 6 were wasted purely on multiprocessing overhead.
  • Added monai as a dependency so we can use its random transforms during data augmentation; its transforms are proper 3d.
  • The data has 4 dimensions: channel, x, y, and z. Each component has its own required axis order: torch has its order, the data is saved on disk in another, monai has its own, and the model expects yet another. In main this order is hardcoded, but I added an explicitly defined ordering for each of them so we can automatically re-order the data as needed (see the sketch after this list). It takes more code, but reduces the mental complexity.
  • When using arrays for the data, main filters out any points that are on the edge of the 3d stack. However, it did not take into account that the input data must be rescaled to the network's expected input size, so the filtering was incorrect for points near the boundary whenever the voxel size of the data differs from the voxel size the network expects (see the previous point). This has been fixed.
  • For a given point, we find a cuboid centered on that point. In main, the way we center it differs between x, y and z. I kept this difference to stay as compatible with main as possible.
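A sketch of the explicit-ordering idea mentioned above (the axis strings are illustrative, not the PR's actual constants):

```python
import numpy as np

# Each consumer declares its expected axis order once, instead of
# hardcoding transpositions at every call site.
DATA_AXES = "zyxc"    # how cubes are stored on disk (assumed)
TORCH_AXES = "czyx"   # what torch/monai transforms expect (assumed)

def reorder(array, src, dst):
    """Move axes from the `src` ordering to the `dst` ordering."""
    return np.moveaxis(
        array, [src.index(a) for a in dst], list(range(len(dst)))
    )

cube = np.zeros((20, 50, 50, 2))                # z, y, x, channel
as_torch = reorder(cube, DATA_AXES, TORCH_AXES)
print(as_torch.shape)                           # (2, 20, 50, 50)
```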

Work to be done:

  • Add docs
  • Add tests
  • Better compare cell classification between main and this branch; specifically, check that the existing weights classify cells similarly on both.
  • Compare a model trained with main vs this branch to see how similarly they classify the same cells.

References

None

How has this PR been tested?

I tested it by training on the original serial2p dataset. The train/validation loss and accuracy show that it trained correctly:

train/val accuracy:
[accuracy plot]

loss:
[loss plot]

I would like to do more comparisons to see how points may be classified differently between main and this branch.

Is this a breaking change?

The only potential breaking change is that the augmentation is different. That shouldn't affect inference with existing models, only newly trained ones. At the same time, we need to double-check this.

We may also want to consider retraining the model?

Does this PR require an update to the documentation?

We potentially only need to document the new max_workers parameter.

Checklist:

  • The code has been tested locally
  • Tests have been added to cover all new functionality (unit & integration)
  • The documentation has been updated to reflect any changes
  • The code has been formatted with pre-commit

matham marked this pull request as draft March 14, 2025 06:11
@matham (Contributor, Author) commented Mar 14, 2025

I think one of the issues with the actions is Project-MONAI/MONAI#8277: monai lists an older numpy as a dependency even though it is compatible with the latest numpy.

@adamltyson (Member) commented:
It's updated on main though, so hopefully the next release will work ok.

@adamltyson (Member) commented:
Looks like the next release should be out soon: Project-MONAI/MONAI#8421

@alessandrofelder (Member) commented May 1, 2025

There seems to have been some movement on the MONAI repo last week, getting closer to a new release, but it's not quite there yet. Just commenting here to remind ourselves that this PR still exists :)

@matham (Contributor, Author) commented May 5, 2025

I'm hoping to start working on this again this week! 🤞

@adamltyson (Member) commented:
The slow release cadence of monai has made me rethink having it as a dependency. @matham, do you have an idea of how complicated it would be to remove the dependency from this PR?

@matham (Contributor, Author) commented May 6, 2025

Originally, when I looked into it, I saw 3 options: the original scipy, TorchIO, and monai.

  1. scipy didn't seem to do true 3d transformations; rather, it treated a cube as a stack of 2d images, each of which is transformed independently.
  2. TorchIO uses simpleitk, which I didn't love, as it adds another large transitive dependency and requires moving data between torch and itk. Plus it has more dependencies overall: deprecated, humanize, nibabel, numpy, packaging, rich, scipy, simpleitk, torch, tqdm, typer, wrapt, markdown-it-py, pygments, filelock, typing-extensions, networkx, jinja2, fsspec, sympy, mpmath, colorama, typer-slim, typer-cli, click, shellingham, mdurl, MarkupSafe.
  3. monai works natively with torch without major dependencies: numpy, torch, filelock, typing-extensions, networkx, jinja2, fsspec, sympy, mpmath, MarkupSafe. So I went with that.

I looked at extracting the relevant code from monai, which should be doable, but I didn't think you'd want to have to maintain that, so I didn't do it. Perhaps you would prefer that, though? Ultimately, for monai it comes down to building transforms and applying them. My linear algebra is not amazing, especially in 3d, but I can give it a try?

There are other libraries that can handle 3d transformations, e.g. pytransform3d, or even scipy itself, but they transform vectors, not volumes. You could use them to transform volumes, but that seems unusual and prone to issues.
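For reference, the kind of monai usage meant here looks roughly like this (the transform choices and parameter values are illustrative, not this PR's actual augmentation pipeline):

```python
import torch
from monai.transforms import Compose, RandFlip, RandRotate, RandZoom

# monai transforms operate on channel-first data: (c, z, y, x) here.
cube = torch.rand(2, 20, 50, 50)

augment = Compose([
    RandRotate(range_x=0.2, range_y=0.2, range_z=0.2, prob=0.5),
    RandFlip(prob=0.5, spatial_axis=0),
    RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
])
out = augment(cube)  # true 3d transforms, staying in torch throughout
```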

@adamltyson (Member) commented:
Hmm, let's go ahead with monai, and vendor the necessary code down the line if we feel the need to.

matham marked this pull request as ready for review May 24, 2025 20:02
@matham (Contributor, Author) commented May 24, 2025

  1. I still need to add some tests, and maybe a little documentation here and there.
  2. I removed monai from the dependencies and manually install it with pip install monai --no-deps on the CI. I'm not sure what we want to do: is it ok to merge in this state, or do we need to wait for a new monai release? I'm assuming the latter?

@adamltyson (Member) commented:
> I removed monai from the dependencies and manually install it with pip install monai --no-deps on the CI. I'm not sure what we want to do: is it ok to merge in this state, or do we need to wait for a new monai release? I'm assuming the latter?

I think we should probably wait, although:

  • I will ask the monai developers if they have an ETA on v1.5
  • This PR is pretty big, so it may take a while for us to review it too (the new monai release may be out by then!)

Would you mind finishing adding the tests, so we can review it all in one go?

@adamltyson (Member) commented:
Also, I should add: thanks again for all your hard work on improving cellfinder, @matham!

Code scanning notice (SonarCloud) on:

```python
batch_size = self.batch_size

if shuffle:
    rng = np.random.default_rng()
```

Results that depend on random number generation should be reproducible (Low): provide a seed for this random generator.
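The fix the scanner asks for amounts to threading an explicit seed through to the generator, e.g. (a sketch, not this PR's actual code):

```python
import numpy as np

def make_rng(seed=None):
    # An optional seed makes shuffling reproducible when needed;
    # default_rng(None) preserves the current non-deterministic behaviour.
    return np.random.default_rng(seed)

rng = make_rng(seed=42)
order = rng.permutation(10)  # deterministic for a fixed seed
```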
@imagesc-bot commented:
This pull request has been mentioned on Image.sc Forum. There might be relevant details there:

https://forum.image.sc/t/brainmapper-cell-classification-processing-time/117138/2

@alessandrofelder (Member) commented Jan 21, 2026

If it fits with your current priorities, @matham, shall we pick this PR back up now that the latest MONAI version should support numpy 2, IIUC?

@matham (Contributor, Author) commented Jan 22, 2026

I would like to resume this PR, but I wasn't sure how to move forward. This is a giant PR, which I didn't want to just dump on y'all, but I'm not sure how to pare it down.

There are about 4 distinct overall changes in this branch:

  1. The core changes are the wholesale changes to the cube generator and augmentation, and therefore to the classify and train run files etc.
  2. But then, to use and test it with my data, I had to update the napari curation, because e.g. it hard-coded the voxel sizes and other things.
  3. And to train the model and make it work with my data, I had to add support for specifying more nuanced learning-rate scheduling.
  4. And then I further had to add support for z-score normalizing the input cube data to the model during training and testing (see the sketch at the end of this comment).

I could try to split it up into these 4 parts, but of course I would then have to test each of them, and I mostly wrote them all together. And even just number 1 is a lot of changes. Is that okay?
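For item 4, the z-score normalization amounts to something like the following (a sketch; per-cube statistics are an assumption):

```python
import torch

def zscore(cube, eps=1e-8):
    # Scale each input cube to zero mean and unit variance before
    # it is fed to the network, during both training and inference.
    return (cube - cube.mean()) / (cube.std() + eps)
```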

@alessandrofelder (Member) commented:
Thanks @matham!

Based on your latest comment, it sounds like at least 2 could be fairly easily separated into its own PR (and maybe also 3 + 4... not sure), as "pre-cursor" PRs to 1. But I also don't want to give you more work: you've done so much! (And I'll have 4000 lines to review anyway 😁)

I wonder whether a good way forward would instead be to spend a short amount of time putting together information that might guide me as a reviewer, like:

  • which files/functions/classes with changes would you look at first? (my mind is not as deep in cellfinder as it once was, but I am happy to take another deep dive soon!)
  • are there any "manual" experiments (maybe comparing main and this branch) that would help me verify that this PR is helpful? (I'm sure it is, but I have to check 😄)
  • how the various changes fit together (the PR description and your latest comment are already helpful in this regard)
  • anything else you can think of that might help me find my way through the changes more easily

(There are also some TODOs listed in the PR description, but maybe those are out of date already?)

@alessandrofelder (Member) commented:
(PS: also, if you could pick up the merge conflict... I think it should be straightforward to resolve)
