Skip to content

Commit c608375

Browse files
authored
Merge branch 'dev' into dev
2 parents 7342b84 + 90de55b commit c608375

File tree

9 files changed

+892
-253
lines changed

9 files changed

+892
-253
lines changed

docs/source/networks.rst

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,3 +775,38 @@ Utilities
775775

776776
.. automodule:: monai.apps.reconstruction.networks.nets.utils
777777
:members:
778+
779+
Noise Schedulers
780+
----------------
781+
.. automodule:: monai.networks.schedulers
782+
.. currentmodule:: monai.networks.schedulers
783+
784+
`Scheduler`
785+
~~~~~~~~~~~
786+
.. autoclass:: Scheduler
787+
:members:
788+
789+
`NoiseSchedules`
790+
~~~~~~~~~~~~~~~~
791+
.. autoclass:: NoiseSchedules
792+
:members:
793+
794+
`DDPMScheduler`
795+
~~~~~~~~~~~~~~~
796+
.. autoclass:: DDPMScheduler
797+
:members:
798+
799+
`DDIMScheduler`
800+
~~~~~~~~~~~~~~~
801+
.. autoclass:: DDIMScheduler
802+
:members:
803+
804+
`PNDMScheduler`
805+
~~~~~~~~~~~~~~~
806+
.. autoclass:: PNDMScheduler
807+
:members:
808+
809+
`RFlowScheduler`
810+
~~~~~~~~~~~~~~~~
811+
.. autoclass:: RFlowScheduler
812+
:members:

monai/apps/generation/maisi/networks/autoencoderkl_maisi.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
232232
if self.print_info:
233233
logger.info(f"Number of splits: {self.num_splits}")
234234

235+
if self.dim_split <= 1 and self.num_splits <= 1:
236+
x = self.conv(x)
237+
return x
238+
235239
# compute size of splits
236240
l = x.size(self.dim_split + 2)
237241
split_size = l // self.num_splits

monai/inferers/inferer.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
SPADEAutoencoderKL,
4040
SPADEDiffusionModelUNet,
4141
)
42-
from monai.networks.schedulers import Scheduler
42+
from monai.networks.schedulers import RFlowScheduler, Scheduler
4343
from monai.transforms import CenterSpatialCrop, SpatialPad
4444
from monai.utils import BlendMode, Ordering, PatchKeys, PytorchPadMode, ensure_tuple, optional_import
4545
from monai.visualize import CAM, GradCAM, GradCAMpp
@@ -859,12 +859,18 @@ def sample(
859859
if not scheduler:
860860
scheduler = self.scheduler
861861
image = input_noise
862+
863+
all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype)))
862864
if verbose and has_tqdm:
863-
progress_bar = tqdm(scheduler.timesteps)
865+
progress_bar = tqdm(
866+
zip(scheduler.timesteps, all_next_timesteps),
867+
total=min(len(scheduler.timesteps), len(all_next_timesteps)),
868+
)
864869
else:
865-
progress_bar = iter(scheduler.timesteps)
870+
progress_bar = iter(zip(scheduler.timesteps, all_next_timesteps))
866871
intermediates = []
867-
for t in progress_bar:
872+
873+
for t, next_t in progress_bar:
868874
# 1. predict noise model_output
869875
diffusion_model = (
870876
partial(diffusion_model, seg=seg)
@@ -882,9 +888,13 @@ def sample(
882888
)
883889

884890
# 2. compute previous image: x_t -> x_t-1
885-
image, _ = scheduler.step(model_output, t, image) # type: ignore[operator]
891+
if not isinstance(scheduler, RFlowScheduler):
892+
image, _ = scheduler.step(model_output, t, image) # type: ignore
893+
else:
894+
image, _ = scheduler.step(model_output, t, image, next_t) # type: ignore
886895
if save_intermediates and t % intermediate_steps == 0:
887896
intermediates.append(image)
897+
888898
if save_intermediates:
889899
return image, intermediates
890900
else:
@@ -1392,12 +1402,18 @@ def sample( # type: ignore[override]
13921402
if not scheduler:
13931403
scheduler = self.scheduler
13941404
image = input_noise
1405+
1406+
all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype)))
13951407
if verbose and has_tqdm:
1396-
progress_bar = tqdm(scheduler.timesteps)
1408+
progress_bar = tqdm(
1409+
zip(scheduler.timesteps, all_next_timesteps),
1410+
total=min(len(scheduler.timesteps), len(all_next_timesteps)),
1411+
)
13971412
else:
1398-
progress_bar = iter(scheduler.timesteps)
1413+
progress_bar = iter(zip(scheduler.timesteps, all_next_timesteps))
13991414
intermediates = []
1400-
for t in progress_bar:
1415+
1416+
for t, next_t in progress_bar:
14011417
diffuse = diffusion_model
14021418
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
14031419
diffuse = partial(diffusion_model, seg=seg)
@@ -1436,7 +1452,11 @@ def sample( # type: ignore[override]
14361452
)
14371453

14381454
# 3. compute previous image: x_t -> x_t-1
1439-
image, _ = scheduler.step(model_output, t, image) # type: ignore[operator]
1455+
if not isinstance(scheduler, RFlowScheduler):
1456+
image, _ = scheduler.step(model_output, t, image) # type: ignore
1457+
else:
1458+
image, _ = scheduler.step(model_output, t, image, next_t) # type: ignore
1459+
14401460
if save_intermediates and t % intermediate_steps == 0:
14411461
intermediates.append(image)
14421462
if save_intermediates:

monai/networks/schedulers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@
1414
from .ddim import DDIMScheduler
1515
from .ddpm import DDPMScheduler
1616
from .pndm import PNDMScheduler
17+
from .rectified_flow import RFlowScheduler
1718
from .scheduler import NoiseSchedules, Scheduler

0 commit comments

Comments
 (0)