Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 192 additions & 0 deletions configs/vision/radiology/online/segmentation/btcv.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
---
# Trainer configuration for the BTCV segmentation run.
# Values written as `${oc.env:NAME, default}` are OmegaConf interpolations:
# they read the environment variable NAME and fall back to the given default.
trainer:
  class_path: eva.Trainer
  init_args:
    n_runs: &N_RUNS ${oc.env:N_RUNS, 5}
    # Output directory; by default it nests the model name under `logs/`.
    # Anchored so the TensorBoard logger below writes to the same location.
    default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, radiology/voco_b}/btcv}
    # Anchored so the LR scheduler's `T_max` in the model section can reuse it.
    max_epochs: &MAX_EPOCHS ${oc.env:MAX_EPOCHS, 500}
    checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best}
    # Validation runs only every N epochs (default 50).
    check_val_every_n_epoch: ${oc.env:CHECK_VAL_EVERY_N_EPOCHS, 50}
    # 0 disables the sanity-check validation steps before training starts.
    num_sanity_val_steps: 0
    log_every_n_steps: ${oc.env:LOG_EVERY_N_STEPS, 100}
    callbacks:
      - class_path: eva.callbacks.ConfigurationLogger
      - class_path: lightning.pytorch.callbacks.TQDMProgressBar
        init_args:
          refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1}
    logger:
      - class_path: lightning.pytorch.loggers.TensorBoardLogger
        init_args:
          save_dir: *OUTPUT_ROOT
          name: ""
      # Optional Weights & Biases logger; uncomment the entry below to enable it.
      # - class_path: lightning.pytorch.loggers.WandbLogger
      #   init_args:
      #     project: ${oc.env:WANDB_PROJECT, radiology}
      #     name: ${oc.env:WANDB_RUN_NAME, btcv-${oc.env:MODEL_NAME, radiology/voco_b}}
# Model configuration: encoder + decoder for semantic segmentation, together
# with the loss, optimizer, LR schedule, output post-processing and metrics.
model:
  class_path: eva.vision.models.modules.SemanticSegmentationModule
  init_args:
    # Encoder selected by name; override with the MODEL_NAME environment variable.
    encoder:
      class_path: eva.vision.models.ModelFromRegistry
      init_args:
        model_name: ${oc.env:MODEL_NAME, radiology/voco_b}
        model_kwargs:
          out_indices: ${oc.env:OUT_INDICES, 6}
    decoder:
      class_path: eva.vision.models.networks.decoders.segmentation.SwinUNETRDecoder
      init_args:
        feature_size: ${oc.env:IN_FEATURES, 48}
        # 14 output channels: one per entry of the `labels` list further down
        # (background + 13 organs). Anchored for reuse by transforms and metrics.
        out_channels: &NUM_CLASSES 14
    # Sliding-window inference settings; the window size is anchored so the
    # data transforms can pad and crop to the same size.
    inferer:
      class_path: monai.inferers.SlidingWindowInferer
      init_args:
        roi_size: &ROI_SIZE ${oc.env:ROI_SIZE, [96, 96, 96]}
        sw_batch_size: ${oc.env:SW_BATCH_SIZE, 8}
        overlap: ${oc.env:SW_OVERLAP, 0.75}
    criterion:
      class_path: monai.losses.DiceCELoss
      init_args:
        include_background: false
        to_onehot_y: true
        softmax: true
    # NOTE(review): a multiplier of 0.0 presumably keeps the encoder frozen so
    # that only the decoder is trained; confirm against SemanticSegmentationModule.
    lr_multiplier_encoder: 0.0
    optimizer:
      class_path: torch.optim.AdamW
      init_args:
        lr: ${oc.env:LR_VALUE, 0.001}
        betas: [0.9, 0.999]
        weight_decay: ${oc.env:WEIGHT_DECAY, 0.01}
    # Cosine schedule spanning the full run (`T_max` aliases `max_epochs`).
    lr_scheduler:
      class_path: torch.optim.lr_scheduler.CosineAnnealingLR
      init_args:
        T_max: *MAX_EPOCHS
        eta_min: ${oc.env:LR_VALUE_END, 0.0001}
    # Both predictions and targets are converted to one-hot format, which is
    # the `input_format` the metrics below are configured with.
    postprocess:
      predictions_transforms:
        # Logits -> argmax over the channel dimension -> one-hot.
        - class_path: eva.core.models.transforms.AsDiscrete
          init_args:
            argmax: true
            to_onehot: *NUM_CLASSES
      targets_transforms:
        # Label map -> one-hot (no argmax).
        - class_path: eva.core.models.transforms.AsDiscrete
          init_args:
            to_onehot: *NUM_CLASSES
    metrics:
      common:
        - class_path: eva.metrics.AverageLoss
      evaluation:
        # Macro-averaged Dice score with the background class excluded.
        - class_path: torchmetrics.segmentation.DiceScore
          init_args:
            num_classes: *NUM_CLASSES
            include_background: false
            average: macro
            input_format: one-hot
        # Per-class Dice scores (background included), reported under names
        # built from `prefix` and the `labels` entries.
        - class_path: torchmetrics.ClasswiseWrapper
          init_args:
            metric:
              class_path: eva.vision.metrics.MonaiDiceScore
              init_args:
                include_background: true
                num_classes: *NUM_CLASSES
                input_format: one-hot
                reduction: none
            prefix: DiceScore_
            # One label per class; the leading number is the class index.
            labels:
              - "0_background"
              - "1_spleen"
              - "2_right_kidney"
              - "3_left_kidney"
              - "4_gallbladder"
              - "5_esophagus"
              - "6_liver"
              - "7_stomach"
              - "8_aorta"
              - "9_inferior_vena_cava"
              - "10_portal_and_splenic_vein"
              - "11_pancreas"
              - "12_right_adrenal_gland"
              - "13_left_adrenal_gland"
# Data configuration: BTCV train/val datasets, their transform pipelines and
# the dataloader settings.
data:
  class_path: eva.DataModule
  init_args:
    datasets:
      train:
        class_path: eva.vision.datasets.BTCV
        # Anchored so the validation dataset can merge in the same arguments.
        init_args: &DATASET_ARGS
          root: ${oc.env:DATA_ROOT, ./data/btcv}
          split: train
          download: ${oc.env:DOWNLOAD_DATA, false}
          # Set `download: true` to download the dataset automatically.
          # The BTCV dataset is distributed under the CC BY 4.0 license
          # (https://creativecommons.org/licenses/by/4.0/legalcode)
          # Training pipeline: preprocessing, padding and random cropping to
          # the inference window size, then random augmentations.
          transforms:
            class_path: torchvision.transforms.v2.Compose
            init_args:
              transforms:
                - class_path: eva.vision.data.transforms.EnsureChannelFirst
                  init_args:
                    channel_dim: 1
                - class_path: eva.vision.data.transforms.Spacing
                  init_args:
                    pixdim: [1.5, 1.5, 1.5]
                # Maps the input intensity range (default [-175, 250]) to [0, 1].
                - class_path: eva.vision.data.transforms.ScaleIntensityRange
                  init_args:
                    input_range:
                      - ${oc.env:SCALE_INTENSITY_MIN, -175.0}
                      - ${oc.env:SCALE_INTENSITY_MAX, 250.0}
                    output_range: [0.0, 1.0]
                - class_path: eva.vision.data.transforms.CropForeground
                - class_path: eva.vision.data.transforms.SpatialPad
                  init_args:
                    spatial_size: *ROI_SIZE
                # Draws `num_samples` crops per volume; `pos` / `neg` weight
                # the sampling (defaults 9 and 1).
                - class_path: eva.vision.data.transforms.RandCropByPosNegLabel
                  init_args:
                    spatial_size: *ROI_SIZE
                    num_samples: ${oc.env:SAMPLE_BATCH_SIZE, 4}
                    pos: ${oc.env:RAND_CROP_POS_WEIGHT, 9}
                    neg: ${oc.env:RAND_CROP_NEG_WEIGHT, 1}
                - class_path: eva.vision.data.transforms.RandFlip
                  init_args:
                    spatial_axes: [0, 1, 2]
                - class_path: eva.vision.data.transforms.RandRotate90
                  init_args:
                    spatial_axes: [1, 2]
                - class_path: eva.vision.data.transforms.RandScaleIntensity
                  init_args:
                    factors: 0.1
                    prob: 0.1
                - class_path: eva.vision.data.transforms.RandShiftIntensity
                  init_args:
                    offsets: 0.1
                    prob: 0.1
      val:
        class_path: eva.vision.datasets.BTCV
        init_args:
          # Reuse the train arguments; `split` and `transforms` are overridden.
          <<: *DATASET_ARGS
          split: val
          # Validation pipeline: same preprocessing as training, but without
          # padding, random cropping or augmentation.
          transforms:
            class_path: torchvision.transforms.v2.Compose
            init_args:
              transforms:
                - class_path: eva.vision.data.transforms.EnsureChannelFirst
                  init_args:
                    channel_dim: 1
                - class_path: eva.vision.data.transforms.Spacing
                  init_args:
                    pixdim: [1.5, 1.5, 1.5]
                - class_path: eva.vision.data.transforms.ScaleIntensityRange
                  init_args:
                    input_range:
                      - ${oc.env:SCALE_INTENSITY_MIN, -175.0}
                      - ${oc.env:SCALE_INTENSITY_MAX, 250.0}
                    output_range: [0.0, 1.0]
                - class_path: eva.vision.data.transforms.CropForeground
    dataloaders:
      train:
        batch_size: ${oc.env:BATCH_SIZE, 2}
        num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 8}
        shuffle: true
        # NOTE(review): custom collate function, presumably needed because the
        # random-crop transform yields several samples per volume; confirm.
        collate_fn: eva.vision.data.dataloaders.collate_fn.collection_collate
      val:
        # Validation volumes are loaded one per batch.
        batch_size: 1
        num_workers: *N_DATA_WORKERS
59 changes: 59 additions & 0 deletions docs/datasets/btcv.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Beyond the Cranial Vault (BTCV) Abdomen dataset.

The BTCV dataset comprises abdominal CT scans acquired at the Vanderbilt University Medical Center from metastatic liver cancer patients or post-operative ventral hernia patients.

The annotations cover segmentations of the spleen, right and left kidney, gallbladder, esophagus, liver, stomach, aorta, inferior vena cava, portal vein and splenic vein, pancreas, right adrenal gland, and left adrenal gland. Images were manually labeled by two experienced undergraduate students and verified by a radiologist.


## Raw data

### Key stats

| | |
|-----------------------|-----------------------------------------------------------|
| **Modality** | Vision (radiology, CT scans) |
| **Task** | Segmentation (14 classes) |
| **Image dimension** | 512 x 512 x ~140 (number of slices) |
| **Files format** | `.nii` ("NIFTI") images |
| **Number of scans** | 30 |
| **Splits in use** | train (80%) / val (20%) |


### Splits

While the full dataset contains 90 CT scans, we use the train/val split from MONAI which uses a subset of 30 CT scans (https://github.com/Luffy03/Large-Scale-Medical/blob/main/Downstream/monai/BTCV/dataset/dataset_0.json):

| Splits | Train | Validation |
|----------------|------------------|-------------------|
| #Scans | 24 (80%) | 6 (20%) |


### Organization

The training data are organized as follows:

```
imagesTr
├── img0001.nii.gz
├── img0002.nii.gz
└── ...

labelsTr
├── label0001.nii.gz
├── label0002.nii.gz
└── ...
```

## Download

The `BTCV` dataset class supports downloading the data during runtime by setting the init argument `download=True`.

## Relevant links

* [zenodo download source](https://zenodo.org/records/1169361)
* [huggingface dataset](https://huggingface.co/datasets/Luffy503/VoCo_Downstream/blob/main/BTCV.zip)


## License

[CC BY 4.0](https://creativecommons.org/licenses/by/4.0)
3 changes: 2 additions & 1 deletion docs/datasets/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@

| Dataset | #Images | Image Size | Task | Download provided
|---|---|---|---|---|
| [BTCV](btcv.md) | 30 | 512 x 512 x ~140 \* | Semantic Segmentation (14 classes) | Yes |
| [TotalSegmentator](total_segmentator.md) | 1228 | ~300 x ~300 x ~350 \* | Semantic Segmentation (117 classes) | Yes |
| [LiTS](lits.md) | 131 (58638) | ~300 x ~300 x ~350 \* | Semantic Segmentation (2 classes) | No |
| [LiTS](lits.md) | 131 (58638) | ~300 x ~300 x ~350 \* | Semantic Segmentation (3 classes) | No |

\* 3D images of varying sizes
5 changes: 3 additions & 2 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ Supported datasets & tasks include:

*Radiology datasets*

- **[TotalSegmentator](datasets/total_segmentator.md)**: radiology/CT-scan for segmentation of anatomical structures
- **[LiTS](datasets/lits.md)**: radiology/CT-scan for segmentation of liver and tumor
- **[BTCV](datasets/btcv.md)**: Segmentation of abdominal organs (CT scans).
- **[TotalSegmentator](datasets/total_segmentator.md)**: Segmentation of anatomical structures (CT scans).
- **[LiTS](datasets/lits.md)**: Segmentation of liver and tumor (CT scans).

To evaluate FMs, *eva* provides support for different model-formats, including models trained with PyTorch, models available on HuggingFace and ONNX-models. For other formats custom wrappers can be implemented.

Expand Down
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ nav:
- PANDA: datasets/panda.md
- PANDASmall: datasets/panda_small.md
- Radiology:
- BTCV: datasets/btcv.md
- TotalSegmentator: datasets/total_segmentator.md
- LiTS: datasets/lits.md
- Reference API:
Expand Down
6 changes: 4 additions & 2 deletions src/eva/core/models/modules/head.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Neural Network Head Module."""

from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, List

import torch
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
Expand Down Expand Up @@ -108,7 +108,9 @@ def test_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPU
return self._batch_step(batch)

@override
def predict_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> torch.Tensor:
def predict_step(
self, batch: INPUT_BATCH, *args: Any, **kwargs: Any
) -> torch.Tensor | List[torch.Tensor]:
tensor = INPUT_BATCH(*batch).data
return tensor if self.backbone is None else self.backbone(tensor)

Expand Down
4 changes: 2 additions & 2 deletions src/eva/core/models/modules/typings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Type annotations for model modules."""

from typing import Any, Dict, NamedTuple
from typing import Any, Dict, List, NamedTuple

import lightning.pytorch as pl
import torch
Expand All @@ -13,7 +13,7 @@
class INPUT_BATCH(NamedTuple):
"""The default input batch data scheme."""

data: torch.Tensor
data: torch.Tensor | List[torch.Tensor]
"""The data batch."""

targets: torch.Tensor | None = None
Expand Down
3 changes: 2 additions & 1 deletion src/eva/core/models/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Model outputs transforms API."""

from eva.core.models.transforms.as_discrete import AsDiscrete
from eva.core.models.transforms.extract_cls_features import ExtractCLSFeatures
from eva.core.models.transforms.extract_patch_features import ExtractPatchFeatures

__all__ = ["ExtractCLSFeatures", "ExtractPatchFeatures"]
__all__ = ["AsDiscrete", "ExtractCLSFeatures", "ExtractPatchFeatures"]
57 changes: 57 additions & 0 deletions src/eva/core/models/transforms/as_discrete.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""Defines the AsDiscrete transformation."""

import torch


class AsDiscrete:
"""Convert the logits tensor to discrete values."""

def __init__(
self,
argmax: bool = False,
to_onehot: int | bool | None = None,
threshold: float | None = None,
) -> None:
"""Convert the input tensor/array into discrete values.

Args:
argmax: Whether to execute argmax function on input data before transform.
to_onehot: if not None, convert input data into the one-hot format with
specified number of classes. If bool, it will try to infer the number
of classes.
threshold: If not None, threshold the float values to int number 0 or 1
with specified threshold.
"""
super().__init__()

self._argmax = argmax
self._to_onehot = to_onehot
self._threshold = threshold

def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
"""Call method for the transformation."""
if self._argmax:
tensor = torch.argmax(tensor, dim=1, keepdim=True)

if self._to_onehot is not None:
tensor = _one_hot(tensor, num_classes=self._to_onehot, dim=1, dtype=torch.long)

if self._threshold is not None:
tensor = tensor >= self._threshold

return tensor


def _one_hot(
tensor: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1
) -> torch.Tensor:
"""Convert input tensor into one-hot format (implementation taken from MONAI)."""
shape = list(tensor.shape)
if shape[dim] != 1:
raise AssertionError(f"Input tensor must have 1 channel at dim {dim}.")

shape[dim] = num_classes
o = torch.zeros(size=shape, dtype=dtype, device=tensor.device)
tensor = o.scatter_(dim=dim, index=tensor.long(), value=1)

return tensor
2 changes: 1 addition & 1 deletion src/eva/core/models/wrappers/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def load_state_dict_from_url(
os.makedirs(model_dir, exist_ok=True)

cached_file = os.path.join(model_dir, filename or os.path.basename(url))
if force or not _check_integrity(cached_file, md5):
if force or not os.path.exists(cached_file) or not _check_integrity(cached_file, md5):
sys.stderr.write(f"Downloading: '{url}' to {cached_file}\n")
_download_url_to_file(url, cached_file, progress=progress)
if md5 is None or not _check_integrity(cached_file, md5):
Expand Down
Loading