Skip to content

Commit 4445bad

Browse files
authored
Merge branch 'dev' into 8085-r2-score
2 parents 10d2423 + d98f348 commit 4445bad

21 files changed

+685
-43
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@
1818

1919
MONAI is a [PyTorch](https://pytorch.org/)-based, [open-source](https://github.com/Project-MONAI/MONAI/blob/dev/LICENSE) framework for deep learning in healthcare imaging, part of the [PyTorch Ecosystem](https://pytorch.org/ecosystem/).
2020
Its ambitions are as follows:
21+
2122
- Developing a community of academic, industrial and clinical researchers collaborating on a common foundation;
2223
- Creating state-of-the-art, end-to-end training workflows for healthcare imaging;
2324
- Providing researchers with the optimized and standardized way to create and evaluate deep learning models.
2425

25-
2626
## Features
27+
2728
> _Please see [the technical highlights](https://docs.monai.io/en/latest/highlights.html) and [What's New](https://docs.monai.io/en/latest/whatsnew.html) of the milestone releases._
2829
2930
- flexible pre-processing for multi-dimensional medical imaging data;

docs/source/handlers.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ ROC AUC metrics handler
5353
:members:
5454

5555

56+
Average Precision metric handler
57+
--------------------------------
58+
.. autoclass:: AveragePrecision
59+
:members:
60+
61+
5662
Confusion matrix metrics handler
5763
--------------------------------
5864
.. autoclass:: ConfusionMatrix

docs/source/metrics.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,13 @@ Metrics
8080
.. autoclass:: ROCAUCMetric
8181
:members:
8282

83+
`Average Precision`
84+
-------------------
85+
.. autofunction:: compute_average_precision
86+
87+
.. autoclass:: AveragePrecisionMetric
88+
:members:
89+
8390
`Confusion matrix`
8491
------------------
8592
.. autofunction:: get_confusion_matrix

monai/handlers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from __future__ import annotations
1313

14+
from .average_precision import AveragePrecision
1415
from .checkpoint_loader import CheckpointLoader
1516
from .checkpoint_saver import CheckpointSaver
1617
from .classification_saver import ClassificationSaver
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
from collections.abc import Callable
15+
16+
from monai.handlers.ignite_metric import IgniteMetricHandler
17+
from monai.metrics import AveragePrecisionMetric
18+
from monai.utils import Average
19+
20+
21+
class AveragePrecision(IgniteMetricHandler):
22+
"""
23+
Computes Average Precision (AP).
24+
accumulating predictions and the ground-truth during an epoch and applying `compute_average_precision`.
25+
26+
Args:
27+
average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
28+
Type of averaging performed if not binary classification. Defaults to ``"macro"``.
29+
30+
- ``"macro"``: calculate metrics for each label, and find their unweighted mean.
31+
This does not take label imbalance into account.
32+
- ``"weighted"``: calculate metrics for each label, and find their average,
33+
weighted by support (the number of true instances for each label).
34+
- ``"micro"``: calculate metrics globally by considering each element of the label
35+
indicator matrix as a label.
36+
- ``"none"``: the scores for each class are returned.
37+
38+
output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then
39+
construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or
40+
lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`.
41+
`engine.state` and `output_transform` inherit from the ignite concept:
42+
https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
43+
https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
44+
45+
Note:
46+
Average Precision expects y to be comprised of 0's and 1's.
47+
y_pred must either be probability estimates or confidence values.
48+
49+
"""
50+
51+
def __init__(self, average: Average | str = Average.MACRO, output_transform: Callable = lambda x: x) -> None:
52+
metric_fn = AveragePrecisionMetric(average=Average(average))
53+
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=False)

monai/inferers/inferer.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,15 +1202,16 @@ def sample( # type: ignore[override]
12021202

12031203
if self.autoencoder_latent_shape is not None:
12041204
latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
1205-
latent_intermediates = [
1206-
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
1207-
]
1205+
if save_intermediates:
1206+
latent_intermediates = [
1207+
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
1208+
for l in latent_intermediates
1209+
]
12081210

12091211
decode = autoencoder_model.decode_stage_2_outputs
12101212
if isinstance(autoencoder_model, SPADEAutoencoderKL):
12111213
decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
12121214
image = decode(latent / self.scale_factor)
1213-
12141215
if save_intermediates:
12151216
intermediates = []
12161217
for latent_intermediate in latent_intermediates:
@@ -1727,9 +1728,11 @@ def sample( # type: ignore[override]
17271728

17281729
if self.autoencoder_latent_shape is not None:
17291730
latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
1730-
latent_intermediates = [
1731-
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
1732-
]
1731+
if save_intermediates:
1732+
latent_intermediates = [
1733+
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
1734+
for l in latent_intermediates
1735+
]
17331736

17341737
decode = autoencoder_model.decode_stage_2_outputs
17351738
if isinstance(autoencoder_model, SPADEAutoencoderKL):

monai/metrics/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from __future__ import annotations
1313

1414
from .active_learning_metrics import LabelQualityScore, VarianceMetric, compute_variance, label_quality_score
15+
from .average_precision import AveragePrecisionMetric, compute_average_precision
1516
from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix
1617
from .cumulative_average import CumulativeAverage
1718
from .f_beta_score import FBetaScore

monai/metrics/average_precision.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import warnings
15+
from typing import TYPE_CHECKING, cast
16+
17+
import numpy as np
18+
19+
if TYPE_CHECKING:
20+
import numpy.typing as npt
21+
22+
import torch
23+
24+
from monai.utils import Average, look_up_option
25+
26+
from .metric import CumulativeIterationMetric
27+
28+
29+
class AveragePrecisionMetric(CumulativeIterationMetric):
30+
"""
31+
Computes Average Precision (AP). AP is a useful metric to evaluate a classifier when the classes are
32+
imbalanced. It can take values between 0.0 and 1.0, 1.0 being the best possible score.
33+
It summarizes a Precision-Recall curve as the weighted mean of precisions achieved at each
34+
threshold, with the increase in recall from the previous threshold used as the weight:
35+
36+
.. math::
37+
\\text{AP} = \\sum_n (R_n - R_{n-1}) P_n
38+
:label: ap
39+
40+
where :math:`P_n` and :math:`R_n` are the precision and recall at the :math:`n^{th}` threshold.
41+
42+
Referring to: `sklearn.metrics.average_precision_score
43+
<https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score>`_.
44+
45+
The input `y_pred` and `y` can be a list of `channel-first` Tensor or a `batch-first` Tensor.
46+
47+
Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
48+
49+
Args:
50+
average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
51+
Type of averaging performed if not binary classification.
52+
Defaults to ``"macro"``.
53+
54+
- ``"macro"``: calculate metrics for each label, and find their unweighted mean.
55+
This does not take label imbalance into account.
56+
- ``"weighted"``: calculate metrics for each label, and find their average,
57+
weighted by support (the number of true instances for each label).
58+
- ``"micro"``: calculate metrics globally by considering each element of the label
59+
indicator matrix as a label.
60+
- ``"none"``: the scores for each class are returned.
61+
62+
"""
63+
64+
def __init__(self, average: Average | str = Average.MACRO) -> None:
65+
super().__init__()
66+
self.average = average
67+
68+
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: # type: ignore[override]
69+
return y_pred, y
70+
71+
def aggregate(self, average: Average | str | None = None) -> np.ndarray | float | npt.ArrayLike:
72+
"""
73+
Typically `y_pred` and `y` are stored in the cumulative buffers at each iteration,
74+
This function reads the buffers and computes the Average Precision.
75+
76+
Args:
77+
average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
78+
Type of averaging performed if not binary classification. Defaults to `self.average`.
79+
80+
"""
81+
y_pred, y = self.get_buffer()
82+
# compute final value and do metric reduction
83+
if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor):
84+
raise ValueError("y_pred and y must be PyTorch Tensor.")
85+
86+
return compute_average_precision(y_pred=y_pred, y=y, average=average or self.average)
87+
88+
89+
def _calculate(y_pred: torch.Tensor, y: torch.Tensor) -> float:
90+
if not (y.ndimension() == y_pred.ndimension() == 1 and len(y) == len(y_pred)):
91+
raise AssertionError("y and y_pred must be 1 dimension data with same length.")
92+
y_unique = y.unique()
93+
if len(y_unique) == 1:
94+
warnings.warn(f"y values can not be all {y_unique.item()}, skip AP computation and return `Nan`.")
95+
return float("nan")
96+
if not y_unique.equal(torch.tensor([0, 1], dtype=y.dtype, device=y.device)):
97+
warnings.warn(f"y values must be 0 or 1, but in {y_unique.tolist()}, skip AP computation and return `Nan`.")
98+
return float("nan")
99+
100+
n = len(y)
101+
indices = y_pred.argsort(descending=True)
102+
y = y[indices].cpu().numpy() # type: ignore[assignment]
103+
y_pred = y_pred[indices].cpu().numpy() # type: ignore[assignment]
104+
npos = ap = tmp_pos = 0.0
105+
106+
for i in range(n):
107+
y_i = cast(float, y[i])
108+
if i + 1 < n and y_pred[i] == y_pred[i + 1]:
109+
tmp_pos += y_i
110+
else:
111+
tmp_pos += y_i
112+
npos += tmp_pos
113+
ap += tmp_pos * npos / (i + 1)
114+
tmp_pos = 0
115+
116+
return ap / npos
117+
118+
119+
def compute_average_precision(
120+
y_pred: torch.Tensor, y: torch.Tensor, average: Average | str = Average.MACRO
121+
) -> np.ndarray | float | npt.ArrayLike:
122+
"""Computes Average Precision (AP). AP is a useful metric to evaluate a classifier when the classes are
123+
imbalanced. It summarizes a Precision-Recall according to equation :eq:`ap`.
124+
Referring to: `sklearn.metrics.average_precision_score
125+
<https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score>`_.
126+
127+
Args:
128+
y_pred: input data to compute, typical classification model output.
129+
the first dim must be batch, if multi-classes, it must be in One-Hot format.
130+
for example: shape `[16]` or `[16, 1]` for a binary data, shape `[16, 2]` for 2 classes data.
131+
y: ground truth to compute AP metric, the first dim must be batch.
132+
if multi-classes, it must be in One-Hot format.
133+
for example: shape `[16]` or `[16, 1]` for a binary data, shape `[16, 2]` for 2 classes data.
134+
average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
135+
Type of averaging performed if not binary classification.
136+
Defaults to ``"macro"``.
137+
138+
- ``"macro"``: calculate metrics for each label, and find their unweighted mean.
139+
This does not take label imbalance into account.
140+
- ``"weighted"``: calculate metrics for each label, and find their average,
141+
weighted by support (the number of true instances for each label).
142+
- ``"micro"``: calculate metrics globally by considering each element of the label
143+
indicator matrix as a label.
144+
- ``"none"``: the scores for each class are returned.
145+
146+
Raises:
147+
ValueError: When ``y_pred`` dimension is not one of [1, 2].
148+
ValueError: When ``y`` dimension is not one of [1, 2].
149+
ValueError: When ``average`` is not one of ["macro", "weighted", "micro", "none"].
150+
151+
Note:
152+
Average Precision expects y to be comprised of 0's and 1's. `y_pred` must be either prob. estimates or confidence values.
153+
154+
"""
155+
y_pred_ndim = y_pred.ndimension()
156+
y_ndim = y.ndimension()
157+
if y_pred_ndim not in (1, 2):
158+
raise ValueError(
159+
f"Predictions should be of shape (batch_size, num_classes) or (batch_size, ), got {y_pred.shape}."
160+
)
161+
if y_ndim not in (1, 2):
162+
raise ValueError(f"Targets should be of shape (batch_size, num_classes) or (batch_size, ), got {y.shape}.")
163+
if y_pred_ndim == 2 and y_pred.shape[1] == 1:
164+
y_pred = y_pred.squeeze(dim=-1)
165+
y_pred_ndim = 1
166+
if y_ndim == 2 and y.shape[1] == 1:
167+
y = y.squeeze(dim=-1)
168+
169+
if y_pred_ndim == 1:
170+
return _calculate(y_pred, y)
171+
172+
if y.shape != y_pred.shape:
173+
raise ValueError(f"data shapes of y_pred and y do not match, got {y_pred.shape} and {y.shape}.")
174+
175+
average = look_up_option(average, Average)
176+
if average == Average.MICRO:
177+
return _calculate(y_pred.flatten(), y.flatten())
178+
y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1)
179+
ap_values = [_calculate(y_pred_, y_) for y_pred_, y_ in zip(y_pred, y)]
180+
if average == Average.NONE:
181+
return ap_values
182+
if average == Average.MACRO:
183+
return np.mean(ap_values)
184+
if average == Average.WEIGHTED:
185+
weights = [sum(y_) for y_ in y]
186+
return np.average(ap_values, weights=weights) # type: ignore[no-any-return]
187+
raise ValueError(f'Unsupported average: {average}, available options are ["macro", "weighted", "micro", "none"].')

monai/transforms/utility/array.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@
6666
optional_import,
6767
)
6868
from monai.utils.enums import TransformBackends
69-
from monai.utils.misc import is_module_ver_at_least
7069
from monai.utils.type_conversion import convert_to_dst_type, get_dtype_string, get_equivalent_dtype
7170

7271
PILImageImage, has_pil = optional_import("PIL.Image", name="Image")
@@ -939,19 +938,10 @@ def __call__(
939938
data = img[[*select_labels]]
940939
else:
941940
where: Callable = np.where if isinstance(img, np.ndarray) else torch.where # type: ignore
942-
if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)):
943-
data = where(in1d(img, select_labels), True, False).reshape(img.shape)
944-
# pre pytorch 1.8.0, need to use 1/0 instead of True/False
945-
else:
946-
data = where(
947-
in1d(img, select_labels), torch.tensor(1, device=img.device), torch.tensor(0, device=img.device)
948-
).reshape(img.shape)
941+
data = where(in1d(img, select_labels), True, False).reshape(img.shape)
949942

950943
if merge_channels or self.merge_channels:
951-
if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)):
952-
return data.any(0)[None]
953-
# pre pytorch 1.8.0 compatibility
954-
return data.to(torch.uint8).any(0)[None].to(bool) # type: ignore
944+
return data.any(0)[None]
955945

956946
return data
957947

monai/transforms/utils_pytorch_numpy_unification.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import torch
1919

2020
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
21-
from monai.utils.misc import is_module_ver_at_least
2221
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type
2322

2423
__all__ = [
@@ -215,10 +214,9 @@ def floor_divide(a: NdarrayOrTensor, b) -> NdarrayOrTensor:
215214
Element-wise floor division between two arrays/tensors.
216215
"""
217216
if isinstance(a, torch.Tensor):
218-
if is_module_ver_at_least(torch, (1, 8, 0)):
219-
return torch.div(a, b, rounding_mode="floor")
220217
return torch.floor_divide(a, b)
221-
return np.floor_divide(a, b)
218+
else:
219+
return np.floor_divide(a, b)
222220

223221

224222
def unravel_index(idx, shape) -> NdarrayOrTensor:

0 commit comments

Comments
 (0)