188 changes: 188 additions & 0 deletions configs/vision/radiology/online/segmentation/btcv.yaml
@@ -0,0 +1,188 @@
---
trainer:
  class_path: eva.Trainer
  init_args:
    n_runs: &N_RUNS ${oc.env:N_RUNS, 5}
    default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, radiology/voco_b}/btcv}
    max_epochs: &MAX_EPOCHS ${oc.env:MAX_EPOCHS, 500}
    checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best}
    check_val_every_n_epoch: ${oc.env:CHECK_VAL_EVERY_N_EPOCHS, 50}
    num_sanity_val_steps: 0
    log_every_n_steps: ${oc.env:LOG_EVERY_N_STEPS, 100}
    callbacks:
      - class_path: eva.callbacks.ConfigurationLogger
      - class_path: lightning.pytorch.callbacks.TQDMProgressBar
        init_args:
          refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1}
    logger:
      - class_path: lightning.pytorch.loggers.TensorBoardLogger
        init_args:
          save_dir: *OUTPUT_ROOT
          name: ""
      # - class_path: lightning.pytorch.loggers.WandbLogger
      #   init_args:
      #     project: ${oc.env:WANDB_PROJECT, radiology}
      #     name: ${oc.env:WANDB_RUN_NAME, btcv-${oc.env:MODEL_NAME, radiology/voco_b}}
model:
  class_path: eva.vision.models.modules.SemanticSegmentationModule
  init_args:
    encoder:
      class_path: eva.vision.models.ModelFromRegistry
      init_args:
        model_name: ${oc.env:MODEL_NAME, radiology/voco_b}
        model_kwargs:
          out_indices: ${oc.env:OUT_INDICES, 6}
    decoder:
      class_path: eva.vision.models.networks.decoders.segmentation.SwinUNETRDecoder
      init_args:
        feature_size: ${oc.env:IN_FEATURES, 48}
        out_channels: &NUM_CLASSES 14
    inferer:
      class_path: monai.inferers.SlidingWindowInferer
      init_args:
        roi_size: &ROI_SIZE ${oc.env:ROI_SIZE, [96, 96, 96]}
        sw_batch_size: ${oc.env:SW_BATCH_SIZE, 8}
        overlap: ${oc.env:SW_OVERLAP, 0.75}
    criterion:
      class_path: monai.losses.DiceCELoss
      init_args:
        include_background: false
        to_onehot_y: true
        softmax: true
    lr_multiplier_encoder: 0.0
    optimizer:
      class_path: torch.optim.AdamW
      init_args:
        lr: ${oc.env:LR_VALUE, 0.001}
        betas: [0.9, 0.999]
        weight_decay: ${oc.env:WEIGHT_DECAY, 0.01}
    lr_scheduler:
      class_path: torch.optim.lr_scheduler.CosineAnnealingLR
      init_args:
        T_max: *MAX_EPOCHS
        eta_min: ${oc.env:LR_VALUE_END, 0.0001}
    postprocess:
      predictions_transforms:
        - class_path: eva.core.models.transforms.AsDiscrete
          init_args:
            argmax: true
            to_onehot: *NUM_CLASSES
      targets_transforms:
        - class_path: eva.core.models.transforms.AsDiscrete
          init_args:
            to_onehot: *NUM_CLASSES
    metrics:
      common:
        - class_path: eva.metrics.AverageLoss
      evaluation:
        - class_path: torchmetrics.segmentation.DiceScore
          init_args:
            num_classes: *NUM_CLASSES
            include_background: false
            average: macro
            input_format: one-hot
        - class_path: torchmetrics.ClasswiseWrapper
          init_args:
            metric:
              class_path: eva.vision.metrics.MonaiDiceScore
              init_args:
                include_background: true
                num_classes: *NUM_CLASSES
                input_format: one-hot
                reduction: none
            prefix: DiceScore_
            labels:
              - "0_background"
              - "1_spleen"
              - "2_right_kidney"
              - "3_left_kidney"
              - "4_gallbladder"
              - "5_esophagus"
              - "6_liver"
              - "7_stomach"
              - "8_aorta"
              - "9_inferior_vena_cava"
              - "10_portal_and_splenic_vein"
              - "11_pancreas"
              - "12_right_adrenal_gland"
              - "13_left_adrenal_gland"
data:
  class_path: eva.DataModule
  init_args:
    datasets:
      train:
        class_path: eva.vision.datasets.BTCV
        init_args: &DATASET_ARGS
          root: ${oc.env:DATA_ROOT, ./data/btcv}
          split: train
          transforms:
            class_path: torchvision.transforms.v2.Compose
            init_args:
              transforms:
                - class_path: eva.vision.data.transforms.EnsureChannelFirst
                  init_args:
                    channel_dim: 1
                - class_path: eva.vision.data.transforms.Spacing
                  init_args:
                    pixdim: [1.5, 1.5, 1.5]
                - class_path: eva.vision.data.transforms.ScaleIntensityRange
                  init_args:
                    input_range:
                      - ${oc.env:SCALE_INTENSITY_MIN, -175.0}
                      - ${oc.env:SCALE_INTENSITY_MAX, 250.0}
                    output_range: [0.0, 1.0]
                - class_path: eva.vision.data.transforms.CropForeground
                - class_path: eva.vision.data.transforms.SpatialPad
                  init_args:
                    spatial_size: *ROI_SIZE
                - class_path: eva.vision.data.transforms.RandCropByPosNegLabel
                  init_args:
                    spatial_size: *ROI_SIZE
                    num_samples: ${oc.env:SAMPLE_BATCH_SIZE, 4}
                    pos: ${oc.env:RAND_CROP_POS_WEIGHT, 9}
                    neg: ${oc.env:RAND_CROP_NEG_WEIGHT, 1}
                - class_path: eva.vision.data.transforms.RandFlip
                  init_args:
                    spatial_axes: [0, 1, 2]
                - class_path: eva.vision.data.transforms.RandRotate90
                  init_args:
                    spatial_axes: [1, 2]
                - class_path: eva.vision.data.transforms.RandScaleIntensity
                  init_args:
                    factors: 0.1
                    prob: 0.1
                - class_path: eva.vision.data.transforms.RandShiftIntensity
                  init_args:
                    offsets: 0.1
                    prob: 0.1
      val:
        class_path: eva.vision.datasets.BTCV
        init_args:
          <<: *DATASET_ARGS
          split: val
          transforms:
            class_path: torchvision.transforms.v2.Compose
            init_args:
              transforms:
                - class_path: eva.vision.data.transforms.EnsureChannelFirst
                  init_args:
                    channel_dim: 1
                - class_path: eva.vision.data.transforms.Spacing
                  init_args:
                    pixdim: [1.5, 1.5, 1.5]
                - class_path: eva.vision.data.transforms.ScaleIntensityRange
                  init_args:
                    input_range:
                      - ${oc.env:SCALE_INTENSITY_MIN, -175.0}
                      - ${oc.env:SCALE_INTENSITY_MAX, 250.0}
                    output_range: [0.0, 1.0]
                - class_path: eva.vision.data.transforms.CropForeground
    dataloaders:
      train:
        batch_size: ${oc.env:BATCH_SIZE, 2}
        num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 8}
        shuffle: true
        collate_fn: eva.vision.data.dataloaders.collate_fn.collection_collate
      val:
        batch_size: 1
        num_workers: *N_DATA_WORKERS
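For orientation, the `model` section above wires together standard MONAI and PyTorch components. The sketch below shows roughly what the inferer, criterion, optimizer and scheduler resolve to under the config's default values (96x96x96 ROI, 14 classes, LR 0.001); it is illustrative only, the real objects are built by eva's config parser, and the parameter list handed to AdamW is a placeholder standing in for the decoder parameters (the encoder is frozen via lr_multiplier_encoder: 0.0).

import torch
from monai.inferers import SlidingWindowInferer
from monai.losses import DiceCELoss

# Sliding-window inference over full CT volumes with the defaults above.
inferer = SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=8, overlap=0.75)

# Dice + cross-entropy loss on index-encoded targets, background excluded.
criterion = DiceCELoss(include_background=False, to_onehot_y=True, softmax=True)

# Placeholder parameter list; in practice these are the decoder parameters.
decoder_params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(decoder_params, lr=1e-3, betas=(0.9, 0.999), weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=500, eta_min=1e-4)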
4 changes: 2 additions & 2 deletions src/eva/core/models/modules/typings.py
@@ -1,6 +1,6 @@
"""Type annotations for model modules."""

-from typing import Any, Dict, NamedTuple
+from typing import Any, Dict, List, NamedTuple

import lightning.pytorch as pl
import torch
@@ -13,7 +13,7 @@
class INPUT_BATCH(NamedTuple):
    """The default input batch data scheme."""

-    data: torch.Tensor
+    data: torch.Tensor | List[torch.Tensor]
    """The data batch."""

    targets: torch.Tensor | None = None
3 changes: 2 additions & 1 deletion src/eva/core/models/transforms/__init__.py
@@ -1,6 +1,7 @@
"""Model outputs transforms API."""

+from eva.core.models.transforms.as_discrete import AsDiscrete
from eva.core.models.transforms.extract_cls_features import ExtractCLSFeatures
from eva.core.models.transforms.extract_patch_features import ExtractPatchFeatures

-__all__ = ["ExtractCLSFeatures", "ExtractPatchFeatures"]
+__all__ = ["AsDiscrete", "ExtractCLSFeatures", "ExtractPatchFeatures"]
40 changes: 40 additions & 0 deletions src/eva/core/models/transforms/as_discrete.py
@@ -0,0 +1,40 @@
"""Defines the AsDiscrete transformation."""

import torch
from monai.networks.utils import one_hot


class AsDiscrete:
def __init__(
self,
argmax: bool = False,
to_onehot: int | bool | None = None,
threshold: float | None = None,
) -> None:
"""Convert the input tensor/array into discrete values.

Args:
argmax: Whether to execute argmax function on input data before transform.
to_onehot: if not None, convert input data into the one-hot format with
specified number of classes. If bool, it will try to infer the number
of classes.
threshold: If not None, threshold the float values to int number 0 or 1
with specified threshold.
"""
super().__init__()

self._argmax = argmax
self._to_onehot = to_onehot
self._threshold = threshold

def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
if self._argmax:
tensor = torch.argmax(tensor, dim=1, keepdim=True)

if self._to_onehot is not None:
tensor = one_hot(tensor, num_classes=self._to_onehot, dim=1, dtype=torch.long)

if self._threshold is not None:
tensor = tensor >= self._threshold

return tensor
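A usage sketch of the new transform (shapes are illustrative, the class count matches the BTCV config): decoder logits are argmax-ed over the channel dimension and re-expanded to one-hot, which is what the predictions_transforms entry in btcv.yaml relies on; index-encoded targets can be one-hot encoded the same way.

import torch
from eva.core.models.transforms import AsDiscrete

logits = torch.randn(2, 14, 8, 8, 8)             # [batch, classes, D, H, W], illustrative shape
to_discrete = AsDiscrete(argmax=True, to_onehot=14)
preds = to_discrete(logits)                      # one-hot, shape [2, 14, 8, 8, 8], dtype long

targets = torch.randint(0, 14, (2, 1, 8, 8, 8))  # index-encoded masks with a singleton channel
onehot_targets = AsDiscrete(to_onehot=14)(targets)  # shape [2, 14, 8, 8, 8]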
2 changes: 1 addition & 1 deletion src/eva/core/models/wrappers/_utils.py
@@ -63,7 +63,7 @@ def load_state_dict_from_url(
    os.makedirs(model_dir, exist_ok=True)

    cached_file = os.path.join(model_dir, filename or os.path.basename(url))
-    if force or not _check_integrity(cached_file, md5):
+    if force or not os.path.exists(cached_file) or not _check_integrity(cached_file, md5):
        sys.stderr.write(f"Downloading: '{url}' to {cached_file}\n")
        _download_url_to_file(url, cached_file, progress=progress)
    if md5 is None or not _check_integrity(cached_file, md5):
5 changes: 5 additions & 0 deletions src/eva/vision/data/dataloaders/__init__.py
@@ -0,0 +1,5 @@
"""Dataloader related utilities and functions."""

from eva.vision.data.dataloaders import collate_fn

__all__ = ["collate_fn"]
5 changes: 5 additions & 0 deletions src/eva/vision/data/dataloaders/collate_fn/__init__.py
@@ -0,0 +1,5 @@
"""Dataloader collate API."""

from eva.vision.data.dataloaders.collate_fn.collection import collection_collate

__all__ = ["collection_collate"]
22 changes: 22 additions & 0 deletions src/eva/vision/data/dataloaders/collate_fn/collection.py
@@ -0,0 +1,22 @@
"""Data only collate filter function."""

from typing import Any, List

import torch

from eva.core.models.modules.typings import INPUT_BATCH


def collection_collate(batch: List[List[INPUT_BATCH]]) -> Any:
"""Collate function for stacking a collection of data samples.

Args:
batch: The batch to be collated.

Returns:
The collated batch.
"""
tensors, targets, metadata = zip(*batch, strict=False)
batch_tensors = torch.cat(list(map(torch.stack, tensors)))
batch_targets = torch.cat(list(map(torch.stack, targets)))
return batch_tensors, batch_targets, metadata
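A minimal sketch of what this collate does in the BTCV training pipeline (shapes and the metadata key are illustrative): RandCropByPosNegLabel yields num_samples crops per volume, so each dataset item carries lists of tensors, and the collate flattens them into a single batch of batch_size * num_samples crops.

import torch
from eva.vision.data.dataloaders.collate_fn import collection_collate

item = (
    [torch.rand(1, 96, 96, 96) for _ in range(4)],   # 4 image crops per volume
    [torch.zeros(1, 96, 96, 96) for _ in range(4)],  # matching label crops
    {"case_id": "case_0001"},                        # hypothetical metadata
)
images, labels, metadata = collection_collate([item, item])  # batch_size = 2
print(images.shape)  # torch.Size([8, 1, 96, 96, 96]) -> 2 volumes x 4 crops
print(labels.shape)  # torch.Size([8, 1, 96, 96, 96])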
11 changes: 9 additions & 2 deletions src/eva/vision/metrics/segmentation/monai_dice.py
@@ -1,5 +1,7 @@
"""Wrapper for dice score metric from MONAI."""

+from typing import Literal
+
from monai.metrics.meandice import DiceMetric
from typing_extensions import override

@@ -14,6 +16,7 @@ def __init__(
        self,
        num_classes: int,
        include_background: bool = True,
+        input_format: Literal["one-hot", "index"] = "index",
        reduction: str = "mean",
        ignore_index: int | None = None,
        **kwargs,
@@ -24,6 +27,8 @@
            num_classes: The number of classes in the dataset.
            include_background: Whether to include the background class in the computation.
            reduction: The method to reduce the dice score. Options are `"mean"`, `"sum"`, `"none"`.
+            input_format: Choose between "one-hot" for one-hot encoded tensors or "index"
+                for index tensors.
            ignore_index: Integer specifying a target class to ignore. If given, this class
                index does not contribute to the returned score.
            kwargs: Additional keyword arguments for instantiating monai's `DiceMetric` class.
@@ -40,11 +45,13 @@ def __init__(
        self.reduction = reduction
        self.orig_num_classes = num_classes
        self.ignore_index = ignore_index
+        self.input_format = input_format

    @override
    def update(self, preds, target):
-        preds = _utils.index_to_one_hot(preds, num_classes=self.orig_num_classes)
-        target = _utils.index_to_one_hot(target, num_classes=self.orig_num_classes)
+        if self.input_format == "index":
+            preds = _utils.index_to_one_hot(preds, num_classes=self.orig_num_classes)
+            target = _utils.index_to_one_hot(target, num_classes=self.orig_num_classes)
        if self.ignore_index is not None:
            preds, target = _utils.apply_ignore_index(preds, target, self.ignore_index)
        return super().update(preds, target)
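A small usage sketch of the new flag, assuming the torchmetrics-style update/compute calls shown above and illustrative shapes: with input_format="one-hot" the metric consumes one-hot tensors directly (for example those produced by the AsDiscrete postprocess transforms) instead of converting index masks itself.

import torch
from eva.vision.metrics import MonaiDiceScore

metric = MonaiDiceScore(
    num_classes=14, include_background=True, input_format="one-hot", reduction="none"
)
masks = torch.randint(0, 14, (2, 8, 8, 8))
one_hot = torch.nn.functional.one_hot(masks, num_classes=14).permute(0, 4, 1, 2, 3).float()
metric.update(one_hot, one_hot.clone())  # identical preds and targets
print(metric.compute())                  # per-class dice, 1.0 for every class present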
2 changes: 1 addition & 1 deletion src/eva/vision/models/modules/semantic_segmentation.py
@@ -25,7 +25,7 @@ class SemanticSegmentationModule(module.ModelModule):

    def __init__(
        self,
-        decoder: decoders.Decoder,
+        decoder: decoders.Decoder | nn.Module,
        criterion: Callable[..., torch.Tensor],
        encoder: Dict[str, Any] | Callable[[torch.Tensor], List[torch.Tensor]] | None = None,
        lr_multiplier_encoder: float = 0.0,
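With the relaxed annotation, any nn.Module can be passed as the decoder. A minimal sketch using the values from btcv.yaml; constructor arguments not shown in this diff are assumed to keep their defaults.

from monai.losses import DiceCELoss

from eva.vision.models.modules import SemanticSegmentationModule
from eva.vision.models.networks.decoders.segmentation import SwinUNETRDecoder

module = SemanticSegmentationModule(
    decoder=SwinUNETRDecoder(feature_size=48, out_channels=14),  # plain nn.Module decoder
    criterion=DiceCELoss(include_background=False, to_onehot_y=True, softmax=True),
    lr_multiplier_encoder=0.0,  # keep the (pre-trained) encoder frozen
)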