1414import torch
1515
1616from monai .metrics .utils import do_metric_reduction
17- from monai .utils import MetricReduction
17+ from monai .utils import MetricReduction , deprecated_arg
1818
1919from .metric import CumulativeIterationMetric
2020
2323
2424class DiceMetric (CumulativeIterationMetric ):
2525 """
26- Compute average Dice score for a set of pairs of prediction-groundtruth segmentations.
26+ Computes Dice score for a set of pairs of prediction-groundtruth labels. It supports single-channel label maps
27+ or multi-channel images with class segmentations per channel. This allows the computation for both multi-class
28+ and multi-label tasks.
2729
28- It supports both multi-classes and multi-labels tasks.
29- Input `y_pred` is compared with ground truth `y`.
30- `y_pred` is expected to have binarized predictions and `y` can be single-channel class indices or in the
31- one-hot format. The `include_background` parameter can be set to ``False`` to exclude
32- the first category (channel index 0) which is by convention assumed to be background. If the non-background
33- segmentations are small compared to the total image size they can get overwhelmed by the signal from the
34- background. `y_preds` and `y` can be a list of channel-first Tensor (CHW[D]) or a batch-first Tensor (BCHW[D]),
35- `y` can also be in the format of `B1HW[D]`.
30+ If either prediction ``y_pred`` or ground truth ``y`` have shape BCHW[D], it is expected that these represent one-
31+ hot segmentations for C number of classes. If either shape is B1HW[D], it is expected that these are label maps
32+ and the number of classes must be specified by the ``num_classes`` parameter. In either case for either inputs,
33+ this metric applies no activations and so non-binary values will produce unexpected results if this metric is used
34+ for binary overlap measurement (ie. either was expected to be one-hot formatted). Soft labels are thus permitted by
35+ this metric. Typically this implies that raw predictions from a network must first be activated and possibly made
36+ into label maps, eg. for a multi-class prediction tensor softmax and then argmax should be applied over the channel
37+ dimensions to produce a label map.
38+
39+ The ``include_background`` parameter can be set to `False` to exclude the first category (channel index 0) which
40+ is by convention assumed to be background. If the non-background segmentations are small compared to the total
41+ image size they can get overwhelmed by the signal from the background. This assumes the shape of both prediction
42+ and ground truth is BCHW[D].
43+
44+ The typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
45+
46+ Further information can be found in the official
47+ `MONAI Dice Overview <https://github.com/Project-MONAI/tutorials/blob/main/modules/dice_loss_metric_notes.ipynb>`.
48+
49+ Example:
50+
51+ .. code-block:: python
52+
53+ import torch
54+ from monai.metrics import DiceMetric
55+ from monai.losses import DiceLoss
56+ from monai.networks import one_hot
57+
58+ batch_size, n_classes, h, w = 7, 5, 128, 128
59+
60+ y_pred = torch.rand(batch_size, n_classes, h, w) # network predictions
61+ y_pred = torch.argmax(y_pred, 1, True) # convert to label map
62+
63+ # ground truth as label map
64+ y = torch.randint(0, n_classes, size=(batch_size, 1, h, w))
65+
66+ dm = DiceMetric(
67+ reduction="mean_batch", return_with_label=True, num_classes=n_classes
68+ )
69+
70+ raw_scores = dm(y_pred, y)
71+ print(dm.aggregate())
72+
73+ # now compute the Dice loss which should be the same as 1 - raw_scores
74+ dl = DiceLoss(to_onehot_y=True, reduction="none")
75+ loss = dl(one_hot(y_pred, n_classes), y).squeeze()
76+
77+ print(1.0 - loss) # same as raw_scores
3678
37- Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
3879
3980 Args:
40- include_background: whether to include Dice computation on the first channel of
41- the predicted output . Defaults to ``True``.
42- reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
43- available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
44- ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
45- get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).
46- Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric .
47- ignore_empty: whether to ignore empty ground truth cases during calculation.
48- If `True`, NaN value will be set for empty ground truth cases.
49- If `False`, 1 will be set if the predictions of empty ground truth cases are also empty.
50- num_classes: number of input channels (always including the background). When this is None,
81+ include_background: whether to include Dice computation on the first channel/category of the prediction and
82+ ground truth . Defaults to ``True``, use ``False`` to exclude the background class .
83+ reduction: defines mode of reduction to the metrics, this will only apply reduction on `not-nan` values. The
84+ available reduction modes are enumerated by :py:class:`monai.utils.enums.MetricReduction`. If "none", is
85+ selected, the metric will not do reduction.
86+ get_not_nans: whether to return the `not_nans` count. If True, aggregate() returns ` (metric, not_nans)` where
87+ `not_nans` counts the number of valid values in the result, and will have the same shape .
88+ ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be
89+ set for an empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases
90+ are also empty.
91+ num_classes: number of input channels (always including the background). When this is `` None`` ,
5192 ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
5293 single-channel class indices and the number of classes is not automatically inferred from data.
5394 return_with_label: whether to return the metrics with label, only works when reduction is "mean_batch".
54- If `True`, use "label_{index}" as the key corresponding to C channels; if ' include_background' is True,
95+ If `True`, use "label_{index}" as the key corresponding to C channels; if `` include_background`` is True,
5596 the index begins at "0", otherwise at "1". It can also take a list of label names.
5697 The outcome will then be returned as a dictionary.
5798
@@ -77,22 +118,21 @@ def __init__(
77118 include_background = self .include_background ,
78119 reduction = MetricReduction .NONE ,
79120 get_not_nans = False ,
80- softmax = False ,
121+ apply_argmax = False ,
81122 ignore_empty = self .ignore_empty ,
82123 num_classes = self .num_classes ,
83124 )
84125
85126 def _compute_tensor (self , y_pred : torch .Tensor , y : torch .Tensor ) -> torch .Tensor : # type: ignore[override]
86127 """
128+ Compute the dice value using ``DiceHelper``.
129+
87130 Args:
88- y_pred: input data to compute, typical segmentation model output.
89- It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values
90- should be binarized.
91- y: ground truth to compute mean Dice metric. `y` can be single-channel class indices or
92- in the one-hot format.
131+ y_pred: prediction value, see class docstring for format definition.
132+ y: ground truth label.
93133
94134 Raises:
95- ValueError: when `y_pred` has less than three dimensions.
135+ ValueError: when `y_pred` has fewer than three dimensions.
96136 """
97137 dims = y_pred .ndimension ()
98138 if dims < 3 :
@@ -107,10 +147,8 @@ def aggregate(
107147 Execute reduction and aggregation logic for the output of `compute_dice`.
108148
109149 Args:
110- reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
111- available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
112- ``"mean_channel"``, ``"sum_channel"``}, default to `self.reduction`. if "none", will not do reduction.
113-
150+ reduction: defines mode of reduction as enumerated in :py:class:`monai.utils.enums.MetricReduction`.
151+ By default this will do no reduction.
114152 """
115153 data = self .get_buffer ()
116154 if not isinstance (data , torch .Tensor ):
@@ -138,18 +176,20 @@ def compute_dice(
138176 ignore_empty : bool = True ,
139177 num_classes : int | None = None ,
140178) -> torch .Tensor :
141- """Computes Dice score metric for a batch of predictions.
179+ """
180+ Computes Dice score metric for a batch of predictions. This performs the same computation as
181+ :py:class:`monai.metrics.DiceMetric`, which is preferrable to use over this function. For input formats, see the
182+ documentation for that class .
142183
143184 Args:
144185 y_pred: input data to compute, typical segmentation model output.
145- `y_pred` can be single-channel class indices or in the one-hot format.
146- y: ground truth to compute mean dice metric. `y` can be single-channel class indices or in the one-hot format.
147- include_background: whether to include Dice computation on the first channel of
148- the predicted output. Defaults to True.
149- ignore_empty: whether to ignore empty ground truth cases during calculation.
150- If `True`, NaN value will be set for empty ground truth cases.
151- If `False`, 1 will be set if the predictions of empty ground truth cases are also empty.
152- num_classes: number of input channels (always including the background). When this is None,
186+ y: ground truth to compute mean dice metric.
187+ include_background: whether to include Dice computation on the first channel/category of the prediction and
188+ ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
189+ ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be
190+ set for an empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases
191+ are also empty.
192+ num_classes: number of input channels (always including the background). When this is ``None``,
153193 ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
154194 single-channel class indices and the number of classes is not automatically inferred from data.
155195
@@ -161,16 +201,16 @@ def compute_dice(
161201 include_background = include_background ,
162202 reduction = MetricReduction .NONE ,
163203 get_not_nans = False ,
164- softmax = False ,
204+ apply_argmax = False ,
165205 ignore_empty = ignore_empty ,
166206 num_classes = num_classes ,
167207 )(y_pred = y_pred , y = y )
168208
169209
170210class DiceHelper :
171211 """
172- Compute Dice score between two tensors `y_pred` and `y`.
173- `y_pred` and `y` can be single-channel class indices or in the one-hot format .
212+ Compute Dice score between two tensors `` y_pred`` and ``y``. This is used by :py:class:`monai.metrics.DiceMetric`,
213+ see the documentation for that class for input formats .
174214
175215 Example:
176216
@@ -188,49 +228,65 @@ class DiceHelper:
188228 score, not_nans = DiceHelper(include_background=False, sigmoid=True, softmax=True)(y_pred, y)
189229 print(score, not_nans)
190230
231+ Args:
232+ include_background: whether to include Dice computation on the first channel/category of the prediction and
233+ ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
234+ threshold: if ``True`, ``y_pred`` will be thresholded at a value of 0.5. Defaults to False.
235+ apply_argmax: whether ``y_pred`` are softmax activated outputs. If True, `argmax` will be performed to
236+ get the discrete prediction. Defaults to the value of ``not threshold``.
237+ activate: if this and ``threshold` are ``True``, sigmoid activation is applied to ``y_pred`` before
238+ thresholding. Defaults to False.
239+ get_not_nans: whether to return the number of not-nan values.
240+ reduction: defines mode of reduction to the metrics, this will only apply reduction on `not-nan` values. The
241+ available reduction modes are enumerated by :py:class:`monai.utils.enums.MetricReduction`. If "none", is
242+ selected, the metric will not do reduction.
243+ ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be
244+ set for an empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases
245+ are also empty.
246+ num_classes: number of input channels (always including the background). When this is ``None``,
247+ ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
248+ single-channel class indices and the number of classes is not automatically inferred from data.
191249 """
192250
251+ @deprecated_arg ("softmax" , "1.5" , "1.7" , "Use `apply_argmax` instead." , new_name = "apply_argmax" )
252+ @deprecated_arg ("sigmoid" , "1.5" , "1.7" , "Use `threshold` instead." , new_name = "threshold" )
193253 def __init__ (
194254 self ,
195255 include_background : bool | None = None ,
196- sigmoid : bool = False ,
197- softmax : bool | None = None ,
256+ threshold : bool = False ,
257+ apply_argmax : bool | None = None ,
198258 activate : bool = False ,
199259 get_not_nans : bool = True ,
200260 reduction : MetricReduction | str = MetricReduction .MEAN_BATCH ,
201261 ignore_empty : bool = True ,
202262 num_classes : int | None = None ,
263+ sigmoid : bool | None = None ,
264+ softmax : bool | None = None ,
203265 ) -> None :
204- """
266+ # handling deprecated arguments
267+ if sigmoid is not None :
268+ threshold = sigmoid
269+ if softmax is not None :
270+ apply_argmax = softmax
205271
206- Args:
207- include_background: whether to include the score on the first channel
208- (default to the value of `sigmoid`, False).
209- sigmoid: whether ``y_pred`` are/will be sigmoid activated outputs. If True, thresholding at 0.5
210- will be performed to get the discrete prediction. Defaults to False.
211- softmax: whether ``y_pred`` are softmax activated outputs. If True, `argmax` will be performed to
212- get the discrete prediction. Defaults to the value of ``not sigmoid``.
213- activate: whether to apply sigmoid to ``y_pred`` if ``sigmoid`` is True. Defaults to False.
214- This option is only valid when ``sigmoid`` is True.
215- get_not_nans: whether to return the number of not-nan values.
216- reduction: define mode of reduction to the metrics
217- ignore_empty: if `True`, NaN value will be set for empty ground truth cases.
218- If `False`, 1 will be set if the Union of ``y_pred`` and ``y`` is empty.
219- num_classes: number of input channels (always including the background). When this is None,
220- ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
221- single-channel class indices and the number of classes is not automatically inferred from data.
222- """
223- self .sigmoid = sigmoid
272+ self .threshold = threshold
224273 self .reduction = reduction
225274 self .get_not_nans = get_not_nans
226- self .include_background = sigmoid if include_background is None else include_background
227- self .softmax = not sigmoid if softmax is None else softmax
275+ self .include_background = threshold if include_background is None else include_background
276+ self .apply_argmax = not threshold if apply_argmax is None else apply_argmax
228277 self .activate = activate
229278 self .ignore_empty = ignore_empty
230279 self .num_classes = num_classes
231280
232281 def compute_channel (self , y_pred : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
233- """"""
282+ """
283+ Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately
284+ for each batch item and for each channel of those items.
285+
286+ Args:
287+ y_pred: input predictions with shape HW[D].
288+ y: ground truth with shape HW[D].
289+ """
234290 y_o = torch .sum (y )
235291 if y_o > 0 :
236292 return (2.0 * torch .sum (torch .masked_select (y , y_pred ))) / (y_o + torch .sum (y_pred ))
@@ -243,25 +299,25 @@ def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
243299
244300 def __call__ (self , y_pred : torch .Tensor , y : torch .Tensor ) -> torch .Tensor | tuple [torch .Tensor , torch .Tensor ]:
245301 """
302+ Compute the metric for the given prediction and ground truth.
246303
247304 Args:
248305 y_pred: input predictions with shape (batch_size, num_classes or 1, spatial_dims...).
249306 the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``.
250307 y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).
251308 """
252- _softmax , _sigmoid = self .softmax , self .sigmoid
309+ _apply_argmax , _threshold = self .apply_argmax , self .threshold
253310 if self .num_classes is None :
254311 n_pred_ch = y_pred .shape [1 ] # y_pred is in one-hot format or multi-channel scores
255312 else :
256313 n_pred_ch = self .num_classes
257314 if y_pred .shape [1 ] == 1 and self .num_classes > 1 : # y_pred is single-channel class indices
258- _softmax = _sigmoid = False
315+ _apply_argmax = _threshold = False
259316
260- if _softmax :
261- if n_pred_ch > 1 :
262- y_pred = torch .argmax (y_pred , dim = 1 , keepdim = True )
317+ if _apply_argmax and n_pred_ch > 1 :
318+ y_pred = torch .argmax (y_pred , dim = 1 , keepdim = True )
263319
264- elif _sigmoid :
320+ elif _threshold :
265321 if self .activate :
266322 y_pred = torch .sigmoid (y_pred )
267323 y_pred = y_pred > 0.5
0 commit comments