1414import torch
1515
1616from monai .metrics .utils import do_metric_reduction , ignore_background
17- from monai .utils import MetricReduction , Weight , look_up_option
17+ from monai .utils import MetricReduction , Weight , deprecated_arg , deprecated_arg_default , look_up_option
1818
1919from .metric import CumulativeIterationMetric
2020
2121
2222class GeneralizedDiceScore (CumulativeIterationMetric ):
23- """Compute the Generalized Dice Score metric between tensors, as the complement of the Generalized Dice Loss defined in:
23+ """
24+ Compute the Generalized Dice Score metric between tensors.
2425
26+ This metric is the complement of the Generalized Dice Loss defined in:
2527 Sudre, C. et. al. (2017) Generalised Dice overlap as a deep learning
26- loss function for highly unbalanced segmentations. DLMIA 2017.
28+ loss function for highly unbalanced segmentations. DLMIA 2017.
2729
28- The inputs `y_pred` and `y` are expected to be one-hot, binarized channel-first
29- or batch-first tensors, i.e., CHW[D] or BCHW[D].
30+ The inputs `y_pred` and `y` are expected to be one-hot, binarized batch-first tensors, i.e., NCHW[D].
3031
3132 Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
3233
3334 Args:
34- include_background (bool, optional): whether to include the background class (assumed to be in channel 0), in the
35+ include_background: Whether to include the background class (assumed to be in channel 0) in the
3536 score computation. Defaults to True.
36- reduction (str, optional): define mode of reduction to the metrics. Available reduction modes:
37- {``"none"``, ``"mean_batch"``, ``"sum_batch"``}. Default to ``"mean_batch"``. If "none", will not do reduction.
38- weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
37+ reduction: Define mode of reduction to the metrics. Available reduction modes:
38+ {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
39+ ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
40+ weight_type: {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
3941 ground truth volume into a weight factor. Defaults to ``"square"``.
4042
4143 Raises:
42- ValueError: when the `weight_type ` is not one of {``"none"``, ``"mean"``, ``"sum"``} .
44+ ValueError: When the `reduction ` is not one of MetricReduction enum .
4345 """
4446
47+ @deprecated_arg_default (
48+ "reduction" ,
49+ old_default = MetricReduction .MEAN_BATCH ,
50+ new_default = MetricReduction .MEAN ,
51+ since = "1.4.0" ,
52+ replaced = "1.5.0" ,
53+ msg_suffix = (
54+ "Old versions computed `mean` when `mean_batch` was provided due to bug in reduction, "
55+ "If you want to retain the old behavior (calculating the mean), please explicitly set the parameter to 'mean'."
56+ ),
57+ )
4558 def __init__ (
4659 self ,
4760 include_background : bool = True ,
@@ -50,79 +63,90 @@ def __init__(
5063 ) -> None :
5164 super ().__init__ ()
5265 self .include_background = include_background
53- reduction_options = [
54- "none" ,
55- "mean_batch" ,
56- "sum_batch" ,
57- MetricReduction .NONE ,
58- MetricReduction .MEAN_BATCH ,
59- MetricReduction .SUM_BATCH ,
60- ]
61- self .reduction = reduction
62- if self .reduction not in reduction_options :
63- raise ValueError (f"reduction must be one of { reduction_options } " )
66+ self .reduction = look_up_option (reduction , MetricReduction )
6467 self .weight_type = look_up_option (weight_type , Weight )
68+ self .sum_over_classes = self .reduction in {
69+ MetricReduction .SUM ,
70+ MetricReduction .MEAN ,
71+ MetricReduction .MEAN_CHANNEL ,
72+ MetricReduction .SUM_CHANNEL ,
73+ }
6574
6675 def _compute_tensor (self , y_pred : torch .Tensor , y : torch .Tensor ) -> torch .Tensor : # type: ignore[override]
67- """Computes the Generalized Dice Score and returns a tensor with its per image values.
76+ """
77+ Computes the Generalized Dice Score and returns a tensor with its per image values.
6878
6979 Args:
70- y_pred (torch.Tensor): binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format,
80+ y_pred (torch.Tensor): Binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format,
7181 where N is the batch dimension, C is the channel dimension, and the remaining are the spatial dimensions.
72- y (torch.Tensor): binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`.
82+ y (torch.Tensor): Binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`.
83+
84+ Returns:
85+ torch.Tensor: Generalized Dice Score averaged across batch and class
7386
7487 Raises:
75- ValueError: if `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape.
88+ ValueError: If `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape.
7689 """
7790 return compute_generalized_dice (
78- y_pred = y_pred , y = y , include_background = self .include_background , weight_type = self .weight_type
91+ y_pred = y_pred ,
92+ y = y ,
93+ include_background = self .include_background ,
94+ weight_type = self .weight_type ,
95+ sum_over_classes = self .sum_over_classes ,
7996 )
8097
98+ @deprecated_arg (
99+ "reduction" ,
100+ since = "1.3.3" ,
101+ removed = "1.7.0" ,
102+ msg_suffix = "Reduction will be ignored. Set reduction during init. as gen.dice needs it during compute" ,
103+ )
81104 def aggregate (self , reduction : MetricReduction | str | None = None ) -> torch .Tensor :
82105 """
83106 Execute reduction logic for the output of `compute_generalized_dice`.
84107
85- Args:
86- reduction (Union[MetricReduction, str, None], optional): define mode of reduction to the metrics.
87- Available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``}.
88- Defaults to ``"mean"``. If "none", will not do reduction.
108+ Returns:
109+ torch.Tensor: Aggregated metric value.
110+
111+ Raises:
112+ ValueError: If the data to aggregate is not a PyTorch Tensor.
89113 """
90114 data = self .get_buffer ()
91115 if not isinstance (data , torch .Tensor ):
92116 raise ValueError ("The data to aggregate must be a PyTorch Tensor." )
93117
94- # Validate reduction argument if specified
95- if reduction is not None :
96- reduction_options = ["none" , "mean" , "sum" , "mean_batch" , "sum_batch" ]
97- if reduction not in reduction_options :
98- raise ValueError (f"reduction must be one of { reduction_options } " )
99-
100118 # Do metric reduction and return
101- f , _ = do_metric_reduction (data , reduction or self .reduction )
119+ f , _ = do_metric_reduction (data , self .reduction )
102120
103121 return f
104122
105123
106124def compute_generalized_dice (
107- y_pred : torch .Tensor , y : torch .Tensor , include_background : bool = True , weight_type : Weight | str = Weight .SQUARE
125+ y_pred : torch .Tensor ,
126+ y : torch .Tensor ,
127+ include_background : bool = True ,
128+ weight_type : Weight | str = Weight .SQUARE ,
129+ sum_over_classes : bool = False ,
108130) -> torch .Tensor :
109- """Computes the Generalized Dice Score and returns a tensor with its per image values.
131+ """
132+ Computes the Generalized Dice Score and returns a tensor with its per image values.
110133
111134 Args:
112- y_pred (torch.Tensor): binarized segmentation model output. It should be binarized, in one-hot format
135+ y_pred (torch.Tensor): Binarized segmentation model output. It should be binarized, in one-hot format
113136 and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the
114137 remaining are the spatial dimensions.
115- y (torch.Tensor): binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`.
116- include_background (bool, optional): whether to include score computation on the first channel of the
138+ y (torch.Tensor): Binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`.
139+ include_background: Whether to include score computation on the first channel of the
117140 predicted output. Defaults to True.
118141 weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to
119142 transform ground truth volume into a weight factor. Defaults to ``"square"``.
143+ sum_over_labels (bool): Whether to sum the numerator and denominator across all labels before the final computation.
120144
121145 Returns:
122- torch.Tensor: per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes].
146+ torch.Tensor: Per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes].
123147
124148 Raises:
125- ValueError: if `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions,
149+ ValueError: If `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions,
126150 or `y_pred` and `y` don't have the same shape.
127151 """
128152 # Ensure tensors have at least 3 dimensions and have the same shape
@@ -158,16 +182,21 @@ def compute_generalized_dice(
158182 b [infs ] = 0
159183 b [infs ] = torch .max (b )
160184
161- # Compute the weighted numerator and denominator, summing along the class axis
162- numer = 2.0 * (intersection * w ).sum (dim = 1 )
163- denom = (denominator * w ).sum (dim = 1 )
185+ # Compute the weighted numerator and denominator, summing along the class axis when sum_over_classes is True
186+ if sum_over_classes :
187+ numer = 2.0 * (intersection * w ).sum (dim = 1 , keepdim = True )
188+ denom = (denominator * w ).sum (dim = 1 , keepdim = True )
189+ y_pred_o = y_pred_o .sum (dim = - 1 , keepdim = True )
190+ else :
191+ numer = 2.0 * (intersection * w )
192+ denom = denominator * w
193+ y_pred_o = y_pred_o
164194
165195 # Compute the score
166196 generalized_dice_score = numer / denom
167197
168198 # Handle zero division. Where denom == 0 and the prediction volume is 0, score is 1.
169199 # Where denom == 0 but the prediction volume is not 0, score is 0
170- y_pred_o = y_pred_o .sum (dim = - 1 )
171200 denom_zeros = denom == 0
172201 generalized_dice_score [denom_zeros ] = torch .where (
173202 (y_pred_o == 0 )[denom_zeros ],
0 commit comments