Skip to content

Commit 5e066e2

Browse files
committed
Amending documentation and tests
Signed-off-by: Eric Kerfoot <[email protected]>
1 parent b11540b commit 5e066e2

File tree

2 files changed

+78
-45
lines changed

2 files changed

+78
-45
lines changed

monai/metrics/meandice.py

Lines changed: 75 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import torch
1515

1616
from monai.metrics.utils import do_metric_reduction
17-
from monai.utils import MetricReduction
17+
from monai.utils import MetricReduction, deprecated_arg
1818

1919
from .metric import CumulativeIterationMetric
2020

@@ -24,36 +24,44 @@
2424
class DiceMetric(CumulativeIterationMetric):
2525
"""
2626
Computes Dice score for a set of pairs of prediction-groundtruth labels. It supports single-channel label maps
27-
or multi-channel images with class segmentations per channel. This allows the computation for both multi-class
28-
and multi-label tasks.
27+
or multi-channel images with class segmentations per channel. This allows the computation for both multi-class
28+
and multi-label tasks.
2929
3030
If either prediction ``y_pred`` or ground truth ``y`` have shape BCHW[D], it is expected that these represent one-
3131
hot segmentations for C number of classes. If either shape is B1HW[D], it is expected that these are label maps
32-
and the number of classes must be specified by the ``num_classes`` parameter. In either case for either inputs,
33-
this metric applies no activations and so non-binary values will produce unexpected results if this metric is used
34-
for binary overlap measurement. Soft labels are thus permitted by this metric.
35-
36-
The ``include_background`` parameter can be set to `False` to exclude the first category (channel index 0) which
37-
is by convention assumed to be background. If the non-background segmentations are small compared to the total
32+
and the number of classes must be specified by the ``num_classes`` parameter. In either case for either inputs,
33+
this metric applies no activations and so non-binary values will produce unexpected results if this metric is used
34+
for binary overlap measurement (ie. either was expected to be one-hot formatted). Soft labels are thus permitted by
35+
this metric. Typically this implies that raw predictions from a network must first be activated and possibly made
36+
into label maps, eg. for a multi-class prediction tensor softmax and then argmax should be applied over the channel
37+
dimensions to produce a label map.
38+
39+
The ``include_background`` parameter can be set to `False` to exclude the first category (channel index 0) which
40+
is by convention assumed to be background. If the non-background segmentations are small compared to the total
3841
image size they can get overwhelmed by the signal from the background. This assumes the shape of both prediction
3942
and ground truth is BCHW[D].
4043
41-
An example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
44+
The typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
45+
46+
Further information can be found in the official
47+
`MONAI Dice Overview <https://github.com/Project-MONAI/tutorials/blob/main/modules/dice_loss_metric_notes.ipynb>`.
4248
4349
Example:
4450
4551
.. code-block:: python
4652
4753
import torch
4854
from monai.metrics import DiceMetric
55+
from monai.losses import DiceLoss
56+
from monai.networks import one_hot
4957
5058
batch_size, n_classes, h, w = 7, 5, 128, 128
5159
5260
y_pred = torch.rand(batch_size, n_classes, h, w) # network predictions
5361
y_pred = torch.argmax(y_pred, 1, True) # convert to label map
5462
5563
# ground truth as label map
56-
y = torch.randint(0, n_classes, size=(batch_size, 1, h, w))
64+
y = torch.randint(0, n_classes, size=(batch_size, 1, h, w))
5765
5866
dm = DiceMetric(
5967
reduction="mean_batch", return_with_label=True, num_classes=n_classes
@@ -62,16 +70,22 @@ class DiceMetric(CumulativeIterationMetric):
6270
raw_scores = dm(y_pred, y)
6371
print(dm.aggregate())
6472
73+
# now compute the Dice loss which should be the same as 1 - raw_scores
74+
dl = DiceLoss(to_onehot_y=True, reduction="none")
75+
loss = dl(one_hot(y_pred, n_classes), y).squeeze()
76+
77+
print(1.0 - loss) # same as raw_scores
78+
6579
6680
Args:
6781
include_background: whether to include Dice computation on the first channel/category of the prediction and
68-
ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
82+
ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
6983
reduction: defines mode of reduction to the metrics, this will only apply reduction on `not-nan` values. The
7084
available reduction modes are enumerated by :py:class:`monai.utils.enums.MetricReduction`. If "none", is
7185
selected, the metric will not do reduction.
7286
get_not_nans: whether to return the `not_nans` count. If True, aggregate() returns `(metric, not_nans)` where
7387
`not_nans` counts the number of valid values in the result, and will have the same shape.
74-
ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be
88+
ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be
7589
set for an empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases
7690
are also empty.
7791
num_classes: number of input channels (always including the background). When this is ``None``,
@@ -104,14 +118,14 @@ def __init__(
104118
include_background=self.include_background,
105119
reduction=MetricReduction.NONE,
106120
get_not_nans=False,
107-
softmax=False,
121+
apply_argmax=False,
108122
ignore_empty=self.ignore_empty,
109123
num_classes=self.num_classes,
110124
)
111125

112126
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
113127
"""
114-
Compute the dice value using ``DiceHelper``.
128+
Compute the dice value using ``DiceHelper``.
115129
116130
Args:
117131
y_pred: prediction value, see class docstring for format definition.
@@ -133,7 +147,7 @@ def aggregate(
133147
Execute reduction and aggregation logic for the output of `compute_dice`.
134148
135149
Args:
136-
reduction: defines mode of reduction as enumerated in :py:class:`monai.utils.enums.MetricReduction`.
150+
reduction: defines mode of reduction as enumerated in :py:class:`monai.utils.enums.MetricReduction`.
137151
By default this will do no reduction.
138152
"""
139153
data = self.get_buffer()
@@ -163,15 +177,16 @@ def compute_dice(
163177
num_classes: int | None = None,
164178
) -> torch.Tensor:
165179
"""
166-
Computes Dice score metric for a batch of predictions. This performs the same computation as
167-
:py:class:`monai.metrics.DiceMetric`, see the documentation for that class for input formats.
180+
Computes Dice score metric for a batch of predictions. This performs the same computation as
181+
:py:class:`monai.metrics.DiceMetric`, which is preferrable to use over this function. For input formats, see the
182+
documentation for that class .
168183
169184
Args:
170185
y_pred: input data to compute, typical segmentation model output.
171-
y: ground truth to compute mean dice metric.
186+
y: ground truth to compute mean dice metric.
172187
include_background: whether to include Dice computation on the first channel/category of the prediction and
173-
ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
174-
ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be
188+
ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
189+
ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be
175190
set for an empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases
176191
are also empty.
177192
num_classes: number of input channels (always including the background). When this is ``None``,
@@ -186,16 +201,16 @@ def compute_dice(
186201
include_background=include_background,
187202
reduction=MetricReduction.NONE,
188203
get_not_nans=False,
189-
softmax=False,
204+
apply_argmax=False,
190205
ignore_empty=ignore_empty,
191206
num_classes=num_classes,
192207
)(y_pred=y_pred, y=y)
193208

194209

195210
class DiceHelper:
196211
"""
197-
Compute Dice score between two tensors ``y_pred`` and ``y``. This is used by :py:class:`monai.metrics.DiceMetric`,
198-
see the documentation for that class for input formats.
212+
Compute Dice score between two tensors ``y_pred`` and ``y``. This is used by :py:class:`monai.metrics.DiceMetric`,
213+
see the documentation for that class for input formats.
199214
200215
Example:
201216
@@ -215,45 +230,63 @@ class DiceHelper:
215230
216231
Args:
217232
include_background: whether to include Dice computation on the first channel/category of the prediction and
218-
ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
219-
sigmoid: if ``True`, ``y_pred`` will be thresholded at a value of 0.5. Defaults to False.
220-
softmax: whether ``y_pred`` are softmax activated outputs. If True, `argmax` will be performed to
221-
get the discrete prediction. Defaults to the value of ``not sigmoid``.
222-
activate: if this and ``sigmoid` are ``True``, sigmoid activation is applied to ``y_pred``. Defaults to False.
233+
ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
234+
threshold: if ``True`, ``y_pred`` will be thresholded at a value of 0.5. Defaults to False.
235+
apply_argmax: whether ``y_pred`` are softmax activated outputs. If True, `argmax` will be performed to
236+
get the discrete prediction. Defaults to the value of ``not threshold``.
237+
activate: if this and ``threshold` are ``True``, sigmoid activation is applied to ``y_pred`` before
238+
thresholding. Defaults to False.
223239
get_not_nans: whether to return the number of not-nan values.
224240
reduction: defines mode of reduction to the metrics, this will only apply reduction on `not-nan` values. The
225241
available reduction modes are enumerated by :py:class:`monai.utils.enums.MetricReduction`. If "none", is
226242
selected, the metric will not do reduction.
227-
ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be
243+
ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be
228244
set for an empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases
229245
are also empty.
230246
num_classes: number of input channels (always including the background). When this is ``None``,
231247
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
232248
single-channel class indices and the number of classes is not automatically inferred from data.
233249
"""
234250

251+
@deprecated_arg("softmax", "1.5", "1.7", "Use `apply_argmax` instead.", new_name="apply_argmax")
252+
@deprecated_arg("sigmoid", "1.5", "1.7", "Use `threshold` instead.", new_name="threshold")
235253
def __init__(
236254
self,
237255
include_background: bool | None = None,
238-
sigmoid: bool = False,
239-
softmax: bool | None = None,
256+
threshold: bool = False,
257+
apply_argmax: bool | None = None,
240258
activate: bool = False,
241259
get_not_nans: bool = True,
242260
reduction: MetricReduction | str = MetricReduction.MEAN_BATCH,
243261
ignore_empty: bool = True,
244262
num_classes: int | None = None,
263+
sigmoid: bool | None = None,
264+
softmax: bool | None = None,
245265
) -> None:
246-
self.sigmoid = sigmoid
266+
# handling deprecated arguments
267+
if sigmoid is not None:
268+
threshold = sigmoid
269+
if softmax is not None:
270+
apply_argmax = softmax
271+
272+
self.threshold = threshold
247273
self.reduction = reduction
248274
self.get_not_nans = get_not_nans
249-
self.include_background = sigmoid if include_background is None else include_background
250-
self.softmax = not sigmoid if softmax is None else softmax
275+
self.include_background = threshold if include_background is None else include_background
276+
self.apply_argmax = not threshold if apply_argmax is None else apply_argmax
251277
self.activate = activate
252278
self.ignore_empty = ignore_empty
253279
self.num_classes = num_classes
254280

255281
def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
256-
""""""
282+
"""
283+
Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately
284+
for each batch item and for each channel of those items.
285+
286+
Args:
287+
y_pred: input predictions with shape HW[D].
288+
y: ground truth with shape HW[D].
289+
"""
257290
y_o = torch.sum(y)
258291
if y_o > 0:
259292
return (2.0 * torch.sum(torch.masked_select(y, y_pred))) / (y_o + torch.sum(y_pred))
@@ -266,25 +299,25 @@ def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
266299

267300
def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
268301
"""
302+
Compute the metric for the given prediction and ground truth.
269303
270304
Args:
271305
y_pred: input predictions with shape (batch_size, num_classes or 1, spatial_dims...).
272306
the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``.
273307
y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).
274308
"""
275-
_softmax, _sigmoid = self.softmax, self.sigmoid
309+
_apply_argmax, _threshold = self.apply_argmax, self.threshold
276310
if self.num_classes is None:
277311
n_pred_ch = y_pred.shape[1] # y_pred is in one-hot format or multi-channel scores
278312
else:
279313
n_pred_ch = self.num_classes
280314
if y_pred.shape[1] == 1 and self.num_classes > 1: # y_pred is single-channel class indices
281-
_softmax = _sigmoid = False
315+
_apply_argmax = _threshold = False
282316

283-
if _softmax:
284-
if n_pred_ch > 1:
285-
y_pred = torch.argmax(y_pred, dim=1, keepdim=True)
317+
if _apply_argmax and n_pred_ch > 1:
318+
y_pred = torch.argmax(y_pred, dim=1, keepdim=True)
286319

287-
elif _sigmoid:
320+
elif _threshold:
288321
if self.activate:
289322
y_pred = torch.sigmoid(y_pred)
290323
y_pred = y_pred > 0.5

tests/metrics/test_compute_meandice.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -267,15 +267,15 @@ def test_nans(self, input_data, expected_value):
267267
@parameterized.expand([TEST_CASE_3])
268268
def test_helper(self, input_data, _unused):
269269
vals = {"y_pred": dict(input_data).pop("y_pred"), "y": dict(input_data).pop("y")}
270-
result = DiceHelper(sigmoid=True)(**vals)
270+
result = DiceHelper(threshold=True)(**vals)
271271
np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0, 0.0], atol=1e-4)
272272
np.testing.assert_allclose(sorted(result[1].cpu().numpy()), [0.0, 1.0, 2.0], atol=1e-4)
273-
result = DiceHelper(softmax=True, get_not_nans=False)(**vals)
273+
result = DiceHelper(apply_argmax=True, get_not_nans=False)(**vals)
274274
np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0], atol=1e-4)
275275

276276
num_classes = vals["y_pred"].shape[1]
277277
vals["y_pred"] = torch.argmax(vals["y_pred"], dim=1, keepdim=True)
278-
result = DiceHelper(sigmoid=True, num_classes=num_classes)(**vals)
278+
result = DiceHelper(threshold=True, num_classes=num_classes)(**vals)
279279
np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0, 0.0], atol=1e-4)
280280

281281
# DiceMetric class tests

0 commit comments

Comments
 (0)