 import torch

 from monai.metrics.utils import do_metric_reduction
-from monai.utils import MetricReduction
+from monai.utils import MetricReduction, deprecated_arg

 from .metric import CumulativeIterationMetric

 class DiceMetric(CumulativeIterationMetric):
     """
     Computes Dice score for a set of pairs of prediction-groundtruth labels. It supports single-channel label maps
-    or multi-channel images with class segmentations per channel. This allows the computation for both multi-class
-    and multi-label tasks.
+    or multi-channel images with class segmentations per channel. This allows the computation for both multi-class
+    and multi-label tasks.

     If either prediction ``y_pred`` or ground truth ``y`` have shape BCHW[D], it is expected that these represent one-
     hot segmentations for C number of classes. If either shape is B1HW[D], it is expected that these are label maps
-    and the number of classes must be specified by the ``num_classes`` parameter. In either case for either inputs,
-    this metric applies no activations and so non-binary values will produce unexpected results if this metric is used
-    for binary overlap measurement. Soft labels are thus permitted by this metric.
-
-    The ``include_background`` parameter can be set to `False` to exclude the first category (channel index 0) which
-    is by convention assumed to be background. If the non-background segmentations are small compared to the total
+    and the number of classes must be specified by the ``num_classes`` parameter. In either case for either inputs,
+    this metric applies no activations and so non-binary values will produce unexpected results if this metric is used
+    for binary overlap measurement (i.e. where either input was expected to be one-hot formatted). Soft labels are thus
+    permitted by this metric. Typically this implies that raw predictions from a network must first be activated and
+    possibly made into label maps, e.g. for a multi-class prediction tensor, softmax and then argmax should be applied
+    over the channel dimension to produce a label map.
+
+    The ``include_background`` parameter can be set to `False` to exclude the first category (channel index 0) which
+    is by convention assumed to be background. If the non-background segmentations are small compared to the total
     image size they can get overwhelmed by the signal from the background. This assumes the shape of both prediction
     and ground truth is BCHW[D].

-    An example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
+    The typical execution steps of this metric class follow :py:class:`monai.metrics.metric.Cumulative`.
+
+    Further information can be found in the official
+    `MONAI Dice Overview <https://github.com/Project-MONAI/tutorials/blob/main/modules/dice_loss_metric_notes.ipynb>`_.

     Example:

     .. code-block:: python

         import torch
         from monai.metrics import DiceMetric
+        from monai.losses import DiceLoss
+        from monai.networks import one_hot

         batch_size, n_classes, h, w = 7, 5, 128, 128

         y_pred = torch.rand(batch_size, n_classes, h, w)  # network predictions
         y_pred = torch.argmax(y_pred, 1, True)  # convert to label map

         # ground truth as label map
-        y = torch.randint(0, n_classes, size=(batch_size, 1, h, w))
+        y = torch.randint(0, n_classes, size=(batch_size, 1, h, w))

         dm = DiceMetric(
             reduction="mean_batch", return_with_label=True, num_classes=n_classes
@@ -62,16 +70,22 @@ class DiceMetric(CumulativeIterationMetric):
         raw_scores = dm(y_pred, y)
         print(dm.aggregate())

+        # now compute the Dice loss which should be the same as 1 - raw_scores
+        dl = DiceLoss(to_onehot_y=True, reduction="none")
+        loss = dl(one_hot(y_pred, n_classes), y).squeeze()
+
+        print(1.0 - loss)  # same as raw_scores
+

     Args:
         include_background: whether to include Dice computation on the first channel/category of the prediction and
-            ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
+            ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
         reduction: defines mode of reduction to the metrics, this will only apply reduction on `not-nan` values. The
             available reduction modes are enumerated by :py:class:`monai.utils.enums.MetricReduction`. If "none", is
             selected, the metric will not do reduction.
         get_not_nans: whether to return the `not_nans` count. If True, aggregate() returns `(metric, not_nans)` where
             `not_nans` counts the number of valid values in the result, and will have the same shape.
-        ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be
+        ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be
             set for an empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases
             are also empty.
         num_classes: number of input channels (always including the background). When this is ``None``,
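As the docstring above notes, ``DiceMetric`` applies no activations itself, so raw network logits should first be activated and reduced to a label map before being passed in. A minimal sketch of that preprocessing for a multi-class model (tensor names here are illustrative, not from the patch):

import torch

batch_size, n_classes, h, w = 2, 4, 64, 64
logits = torch.randn(batch_size, n_classes, h, w)  # raw BCHW network output

# softmax over the channel dimension, then argmax to obtain a B1HW label map
probs = torch.softmax(logits, dim=1)
label_map = torch.argmax(probs, dim=1, keepdim=True)  # integer class indices, shape (2, 1, 64, 64)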
@@ -104,14 +118,14 @@ def __init__(
             include_background=self.include_background,
             reduction=MetricReduction.NONE,
             get_not_nans=False,
-            softmax=False,
+            apply_argmax=False,
             ignore_empty=self.ignore_empty,
             num_classes=self.num_classes,
         )

     def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
         """
-        Compute the dice value using ``DiceHelper``.
+        Compute the dice value using ``DiceHelper``.

         Args:
             y_pred: prediction value, see class docstring for format definition.
@@ -133,7 +147,7 @@ def aggregate(
         Execute reduction and aggregation logic for the output of `compute_dice`.

         Args:
-            reduction: defines mode of reduction as enumerated in :py:class:`monai.utils.enums.MetricReduction`.
+            reduction: defines mode of reduction as enumerated in :py:class:`monai.utils.enums.MetricReduction`.
                 By default this will do no reduction.
         """
         data = self.get_buffer()
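For context, a small sketch of how the reduction modes affect ``aggregate()``, reusing the tensors from the class docstring example; the shapes described in the comments are illustrative rather than guaranteed by the patch:

dm = DiceMetric(reduction="mean_batch", num_classes=n_classes)
dm(y_pred, y)
print(dm.aggregate())         # one score per class, averaged over the 7 batch items

dm_scalar = DiceMetric(reduction="mean", num_classes=n_classes)
dm_scalar(y_pred, y)
print(dm_scalar.aggregate())  # a single value averaged over batch items and classes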
@@ -163,15 +177,16 @@ def compute_dice(
     num_classes: int | None = None,
 ) -> torch.Tensor:
     """
-    Computes Dice score metric for a batch of predictions. This performs the same computation as
-    :py:class:`monai.metrics.DiceMetric`, see the documentation for that class for input formats.
+    Computes Dice score metric for a batch of predictions. This performs the same computation as
+    :py:class:`monai.metrics.DiceMetric`, which is preferable to use over this function. For input formats, see the
+    documentation for that class.

     Args:
         y_pred: input data to compute, typical segmentation model output.
-        y: ground truth to compute mean dice metric.
+        y: ground truth to compute mean dice metric.
         include_background: whether to include Dice computation on the first channel/category of the prediction and
-            ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
-        ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be
+            ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
+        ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be
             set for an empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases
             are also empty.
         num_classes: number of input channels (always including the background). When this is ``None``,
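A brief usage sketch of the functional form, assuming the same label-map tensors as in the ``DiceMetric`` example above (the class-based API is the one the docstring recommends):

scores = compute_dice(y_pred, y, include_background=False, num_classes=n_classes)
print(scores.shape)  # per-item, per-class scores; no reduction is applied by this function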
@@ -186,16 +201,16 @@ def compute_dice(
         include_background=include_background,
         reduction=MetricReduction.NONE,
         get_not_nans=False,
-        softmax=False,
+        apply_argmax=False,
         ignore_empty=ignore_empty,
         num_classes=num_classes,
     )(y_pred=y_pred, y=y)


 class DiceHelper:
     """
-    Compute Dice score between two tensors ``y_pred`` and ``y``. This is used by :py:class:`monai.metrics.DiceMetric`,
-    see the documentation for that class for input formats.
+    Compute Dice score between two tensors ``y_pred`` and ``y``. This is used by :py:class:`monai.metrics.DiceMetric`,
+    see the documentation for that class for input formats.

     Example:

@@ -215,45 +230,63 @@ class DiceHelper:

     Args:
         include_background: whether to include Dice computation on the first channel/category of the prediction and
-            ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
-        sigmoid: if ``True`, ``y_pred`` will be thresholded at a value of 0.5. Defaults to False.
-        softmax: whether ``y_pred`` are softmax activated outputs. If True, `argmax` will be performed to
-            get the discrete prediction. Defaults to the value of ``not sigmoid``.
-        activate: if this and ``sigmoid` are ``True``, sigmoid activation is applied to ``y_pred``. Defaults to False.
+            ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
+        threshold: if ``True``, ``y_pred`` will be thresholded at a value of 0.5. Defaults to False.
+        apply_argmax: whether ``y_pred`` are softmax activated outputs. If True, `argmax` will be performed to
+            get the discrete prediction. Defaults to the value of ``not threshold``.
+        activate: if this and ``threshold`` are ``True``, sigmoid activation is applied to ``y_pred`` before
+            thresholding. Defaults to False.
         get_not_nans: whether to return the number of not-nan values.
         reduction: defines mode of reduction to the metrics, this will only apply reduction on `not-nan` values. The
             available reduction modes are enumerated by :py:class:`monai.utils.enums.MetricReduction`. If "none", is
             selected, the metric will not do reduction.
-        ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be
+        ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be
             set for an empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases
             are also empty.
         num_classes: number of input channels (always including the background). When this is ``None``,
             ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
             single-channel class indices and the number of classes is not automatically inferred from data.
     """

+    @deprecated_arg("softmax", "1.5", "1.7", "Use `apply_argmax` instead.", new_name="apply_argmax")
+    @deprecated_arg("sigmoid", "1.5", "1.7", "Use `threshold` instead.", new_name="threshold")
     def __init__(
         self,
         include_background: bool | None = None,
-        sigmoid: bool = False,
-        softmax: bool | None = None,
+        threshold: bool = False,
+        apply_argmax: bool | None = None,
         activate: bool = False,
         get_not_nans: bool = True,
         reduction: MetricReduction | str = MetricReduction.MEAN_BATCH,
         ignore_empty: bool = True,
         num_classes: int | None = None,
+        sigmoid: bool | None = None,
+        softmax: bool | None = None,
     ) -> None:
-        self.sigmoid = sigmoid
+        # handling deprecated arguments
+        if sigmoid is not None:
+            threshold = sigmoid
+        if softmax is not None:
+            apply_argmax = softmax
+
+        self.threshold = threshold
         self.reduction = reduction
         self.get_not_nans = get_not_nans
-        self.include_background = sigmoid if include_background is None else include_background
-        self.softmax = not sigmoid if softmax is None else softmax
+        self.include_background = threshold if include_background is None else include_background
+        self.apply_argmax = not threshold if apply_argmax is None else apply_argmax
         self.activate = activate
         self.ignore_empty = ignore_empty
         self.num_classes = num_classes

     def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-        """"""
+        """
+        Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately
+        for each batch item and for each channel of those items.
+
+        Args:
+            y_pred: input predictions with shape HW[D].
+            y: ground truth with shape HW[D].
+        """
         y_o = torch.sum(y)
         if y_o > 0:
             return (2.0 * torch.sum(torch.masked_select(y, y_pred))) / (y_o + torch.sum(y_pred))
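The per-channel computation above is the standard binary Dice score, 2|A∩B| / (|A| + |B|), with ``y_pred`` acting as a boolean mask for the intersection. A tiny worked sketch (values chosen for illustration only):

y = torch.tensor([[1.0, 1.0], [0.0, 0.0]])            # |A| = 2
y_pred = torch.tensor([[True, False], [False, False]])  # |B| = 1, intersection = 1

# 2 * 1 / (2 + 1) = 0.6667
print(DiceHelper().compute_channel(y_pred, y))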
@@ -266,25 +299,25 @@ def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor

     def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         """
+        Compute the metric for the given prediction and ground truth.

         Args:
             y_pred: input predictions with shape (batch_size, num_classes or 1, spatial_dims...).
                 the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``.
             y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).
         """
-        _softmax, _sigmoid = self.softmax, self.sigmoid
+        _apply_argmax, _threshold = self.apply_argmax, self.threshold
         if self.num_classes is None:
             n_pred_ch = y_pred.shape[1]  # y_pred is in one-hot format or multi-channel scores
         else:
             n_pred_ch = self.num_classes
             if y_pred.shape[1] == 1 and self.num_classes > 1:  # y_pred is single-channel class indices
-                _softmax = _sigmoid = False
+                _apply_argmax = _threshold = False

-        if _softmax:
-            if n_pred_ch > 1:
-                y_pred = torch.argmax(y_pred, dim=1, keepdim=True)
+        if _apply_argmax and n_pred_ch > 1:
+            y_pred = torch.argmax(y_pred, dim=1, keepdim=True)

-        elif _threshold:
+        elif _threshold:
             if self.activate:
                 y_pred = torch.sigmoid(y_pred)
             y_pred = y_pred > 0.5
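To close, a hedged sketch of the renamed arguments in use; the old ``sigmoid``/``softmax`` keywords remain accepted through the ``deprecated_arg`` shims added above, but now emit deprecation warnings. Tensor shapes below are illustrative and the reduced output depends on code outside this hunk:

# multi-class logits: argmax over the channel dimension is applied inside the helper
helper = DiceHelper(apply_argmax=True, reduction="mean_batch", get_not_nans=False)
per_class = helper(
    y_pred=torch.randn(2, 3, 32, 32),       # BCHW logits
    y=torch.randint(0, 3, (2, 1, 32, 32)),  # B1HW label map
)

# binary predictions: sigmoid then threshold at 0.5 instead of argmax
binary_helper = DiceHelper(threshold=True, activate=True, get_not_nans=False)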