Commit 9edc956

adding confusion_matrix_patch.py with descriptions
1 parent 996e876 commit 9edc956

1 file changed: 360 additions & 0 deletions
@@ -0,0 +1,360 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings
from collections.abc import Sequence

import torch

from monai.metrics.utils import do_metric_reduction, ignore_background
from monai.utils import MetricReduction, ensure_tuple

from .metric import CumulativeIterationMetric


class ConfusionMatrixMetricPatch(CumulativeIterationMetric):
    """
    Compute confusion-matrix-related metrics. This class supports all metrics mentioned in:
    `Confusion matrix <https://en.wikipedia.org/wiki/Confusion_matrix>`_.
    It supports both multi-class and multi-label classification and segmentation tasks.
    `y_pred` is expected to contain binarized predictions and `y` should be in one-hot format. You can use suitable transforms
    in ``monai.transforms.post`` first to achieve binarized values.
    The `include_background` parameter can be set to ``False`` to exclude
    the first category (channel index 0), which is by convention assumed to be background. If the non-background
    segmentations are small compared to the total image size, they can get overwhelmed by the signal from the
    background.

    The typical execution steps of this metric class follow :py:class:`monai.metrics.metric.Cumulative`.

    Args:
        include_background: whether to include metric computation on the first channel of
            the predicted output. Defaults to True.
        metric_name: [``"sensitivity"``, ``"specificity"``, ``"precision"``, ``"negative predictive value"``,
            ``"miss rate"``, ``"fall out"``, ``"false discovery rate"``, ``"false omission rate"``,
            ``"prevalence threshold"``, ``"threat score"``, ``"accuracy"``, ``"balanced accuracy"``,
            ``"f1 score"``, ``"matthews correlation coefficient"``, ``"fowlkes mallows index"``,
            ``"informedness"``, ``"markedness"``]
            Some of the metrics have multiple aliases (as shown in the Wikipedia page linked above),
            and you can also input those names instead.
            Besides a single metric name, multiple metrics are supported by passing a sequence of metric names, such as
            ("sensitivity", "precision", "recall"); if ``compute_sample`` is ``True``, multiple ``f`` and ``not_nans`` will be
            returned in the same order as the input names when calling the class.
        compute_sample: when reducing, if ``True``, each sample's metric will be computed based on its confusion matrix first.
            If ``False``, compute reduction on the confusion matrices first. Defaults to ``False``.
        reduction: defines the mode of reduction applied to the metrics; reduction is only applied to `not-nan` values.
            Available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}, defaults to ``"mean"``. If "none", no reduction is performed.
        get_not_nans: whether to return the `not_nans` count. If True, aggregate() returns [(metric, not_nans), ...]; if False,
            aggregate() returns [metric, ...].
            Here `not_nans` counts the number of not-nan values for True Positive, False Positive, True Negative and False Negative.
            Its shape depends on the shape of the metric, with one extra dimension of size 4. For example, if the shape
            of the metric is [3, 3], `not_nans` has the shape [3, 3, 4].

    """

    def __init__(
        self,
        include_background: bool = True,
        metric_name: Sequence[str] | str = "hit_rate",
        compute_sample: bool = False,
        reduction: MetricReduction | str = MetricReduction.MEAN,
        get_not_nans: bool = False,
    ) -> None:
        super().__init__()
        self.include_background = include_background
        self.metric_name = ensure_tuple(metric_name)
        self.compute_sample = compute_sample
        self.reduction = reduction
        self.get_not_nans = get_not_nans

    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        """
        Args:
            y_pred: input data to compute. It must be in one-hot format and the first dim is batch.
                The values should be binarized.
            y: ground truth to compute the metric. It must be in one-hot format and the first dim is batch.
                The values should be binarized.

        Raises:
            ValueError: when `y_pred` has less than two dimensions.
        """
        # check dimension
        dims = y_pred.ndimension()
        if dims < 2:
            raise ValueError("y_pred should have at least two dimensions.")
        if dims == 2 or (dims == 3 and y_pred.shape[-1] == 1):
            if self.compute_sample:
                warnings.warn("For classification tasks, compute_sample should be False.")
                self.compute_sample = False

        return get_confusion_matrix(y_pred=y_pred, y=y, include_background=self.include_background)

    def aggregate(
        self, compute_sample: bool = False, reduction: MetricReduction | str | None = None
    ) -> list[torch.Tensor | tuple[torch.Tensor, torch.Tensor]]:
        """
        Execute reduction for the confusion matrix values.

        Args:
            compute_sample: when reducing, if ``True``, each sample's metric will be computed based on its confusion matrix first.
                If ``False``, compute reduction on the confusion matrices first. Defaults to ``False``.
            reduction: defines the mode of reduction applied to the metrics; reduction is only applied to `not-nan` values.
                Available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
                ``"mean_channel"``, ``"sum_channel"``}, defaults to `self.reduction`. If "none", no reduction is performed.

        """
        data = self.get_buffer()
        if not isinstance(data, torch.Tensor):
            raise ValueError("the data to aggregate must be a PyTorch Tensor.")

        results: list[torch.Tensor | tuple[torch.Tensor, torch.Tensor]] = []
        for metric_name in self.metric_name:
            if compute_sample or self.compute_sample:
                # compute the metric from each sample's confusion matrix, then reduce over samples
                sub_confusion_matrix = compute_confusion_matrix_metric(metric_name, data)
                f, not_nans = do_metric_reduction(sub_confusion_matrix, reduction or self.reduction)
            else:
                # reduce the confusion matrices first, then compute the metric from the reduced values
                f, not_nans = do_metric_reduction(data, reduction or self.reduction)
                f = compute_confusion_matrix_metric(metric_name, f)
            if self.get_not_nans:
                results.append((f, not_nans))
            else:
                results.append(f)
        return results
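

# A minimal usage sketch for the metric class above. This helper is illustrative only
# (an assumption added here, not part of the original MONAI API); the tiny tensors are a
# 1-sample, 2-class, 2-"pixel" example. It is not executed on import.
def _example_usage() -> None:
    metric = ConfusionMatrixMetricPatch(metric_name=("sensitivity", "precision"))
    # binarized, one-hot inputs of shape [batch=1, classes=2, pixels=2]
    y_pred = torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])
    y = torch.tensor([[[1.0, 1.0], [0.0, 0.0]]])
    metric(y_pred=y_pred, y=y)  # accumulate one iteration of confusion matrices
    sensitivity, precision = metric.aggregate()  # reduce the buffered values
    print("sensitivity:", sensitivity, "precision:", precision)
    metric.reset()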


def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True) -> torch.Tensor:
    """
    Compute the confusion matrix. A tensor with the shape [B, C, 4] will be returned, where the last dimension
    holds the number of true positive, false positive, true negative and false negative values for
    each channel of each sample within the input batch. B equals the batch size and C equals
    the number of classes to compute.

    Args:
        y_pred: input data to compute. It must be in one-hot format and the first dim is batch.
            The values should be binarized.
        y: ground truth to compute the metric. It must be in one-hot format and the first dim is batch.
            The values should be binarized.
        include_background: whether to include metric computation on the first channel of
            the predicted output. Defaults to True.

    Raises:
        ValueError: when `y_pred` and `y` have different shapes.
    """

    if not include_background:
        y_pred, y = ignore_background(y_pred=y_pred, y=y)

    if y.shape != y_pred.shape:
        raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.")

    # get confusion matrix related metric
    batch_size, n_class = y_pred.shape[:2]
    # convert to [B, N, S], where S is the number of pixels for one sample.
    # For classification tasks, S equals 1.
    y_pred = y_pred.reshape(batch_size, n_class, -1)
    y = y.reshape(batch_size, n_class, -1)
    tp = (y_pred + y) == 2
    tn = (y_pred + y) == 0

    tp = tp.sum(dim=[2]).float()
    tn = tn.sum(dim=[2]).float()
    p = y.sum(dim=[2]).float()
    n = y.shape[-1] - p

    fn = p - tp
    fp = n - tn

    return torch.stack([tp, fp, tn, fn], dim=-1)
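

# Worked example for get_confusion_matrix (illustrative values, an assumption added here):
# for a single sample and a single channel with y_pred = [1, 0, 1] and y = [1, 1, 0],
# tp = 1, fp = 1, tn = 0 and fn = 1, so the corresponding entry of the returned
# [B, C, 4] tensor is [1., 1., 0., 1.].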
"""
180+
This function is used to compute confusion matrix related metric.
181+
182+
Args:
183+
metric_name: [``"sensitivity"``, ``"specificity"``, ``"precision"``, ``"negative predictive value"``,
184+
``"miss rate"``, ``"fall out"``, ``"false discovery rate"``, ``"false omission rate"``,
185+
``"prevalence threshold"``, ``"threat score"``, ``"accuracy"``, ``"balanced accuracy"``,
186+
``"f1 score"``, ``"matthews correlation coefficient"``, ``"fowlkes mallows index"``,
187+
``"informedness"``, ``"markedness"``]
188+
Some of the metrics have multiple aliases (as shown in the wikipedia page aforementioned),
189+
and you can also input those names instead.
190+
confusion_matrix: Please see the doc string of the function ``get_confusion_matrix`` for more details.
191+
192+
Raises:
193+
ValueError: when the size of the last dimension of confusion_matrix is not 4.
194+
NotImplementedError: when specify a not implemented metric_name.
195+
196+
"""
197+
198+
def compute_confusion_matrix_metric(metric_name: str, confusion_matrix: torch.Tensor) -> torch.Tensor:
199+
200+
metric = check_confusion_matrix_metric_name(metric_name)
201+
202+
""" Ckeck dimensionality of confusion_matrix tensor """
203+
input_dim = confusion_matrix.ndimension()
204+
205+
"""
206+
If confusion_matrix is a one-dimensional tensor, it will be given a new dimension.
207+
If the size of the last dimension of confusion_matrix is not 4 (the expected size of a standard 2x2 confusion matrix), a ValueError will be raised.
208+
"""
209+
if input_dim == 1:
210+
confusion_matrix = confusion_matrix.unsqueeze(dim=0)
211+
if confusion_matrix.shape[-1] != 4:
212+
raise ValueError("the size of the last dimension of confusion_matrix should be 4.")
213+
214+
tp = confusion_matrix[..., 0] # get True Positive (TP) from confusion_metrix
215+
fp = confusion_matrix[..., 1] # get False Positive (FP) ...
216+
tn = confusion_matrix[..., 2] # get True Negative (TN) ...
217+
fn = confusion_matrix[..., 3] # get False Negative (FN) ...
218+
p = tp + fn # total number of actual positive cases
219+
n = fp + tn # total number of actual negative cases
220+
221+
# calculate metric
222+
numerator: torch.Tensor
223+
denominator: torch.Tensor | float
224+
nan_tensor = torch.tensor(float("nan"), device=confusion_matrix.device)
225+
226+
"""
227+
1. tpr - True Positive Rate (Recall): The ratio of correctly predicted positive samples to the total number of samples that are actually positive.
228+
2. tnr - True Negative Rate: The proportion of correctly predicted negative samples over the total number of samples that are actually negative.
229+
3. ppv - Positive Predictive Value (Precision): The ratio of correctly predicted positive samples to the total number of samples predicted to be positive.
230+
4. npv - Negative Predictive Value: The ratio of correctly predicted negative samples to the total number of samples predicted to be negative.
231+
5. fnr - False Negative Rate: The ratio of positive samples that are incorrectly predicted to be negative to the total number of samples that are actually positive.
232+
6. fpr - False Positive Rate: The ratio of negative samples that are incorrectly predicted as positive to the total number of samples that are actually negative.
233+
7. fdr - False Discovery Rate: The ratio of predicted positive samples that are actually negative to the total number of samples predicted to be positive.
234+
8. for - False Omission Rate: The ratio of predicted negative samples that are actually positive to the total number of samples predicted to be negative.
235+
9. pt - Prevalence Threshold: It provides insight into the optimal balance point for deciding a positive or negative classification based on the prevalence of the condition in the dataset.
236+
10. ts - Threat Score: It measures the proportion of correct predictions among all relevant events.
237+
11. acc - Accuracy: It measures the proportion of correctly classified instances out of the total number of instances in a dataset.
238+
12. ba - Balanced Accuracy: It adjusts the traditional accuracy by accounting for both the True Positive Rate (Sensitivity) and the True Negative Rate (Specificity).
239+
13. f1 - F1-score: It is a performance metric for classification tasks, especially useful when the dataset is imbalanced. It combines Precision and Recall into a single metric by calculating their harmonic mean.
240+
14. mcc - Matthews Correlation Coefficient: A more robust measure of correlation between prediction and observation than accuracy, especially in cases of imbalanced classes.
241+
15. fm - Fowlkes-Mallows Index: It measures the geometric mean of precision and recall.
242+
16. bm - Informedness: It measures the extent to which the model's predictions are better than random guessing.
243+
17. mk - Markedness: It measures the extent to which the model's predictions are better than random guessing, focusing on the positive class.
244+
"""
245+
246+
match metric:
247+
case "tpr":
248+
numerator, denominator = tp, p
249+
case "tnr": #
250+
numerator, denominator = tn, n
251+
case "ppv":
252+
numerator, denominator = tp, (tp + fp)
253+
case "npv":
254+
numerator, denominator = tn, (tn + fn)
255+
case "fnr":
256+
numerator, denominator = fn, p
257+
case "fpr":
258+
numerator, denominator = fp, n
259+
case "fdr":
260+
numerator, denominator = fp, (fp + tp)
261+
case "for":
262+
numerator, denominator = fn, (fn + tn)
263+
case "pt":
264+
tpr = torch.where(p > 0, tp / p, nan_tensor)
265+
tnr = torch.where(n > 0, tn / n, nan_tensor)
266+
numerator = torch.sqrt(tpr * (1.0 - tnr)) + tnr - 1.0
267+
denominator = tpr + tnr - 1.0
268+
case "ts":
269+
numerator, denominator = tp, (tp + fn + fp)
270+
case "acc":
271+
numerator, denominator = (tp + tn), (p + n)
272+
case "ba":
273+
tpr = torch.where(p > 0, tp / p, nan_tensor)
274+
tnr = torch.where(n > 0, tn / n, nan_tensor)
275+
numerator, denominator = (tpr + tnr), 2.0
276+
case "f1":
277+
numerator, denominator = tp * 2.0, (tp * 2.0 + fn + fp)
278+
case "mcc":
279+
numerator = tp * tn - fp * fn
280+
denominator = torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
281+
case "fm":
282+
tpr = torch.where(p > 0, tp / p, nan_tensor)
283+
ppv = torch.where((tp + fp) > 0, tp / (tp + fp), nan_tensor)
284+
numerator = torch.sqrt(ppv * tpr)
285+
denominator = 1.0
286+
case "bm":
287+
tpr = torch.where(p > 0, tp / p, nan_tensor)
288+
tnr = torch.where(n > 0, tn / n, nan_tensor)
289+
numerator = tpr + tnr - 1.0
290+
denominator = 1.0
291+
case "mk":
292+
ppv = torch.where((tp + fp) > 0, tp / (tp + fp), nan_tensor)
293+
npv = torch.where((tn + fn) > 0, tn / (tn + fn), nan_tensor)
294+
numerator = ppv + npv - 1.0
295+
denominator = 1.0
296+
case _:
297+
raise NotImplementedError("the metric is not implemented.")
298+
299+
if isinstance(denominator, torch.Tensor):
300+
return torch.where(denominator != 0, numerator / denominator, nan_tensor)
301+
return numerator / denominator
302+
303+
304+
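

# Worked example for compute_confusion_matrix_metric (same illustrative numbers as above,
# [tp, fp, tn, fn] = [1, 1, 0, 1]): "sensitivity" gives tp / (tp + fn) = 1 / 2 = 0.5 and
# "precision" gives tp / (tp + fp) = 1 / 2 = 0.5.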


def check_confusion_matrix_metric_name(metric_name: str) -> str:
    """
    There are many metrics related to the confusion matrix, and some of them have
    more than one name. In addition, some of the names are very long.
    Therefore, this function is used to check and simplify the name.

    Returns:
        Simplified metric name.

    Raises:
        NotImplementedError: when the metric is not implemented.
    """
    metric_name = metric_name.replace(" ", "_")
    metric_name = metric_name.lower()
    if metric_name in ["sensitivity", "recall", "hit_rate", "true_positive_rate", "tpr"]:
        return "tpr"
    if metric_name in ["specificity", "selectivity", "true_negative_rate", "tnr"]:
        return "tnr"
    if metric_name in ["precision", "positive_predictive_value", "ppv"]:
        return "ppv"
    if metric_name in ["negative_predictive_value", "npv"]:
        return "npv"
    if metric_name in ["miss_rate", "false_negative_rate", "fnr"]:
        return "fnr"
    if metric_name in ["fall_out", "false_positive_rate", "fpr"]:
        return "fpr"
    if metric_name in ["false_discovery_rate", "fdr"]:
        return "fdr"
    if metric_name in ["false_omission_rate", "for"]:
        return "for"
    if metric_name in ["prevalence_threshold", "pt"]:
        return "pt"
    if metric_name in ["threat_score", "critical_success_index", "ts", "csi"]:
        return "ts"
    if metric_name in ["accuracy", "acc"]:
        return "acc"
    if metric_name in ["balanced_accuracy", "ba"]:
        return "ba"
    if metric_name in ["f1_score", "f1"]:
        return "f1"
    if metric_name in ["matthews_correlation_coefficient", "mcc"]:
        return "mcc"
    if metric_name in ["fowlkes_mallows_index", "fm"]:
        return "fm"
    if metric_name in ["informedness", "bookmaker_informedness", "bm", "youden_index", "youden"]:
        return "bm"
    if metric_name in ["markedness", "deltap", "mk"]:
        return "mk"
    raise NotImplementedError("the metric is not implemented.")
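

# For example, check_confusion_matrix_metric_name("Balanced Accuracy") and
# check_confusion_matrix_metric_name("ba") both return "ba", while an unknown
# name raises NotImplementedError.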


if __name__ == "__main__":
    # Quick sanity check against scikit-learn; guarded so it only runs when this file is
    # executed directly rather than on import.
    from sklearn.metrics import confusion_matrix

    y_test = [0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1]
    y_pred = [0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1]

    print(confusion_matrix(y_test, y_pred))
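
    # A hedged cross-check (assumption: the 0/1 labels above can be viewed as one sample
    # with 11 "pixels" in one-hot form, i.e. tensors of shape [B=1, C=2, S=11]); channel 1
    # of the result holds [tp, fp, tn, fn] for the positive class and should agree with the
    # sklearn matrix printed above, which is laid out as [[tn, fp], [fn, tp]].
    y_true_t = torch.as_tensor(y_test)
    y_pred_t = torch.as_tensor(y_pred)
    y_onehot = torch.stack([(y_true_t == 0).float(), (y_true_t == 1).float()]).unsqueeze(0)
    y_pred_onehot = torch.stack([(y_pred_t == 0).float(), (y_pred_t == 1).float()]).unsqueeze(0)
    print(get_confusion_matrix(y_pred=y_pred_onehot, y=y_onehot))

    # Run the illustrative class-level example defined near the top of the file.
    _example_usage()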
