Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 140 additions & 61 deletions CCMetrics/CC_base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import copy
import gc
import hashlib
import os
from enum import Enum

import numpy as np
import torch
Expand All @@ -13,9 +11,11 @@
SurfaceDiceMetric,
SurfaceDistanceMetric,
)
from torch.nn import functional as F

from CCMetrics.space_separation import compute_voronoi_regions as space_separation
from CCMetrics.space_separation import compute_voronoi_regions_fast

# Globally disable gradient computation for this entire module
torch.set_grad_enabled(False)


class CCBaseMetric:
Expand All @@ -24,7 +24,7 @@ def __init__(
self,
BaseMetric: Cumulative,
*args,
use_caching=True,
use_caching=False,
caching_dir=".cache",
metric_best_score=None,
metric_worst_score=None,
Expand Down Expand Up @@ -72,32 +72,25 @@ def __init__(
if self.use_caching and not os.path.exists(self.caching_dir):
os.makedirs(self.caching_dir)

def __call__(self, y_pred, y):
"""
Calculates the metric for the predicted and ground truth tensors.
# Set cpu backend
self.xp = np
self.backend = "numpy"
self.space_separation = compute_voronoi_regions_fast

Args:
y_pred (numpy.ndarray or torch.Tensor): The predicted tensor.
y (numpy.ndarray or torch.Tensor): The ground truth tensor.
def _verify_and_convert(self, y_pred, y):

Raises:
AssertionError: If the input shapes or conditions are not correct.

Returns:
None
"""
# Check if tensor or numpy array
if isinstance(y_pred, np.ndarray):
y_pred = torch.from_numpy(y_pred)
if isinstance(y, np.ndarray):
y = torch.from_numpy(y)
# Automatically convert to numy
if isinstance(y_pred, torch.Tensor):
y_pred = y_pred.detach().cpu().numpy()
if isinstance(y, torch.Tensor):
y = y.detach().cpu().numpy()

assert isinstance(
y_pred, torch.Tensor
), f"Input is not a torch tensor. Got {type(y_pred)}"
y_pred, self.xp.ndarray
), f"Input is not a numpy array. Got {type(y_pred)}"
assert isinstance(
y, torch.Tensor
), f"Input is not a torch tensor. Got {type(y)}"
y, self.xp.ndarray
), f"Input is not a numpy array. Got {type(y)}"

# Check conditions
assert (
Expand All @@ -119,59 +112,103 @@ def __call__(self, y_pred, y):
assert (
y.shape[0] == 1
), f"Currently only a batch size of 1 is supported. Got {y.shape[0]} in y"
# Collect from previous runs
gc.collect()

# Check if pure backgorund class
if y[0].argmax(0).sum() == 0:
if y_pred[0].argmax(0).sum() == 0:
return y_pred, y

def _convert_to_target(self, y_pred, y):
if type(y_pred) == self.xp.ndarray:
y_pred = torch.from_numpy(y_pred)
if type(y) == self.xp.ndarray:
y = torch.from_numpy(y)
return y_pred, y

def __call__(self, y_pred, y):
"""
Calculates the metric for the predicted and ground truth tensors.

Args:
y_pred (numpy.ndarray or torch.Tensor): The predicted tensor.
y (numpy.ndarray or torch.Tensor): The ground truth tensor.

Raises:
AssertionError: If the input shapes or conditions are not correct.

Returns:
None
"""
y_pred, y = self._verify_and_convert(y_pred, y)

# Compute argmax
pred_helper = y_pred.argmax(1)
label_helper = y.argmax(1)

# Check if pure background class
if label_helper[0].sum() == 0:
if pred_helper[0].sum() == 0:
# Case perfect prediction: No foreground class present in prediction
self.buffer_collection.append(torch.tensor([self.metric_perfect_score]))
else:
# Case worst prediction: Predicted Foreground class but no GT
self.buffer_collection.append(torch.tensor([self.metric_worst_score]))
return

# Get separation as by ground-truth
if self.use_caching:

gt_fingerprint = hashlib.md5(
y[0].argmax(0).cpu().numpy().tobytes()
).hexdigest()
target_path = f"{os.path.join(self.caching_dir, gt_fingerprint)}.npy"
if os.path.exists(target_path):
cc_assignment = np.load(target_path)
else:
cc_assignment = space_separation(y[0].argmax(0).cpu().numpy())
np.save(target_path, cc_assignment)
else:
cc_assignment = space_separation(y[0].argmax(0).cpu().numpy())

cc_assignment = torch.from_numpy(cc_assignment).type(torch.int64)
cc_assignment = self.space_separation(label_helper[0])

missed_components = 0

for cc_id in cc_assignment.unique().tolist():
pred_helper = copy.deepcopy(y_pred[0]).argmax(0)
label_helper = copy.deepcopy(y[0]).argmax(0)
# Find current region of interest
for cc_id in self.xp.unique(cc_assignment):
cc_mask = cc_assignment == cc_id
pred_helper[torch.logical_not(cc_mask)] = 0
label_helper[torch.logical_not(cc_mask)] = 0
if pred_helper.sum() == 0:
min_corner_idx = self.xp.argwhere(cc_mask).min(axis=0)
max_corner_idx = self.xp.argwhere(cc_mask).max(axis=0)

# Cut out the region of interest
crop_pred = pred_helper[0][
min_corner_idx[0] : max_corner_idx[0] + 1,
min_corner_idx[1] : max_corner_idx[1] + 1,
min_corner_idx[2] : max_corner_idx[2] + 1,
]
crop_label = label_helper[0][
min_corner_idx[0] : max_corner_idx[0] + 1,
min_corner_idx[1] : max_corner_idx[1] + 1,
min_corner_idx[2] : max_corner_idx[2] + 1,
]
pred_masked = (
crop_pred
* cc_mask[
min_corner_idx[0] : max_corner_idx[0] + 1,
min_corner_idx[1] : max_corner_idx[1] + 1,
min_corner_idx[2] : max_corner_idx[2] + 1,
]
)
label_masked = (
crop_label
* cc_mask[
min_corner_idx[0] : max_corner_idx[0] + 1,
min_corner_idx[1] : max_corner_idx[1] + 1,
min_corner_idx[2] : max_corner_idx[2] + 1,
]
)

if pred_masked.sum() == 0:
missed_components += 1

# Remap metrics back to one-hot encoding
pred_helper = F.one_hot(pred_helper, num_classes=2).permute(3, 0, 1, 2)
label_helper = F.one_hot(label_helper, num_classes=2).permute(3, 0, 1, 2)
# pred_onehot = F.one_hot(pred_masked, num_classes=2).permute(3, 0, 1, 2)
# label_onehot = F.one_hot(label_masked, num_classes=2).permute(3, 0, 1, 2)
pred_onehot = self.xp.moveaxis(self.xp.eye(2)[pred_masked], -1, 0)
label_onehot = self.xp.moveaxis(self.xp.eye(2)[label_masked], -1, 0)

self.base_metric(
y_pred=pred_helper.unsqueeze(0), y=label_helper.unsqueeze(0)
pred_onehot, label_onehot = self._convert_to_target(
pred_onehot[self.xp.newaxis], label_onehot[self.xp.newaxis]
)
del pred_helper
del label_helper

self.base_metric(y_pred=pred_onehot, y=label_onehot)

del crop_pred, crop_label, pred_masked, label_masked
del cc_mask
gc.collect()
del pred_helper
del label_helper

# Get metric buffer and reset it
metric_buffer = self.base_metric.get_buffer()
Expand Down Expand Up @@ -264,7 +301,7 @@ def cache_datapoint(self, y):
target_path = f"{os.path.join(self.caching_dir, gt_fingerprint)}.npy"
if os.path.exists(target_path):
return
cc_assignment = space_separation(y)
cc_assignment = self.space_separation(y)
np.save(target_path, cc_assignment)
else:
raise ValueError("Caching is disabled")
Expand All @@ -285,6 +322,48 @@ def __init__(self, *args, **kwargs):
DiceMetric, *args, metric_best_score=1.0, metric_worst_score=0.0, **kwargs
)

def __call__(self, y_pred, y):
"""
Calculates the Dice metric for the predicted and ground truth tensors.
Args:
y_pred (numpy.ndarray or torch.Tensor): The predicted tensor.
y (numpy.ndarray or torch.Tensor): The ground truth tensor.
Raises:
AssertionError: If the input shapes or conditions are not correct.
Returns:
None
"""
y_pred, y = self._verify_and_convert(y_pred, y)

pred_helper = y_pred.argmax(1)
label_helper = y.argmax(1)
if label_helper[0].sum() == 0:
if pred_helper[0].sum() == 0:
# Case perfect prediction: No foreground class present in prediction
self.buffer_collection.append(torch.tensor([self.metric_perfect_score]))
else:
# Case worst prediction: Predicted Foreground class but no GT
self.buffer_collection.append(torch.tensor([self.metric_worst_score]))
return
cc_assignment = self.space_separation(label_helper[0])

uniq, inv = self.xp.unique(cc_assignment.ravel(), return_inverse=True)

nof_components = uniq.size

code = label_helper.ravel() << 1 | pred_helper.ravel()
idx = inv << 2 | code
hist = self.xp.bincount(idx, minlength=nof_components * 4).reshape(-1, 4)
TN, FP, FN, TP = hist[:, 0], hist[:, 1], hist[:, 2], hist[:, 3]
denom = 2 * TP + FP + FN
dice_scores = self.xp.where(denom > 0, (2 * TP) / denom, 1.0)
dice_scores = (
torch.from_numpy(self.xp.asnumpy(dice_scores))
if self.backend == "cupy"
else torch.from_numpy(dice_scores)
)
self.buffer_collection.append(dice_scores.unsqueeze(-1))


class CCHausdorffDistanceMetric(CCBaseMetric):
"""
Expand Down
Loading