Skip to content

Commit 97c98bd

Browse files
Merge branch 'Project-MONAI:dev' into 8328-nnunet-bundle-integration
2 parents ea8028f + ab07523 commit 97c98bd

21 files changed

+741
-69
lines changed

README.md

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@
1818

1919
MONAI is a [PyTorch](https://pytorch.org/)-based, [open-source](https://github.com/Project-MONAI/MONAI/blob/dev/LICENSE) framework for deep learning in healthcare imaging, part of the [PyTorch Ecosystem](https://pytorch.org/ecosystem/).
2020
Its ambitions are as follows:
21+
2122
- Developing a community of academic, industrial and clinical researchers collaborating on a common foundation;
2223
- Creating state-of-the-art, end-to-end training workflows for healthcare imaging;
2324
- Providing researchers with the optimized and standardized way to create and evaluate deep learning models.
2425

25-
2626
## Features
27+
2728
> _Please see [the technical highlights](https://docs.monai.io/en/latest/highlights.html) and [What's New](https://docs.monai.io/en/latest/whatsnew.html) of the milestone releases._
2829
2930
- flexible pre-processing for multi-dimensional medical imaging data;
@@ -32,7 +33,6 @@ Its ambitions are as follows:
3233
- customizable design for varying user expertise;
3334
- multi-GPU multi-node data parallelism support.
3435

35-
3636
## Installation
3737

3838
To install [the current release](https://pypi.org/project/monai/), you can simply run:
@@ -53,30 +53,34 @@ Technical documentation is available at [docs.monai.io](https://docs.monai.io).
5353

5454
## Citation
5555

56-
If you have used MONAI in your research, please cite us! The citation can be exported from: https://arxiv.org/abs/2211.02701.
56+
If you have used MONAI in your research, please cite us! The citation can be exported from: <https://arxiv.org/abs/2211.02701>.
5757

5858
## Model Zoo
59+
5960
[The MONAI Model Zoo](https://github.com/Project-MONAI/model-zoo) is a place for researchers and data scientists to share the latest and great models from the community.
6061
Utilizing [the MONAI Bundle format](https://docs.monai.io/en/latest/bundle_intro.html) makes it easy to [get started](https://github.com/Project-MONAI/tutorials/tree/main/model_zoo) building workflows with MONAI.
6162

6263
## Contributing
64+
6365
For guidance on making a contribution to MONAI, see the [contributing guidelines](https://github.com/Project-MONAI/MONAI/blob/dev/CONTRIBUTING.md).
6466

6567
## Community
68+
6669
Join the conversation on Twitter/X [@ProjectMONAI](https://twitter.com/ProjectMONAI) or join our [Slack channel](https://forms.gle/QTxJq3hFictp31UM9).
6770

6871
Ask and answer questions over on [MONAI's GitHub Discussions tab](https://github.com/Project-MONAI/MONAI/discussions).
6972

7073
## Links
71-
- Website: https://monai.io/
72-
- API documentation (milestone): https://docs.monai.io/
73-
- API documentation (latest dev): https://docs.monai.io/en/latest/
74-
- Code: https://github.com/Project-MONAI/MONAI
75-
- Project tracker: https://github.com/Project-MONAI/MONAI/projects
76-
- Issue tracker: https://github.com/Project-MONAI/MONAI/issues
77-
- Wiki: https://github.com/Project-MONAI/MONAI/wiki
78-
- Test status: https://github.com/Project-MONAI/MONAI/actions
79-
- PyPI package: https://pypi.org/project/monai/
80-
- conda-forge: https://anaconda.org/conda-forge/monai
81-
- Weekly previews: https://pypi.org/project/monai-weekly/
82-
- Docker Hub: https://hub.docker.com/r/projectmonai/monai
74+
75+
- Website: <https://monai.io/>
76+
- API documentation (milestone): <https://docs.monai.io/>
77+
- API documentation (latest dev): <https://docs.monai.io/en/latest/>
78+
- Code: <https://github.com/Project-MONAI/MONAI>
79+
- Project tracker: <https://github.com/Project-MONAI/MONAI/projects>
80+
- Issue tracker: <https://github.com/Project-MONAI/MONAI/issues>
81+
- Wiki: <https://github.com/Project-MONAI/MONAI/wiki>
82+
- Test status: <https://github.com/Project-MONAI/MONAI/actions>
83+
- PyPI package: <https://pypi.org/project/monai/>
84+
- conda-forge: <https://anaconda.org/conda-forge/monai>
85+
- Weekly previews: <https://pypi.org/project/monai-weekly/>
86+
- Docker Hub: <https://hub.docker.com/r/projectmonai/monai>

docs/source/handlers.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ ROC AUC metrics handler
5353
:members:
5454

5555

56+
Average Precision metric handler
57+
--------------------------------
58+
.. autoclass:: AveragePrecision
59+
:members:
60+
61+
5662
Confusion matrix metrics handler
5763
--------------------------------
5864
.. autoclass:: ConfusionMatrix

docs/source/metrics.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,13 @@ Metrics
8080
.. autoclass:: ROCAUCMetric
8181
:members:
8282

83+
`Average Precision`
84+
-------------------
85+
.. autofunction:: compute_average_precision
86+
87+
.. autoclass:: AveragePrecisionMetric
88+
:members:
89+
8390
`Confusion matrix`
8491
------------------
8592
.. autofunction:: get_confusion_matrix

monai/handlers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from __future__ import annotations
1313

14+
from .average_precision import AveragePrecision
1415
from .checkpoint_loader import CheckpointLoader
1516
from .checkpoint_saver import CheckpointSaver
1617
from .classification_saver import ClassificationSaver
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
from collections.abc import Callable
15+
16+
from monai.handlers.ignite_metric import IgniteMetricHandler
17+
from monai.metrics import AveragePrecisionMetric
18+
from monai.utils import Average
19+
20+
21+
class AveragePrecision(IgniteMetricHandler):
22+
"""
23+
Computes Average Precision (AP).
24+
accumulating predictions and the ground-truth during an epoch and applying `compute_average_precision`.
25+
26+
Args:
27+
average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
28+
Type of averaging performed if not binary classification. Defaults to ``"macro"``.
29+
30+
- ``"macro"``: calculate metrics for each label, and find their unweighted mean.
31+
This does not take label imbalance into account.
32+
- ``"weighted"``: calculate metrics for each label, and find their average,
33+
weighted by support (the number of true instances for each label).
34+
- ``"micro"``: calculate metrics globally by considering each element of the label
35+
indicator matrix as a label.
36+
- ``"none"``: the scores for each class are returned.
37+
38+
output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then
39+
construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or
40+
lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`.
41+
`engine.state` and `output_transform` inherit from the ignite concept:
42+
https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
43+
https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
44+
45+
Note:
46+
Average Precision expects y to be comprised of 0's and 1's.
47+
y_pred must either be probability estimates or confidence values.
48+
49+
"""
50+
51+
def __init__(self, average: Average | str = Average.MACRO, output_transform: Callable = lambda x: x) -> None:
52+
metric_fn = AveragePrecisionMetric(average=Average(average))
53+
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=False)

monai/inferers/inferer.py

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,15 +1202,16 @@ def sample( # type: ignore[override]
12021202

12031203
if self.autoencoder_latent_shape is not None:
12041204
latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
1205-
latent_intermediates = [
1206-
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
1207-
]
1205+
if save_intermediates:
1206+
latent_intermediates = [
1207+
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
1208+
for l in latent_intermediates
1209+
]
12081210

12091211
decode = autoencoder_model.decode_stage_2_outputs
12101212
if isinstance(autoencoder_model, SPADEAutoencoderKL):
12111213
decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
12121214
image = decode(latent / self.scale_factor)
1213-
12141215
if save_intermediates:
12151216
intermediates = []
12161217
for latent_intermediate in latent_intermediates:
@@ -1333,13 +1334,15 @@ def __call__( # type: ignore[override]
13331334
raise NotImplementedError(f"{mode} condition is not supported")
13341335

13351336
noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
1336-
down_block_res_samples, mid_block_res_sample = controlnet(
1337-
x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond
1338-
)
1337+
13391338
if mode == "concat" and condition is not None:
13401339
noisy_image = torch.cat([noisy_image, condition], dim=1)
13411340
condition = None
13421341

1342+
down_block_res_samples, mid_block_res_sample = controlnet(
1343+
x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond, context=condition
1344+
)
1345+
13431346
diffuse = diffusion_model
13441347
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
13451348
diffuse = partial(diffusion_model, seg=seg)
@@ -1395,17 +1398,21 @@ def sample( # type: ignore[override]
13951398
progress_bar = iter(scheduler.timesteps)
13961399
intermediates = []
13971400
for t in progress_bar:
1398-
# 1. ControlNet forward
1399-
down_block_res_samples, mid_block_res_sample = controlnet(
1400-
x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond
1401-
)
1402-
# 2. predict noise model_output
14031401
diffuse = diffusion_model
14041402
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
14051403
diffuse = partial(diffusion_model, seg=seg)
14061404

14071405
if mode == "concat" and conditioning is not None:
1406+
# 1. Conditioning
14081407
model_input = torch.cat([image, conditioning], dim=1)
1408+
# 2. ControlNet forward
1409+
down_block_res_samples, mid_block_res_sample = controlnet(
1410+
x=model_input,
1411+
timesteps=torch.Tensor((t,)).to(input_noise.device),
1412+
controlnet_cond=cn_cond,
1413+
context=None,
1414+
)
1415+
# 3. predict noise model_output
14091416
model_output = diffuse(
14101417
model_input,
14111418
timesteps=torch.Tensor((t,)).to(input_noise.device),
@@ -1414,6 +1421,12 @@ def sample( # type: ignore[override]
14141421
mid_block_additional_residual=mid_block_res_sample,
14151422
)
14161423
else:
1424+
down_block_res_samples, mid_block_res_sample = controlnet(
1425+
x=image,
1426+
timesteps=torch.Tensor((t,)).to(input_noise.device),
1427+
controlnet_cond=cn_cond,
1428+
context=conditioning,
1429+
)
14171430
model_output = diffuse(
14181431
image,
14191432
timesteps=torch.Tensor((t,)).to(input_noise.device),
@@ -1484,16 +1497,16 @@ def get_likelihood( # type: ignore[override]
14841497
for t in progress_bar:
14851498
timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
14861499
noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
1487-
down_block_res_samples, mid_block_res_sample = controlnet(
1488-
x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond
1489-
)
14901500

14911501
diffuse = diffusion_model
14921502
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
14931503
diffuse = partial(diffusion_model, seg=seg)
14941504

14951505
if mode == "concat" and conditioning is not None:
14961506
noisy_image = torch.cat([noisy_image, conditioning], dim=1)
1507+
down_block_res_samples, mid_block_res_sample = controlnet(
1508+
x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond, context=None
1509+
)
14971510
model_output = diffuse(
14981511
noisy_image,
14991512
timesteps=timesteps,
@@ -1502,6 +1515,12 @@ def get_likelihood( # type: ignore[override]
15021515
mid_block_additional_residual=mid_block_res_sample,
15031516
)
15041517
else:
1518+
down_block_res_samples, mid_block_res_sample = controlnet(
1519+
x=noisy_image,
1520+
timesteps=torch.Tensor((t,)).to(inputs.device),
1521+
controlnet_cond=cn_cond,
1522+
context=conditioning,
1523+
)
15051524
model_output = diffuse(
15061525
x=noisy_image,
15071526
timesteps=timesteps,
@@ -1727,9 +1746,11 @@ def sample( # type: ignore[override]
17271746

17281747
if self.autoencoder_latent_shape is not None:
17291748
latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
1730-
latent_intermediates = [
1731-
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
1732-
]
1749+
if save_intermediates:
1750+
latent_intermediates = [
1751+
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
1752+
for l in latent_intermediates
1753+
]
17331754

17341755
decode = autoencoder_model.decode_stage_2_outputs
17351756
if isinstance(autoencoder_model, SPADEAutoencoderKL):

monai/metrics/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from __future__ import annotations
1313

1414
from .active_learning_metrics import LabelQualityScore, VarianceMetric, compute_variance, label_quality_score
15+
from .average_precision import AveragePrecisionMetric, compute_average_precision
1516
from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix
1617
from .cumulative_average import CumulativeAverage
1718
from .f_beta_score import FBetaScore

0 commit comments

Comments
 (0)