diff --git a/configs/vision/radiology/online/segmentation/btcv.yaml b/configs/vision/radiology/online/segmentation/btcv.yaml new file mode 100644 index 000000000..ac22d63be --- /dev/null +++ b/configs/vision/radiology/online/segmentation/btcv.yaml @@ -0,0 +1,192 @@ +--- +trainer: + class_path: eva.Trainer + init_args: + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} + default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, radiology/voco_b}/btcv} + max_epochs: &MAX_EPOCHS ${oc.env:MAX_EPOCHS, 500} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} + check_val_every_n_epoch: ${oc.env:CHECK_VAL_EVERY_N_EPOCHS, 50} + num_sanity_val_steps: 0 + log_every_n_steps: ${oc.env:LOG_EVERY_N_STEPS, 100} + callbacks: + - class_path: eva.callbacks.ConfigurationLogger + - class_path: lightning.pytorch.callbacks.TQDMProgressBar + init_args: + refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1} + logger: + - class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: *OUTPUT_ROOT + name: "" + # - class_path: lightning.pytorch.loggers.WandbLogger + # init_args: + # project: ${oc.env:WANDB_PROJECT, radiology} + # name: ${oc.env:WANDB_RUN_NAME, btcv-${oc.env:MODEL_NAME, radiology/voco_b}} +model: + class_path: eva.vision.models.modules.SemanticSegmentationModule + init_args: + encoder: + class_path: eva.vision.models.ModelFromRegistry + init_args: + model_name: ${oc.env:MODEL_NAME, radiology/voco_b} + model_kwargs: + out_indices: ${oc.env:OUT_INDICES, 6} + decoder: + class_path: eva.vision.models.networks.decoders.segmentation.SwinUNETRDecoder + init_args: + feature_size: ${oc.env:IN_FEATURES, 48} + out_channels: &NUM_CLASSES 14 + inferer: + class_path: monai.inferers.SlidingWindowInferer + init_args: + roi_size: &ROI_SIZE ${oc.env:ROI_SIZE, [96, 96, 96]} + sw_batch_size: ${oc.env:SW_BATCH_SIZE, 8} + overlap: ${oc.env:SW_OVERLAP, 0.75} + criterion: + class_path: monai.losses.DiceCELoss + init_args: + include_background: false + to_onehot_y: true + softmax: true + lr_multiplier_encoder: 0.0 + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: ${oc.env:LR_VALUE, 0.001} + betas: [0.9, 0.999] + weight_decay: ${oc.env:WEIGHT_DECAY, 0.01} + lr_scheduler: + class_path: torch.optim.lr_scheduler.CosineAnnealingLR + init_args: + T_max: *MAX_EPOCHS + eta_min: ${oc.env:LR_VALUE_END, 0.0001} + postprocess: + predictions_transforms: + - class_path: eva.core.models.transforms.AsDiscrete + init_args: + argmax: true + to_onehot: *NUM_CLASSES + targets_transforms: + - class_path: eva.core.models.transforms.AsDiscrete + init_args: + to_onehot: *NUM_CLASSES + metrics: + common: + - class_path: eva.metrics.AverageLoss + evaluation: + - class_path: torchmetrics.segmentation.DiceScore + init_args: + num_classes: *NUM_CLASSES + include_background: false + average: macro + input_format: one-hot + - class_path: torchmetrics.ClasswiseWrapper + init_args: + metric: + class_path: eva.vision.metrics.MonaiDiceScore + init_args: + include_background: true + num_classes: *NUM_CLASSES + input_format: one-hot + reduction: none + prefix: DiceScore_ + labels: + - "0_background" + - "1_spleen" + - "2_right_kidney" + - "3_left_kidney" + - "4_gallbladder" + - "5_esophagus" + - "6_liver" + - "7_stomach" + - "8_aorta" + - "9_inferior_vena_cava" + - "10_portal_and_splenic_vein" + - "11_pancreas" + - "12_right_adrenal_gland" + - "13_left_adrenal_gland" +data: + class_path: eva.DataModule + init_args: + datasets: + train: + class_path: eva.vision.datasets.BTCV + init_args: &DATASET_ARGS + root: ${oc.env:DATA_ROOT, 
./data/btcv} + split: train + download: ${oc.env:DOWNLOAD_DATA, false} + # Set `download: true` to download the dataset automatically + # The BTCV dataset is distributed under the CC BY 4.0 license + # (https://creativecommons.org/licenses/by/4.0/legalcode) + transforms: + class_path: torchvision.transforms.v2.Compose + init_args: + transforms: + - class_path: eva.vision.data.transforms.EnsureChannelFirst + init_args: + channel_dim: 1 + - class_path: eva.vision.data.transforms.Spacing + init_args: + pixdim: [1.5, 1.5, 1.5] + - class_path: eva.vision.data.transforms.ScaleIntensityRange + init_args: + input_range: + - ${oc.env:SCALE_INTENSITY_MIN, -175.0} + - ${oc.env:SCALE_INTENSITY_MAX, 250.0} + output_range: [0.0, 1.0] + - class_path: eva.vision.data.transforms.CropForeground + - class_path: eva.vision.data.transforms.SpatialPad + init_args: + spatial_size: *ROI_SIZE + - class_path: eva.vision.data.transforms.RandCropByPosNegLabel + init_args: + spatial_size: *ROI_SIZE + num_samples: ${oc.env:SAMPLE_BATCH_SIZE, 4} + pos: ${oc.env:RAND_CROP_POS_WEIGHT, 9} + neg: ${oc.env:RAND_CROP_NEG_WEIGHT, 1} + - class_path: eva.vision.data.transforms.RandFlip + init_args: + spatial_axes: [0, 1, 2] + - class_path: eva.vision.data.transforms.RandRotate90 + init_args: + spatial_axes: [1, 2] + - class_path: eva.vision.data.transforms.RandScaleIntensity + init_args: + factors: 0.1 + prob: 0.1 + - class_path: eva.vision.data.transforms.RandShiftIntensity + init_args: + offsets: 0.1 + prob: 0.1 + val: + class_path: eva.vision.datasets.BTCV + init_args: + <<: *DATASET_ARGS + split: val + transforms: + class_path: torchvision.transforms.v2.Compose + init_args: + transforms: + - class_path: eva.vision.data.transforms.EnsureChannelFirst + init_args: + channel_dim: 1 + - class_path: eva.vision.data.transforms.Spacing + init_args: + pixdim: [1.5, 1.5, 1.5] + - class_path: eva.vision.data.transforms.ScaleIntensityRange + init_args: + input_range: + - ${oc.env:SCALE_INTENSITY_MIN, -175.0} + - ${oc.env:SCALE_INTENSITY_MAX, 250.0} + output_range: [0.0, 1.0] + - class_path: eva.vision.data.transforms.CropForeground + dataloaders: + train: + batch_size: ${oc.env:BATCH_SIZE, 2} + num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 8} + shuffle: true + collate_fn: eva.vision.data.dataloaders.collate_fn.collection_collate + val: + batch_size: 1 + num_workers: *N_DATA_WORKERS diff --git a/docs/datasets/btcv.md b/docs/datasets/btcv.md new file mode 100644 index 000000000..556402fa1 --- /dev/null +++ b/docs/datasets/btcv.md @@ -0,0 +1,59 @@ +# Beyond the Cranial Vault (BTCV) Abdomen dataset + +The BTCV dataset comprises abdominal CT scans acquired at the Vanderbilt University Medical Center from metastatic liver cancer patients or post-operative ventral hernia patients. + +The annotations included in this dataset cover segmentations of the spleen, right and left kidneys, gallbladder, esophagus, liver, stomach, aorta, inferior vena cava, portal and splenic veins, pancreas, and right and left adrenal glands. Images were manually labeled by two experienced undergraduate students and verified by a radiologist.
+ + +## Raw data + +### Key stats + +| | | +|-----------------------|-----------------------------------------------------------| +| **Modality** | Vision (radiology, CT scans) | +| **Task** | Segmentation (14 classes) | +| **Image dimension** | 512 x 512 x ~140 (number of slices) | +| **Files format** | `.nii` ("NIFTI") images | +| **Number of scans** | 30 | +| **Splits in use** | train (80%) / val (20%) | + + +### Splits + +While the full dataset contains 90 CT scans, we use the train/val split from MONAI which uses a subset of 30 CT scans (https://github.com/Luffy03/Large-Scale-Medical/blob/main/Downstream/monai/BTCV/dataset/dataset_0.json): + +| Splits | Train | Validation | +|----------------|------------------|-------------------| +| #Scans | 24 (80%) | 6 (20%) | + + +### Organization + +The training data are organized as follows: + +``` +imagesTr +├── img0001.nii.gz +├── img0002.nii.gz +└── ... + +labelsTr +├── label0001.nii.gz +├── label0002.nii.gz +└── ... +``` + +## Download + +The `BTCV` dataset class supports downloading the data during runtime by setting the init argument `download=True`. + +## Relevant links + +* [zenodo download source](https://zenodo.org/records/1169361) +* [huggingface dataset](https://huggingface.co/datasets/Luffy503/VoCo_Downstream/blob/main/BTCV.zip) + + +## License + +[CC BY 4.0](https://creativecommons.org/licenses/by/4.0) diff --git a/docs/datasets/index.md b/docs/datasets/index.md index 97cb52f01..dcf47ba49 100644 --- a/docs/datasets/index.md +++ b/docs/datasets/index.md @@ -34,7 +34,8 @@ | Dataset | #Images | Image Size | Task | Download provided |---|---|---|---|---| +| [BTCV](btcv.md) | 30 | 512 x 512 x ~140 \* | Semantic Segmentation (14 classes) | Yes | | [TotalSegmentator](total_segmentator.md) | 1228 | ~300 x ~300 x ~350 \* | Semantic Segmentation (117 classes) | Yes | -| [LiTS](lits.md) | 131 (58638) | ~300 x ~300 x ~350 \* | Semantic Segmentation (2 classes) | No | +| [LiTS](lits.md) | 131 (58638) | ~300 x ~300 x ~350 \* | Semantic Segmentation (3 classes) | No | \* 3D images of varying sizes diff --git a/docs/index.md b/docs/index.md index 8b482fe6f..941e754dd 100644 --- a/docs/index.md +++ b/docs/index.md @@ -61,8 +61,9 @@ Supported datasets & tasks include: *Radiology datasets* -- **[TotalSegmentator](datasets/total_segmentator.md)**: radiology/CT-scan for segmentation of anatomical structures -- **[LiTS](datasets/lits.md)**: radiology/CT-scan for segmentation of liver and tumor +- **[BTCV](datasets/btcv.md)**: Segmentation of abdominal organs (CT scans). +- **[TotalSegmentator](datasets/total_segmentator.md)**: Segmentation of anatomical structures (CT scans). +- **[LiTS](datasets/lits.md)**: Segmentation of liver and tumor (CT scans). To evaluate FMs, *eva* provides support for different model-formats, including models trained with PyTorch, models available on HuggingFace and ONNX-models. For other formats custom wrappers can be implemented. 
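The `Download` section of `docs/datasets/btcv.md` above states that the `BTCV` dataset class can fetch the data at runtime via `download=True`. As a rough, hedged illustration (not part of this diff), direct instantiation might look like the sketch below; the `root`, `split` and `download` argument names mirror the `init_args` used in the BTCV config added above, and everything else (paths, return behaviour) is an assumption.

```python
from eva.vision import datasets

# Hedged sketch: argument names mirror the init_args in
# configs/vision/radiology/online/segmentation/btcv.yaml; not asserted by the diff.
val_split = datasets.BTCV(
    root="./data/btcv",  # DATA_ROOT default from the config
    split="val",
    download=True,       # runtime download, as described in docs/datasets/btcv.md
)
```

When running the config itself, setting `DOWNLOAD_DATA=true` should have the same effect through the `${oc.env:DOWNLOAD_DATA, false}` resolver.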
diff --git a/mkdocs.yml b/mkdocs.yml index 7458bcea3..700b82d56 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -99,6 +99,7 @@ nav: - PANDA: datasets/panda.md - PANDASmall: datasets/panda_small.md - Radiology: + - BTCV: datasets/btcv.md - TotalSegmentator: datasets/total_segmentator.md - LiTS: datasets/lits.md - Reference API: diff --git a/src/eva/core/models/modules/head.py b/src/eva/core/models/modules/head.py index 4d3732ced..56a4b3b3c 100644 --- a/src/eva/core/models/modules/head.py +++ b/src/eva/core/models/modules/head.py @@ -1,6 +1,6 @@ """Neural Network Head Module.""" -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, List import torch from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable @@ -108,7 +108,9 @@ def test_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPU return self._batch_step(batch) @override - def predict_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> torch.Tensor: + def predict_step( + self, batch: INPUT_BATCH, *args: Any, **kwargs: Any + ) -> torch.Tensor | List[torch.Tensor]: tensor = INPUT_BATCH(*batch).data return tensor if self.backbone is None else self.backbone(tensor) diff --git a/src/eva/core/models/modules/typings.py b/src/eva/core/models/modules/typings.py index a999a7a3d..3e7a64985 100644 --- a/src/eva/core/models/modules/typings.py +++ b/src/eva/core/models/modules/typings.py @@ -1,6 +1,6 @@ """Type annotations for model modules.""" -from typing import Any, Dict, NamedTuple +from typing import Any, Dict, List, NamedTuple import lightning.pytorch as pl import torch @@ -13,7 +13,7 @@ class INPUT_BATCH(NamedTuple): """The default input batch data scheme.""" - data: torch.Tensor + data: torch.Tensor | List[torch.Tensor] """The data batch.""" targets: torch.Tensor | None = None diff --git a/src/eva/core/models/transforms/__init__.py b/src/eva/core/models/transforms/__init__.py index db20ffa94..a8f5c940f 100644 --- a/src/eva/core/models/transforms/__init__.py +++ b/src/eva/core/models/transforms/__init__.py @@ -1,6 +1,7 @@ """Model outputs transforms API.""" +from eva.core.models.transforms.as_discrete import AsDiscrete from eva.core.models.transforms.extract_cls_features import ExtractCLSFeatures from eva.core.models.transforms.extract_patch_features import ExtractPatchFeatures -__all__ = ["ExtractCLSFeatures", "ExtractPatchFeatures"] +__all__ = ["AsDiscrete", "ExtractCLSFeatures", "ExtractPatchFeatures"] diff --git a/src/eva/core/models/transforms/as_discrete.py b/src/eva/core/models/transforms/as_discrete.py new file mode 100644 index 000000000..a00d3728f --- /dev/null +++ b/src/eva/core/models/transforms/as_discrete.py @@ -0,0 +1,57 @@ +"""Defines the AsDiscrete transformation.""" + +import torch + + +class AsDiscrete: + """Convert the logits tensor to discrete values.""" + + def __init__( + self, + argmax: bool = False, + to_onehot: int | bool | None = None, + threshold: float | None = None, + ) -> None: + """Convert the input tensor/array into discrete values. + + Args: + argmax: Whether to execute argmax function on input data before transform. + to_onehot: if not None, convert input data into the one-hot format with + specified number of classes. If bool, it will try to infer the number + of classes. + threshold: If not None, threshold the float values to int number 0 or 1 + with specified threshold. 
+ """ + super().__init__() + + self._argmax = argmax + self._to_onehot = to_onehot + self._threshold = threshold + + def __call__(self, tensor: torch.Tensor) -> torch.Tensor: + """Call method for the transformation.""" + if self._argmax: + tensor = torch.argmax(tensor, dim=1, keepdim=True) + + if self._to_onehot is not None: + tensor = _one_hot(tensor, num_classes=self._to_onehot, dim=1, dtype=torch.long) + + if self._threshold is not None: + tensor = tensor >= self._threshold + + return tensor + + +def _one_hot( + tensor: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1 +) -> torch.Tensor: + """Convert input tensor into one-hot format (implementation taken from MONAI).""" + shape = list(tensor.shape) + if shape[dim] != 1: + raise AssertionError(f"Input tensor must have 1 channel at dim {dim}.") + + shape[dim] = num_classes + o = torch.zeros(size=shape, dtype=dtype, device=tensor.device) + tensor = o.scatter_(dim=dim, index=tensor.long(), value=1) + + return tensor diff --git a/src/eva/core/models/wrappers/_utils.py b/src/eva/core/models/wrappers/_utils.py index d2ebd79cb..cb1d9d946 100644 --- a/src/eva/core/models/wrappers/_utils.py +++ b/src/eva/core/models/wrappers/_utils.py @@ -63,7 +63,7 @@ def load_state_dict_from_url( os.makedirs(model_dir, exist_ok=True) cached_file = os.path.join(model_dir, filename or os.path.basename(url)) - if force or not _check_integrity(cached_file, md5): + if force or not os.path.exists(cached_file) or not _check_integrity(cached_file, md5): sys.stderr.write(f"Downloading: '{url}' to {cached_file}\n") _download_url_to_file(url, cached_file, progress=progress) if md5 is None or not _check_integrity(cached_file, md5): diff --git a/src/eva/vision/data/dataloaders/__init__.py b/src/eva/vision/data/dataloaders/__init__.py new file mode 100644 index 000000000..7367b69f7 --- /dev/null +++ b/src/eva/vision/data/dataloaders/__init__.py @@ -0,0 +1,5 @@ +"""Dataloader related utilities and functions.""" + +from eva.vision.data.dataloaders import collate_fn + +__all__ = ["collate_fn"] diff --git a/src/eva/vision/data/dataloaders/collate_fn/__init__.py b/src/eva/vision/data/dataloaders/collate_fn/__init__.py new file mode 100644 index 000000000..da01ce9a5 --- /dev/null +++ b/src/eva/vision/data/dataloaders/collate_fn/__init__.py @@ -0,0 +1,5 @@ +"""Dataloader collate API.""" + +from eva.vision.data.dataloaders.collate_fn.collection import collection_collate + +__all__ = ["collection_collate"] diff --git a/src/eva/vision/data/dataloaders/collate_fn/collection.py b/src/eva/vision/data/dataloaders/collate_fn/collection.py new file mode 100644 index 000000000..9a7d7efa4 --- /dev/null +++ b/src/eva/vision/data/dataloaders/collate_fn/collection.py @@ -0,0 +1,22 @@ +"""Data only collate filter function.""" + +from typing import Any, List + +import torch + +from eva.core.models.modules.typings import INPUT_BATCH + + +def collection_collate(batch: List[List[INPUT_BATCH]]) -> Any: + """Collate function for stacking a collection of data samples. + + Args: + batch: The batch to be collated. + + Returns: + The collated batch. 
+ """ + tensors, targets, metadata = zip(*batch, strict=False) + batch_tensors = torch.cat(list(map(torch.stack, tensors))) + batch_targets = torch.cat(list(map(torch.stack, targets))) + return batch_tensors, batch_targets, metadata diff --git a/src/eva/vision/metrics/segmentation/monai_dice.py b/src/eva/vision/metrics/segmentation/monai_dice.py index 9e7f55ac9..d780895bc 100644 --- a/src/eva/vision/metrics/segmentation/monai_dice.py +++ b/src/eva/vision/metrics/segmentation/monai_dice.py @@ -1,5 +1,7 @@ """Wrapper for dice score metric from MONAI.""" +from typing import Literal + from monai.metrics.meandice import DiceMetric from typing_extensions import override @@ -14,6 +16,7 @@ def __init__( self, num_classes: int, include_background: bool = True, + input_format: Literal["one-hot", "index"] = "index", reduction: str = "mean", ignore_index: int | None = None, **kwargs, @@ -24,6 +27,8 @@ def __init__( num_classes: The number of classes in the dataset. include_background: Whether to include the background class in the computation. reduction: The method to reduce the dice score. Options are `"mean"`, `"sum"`, `"none"`. + input_format: Choose between "one-hot" for one-hot encoded tensors or "index" + for index tensors. ignore_index: Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score. kwargs: Additional keyword arguments for instantiating monai's `DiceMetric` class. @@ -40,11 +45,13 @@ def __init__( self.reduction = reduction self.orig_num_classes = num_classes self.ignore_index = ignore_index + self.input_format = input_format @override def update(self, preds, target): - preds = _utils.index_to_one_hot(preds, num_classes=self.orig_num_classes) - target = _utils.index_to_one_hot(target, num_classes=self.orig_num_classes) + if self.input_format == "index": + preds = _utils.index_to_one_hot(preds, num_classes=self.orig_num_classes) + target = _utils.index_to_one_hot(target, num_classes=self.orig_num_classes) if self.ignore_index is not None: preds, target = _utils.apply_ignore_index(preds, target, self.ignore_index) return super().update(preds, target) diff --git a/src/eva/vision/models/modules/semantic_segmentation.py b/src/eva/vision/models/modules/semantic_segmentation.py index b9596f12e..ae0fec57d 100644 --- a/src/eva/vision/models/modules/semantic_segmentation.py +++ b/src/eva/vision/models/modules/semantic_segmentation.py @@ -25,7 +25,7 @@ class SemanticSegmentationModule(module.ModelModule): def __init__( self, - decoder: decoders.Decoder, + decoder: decoders.Decoder | nn.Module, criterion: Callable[..., torch.Tensor], encoder: Dict[str, Any] | Callable[[torch.Tensor], List[torch.Tensor]] | None = None, lr_multiplier_encoder: float = 0.0, @@ -133,7 +133,9 @@ def test_step(self, batch: INPUT_TENSOR_BATCH, *args: Any, **kwargs: Any) -> STE return self._batch_step(batch) @override - def predict_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> torch.Tensor: + def predict_step( + self, batch: INPUT_BATCH, *args: Any, **kwargs: Any + ) -> torch.Tensor | List[torch.Tensor]: tensor = INPUT_BATCH(*batch).data return self.encoder(tensor) if isinstance(self.encoder, nn.Module) else tensor