diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml new file mode 100644 index 0000000..01bc908 --- /dev/null +++ b/.github/workflows/pre-commit.yml @@ -0,0 +1,22 @@ +name: pre-commit + +on: [push, pull_request] + +jobs: + run-pre-commit: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Install pre-commit and dependencies + run: | + pip install pre-commit black isort + + - name: Run pre-commit hooks + run: pre-commit run --all-files diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..a2330af --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,20 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.2.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files +- repo: https://github.com/psf/black + rev: 24.10.0 + hooks: + - id: black +- repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + name: isort (python) + args: ["--profile", "black"] diff --git a/CCMetrics/CC_base.py b/CCMetrics/CC_base.py new file mode 100644 index 0000000..121d347 --- /dev/null +++ b/CCMetrics/CC_base.py @@ -0,0 +1,340 @@ +import copy +import gc +import hashlib +import os +from enum import Enum + +import numpy as np +import torch +from monai.metrics import ( + Cumulative, + DiceMetric, + HausdorffDistanceMetric, + SurfaceDiceMetric, + SurfaceDistanceMetric, +) +from torch.nn import functional as F + +from CCMetrics.space_separation import compute_voronoi_regions as space_separation + + +class CCBaseMetric: + + def __init__( + self, + BaseMetric: Cumulative, + *args, + use_caching=True, + caching_dir=".cache", + metric_best_score=None, + metric_worst_score=None, + **kwargs, + ): + """ + Initializes a CC_base object. + + Args: + BaseMetric (Cumulative): The base Monai metric to be used. + *args: Variable length argument list, passed to the Monai metric. + use_caching (bool, optional): Flag to enable caching. Defaults to True. + caching_dir (str, optional): Directory to store the cache. Defaults to ".cache". + metric_best_score: The best score for the metric. Must be defined. + metric_worst_score: The worst score for the metric. Must be defined. + **kwargs: Arbitrary keyword arguments, passed to the Monai metric. + + Raises: + AssertionError: If metric_best_score or metric_worst_score is not defined. + + """ + assert metric_best_score is not None, "Best score must be defined" + assert metric_worst_score is not None, "Worst score must be defined" + self.metric_perfect_score = metric_best_score + self.metric_worst_score = metric_worst_score + self.buffer_collection = [] + if kwargs.get("include_background", False): + raise ValueError("Background class is not supported") + else: + kwargs["include_background"] = False + + if kwargs.get("cc_reduction", None): + assert kwargs["cc_reduction"] in [ + "patient", + "overall", + ], f"Unknown aggregation function {kwargs['cc_reduction']}" + self.cc_reduction = kwargs["cc_reduction"] + del kwargs["cc_reduction"] + else: + self.cc_reduction = "patient" + + self.base_metric = BaseMetric(*args, **kwargs) + self.use_caching = use_caching + self.caching_dir = caching_dir + if self.use_caching and not os.path.exists(self.caching_dir): + os.makedirs(self.caching_dir) + + def __call__(self, y_pred, y): + """ + Calculates the metric for the predicted and ground truth tensors. + + Args: + y_pred (numpy.ndarray or torch.Tensor): The predicted tensor. + y (numpy.ndarray or torch.Tensor): The ground truth tensor. + + Raises: + AssertionError: If the input shapes or conditions are not correct. + + Returns: + None + """ + # Check if tensor or numpy array + if isinstance(y_pred, np.ndarray): + y_pred = torch.from_numpy(y_pred) + if isinstance(y, np.ndarray): + y = torch.from_numpy(y) + + assert isinstance( + y_pred, torch.Tensor + ), f"Input is not a torch tensor. Got {type(y_pred)}" + assert isinstance( + y, torch.Tensor + ), f"Input is not a torch tensor. Got {type(y)}" + + # Check conditions + assert ( + len(y_pred.shape) == 5 + ), "Input shape is not correct. Expected shape: (B,C,D,H,W) as input y_pred" + assert ( + len(y.shape) == 5 + ), "Input shape is not correct. Expected shape: (B,C,D,H,W) as input y" + assert ( + y_pred.shape == y.shape + ), f"Input shapes do not match. Got {y_pred.shape} and {y.shape}" + assert ( + y_pred.shape[1] == 2 + ), f"Expected two classes in the input. Got {y_pred.shape[1]}" + assert y.shape[1] == 2, f"Expected two classes in the input. Got {y.shape[1]}" + assert ( + y_pred.shape[0] == 1 + ), f"Currently only a batch size of 1 is supported. Got {y_pred.shape[0]} in y_pred" + assert ( + y.shape[0] == 1 + ), f"Currently only a batch size of 1 is supported. Got {y.shape[0]} in y" + # Collect from previous runs + gc.collect() + + # Check if pure backgorund class + if y[0].argmax(0).sum() == 0: + if y_pred[0].argmax(0).sum() == 0: + # Case perfect prediction: No foreground class present in prediction + self.buffer_collection.append(torch.tensor([self.metric_perfect_score])) + else: + # Case worst prediction: Predicted Foreground class but no GT + self.buffer_collection.append(torch.tensor([self.metric_worst_score])) + return + + # Get separation as by ground-truth + if self.use_caching: + + gt_fingerprint = hashlib.md5( + y[0].argmax(0).cpu().numpy().tobytes() + ).hexdigest() + target_path = f"{os.path.join(self.caching_dir, gt_fingerprint)}.npy" + if os.path.exists(target_path): + cc_assignment = np.load(target_path) + else: + cc_assignment = space_separation(y[0].argmax(0).cpu().numpy()) + np.save(target_path, cc_assignment) + else: + cc_assignment = space_separation(y[0].argmax(0).cpu().numpy()) + + cc_assignment = torch.from_numpy(cc_assignment).type(torch.int64) + + missed_components = 0 + + for cc_id in cc_assignment.unique().tolist(): + pred_helper = copy.deepcopy(y_pred[0]).argmax(0) + label_helper = copy.deepcopy(y[0]).argmax(0) + # Find current region of interest + cc_mask = cc_assignment == cc_id + pred_helper[torch.logical_not(cc_mask)] = 0 + label_helper[torch.logical_not(cc_mask)] = 0 + if pred_helper.sum() == 0: + missed_components += 1 + + # Remap metrics back to one-hot encoding + pred_helper = F.one_hot(pred_helper, num_classes=2).permute(3, 0, 1, 2) + label_helper = F.one_hot(label_helper, num_classes=2).permute(3, 0, 1, 2) + + self.base_metric( + y_pred=pred_helper.unsqueeze(0), y=label_helper.unsqueeze(0) + ) + del pred_helper + del label_helper + del cc_mask + gc.collect() + + # Get metric buffer and reset it + metric_buffer = self.base_metric.get_buffer() + self.buffer_collection.append(metric_buffer) + self.base_metric.reset() + + def cc_aggregate(self, mode=None): + """ + Aggregates the buffer collection based on the specified mode. + + Args: + mode (str, optional): The aggregation mode. Can be "patient" or "overall". + If not provided, the default mode specified in self.cc_reduction will be used. + + Returns: + torch.Tensor: The aggregated result based on the specified mode. + + Raises: + AssertionError: If an unknown aggregation function is provided. + + """ + if mode is None: + mode = self.cc_reduction + assert mode in ["patient", "overall"], f"Unknown aggregation function {mode}" + cleaned_buffer = [ + torch.where( + torch.isinf(x), + torch.tensor( + self.metric_worst_score, dtype=torch.float32, device=x.device + ), + x, + ) + for x in self.buffer_collection + ] + cleaned_buffer = [ + torch.where( + torch.isnan(x), + torch.tensor( + self.metric_worst_score, dtype=torch.float32, device=x.device + ), + x, + ) + for x in cleaned_buffer + ] + cleaned_buffer = [x.reshape(-1, 1) for x in cleaned_buffer] + if mode == "patient": + # Aggregate per patient and return list of means + return torch.stack([x.mean() for x in cleaned_buffer]) + elif mode == "overall": + # Aggregate overall. All components are considered as equal. Return full list + return torch.concatenate(cleaned_buffer).squeeze() + + def get_buffer(self): + """ + Returns the buffer collection. + """ + return self.buffer_collection + + def reset(self): + """ + Resets the buffer collection. + """ + self.buffer_collection = [] + + def cache_datapoint(self, y): + """ + Caches the datapoint if caching is enabled. + + Args: + y (torch.Tensor): The input tensor. + + Raises: + ValueError: If caching is disabled. + + Returns: + None + """ + if self.use_caching: + # Handle data input + if isinstance(y, torch.Tensor): + y = y.cpu().numpy() + assert isinstance( + y, np.ndarray + ), "Input is not a numpy array or torch tensor. Caching is not possible" + assert ( + len(y.shape) == 3 + ), "Input shape is not correct. Expected shape: (D,H,W) as input y" + + gt_fingerprint = hashlib.md5(y.tobytes()).hexdigest() + target_path = f"{os.path.join(self.caching_dir, gt_fingerprint)}.npy" + if os.path.exists(target_path): + return + cc_assignment = space_separation(y) + np.save(target_path, cc_assignment) + else: + raise ValueError("Caching is disabled") + + +# Define used metrics from the paper +# For unbound metrics, the worst score is set to None and should be handled by the user, as it is infinite + + +class CCDiceMetric(CCBaseMetric): + """ + CCDiceMetric is a class that represents the Dice metric for connected components. + It inherits from the CCBaseMetric class. + """ + + def __init__(self, *args, **kwargs): + super().__init__( + DiceMetric, *args, metric_best_score=1.0, metric_worst_score=0.0, **kwargs + ) + + +class CCHausdorffDistanceMetric(CCBaseMetric): + """ + CCHausdorffDistanceMetric is a class that represents the Hausdorff distance metric for connected components. + It inherits from the CCBaseMetric class. + """ + + def __init__(self, *args, **kwargs): + super().__init__( + HausdorffDistanceMetric, *args, metric_best_score=0.0, **kwargs + ) + + +class CCHausdorffDistance95Metric(CCBaseMetric): + """ + A class representing a metric for calculating the 95th percentile Hausdorff distance for connected components. + It inherits from the CCBaseMetric class. + """ + + def __init__(self, *args, **kwargs): + super().__init__( + HausdorffDistanceMetric, + *args, + metric_best_score=0.0, + percentile=95, + **kwargs, + ) + + +class CCSurfaceDistanceMetric(CCBaseMetric): + """ + A class representing a metric for calculating the SurfaceDistance metric for connected components. + It inherits from the CCBaseMetric class. + """ + + def __init__(self, *args, **kwargs): + super().__init__(SurfaceDistanceMetric, *args, metric_best_score=0.0, **kwargs) + + +class CCSurfaceDiceMetric(CCBaseMetric): + """ + A class representing a metric for calculating the SurfaceDiceMetric metric for connected components. + It inherits from the CCBaseMetric class. + """ + + def __init__(self, *args, **kwargs): + super().__init__( + SurfaceDiceMetric, + *args, + metric_best_score=1.0, + metric_worst_score=0.0, + **kwargs, + ) diff --git a/CCMetrics/__init__.py b/CCMetrics/__init__.py new file mode 100644 index 0000000..eec63b2 --- /dev/null +++ b/CCMetrics/__init__.py @@ -0,0 +1,8 @@ +from CCMetrics.CC_base import ( + CCBaseMetric, + CCDiceMetric, + CCHausdorffDistance95Metric, + CCHausdorffDistanceMetric, + CCSurfaceDiceMetric, + CCSurfaceDistanceMetric, +) diff --git a/CCMetrics/space_separation.py b/CCMetrics/space_separation.py new file mode 100644 index 0000000..e01ee21 --- /dev/null +++ b/CCMetrics/space_separation.py @@ -0,0 +1,67 @@ +import cc3d +import numpy as np +from scipy.ndimage import distance_transform_edt +from scipy.spatial import cKDTree + + +def compute_voronoi_regions(labels): + """ + Compute Voronoi regions for the given labels. + + Parameters: + labels (ndarray): Input label array. + + Returns: + ndarray: Array of Voronoi region assignments. + + """ + cc_labels = cc3d.connected_components(labels) + current_assignment = np.zeros_like(cc_labels, dtype="int") + current_mins = np.ones_like(cc_labels, dtype="float") * np.inf + for idx, cc in enumerate(np.unique(cc_labels)): + if cc == 0: + pass + else: + # Compute distance transforms from current cc + cur_dt = distance_transform_edt(np.logical_not(cc_labels == cc)) + # Update the cc_asignment and previous minimas + msk = cur_dt < current_mins + current_mins[msk] = cur_dt[msk] + current_assignment[msk] = idx + cc_asignment = current_assignment + return cc_asignment + + +def compute_voronoi_kdtree(labels): + """ + Computes the Voronoi diagram using a KDTree for a given label image. + + Parameters: + labels (ndarray): The label image. + + Returns: + ndarray: Array of Voronoi region assignments. + """ + cc_labels = cc3d.connected_components(labels) + output = np.zeros_like(cc_labels, dtype=np.int32) + + coords = np.column_stack(np.nonzero(cc_labels)) + cc_ids = cc_labels[cc_labels > 0] + unique_ccs = np.unique(cc_ids) + + # Map each cc_id to its voxel coordinates + cc_points = {cc: coords[cc_ids == cc] for cc in unique_ccs} + + # Build a KDTree using all foreground voxels, tagged with their cc_id + all_pts = np.concatenate([cc_points[cc] for cc in unique_ccs]) + all_tags = np.concatenate([[cc] * len(cc_points[cc]) for cc in unique_ccs]) + + tree = cKDTree(all_pts) + + # For each voxel in the volume, find the nearest foreground point and assign its cc_id + all_voxels = np.indices(cc_labels.shape).reshape(3, -1).T + dists, idxs = tree.query(all_voxels) + nearest_ccs = all_tags[idxs] + output = nearest_ccs.reshape(cc_labels.shape) + + return output diff --git a/README.md b/README.md new file mode 100644 index 0000000..dea6ed4 --- /dev/null +++ b/README.md @@ -0,0 +1,111 @@ +# CC-Metrics +## Every Component Counts: Rethinking the Measure of Success for Medical Semantic Segmentation in Multi-Instance Segmentation Tasks + +[![Paper](https://img.shields.io/badge/PDF-Paper-green.svg)](https://arxiv.org/pdf/2410.18684) [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](LICENSE) + +## Description +CC-Metrics is an evaluation approach for assessing standard evaluation metrics such as Dice or Surface-Dice on a per-connected component basis. To determine how to match predictions to ground-truth connected components, it separates the image space into Voronoi regions and maps predictions within the Voronoi region to the corresponding ground-truth. + +Below is an example visualization of the Voronoi-based mapping process: + +
+ Voronoi Mapping Example +
+ +For more details, you can read the full paper [here](https://arxiv.org/pdf/2410.18684). + +## Table of Contents + +- [Description](#description) +- [Installation](#installation) +- [How to use CC-Metrics](#how-to-use-cc-metrics) + - [Basic Usecase](#basic-usecase) + - [Metric Aggregation](#metric-aggregation) + - [Caching mechanism](#caching-mechanism) +- [Citation](#citation) +- [License](#license) + +## Installation + +It it generally recommended to first install pytorch and Monai before installing CC-Metrics + +``` +git clone https://github.com/alexanderjaus/CC-Metrics.git +cd CC-Metrics +pip install -e . +``` + +## How to use CC-Metrics +CC-Metrics defines a wrapper around Monai's Cumulative metrics. + +#### Basic Usecase + +```python +from CCMetrics import CCDiceMetric + +cc_dice = CCDiceMetric( + cc_reduction="patient", + use_caching=True, + caching_dir=".cache" + ) + +y, y_hat = torch.rand((1, 2, 200, 200, 200)), torch.rand((1, 2, 200, 200, 200)) +# Tensors are expected in shape (B, C, D, H, W) + +cc_dice(y_pred=y_hat, y=y) + +ccdice.cc_aggregate() +``` +Explored metrics in the paper include +- CCDiceMetric +- CCHausdorffDistanceMetric +- CCHausdorffDistance95Metric +- CCSurfaceDistanceMetric +- CCSurfaceDiceMetric + +Unbounded metrics require specifying a worst-case value to replace infinity or NaN, as in + +```CCSurfaceDistanceMetric(cc_reduction="overall", metric_worst_score=30)``` + +This is necessary, as averaging these metrics would otherwise be undefined. + +#### Metric Aggregation + +The `CCBaseMetric` class supports two types of metric aggregation modes: + +1. **Patient-Level Aggregation (`patient`)**: + - Computes the mean metric score for each patient by aggregating all connected components within the patient. + - Returns a list of mean scores, one for each patient. + +2. **Overall Aggregation (`overall`)**: + - Treats all connected components across all patients equally. + - Aggregates the metric scores for all components into a single list. + +The aggregation mode can be specified using the `cc_aggregate` method, with the default mode being `patient`. + +#### Caching mechanism +CC-Metrics requires the computation of a generalized Voronoi diagram which serves as the mapping mechanism between predictions and ground-truth. As the separation of the image space only depends on the ground-truth, the mapping can be cached and reused between intermediate evaluations or across metrics. Even computed in advance to speed up the metric computation. + +Use the ```use_caching``` flag and provide a caching location. This will compute and cache the voronoi regions when they are computed for the first time. It is recommended to precache these regions allowing for a faster computation on the spot. This can be achieved using the +```prepare_caching.py``` script that computes and caches voronoi regions for all ```.nii.gz```images in a given directory. + +``` +python prepare_caching.py --gt --cache_dir --nof_workers +``` + + +## Citation + +If you make use of this project in your work, it would be appreciated if you cite the cc-metrics paper +``` +@article{jaus2024every, + title={Every Component Counts: Rethinking the Measure of Success for Medical Semantic Segmentation in Multi-Instance Segmentation Tasks}, + author={Jaus, Alexander and Seibold, Constantin and Rei{\ss}, Simon and Marinov, Zdravko and Li, Keyi and Ye, Zeling and Krieg, Stefan and Kleesiek, Jens and Stiefelhagen, Rainer}, + journal={arXiv preprint arXiv:2410.18684}, + year={2024} +} +``` + +## License + +This project is licensed under the [Apache 2.0 License](LICENSE). diff --git a/prepare_caching.py b/prepare_caching.py new file mode 100644 index 0000000..20a0ade --- /dev/null +++ b/prepare_caching.py @@ -0,0 +1,74 @@ +import argparse +import os +from multiprocessing import Pool + +import nibabel as nib +from tqdm import tqdm +from tqdm.contrib.concurrent import ( # Import process_map for progress bars with multiprocessing + process_map, +) + +from CCMetrics.CC_base import CCDiceMetric + + +def process_file(args): + gt_file, cache_dir = args + metric = CCDiceMetric(use_caching=True, caching_dir=cache_dir) + y = nib.load(gt_file).get_fdata() + metric.cache_datapoint(y) + + +def main(): + parser = argparse.ArgumentParser(description="Cache data") + parser.add_argument( + "--gt", + type=str, + help="Path to the directory containing the ground truth nii.gz images", + ) + parser.add_argument( + "--cache_dir", + type=str, + help="Path to the directory where the cache files will be stored", + ) + parser.add_argument( + "--nof_workers", + type=int, + default=1, + help="Number of workers to use for parallel processing", + ) + + args = parser.parse_args() + + assert args.gt is not None, "Please provide the path to the ground truth images" + assert os.path.exists(args.gt), "The path to the ground truth images does not exist" + assert args.cache_dir is not None, "Please provide the path to the cache directory" + + if not os.path.exists(args.cache_dir): + os.makedirs(args.cache_dir) + print(f"Created cache directory at {args.cache_dir}") + + identified_gt_files = [x for x in os.listdir(args.gt) if x.endswith(".nii.gz")] + full_path_gt_files = [os.path.join(args.gt, x) for x in identified_gt_files] + + print(f"Found {len(identified_gt_files)} ground truth files in directory {args.gt}") + print(f"Identified files look like this: {identified_gt_files[:5]}") + + if args.nof_workers > 1: + # Use process_map for parallel processing with progress bar + process_map( + process_file, + [(gt_file, args.cache_dir) for gt_file in full_path_gt_files], + max_workers=args.nof_workers, + desc="Processing files", + unit="file", + ) + else: + # For single-worker case, use regular tqdm + metric = CCDiceMetric(use_caching=True, caching_dir=args.cache_dir) + for gt_file in tqdm(full_path_gt_files, desc="Processing files", unit="file"): + y = nib.load(gt_file).get_fdata() + metric.cache_datapoint(y) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..4628ec8 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,38 @@ +[build-system] +requires = ["setuptools >= 65.0.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "CCMetrics" +version = "0.0.1" +authors = [ + {name = "Alexander Jaus", email = "alexander.jaus@kit.edu"}, +] +description = "An evaluation protocol for standard metrics per connected component" +readme = "README.md" +requires-python = ">=3.8" +keywords = ["medical segmentation", "tumor segmentation", "instance segmentation", "detection via segmentation"] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Topic :: Scientific/Engineering :: Medical Science Apps.", +] +dependencies = [ + "torch", + "monai", + "numpy", + "scipy", + "connected-components-3d", + "nibabel", + "tqdm" +] + +[project.urls] +Homepage = "https://github.com/alexanderjaus/CC-Metrics" +Issues = "https://github.com/alexanderjaus/CC-Metrics/issues" + +[tool.setuptools.packages.find] +include = ["CCMetrics"] diff --git a/resources/title_fig.jpg b/resources/title_fig.jpg new file mode 100644 index 0000000..8a09e61 Binary files /dev/null and b/resources/title_fig.jpg differ