adding vista2d (#31)
Adding Vista2D code: a training and inference pipeline for cell
segmentation.

---------

Signed-off-by: am <am>
Co-authored-by: am <am>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
myron and pre-commit-ci[bot] authored Aug 5, 2024
1 parent 31e53e0 commit a8ab29f
Showing 66 changed files with 184,841 additions and 0 deletions.
178 changes: 178 additions & 0 deletions vista2d/README.md
@@ -0,0 +1,178 @@
<!--
Copyright (c) MONAI Consortium
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

## Overview

**VISTA2D** is a cell segmentation training and inference pipeline for cell imaging [[`Blog`](https://developer.nvidia.com/blog/advancing-cell-segmentation-and-morphology-analysis-with-nvidia-ai-foundation-model-vista-2d/)].

A pretrained model was trained on a collection of ~15K public microscopy images. The data collection and training can be reproduced by following the [tutorial](./download_preprocessor/). Alternatively, the model can be retrained on your own dataset. The pretrained VISTA2D model achieves good performance on a diverse set of cell types and microscopy image modalities, and can be further fine-tuned if necessary. The codebase utilizes several components from other great works, including [SegmentAnything](https://github.com/facebookresearch/segment-anything) and [Cellpose](https://www.cellpose.org/), which must be pip-installed as dependencies. The VISTA2D codebase follows the MONAI bundle format and its [specifications](https://docs.monai.io/en/stable/mb_specification.html).

<div align="center"> <img src="https://developer-blogs.nvidia.com/wp-content/uploads/2024/04/magnified-cells-1.png" width="800"/> </div>


### Model highlights

- Robust deep learning algorithm based on transformers
- Generalist model as compared to specialist models
- Multiple dataset sources and file formats supported
- Multiple modalities of imaging data collectively supported
- Multi-GPU and multi-node training support


### Generalization performance

The VISTA2D model was evaluated on multiple public datasets, such as TissueNet, LIVECell, Omnipose, DeepBacs, Cellpose, and [more](./docs/data_license.txt). A total of ~15K annotated cell images were collected to train the generalist VISTA2D model, ensuring broad coverage of many different cell types acquired with various imaging modalities. Benchmarking was performed on the held-out test sets that the dataset contributors had already defined for each public dataset. Average precision at an IoU threshold of 0.5 was used to evaluate performance. The benchmark results are reported in comparison with the best numbers found in the literature, as well as with specialist VISTA2D models trained on only a particular dataset or subset of the data.

<div align="center"> <img src="https://developer-blogs.nvidia.com/wp-content/uploads/2024/04/vista-2d-model-precision-versus-specialist-model-baseline-performance.png" width="800"/> </div>



### Install dependencies

```bash
pip install monai fire tifffile imagecodecs pillow fastremap
pip install --no-deps cellpose natsort roifile
pip install git+https://github.com/facebookresearch/segment-anything.git
pip install mlflow psutil pynvml  # optional, for MLflow support
```
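
Optionally, run a quick sanity check of the environment (a minimal check that the main packages import correctly):

```bash
python -c "import monai, segment_anything; monai.config.print_config()"
```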

### Execute training
```bash
python -m monai.bundle run_workflow "scripts.workflow.VistaCell" --config_file configs/hyper_parameters.yaml
```

#### Quick run with a few data points
```bash
python -m monai.bundle run_workflow "scripts.workflow.VistaCell" --config_file configs/hyper_parameters.yaml --quick True --train#trainer#max_epochs 3
```

### Execute multi-GPU training
```bash
torchrun --nproc_per_node=gpu -m monai.bundle run_workflow "scripts.workflow.VistaCell" --config_file configs/hyper_parameters.yaml
```

### Execute validation
```bash
python -m monai.bundle run_workflow "scripts.workflow.VistaCell" --config_file configs/hyper_parameters.yaml --pretrained_ckpt_name model.pt --mode eval
```
(append `--quick True` for a quick demo)

### Execute multi-GPU validation
```bash
torchrun --nproc_per_node=gpu -m monai.bundle run_workflow "scripts.workflow.VistaCell" --config_file configs/hyper_parameters.yaml --mode eval
```

### Execute inference
```bash
python -m monai.bundle run --config_file configs/inference.json
```

### Execute multi-GPU inference
```bash
torchrun --nproc_per_node=gpu -m monai.bundle run_workflow "scripts.workflow.VistaCell" --config_file configs/hyper_parameters.yaml --mode infer --pretrained_ckpt_name model.pt
```
(append `--quick True` for a quick demo)



#### Finetune starting from a trained checkpoint
(use a smaller learning rate and a smaller number of epochs, and initialize from a previously trained checkpoint)
```bash
python -m monai.bundle run_workflow "scripts.workflow.VistaCell" --config_file configs/hyper_parameters.yaml --learning_rate=0.001 --train#trainer#max_epochs 20 --pretrained_ckpt_path /path/to/saved/model.pt
```


#### Configuration options

To disable writing of the segmentation output:
```
--postprocessing []
```

Load a checkpoint for validation or inference (relative path within results directory):
```
--pretrained_ckpt_name "model.pt"
```

Load a checkpoint for validation or inference (absolute path):
```
--pretrained_ckpt_path "/path/to/another/location/model.pt"
```

`--mode eval` or `--mode infer` will use the corresponding configuration from the `validate` or `infer` section
of `configs/hyper_parameters.yaml`.

By default, the generated `model.pt` corresponds to the checkpoint with the best validation score,
and `model_final.pt` is the checkpoint after the last training epoch.
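
These options can be combined in a single command. For example (an illustrative combination of the flags above), to evaluate a named checkpoint without writing segmentation outputs:

```bash
python -m monai.bundle run_workflow "scripts.workflow.VistaCell" \
  --config_file configs/hyper_parameters.yaml \
  --mode eval \
  --pretrained_ckpt_name "model.pt" \
  --postprocessing []
```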


### Development

For development purposes, it's possible to run the script directly (without MONAI bundle calls):

```bash
python scripts/workflow.py --config_file configs/hyper_parameters.yaml ...
torchrun --nproc_per_node=gpu scripts/workflow.py --config_file configs/hyper_parameters.yaml ...
```

### MLflow support

Enable MLflow logging by specifying `mlflow_tracking_uri` (a local path or a remote URL).

```bash
python -m monai.bundle run_workflow "scripts.workflow.VistaCell" --config_file configs/hyper_parameters.yaml --mlflow_tracking_uri=http://127.0.0.1:8080
```

Optionally use `--mlflow_run_name=..` to specify the MLflow run name, and `--mlflow_log_system_metrics=True/False` to enable logging of CPU/GPU resources (requires `pip install psutil pynvml`).
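
For example, with a local tracking server (the run name below is only an illustration):

```bash
# start a local MLflow tracking server (any reachable tracking URI also works)
mlflow server --host 127.0.0.1 --port 8080

# in another shell, launch training with logging and system metrics enabled
python -m monai.bundle run_workflow "scripts.workflow.VistaCell" \
  --config_file configs/hyper_parameters.yaml \
  --mlflow_tracking_uri=http://127.0.0.1:8080 \
  --mlflow_run_name=vista2d_baseline \
  --mlflow_log_system_metrics=True
```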



### Unit tests

Test single GPU training:
```
python unit_tests/test_vista2d.py
```

Test multi-GPU training (may need to uncomment the `"--standalone"` in the `unit_tests/utils.py` file):
```
python unit_tests/test_vista2d_mgpu.py
```

## Compute Requirements
The minimum GPU memory requirement is 16 GB.


## Contributing
The VISTA2D codebase follows the MONAI bundle format and its [specifications](https://docs.monai.io/en/stable/mb_specification.html).
Make sure to run pre-commit before committing code changes to git:
```bash
pip install pre-commit
python3 -m pre_commit run --all-files
```


## Community

Join the conversation on Twitter [@ProjectMONAI](https://twitter.com/ProjectMONAI) or join
our [Slack channel](https://projectmonai.slack.com/archives/C031QRE0M1C).

Ask and answer questions on [MONAI VISTA's GitHub discussions tab](https://github.com/Project-MONAI/VISTA/discussions).

## License

The codebase is released under the Apache 2.0 License. The model weights are released under CC-BY-NC-SA-4.0. For the licenses of the various public datasets, please see [data_license.txt](./docs/data_license.txt).

## Acknowledgement
- [segment-anything](https://github.com/facebookresearch/segment-anything)
- [Cellpose](https://www.cellpose.org/)
135 changes: 135 additions & 0 deletions vista2d/configs/hyper_parameters.yaml
@@ -0,0 +1,135 @@
imports:
- $import os

# seed: 28022024 # uncomment for deterministic results (but slower)
seed: null

bundle_root: "."
ckpt_path: $os.path.join(@bundle_root, "models") # location to save checkpoints
output_dir: $os.path.join(@bundle_root, "eval") # location to save events and logs
log_output_file: $os.path.join(@output_dir, "vista_cell.log")

mlflow_tracking_uri: null # enable mlflow logging, e.g. $@ckpt_path + '/mlruns/' or "http://127.0.0.1:8080" or a remote url
mlflow_log_system_metrics: true # log system metrics to mlflow (requires: pip install psutil pynvml)
mlflow_run_name: null # optional name of the current run

ckpt_save: true # save checkpoints periodically
amp: true
amp_dtype: "float16" #float16 or bfloat16 (Ampere or newer)
channels_last: true
compile: false # compile the model for faster processing

start_epoch: 0
run_final_testing: true
use_weighted_sampler: false # only applicable when using several dataset jsons for data_list_files

pretrained_ckpt_name: null
pretrained_ckpt_path: null

# for commandline setting of a single dataset
datalist: datalists/tissuenet_skin_mibi_datalist.json
basedir: /data/tissuenet
data_list_files:
- {datalist: "@datalist", basedir: "@basedir"}
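# example with several dataset entries (the file names below are hypothetical placeholders):
# data_list_files:
#   - {datalist: "datalists/tissuenet_skin_mibi_datalist.json", basedir: "/data/tissuenet"}
#   - {datalist: "datalists/cellpose_datalist.json", basedir: "/data/cellpose"}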


fold: 0
learning_rate: 0.01 # try 1.0e-4 if using AdamW
quick: false # whether to use a small subset of data for quick testing
roi_size: [256, 256]

train:
  skip: false
  handlers: []
  trainer:
    num_warmup_epochs: 3
    max_epochs: 200
    num_epochs_per_saving: 1
    num_epochs_per_validation: null
  num_workers: 4
  batch_size: 1
  dataset:
    preprocessing:
      roi_size: "@roi_size"
    data:
      key: null # set to 'testing' to use this subset in periodic validations, instead of the validation set
      data_list_files: "@data_list_files"

dataset:
  data:
    key: "testing"
    data_list_files: "@data_list_files"

validate:
  grouping: true
  evaluator:
    postprocessing: "@postprocessing"
  dataset:
    data: "@dataset#data"
  batch_size: 1
  num_workers: 4
  preprocessing: null
  postprocessing: null
  inferer: null
  handlers: null
  key_metric: null

infer:
  evaluator:
    postprocessing: "@postprocessing"
  dataset:
    data: "@dataset#data"


device: "$torch.device(('cuda:' + os.environ.get('LOCAL_RANK', '0')) if torch.cuda.is_available() else 'cpu')"
network_def:
  _target_: scripts.cell_sam_wrapper.CellSamWrapper
  checkpoint: $os.path.join(@ckpt_path, "sam_vit_b_01ec64.pth")
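  # assumption: the ViT-B SAM checkpoint (sam_vit_b_01ec64.pth) is expected under @ckpt_path;
  # it can be obtained from the facebookresearch/segment-anything repository release assets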
network: $@network_def.to(@device)

loss_function:
  _target_: scripts.components.CellLoss

key_metric:
  _target_: scripts.components.CellAcc

# optimizer:
#   _target_: torch.optim.AdamW
#   params: $@network.parameters()
#   lr: "@learning_rate"
#   weight_decay: 1.0e-5

optimizer:
  _target_: torch.optim.SGD
  params: $@network.parameters()
  momentum: 0.9
  lr: "@learning_rate"
  weight_decay: 1.0e-5

lr_scheduler:
  _target_: monai.optimizers.lr_scheduler.WarmupCosineSchedule
  optimizer: "@optimizer"
  warmup_steps: "@train#trainer#num_warmup_epochs"
  warmup_multiplier: 0.1
  t_total: "@train#trainer#max_epochs"

inferer:
  sliding_inferer:
    _target_: monai.inferers.SlidingWindowInfererAdapt
    roi_size: "@roi_size"
    sw_batch_size: 1
    overlap: 0.625
    mode: "gaussian"
    cache_roi_weight_map: true
    progress: false

image_saver:
  _target_: scripts.components.SaveTiffd
  keys: "seg"
  output_dir: "@output_dir"
  nested_folder: false

postprocessing:
  _target_: monai.transforms.Compose
  transforms:
    - "@image_saver"