Skip to content

Commit f2f8e34

Browse files
authored
Merge branch 'dev' into dev
2 parents 5eaf79f + 8aef9a9 commit f2f8e34

File tree

2 files changed

+135
-79
lines changed

2 files changed

+135
-79
lines changed

monai/metrics/meandice.py

Lines changed: 132 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import torch
1515

1616
from monai.metrics.utils import do_metric_reduction
17-
from monai.utils import MetricReduction
17+
from monai.utils import MetricReduction, deprecated_arg
1818

1919
from .metric import CumulativeIterationMetric
2020

@@ -23,35 +23,76 @@
2323

2424
class DiceMetric(CumulativeIterationMetric):
2525
"""
26-
Compute average Dice score for a set of pairs of prediction-groundtruth segmentations.
26+
Computes Dice score for a set of pairs of prediction-groundtruth labels. It supports single-channel label maps
27+
or multi-channel images with class segmentations per channel. This allows the computation for both multi-class
28+
and multi-label tasks.
2729
28-
It supports both multi-classes and multi-labels tasks.
29-
Input `y_pred` is compared with ground truth `y`.
30-
`y_pred` is expected to have binarized predictions and `y` can be single-channel class indices or in the
31-
one-hot format. The `include_background` parameter can be set to ``False`` to exclude
32-
the first category (channel index 0) which is by convention assumed to be background. If the non-background
33-
segmentations are small compared to the total image size they can get overwhelmed by the signal from the
34-
background. `y_preds` and `y` can be a list of channel-first Tensor (CHW[D]) or a batch-first Tensor (BCHW[D]),
35-
`y` can also be in the format of `B1HW[D]`.
30+
If either prediction ``y_pred`` or ground truth ``y`` has shape BCHW[D], it is expected that these represent one-
31+
hot segmentations for C number of classes. If either shape is B1HW[D], it is expected that these are label maps
32+
and the number of classes must be specified by the ``num_classes`` parameter. In either case for either inputs,
33+
this metric applies no activations and so non-binary values will produce unexpected results if this metric is used
34+
for binary overlap measurement (i.e. either was expected to be one-hot formatted). Soft labels are thus permitted by
35+
this metric. Typically this implies that raw predictions from a network must first be activated and possibly made
36+
into label maps, e.g. for a multi-class prediction tensor softmax and then argmax should be applied over the channel
37+
dimension to produce a label map.
38+
39+
The ``include_background`` parameter can be set to `False` to exclude the first category (channel index 0) which
40+
is by convention assumed to be background. If the non-background segmentations are small compared to the total
41+
image size they can get overwhelmed by the signal from the background. This assumes the shape of both prediction
42+
and ground truth is BCHW[D].
43+
44+
The typical execution steps of this metric class follow :py:class:`monai.metrics.metric.Cumulative`.
45+
46+
Further information can be found in the official
47+
`MONAI Dice Overview <https://github.com/Project-MONAI/tutorials/blob/main/modules/dice_loss_metric_notes.ipynb>`_.
48+
49+
Example:
50+
51+
.. code-block:: python
52+
53+
import torch
54+
from monai.metrics import DiceMetric
55+
from monai.losses import DiceLoss
56+
from monai.networks import one_hot
57+
58+
batch_size, n_classes, h, w = 7, 5, 128, 128
59+
60+
y_pred = torch.rand(batch_size, n_classes, h, w) # network predictions
61+
y_pred = torch.argmax(y_pred, 1, True) # convert to label map
62+
63+
# ground truth as label map
64+
y = torch.randint(0, n_classes, size=(batch_size, 1, h, w))
65+
66+
dm = DiceMetric(
67+
reduction="mean_batch", return_with_label=True, num_classes=n_classes
68+
)
69+
70+
raw_scores = dm(y_pred, y)
71+
print(dm.aggregate())
72+
73+
# now compute the Dice loss which should be the same as 1 - raw_scores
74+
dl = DiceLoss(to_onehot_y=True, reduction="none")
75+
loss = dl(one_hot(y_pred, n_classes), y).squeeze()
76+
77+
print(1.0 - loss) # same as raw_scores
3678
37-
Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
3879
3980
Args:
40-
include_background: whether to include Dice computation on the first channel of
41-
the predicted output. Defaults to ``True``.
42-
reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
43-
available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
44-
``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
45-
get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).
46-
Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric.
47-
ignore_empty: whether to ignore empty ground truth cases during calculation.
48-
If `True`, NaN value will be set for empty ground truth cases.
49-
If `False`, 1 will be set if the predictions of empty ground truth cases are also empty.
50-
num_classes: number of input channels (always including the background). When this is None,
81+
include_background: whether to include Dice computation on the first channel/category of the prediction and
82+
ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
83+
reduction: defines mode of reduction to the metrics, this will only apply reduction on `not-nan` values. The
84+
available reduction modes are enumerated by :py:class:`monai.utils.enums.MetricReduction`. If "none" is
85+
selected, the metric will not do reduction.
86+
get_not_nans: whether to return the `not_nans` count. If True, aggregate() returns `(metric, not_nans)` where
87+
`not_nans` counts the number of valid values in the result, and will have the same shape.
88+
ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be
89+
set for empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases
90+
are also empty.
91+
num_classes: number of input channels (always including the background). When this is ``None``,
5192
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
5293
single-channel class indices and the number of classes is not automatically inferred from data.
5394
return_with_label: whether to return the metrics with label, only works when reduction is "mean_batch".
54-
If `True`, use "label_{index}" as the key corresponding to C channels; if 'include_background' is True,
95+
If `True`, use "label_{index}" as the key corresponding to C channels; if ``include_background`` is True,
5596
the index begins at "0", otherwise at "1". It can also take a list of label names.
5697
The outcome will then be returned as a dictionary.
5798
@@ -77,22 +118,21 @@ def __init__(
77118
include_background=self.include_background,
78119
reduction=MetricReduction.NONE,
79120
get_not_nans=False,
80-
softmax=False,
121+
apply_argmax=False,
81122
ignore_empty=self.ignore_empty,
82123
num_classes=self.num_classes,
83124
)
84125

85126
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
86127
"""
128+
Compute the dice value using ``DiceHelper``.
129+
87130
Args:
88-
y_pred: input data to compute, typical segmentation model output.
89-
It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values
90-
should be binarized.
91-
y: ground truth to compute mean Dice metric. `y` can be single-channel class indices or
92-
in the one-hot format.
131+
y_pred: prediction value, see class docstring for format definition.
132+
y: ground truth label.
93133
94134
Raises:
95-
ValueError: when `y_pred` has less than three dimensions.
135+
ValueError: when `y_pred` has fewer than three dimensions.
96136
"""
97137
dims = y_pred.ndimension()
98138
if dims < 3:
@@ -107,10 +147,8 @@ def aggregate(
107147
Execute reduction and aggregation logic for the output of `compute_dice`.
108148
109149
Args:
110-
reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
111-
available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
112-
``"mean_channel"``, ``"sum_channel"``}, default to `self.reduction`. if "none", will not do reduction.
113-
150+
reduction: defines mode of reduction as enumerated in :py:class:`monai.utils.enums.MetricReduction`.
151+
By default (``None``) the value of ``self.reduction`` is used; if "none", no reduction is done.
114152
"""
115153
data = self.get_buffer()
116154
if not isinstance(data, torch.Tensor):
@@ -138,18 +176,20 @@ def compute_dice(
138176
ignore_empty: bool = True,
139177
num_classes: int | None = None,
140178
) -> torch.Tensor:
141-
"""Computes Dice score metric for a batch of predictions.
179+
"""
180+
Computes Dice score metric for a batch of predictions. This performs the same computation as
181+
:py:class:`monai.metrics.DiceMetric`, which is preferable to use over this function. For input formats, see the
182+
documentation for that class.
142183
143184
Args:
144185
y_pred: input data to compute, typical segmentation model output.
145-
`y_pred` can be single-channel class indices or in the one-hot format.
146-
y: ground truth to compute mean dice metric. `y` can be single-channel class indices or in the one-hot format.
147-
include_background: whether to include Dice computation on the first channel of
148-
the predicted output. Defaults to True.
149-
ignore_empty: whether to ignore empty ground truth cases during calculation.
150-
If `True`, NaN value will be set for empty ground truth cases.
151-
If `False`, 1 will be set if the predictions of empty ground truth cases are also empty.
152-
num_classes: number of input channels (always including the background). When this is None,
186+
y: ground truth to compute mean dice metric.
187+
include_background: whether to include Dice computation on the first channel/category of the prediction and
188+
ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
189+
ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be
190+
set for empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases
191+
are also empty.
192+
num_classes: number of input channels (always including the background). When this is ``None``,
153193
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
154194
single-channel class indices and the number of classes is not automatically inferred from data.
155195
@@ -161,16 +201,16 @@ def compute_dice(
161201
include_background=include_background,
162202
reduction=MetricReduction.NONE,
163203
get_not_nans=False,
164-
softmax=False,
204+
apply_argmax=False,
165205
ignore_empty=ignore_empty,
166206
num_classes=num_classes,
167207
)(y_pred=y_pred, y=y)
168208

169209

170210
class DiceHelper:
171211
"""
172-
Compute Dice score between two tensors `y_pred` and `y`.
173-
`y_pred` and `y` can be single-channel class indices or in the one-hot format.
212+
Compute Dice score between two tensors ``y_pred`` and ``y``. This is used by :py:class:`monai.metrics.DiceMetric`,
213+
see the documentation for that class for input formats.
174214
175215
Example:
176216
@@ -188,49 +228,65 @@ class DiceHelper:
188228
score, not_nans = DiceHelper(include_background=False, sigmoid=True, softmax=True)(y_pred, y)
189229
print(score, not_nans)
190230
231+
Args:
232+
include_background: whether to include Dice computation on the first channel/category of the prediction and
233+
ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
234+
threshold: if ``True``, ``y_pred`` will be thresholded at a value of 0.5. Defaults to False.
235+
apply_argmax: whether ``y_pred`` are softmax activated outputs. If True, `argmax` will be performed to
236+
get the discrete prediction. Defaults to the value of ``not threshold``.
237+
activate: if this and ``threshold`` are ``True``, sigmoid activation is applied to ``y_pred`` before
238+
thresholding. Defaults to False.
239+
get_not_nans: whether to return the number of not-nan values.
240+
reduction: defines mode of reduction to the metrics, this will only apply reduction on `not-nan` values. The
241+
available reduction modes are enumerated by :py:class:`monai.utils.enums.MetricReduction`. If "none" is
242+
selected, the metric will not do reduction.
243+
ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be
244+
set for empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases
245+
are also empty.
246+
num_classes: number of input channels (always including the background). When this is ``None``,
247+
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
248+
single-channel class indices and the number of classes is not automatically inferred from data.
191249
"""
192250

251+
@deprecated_arg("softmax", "1.5", "1.7", "Use `apply_argmax` instead.", new_name="apply_argmax")
252+
@deprecated_arg("sigmoid", "1.5", "1.7", "Use `threshold` instead.", new_name="threshold")
193253
def __init__(
194254
self,
195255
include_background: bool | None = None,
196-
sigmoid: bool = False,
197-
softmax: bool | None = None,
256+
threshold: bool = False,
257+
apply_argmax: bool | None = None,
198258
activate: bool = False,
199259
get_not_nans: bool = True,
200260
reduction: MetricReduction | str = MetricReduction.MEAN_BATCH,
201261
ignore_empty: bool = True,
202262
num_classes: int | None = None,
263+
sigmoid: bool | None = None,
264+
softmax: bool | None = None,
203265
) -> None:
204-
"""
266+
# handling deprecated arguments
267+
if sigmoid is not None:
268+
threshold = sigmoid
269+
if softmax is not None:
270+
apply_argmax = softmax
205271

206-
Args:
207-
include_background: whether to include the score on the first channel
208-
(default to the value of `sigmoid`, False).
209-
sigmoid: whether ``y_pred`` are/will be sigmoid activated outputs. If True, thresholding at 0.5
210-
will be performed to get the discrete prediction. Defaults to False.
211-
softmax: whether ``y_pred`` are softmax activated outputs. If True, `argmax` will be performed to
212-
get the discrete prediction. Defaults to the value of ``not sigmoid``.
213-
activate: whether to apply sigmoid to ``y_pred`` if ``sigmoid`` is True. Defaults to False.
214-
This option is only valid when ``sigmoid`` is True.
215-
get_not_nans: whether to return the number of not-nan values.
216-
reduction: define mode of reduction to the metrics
217-
ignore_empty: if `True`, NaN value will be set for empty ground truth cases.
218-
If `False`, 1 will be set if the Union of ``y_pred`` and ``y`` is empty.
219-
num_classes: number of input channels (always including the background). When this is None,
220-
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
221-
single-channel class indices and the number of classes is not automatically inferred from data.
222-
"""
223-
self.sigmoid = sigmoid
272+
self.threshold = threshold
224273
self.reduction = reduction
225274
self.get_not_nans = get_not_nans
226-
self.include_background = sigmoid if include_background is None else include_background
227-
self.softmax = not sigmoid if softmax is None else softmax
275+
self.include_background = threshold if include_background is None else include_background
276+
self.apply_argmax = not threshold if apply_argmax is None else apply_argmax
228277
self.activate = activate
229278
self.ignore_empty = ignore_empty
230279
self.num_classes = num_classes
231280

232281
def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
233-
""""""
282+
"""
283+
Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately
284+
for each batch item and for each channel of those items.
285+
286+
Args:
287+
y_pred: input predictions with shape HW[D].
288+
y: ground truth with shape HW[D].
289+
"""
234290
y_o = torch.sum(y)
235291
if y_o > 0:
236292
return (2.0 * torch.sum(torch.masked_select(y, y_pred))) / (y_o + torch.sum(y_pred))
@@ -243,25 +299,25 @@ def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
243299

244300
def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
245301
"""
302+
Compute the metric for the given prediction and ground truth.
246303
247304
Args:
248305
y_pred: input predictions with shape (batch_size, num_classes or 1, spatial_dims...).
249306
the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``.
250307
y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).
251308
"""
252-
_softmax, _sigmoid = self.softmax, self.sigmoid
309+
_apply_argmax, _threshold = self.apply_argmax, self.threshold
253310
if self.num_classes is None:
254311
n_pred_ch = y_pred.shape[1] # y_pred is in one-hot format or multi-channel scores
255312
else:
256313
n_pred_ch = self.num_classes
257314
if y_pred.shape[1] == 1 and self.num_classes > 1: # y_pred is single-channel class indices
258-
_softmax = _sigmoid = False
315+
_apply_argmax = _threshold = False
259316

260-
if _softmax:
261-
if n_pred_ch > 1:
262-
y_pred = torch.argmax(y_pred, dim=1, keepdim=True)
317+
if _apply_argmax and n_pred_ch > 1:
318+
y_pred = torch.argmax(y_pred, dim=1, keepdim=True)
263319

264-
elif _sigmoid:
320+
elif _threshold:
265321
if self.activate:
266322
y_pred = torch.sigmoid(y_pred)
267323
y_pred = y_pred > 0.5

tests/metrics/test_compute_meandice.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -267,15 +267,15 @@ def test_nans(self, input_data, expected_value):
267267
@parameterized.expand([TEST_CASE_3])
268268
def test_helper(self, input_data, _unused):
269269
vals = {"y_pred": dict(input_data).pop("y_pred"), "y": dict(input_data).pop("y")}
270-
result = DiceHelper(sigmoid=True)(**vals)
270+
result = DiceHelper(threshold=True)(**vals)
271271
np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0, 0.0], atol=1e-4)
272272
np.testing.assert_allclose(sorted(result[1].cpu().numpy()), [0.0, 1.0, 2.0], atol=1e-4)
273-
result = DiceHelper(softmax=True, get_not_nans=False)(**vals)
273+
result = DiceHelper(apply_argmax=True, get_not_nans=False)(**vals)
274274
np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0], atol=1e-4)
275275

276276
num_classes = vals["y_pred"].shape[1]
277277
vals["y_pred"] = torch.argmax(vals["y_pred"], dim=1, keepdim=True)
278-
result = DiceHelper(sigmoid=True, num_classes=num_classes)(**vals)
278+
result = DiceHelper(threshold=True, num_classes=num_classes)(**vals)
279279
np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0, 0.0], atol=1e-4)
280280

281281
# DiceMetric class tests

0 commit comments

Comments
 (0)