# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings
from collections.abc import Sequence

import torch

from monai.metrics.utils import do_metric_reduction, ignore_background
from monai.utils import MetricReduction, ensure_tuple

from .metric import CumulativeIterationMetric

class ConfusionMatrixMetricPatch(CumulativeIterationMetric):
    """
    Compute confusion matrix related metrics. This class supports calculating all metrics mentioned in:
    `Confusion matrix <https://en.wikipedia.org/wiki/Confusion_matrix>`_.
    It supports both multi-class and multi-label classification and segmentation tasks.
    `y_pred` is expected to contain binarized predictions and `y` should be in one-hot format. You can use suitable
    transforms in ``monai.transforms.post`` first to achieve binarized values.
    The `include_background` parameter can be set to ``False`` to exclude
    the first category (channel index 0), which is by convention assumed to be background. If the non-background
    segmentations are small compared to the total image size they can get overwhelmed by the signal from the
    background.

    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.

    Args:
        include_background: whether to include metric computation on the first channel of
            the predicted output. Defaults to True.
        metric_name: [``"sensitivity"``, ``"specificity"``, ``"precision"``, ``"negative predictive value"``,
            ``"miss rate"``, ``"fall out"``, ``"false discovery rate"``, ``"false omission rate"``,
            ``"prevalence threshold"``, ``"threat score"``, ``"accuracy"``, ``"balanced accuracy"``,
            ``"f1 score"``, ``"matthews correlation coefficient"``, ``"fowlkes mallows index"``,
            ``"informedness"``, ``"markedness"``]
            Some of the metrics have multiple aliases (as shown in the Wikipedia page linked above),
            and you can also input those names instead.
            Besides a single metric, multiple metrics are also supported by passing a sequence of metric names, such as
            ``("sensitivity", "precision", "recall")``; if ``compute_sample`` is ``True``, multiple ``f`` and ``not_nans``
            will be returned in the same order as the input names when calling the class.
        compute_sample: when reducing, if ``True``, each sample's metric will be computed based on each confusion matrix first.
            if ``False``, compute reduction on the confusion matrices first, defaults to ``False``.
        reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns [(metric, not_nans), ...]. If False,
            aggregate() returns [metric, ...].
            Here `not_nans` counts the number of not nans for True Positive, False Positive, True Negative and False Negative.
            Its shape depends on the shape of the metric, and it has one more dimension with size 4. For example, if the shape
            of the metric is [3, 3], `not_nans` has the shape [3, 3, 4].

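    Example (a minimal usage sketch; the tensors below are illustrative binarized one-hot
    inputs of shape [B, C, S], and the call pattern follows ``CumulativeIterationMetric``)::

        metric = ConfusionMatrixMetricPatch(metric_name=("sensitivity", "precision"))
        y_pred = torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])  # predictions, [B=1, C=2, S=2]
        y = torch.tensor([[[1.0, 1.0], [0.0, 0.0]]])       # ground truth, same shape
        metric(y_pred=y_pred, y=y)  # buffer the per-iteration confusion matrices
        sensitivity, precision = metric.aggregate()
        metric.reset()
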
63+ """

    def __init__(
        self,
        include_background: bool = True,
        metric_name: Sequence[str] | str = "hit_rate",
        compute_sample: bool = False,
        reduction: MetricReduction | str = MetricReduction.MEAN,
        get_not_nans: bool = False,
    ) -> None:
        super().__init__()
        self.include_background = include_background
        self.metric_name = ensure_tuple(metric_name)
        self.compute_sample = compute_sample
        self.reduction = reduction
        self.get_not_nans = get_not_nans

    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        """
        Args:
            y_pred: input data to compute. It must be one-hot format and first dim is batch.
                The values should be binarized.
            y: ground truth to compute the metric. It must be one-hot format and first dim is batch.
                The values should be binarized.

        Raises:
            ValueError: when `y_pred` has less than two dimensions.
        """
        # check dimension
        dims = y_pred.ndimension()
        if dims < 2:
            raise ValueError("y_pred should have at least two dimensions.")
        if dims == 2 or (dims == 3 and y_pred.shape[-1] == 1):
            if self.compute_sample:
                warnings.warn("As for classification task, compute_sample should be False.")
                self.compute_sample = False

        return get_confusion_matrix(y_pred=y_pred, y=y, include_background=self.include_background)

    def aggregate(
        self, compute_sample: bool = False, reduction: MetricReduction | str | None = None
    ) -> list[torch.Tensor | tuple[torch.Tensor, torch.Tensor]]:
        """
        Execute reduction for the confusion matrix values.

        Args:
            compute_sample: when reducing, if ``True``, each sample's metric will be computed based on each confusion matrix first.
                if ``False``, compute reduction on the confusion matrices first, defaults to ``False``.
            reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
                available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
                ``"mean_channel"``, ``"sum_channel"``}, default to `self.reduction`. if "none", will not do reduction.

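        Example (a sketch; assumes results have already been buffered via ``__call__`` as in
        the class-level example above)::

            values = metric.aggregate(reduction="mean_batch")  # one entry per requested metric name
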
114+ """
        data = self.get_buffer()
        if not isinstance(data, torch.Tensor):
            raise ValueError("the data to aggregate must be PyTorch Tensor.")

        results: list[torch.Tensor | tuple[torch.Tensor, torch.Tensor]] = []
        for metric_name in self.metric_name:
            if compute_sample or self.compute_sample:
                sub_confusion_matrix = compute_confusion_matrix_metric(metric_name, data)
                f, not_nans = do_metric_reduction(sub_confusion_matrix, reduction or self.reduction)
            else:
                f, not_nans = do_metric_reduction(data, reduction or self.reduction)
                f = compute_confusion_matrix_metric(metric_name, f)
            if self.get_not_nans:
                results.append((f, not_nans))
            else:
                results.append(f)
        return results

def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True) -> torch.Tensor:
    """
    Compute confusion matrix. A tensor with the shape [B, C, 4] will be returned, where the last dimension
    holds the number of true positive, false positive, true negative and false negative values for
    each channel of each sample within the input batch. B equals the batch size and C equals
    the number of classes that need to be computed.

    Args:
        y_pred: input data to compute. It must be one-hot format and first dim is batch.
            The values should be binarized.
        y: ground truth to compute the metric. It must be one-hot format and first dim is batch.
            The values should be binarized.
        include_background: whether to include metric computation on the first channel of
            the predicted output. Defaults to True.

    Raises:
        ValueError: when `y_pred` and `y` have different shapes.
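
    Example (a minimal sketch on an illustrative binarized one-hot input)::

        y_pred = torch.tensor([[[1.0, 0.0, 1.0]]])  # [B=1, C=1, S=3]
        y = torch.tensor([[[1.0, 1.0, 0.0]]])
        get_confusion_matrix(y_pred, y)  # tensor([[[1., 1., 0., 1.]]]) -> [tp, fp, tn, fn]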
151+ """

    if not include_background:
        y_pred, y = ignore_background(y_pred=y_pred, y=y)

    if y.shape != y_pred.shape:
        raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.")

    # get confusion matrix related metric
    batch_size, n_class = y_pred.shape[:2]
    # convert to [BNS], where S is the number of pixels for one sample.
    # As for classification tasks, S equals to 1.
    y_pred = y_pred.reshape(batch_size, n_class, -1)
    y = y.reshape(batch_size, n_class, -1)
    tp = (y_pred + y) == 2
    tn = (y_pred + y) == 0

    tp = tp.sum(dim=[2]).float()
    tn = tn.sum(dim=[2]).float()
    p = y.sum(dim=[2]).float()
    n = y.shape[-1] - p

    fn = p - tp
    fp = n - tn

    return torch.stack([tp, fp, tn, fn], dim=-1)


179+ """
180+ This function is used to compute confusion matrix related metric.
181+
182+ Args:
183+ metric_name: [``"sensitivity"``, ``"specificity"``, ``"precision"``, ``"negative predictive value"``,
184+ ``"miss rate"``, ``"fall out"``, ``"false discovery rate"``, ``"false omission rate"``,
185+ ``"prevalence threshold"``, ``"threat score"``, ``"accuracy"``, ``"balanced accuracy"``,
186+ ``"f1 score"``, ``"matthews correlation coefficient"``, ``"fowlkes mallows index"``,
187+ ``"informedness"``, ``"markedness"``]
188+ Some of the metrics have multiple aliases (as shown in the wikipedia page aforementioned),
189+ and you can also input those names instead.
190+ confusion_matrix: Please see the doc string of the function ``get_confusion_matrix`` for more details.
191+
192+ Raises:
193+ ValueError: when the size of the last dimension of confusion_matrix is not 4.
194+ NotImplementedError: when specify a not implemented metric_name.
195+
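
    Example (a minimal sketch, reusing the ``[B, C, 4]`` output layout of ``get_confusion_matrix``)::

        cm = torch.tensor([[[1.0, 1.0, 0.0, 1.0]]])  # [tp, fp, tn, fn] for one sample and channel
        compute_confusion_matrix_metric("sensitivity", cm)  # tp / (tp + fn) -> tensor([[0.5000]])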
196+ """
197+
198+ def compute_confusion_matrix_metric (metric_name : str , confusion_matrix : torch .Tensor ) -> torch .Tensor :
199+
200+ metric = check_confusion_matrix_metric_name (metric_name )
201+
202+ """ Ckeck dimensionality of confusion_matrix tensor """
203+ input_dim = confusion_matrix .ndimension ()

    # if confusion_matrix is a one-dimensional tensor, add a batch dimension;
    # the last dimension must have size 4 (the [tp, fp, tn, fn] counts), otherwise raise a ValueError.
    if input_dim == 1:
        confusion_matrix = confusion_matrix.unsqueeze(dim=0)
    if confusion_matrix.shape[-1] != 4:
        raise ValueError("the size of the last dimension of confusion_matrix should be 4.")

    tp = confusion_matrix[..., 0]  # get True Positive (TP) from confusion_matrix
    fp = confusion_matrix[..., 1]  # get False Positive (FP)
    tn = confusion_matrix[..., 2]  # get True Negative (TN)
    fn = confusion_matrix[..., 3]  # get False Negative (FN)
    p = tp + fn  # total number of actual positive cases
    n = fp + tn  # total number of actual negative cases

    # calculate metric
    numerator: torch.Tensor
    denominator: torch.Tensor | float
    nan_tensor = torch.tensor(float("nan"), device=confusion_matrix.device)

    # Abbreviations handled below:
    #  1. tpr - True Positive Rate (Recall): the ratio of correctly predicted positive samples to all samples that are actually positive.
    #  2. tnr - True Negative Rate (Specificity): the proportion of correctly predicted negative samples over all samples that are actually negative.
    #  3. ppv - Positive Predictive Value (Precision): the ratio of correctly predicted positive samples to all samples predicted to be positive.
    #  4. npv - Negative Predictive Value: the ratio of correctly predicted negative samples to all samples predicted to be negative.
    #  5. fnr - False Negative Rate: the ratio of positive samples incorrectly predicted to be negative to all samples that are actually positive.
    #  6. fpr - False Positive Rate: the ratio of negative samples incorrectly predicted as positive to all samples that are actually negative.
    #  7. fdr - False Discovery Rate: the ratio of predicted positive samples that are actually negative to all samples predicted to be positive.
    #  8. for - False Omission Rate: the ratio of predicted negative samples that are actually positive to all samples predicted to be negative.
    #  9. pt  - Prevalence Threshold: the balance point for deciding a positive or negative classification based on the prevalence of the condition.
    # 10. ts  - Threat Score: the proportion of correct positive predictions among all relevant events.
    # 11. acc - Accuracy: the proportion of correctly classified instances out of all instances in a dataset.
    # 12. ba  - Balanced Accuracy: the average of the True Positive Rate (Sensitivity) and the True Negative Rate (Specificity).
    # 13. f1  - F1-score: the harmonic mean of Precision and Recall, especially useful when the dataset is imbalanced.
    # 14. mcc - Matthews Correlation Coefficient: a more robust measure of correlation between prediction and observation than accuracy, especially for imbalanced classes.
    # 15. fm  - Fowlkes-Mallows Index: the geometric mean of precision and recall.
    # 16. bm  - Informedness (Bookmaker Informedness): the extent to which the predictions are better than random guessing.
    # 17. mk  - Markedness: the extent to which the predictions are better than random guessing, based on the predictive values.

    match metric:
        case "tpr":
            numerator, denominator = tp, p
        case "tnr":
            numerator, denominator = tn, n
        case "ppv":
            numerator, denominator = tp, (tp + fp)
        case "npv":
            numerator, denominator = tn, (tn + fn)
        case "fnr":
            numerator, denominator = fn, p
        case "fpr":
            numerator, denominator = fp, n
        case "fdr":
            numerator, denominator = fp, (fp + tp)
        case "for":
            numerator, denominator = fn, (fn + tn)
        case "pt":
            tpr = torch.where(p > 0, tp / p, nan_tensor)
            tnr = torch.where(n > 0, tn / n, nan_tensor)
            numerator = torch.sqrt(tpr * (1.0 - tnr)) + tnr - 1.0
            denominator = tpr + tnr - 1.0
        case "ts":
            numerator, denominator = tp, (tp + fn + fp)
        case "acc":
            numerator, denominator = (tp + tn), (p + n)
        case "ba":
            tpr = torch.where(p > 0, tp / p, nan_tensor)
            tnr = torch.where(n > 0, tn / n, nan_tensor)
            numerator, denominator = (tpr + tnr), 2.0
        case "f1":
            numerator, denominator = tp * 2.0, (tp * 2.0 + fn + fp)
        case "mcc":
            numerator = tp * tn - fp * fn
            denominator = torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
        case "fm":
            tpr = torch.where(p > 0, tp / p, nan_tensor)
            ppv = torch.where((tp + fp) > 0, tp / (tp + fp), nan_tensor)
            numerator = torch.sqrt(ppv * tpr)
            denominator = 1.0
        case "bm":
            tpr = torch.where(p > 0, tp / p, nan_tensor)
            tnr = torch.where(n > 0, tn / n, nan_tensor)
            numerator = tpr + tnr - 1.0
            denominator = 1.0
        case "mk":
            ppv = torch.where((tp + fp) > 0, tp / (tp + fp), nan_tensor)
            npv = torch.where((tn + fn) > 0, tn / (tn + fn), nan_tensor)
            numerator = ppv + npv - 1.0
            denominator = 1.0
        case _:
            raise NotImplementedError("the metric is not implemented.")

    if isinstance(denominator, torch.Tensor):
        return torch.where(denominator != 0, numerator / denominator, nan_tensor)
    return numerator / denominator


def check_confusion_matrix_metric_name(metric_name: str) -> str:
    """
    There are many metrics related to the confusion matrix, and some of the metrics have
    more than one name. In addition, some of the names are very long.
    Therefore, this function is used to check and simplify the name.

    Returns:
        Simplified metric name.

    Raises:
        NotImplementedError: when the metric is not implemented.
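
    Example (a short sketch of the alias handling)::

        check_confusion_matrix_metric_name("False Discovery Rate")  # returns "fdr"
        check_confusion_matrix_metric_name("recall")                # returns "tpr"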
315+ """
    metric_name = metric_name.replace(" ", "_")
    metric_name = metric_name.lower()
    if metric_name in ["sensitivity", "recall", "hit_rate", "true_positive_rate", "tpr"]:
        return "tpr"
    if metric_name in ["specificity", "selectivity", "true_negative_rate", "tnr"]:
        return "tnr"
    if metric_name in ["precision", "positive_predictive_value", "ppv"]:
        return "ppv"
    if metric_name in ["negative_predictive_value", "npv"]:
        return "npv"
    if metric_name in ["miss_rate", "false_negative_rate", "fnr"]:
        return "fnr"
    if metric_name in ["fall_out", "false_positive_rate", "fpr"]:
        return "fpr"
    if metric_name in ["false_discovery_rate", "fdr"]:
        return "fdr"
    if metric_name in ["false_omission_rate", "for"]:
        return "for"
    if metric_name in ["prevalence_threshold", "pt"]:
        return "pt"
    if metric_name in ["threat_score", "critical_success_index", "ts", "csi"]:
        return "ts"
    if metric_name in ["accuracy", "acc"]:
        return "acc"
    if metric_name in ["balanced_accuracy", "ba"]:
        return "ba"
    if metric_name in ["f1_score", "f1"]:
        return "f1"
    if metric_name in ["matthews_correlation_coefficient", "mcc"]:
        return "mcc"
    if metric_name in ["fowlkes_mallows_index", "fm"]:
        return "fm"
    if metric_name in ["informedness", "bookmaker_informedness", "bm", "youden_index", "youden"]:
        return "bm"
    if metric_name in ["markedness", "deltap", "mk"]:
        return "mk"
    raise NotImplementedError("the metric is not implemented.")


if __name__ == "__main__":
    # quick sanity check of a toy binary example against scikit-learn's confusion_matrix
    from sklearn.metrics import confusion_matrix

    y_test = [0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1]
    y_pred = [0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1]

    print(confusion_matrix(y_test, y_pred))
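
    # For comparison, a sketch of the same counts computed with this module's helpers.
    # The reshape into a [B=1, C=1, S=11] single-channel binarized layout below is an
    # illustrative assumption for this binary toy example.
    y_t = torch.tensor(y_test, dtype=torch.float32).reshape(1, 1, -1)
    y_p = torch.tensor(y_pred, dtype=torch.float32).reshape(1, 1, -1)
    cm = get_confusion_matrix(y_pred=y_p, y=y_t)
    print(cm)  # tensor([[[4., 3., 1., 3.]]]) -> [tp, fp, tn, fn]
    print(compute_confusion_matrix_metric("sensitivity", cm))  # tp / (tp + fn) = 4 / 7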