diff --git a/docs/source/networks.rst b/docs/source/networks.rst index e2e509a99b..0119c6db4d 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -750,3 +750,38 @@ Utilities .. automodule:: monai.apps.reconstruction.networks.nets.utils :members: + +Noise Schedulers +---------------- +.. automodule:: monai.networks.schedulers +.. currentmodule:: monai.networks.schedulers + +`Scheduler` +~~~~~~~~~~~ +.. autoclass:: Scheduler + :members: + +`NoiseSchedules` +~~~~~~~~~~~~~~~~ +.. autoclass:: NoiseSchedules + :members: + +`DDPMScheduler` +~~~~~~~~~~~~~~~ +.. autoclass:: DDPMScheduler + :members: + +`DDIMScheduler` +~~~~~~~~~~~~~~~ +.. autoclass:: DDIMScheduler + :members: + +`PNDMScheduler` +~~~~~~~~~~~~~~~ +.. autoclass:: PNDMScheduler + :members: + +`RFlowScheduler` +~~~~~~~~~~~~~~~~ +.. autoclass:: RFlowScheduler + :members: diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index 6251ea8e83..86b4e68864 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -232,6 +232,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.print_info: logger.info(f"Number of splits: {self.num_splits}") + if self.dim_split <= 1 and self.num_splits <= 1: + x = self.conv(x) + return x + # compute size of splits l = x.size(self.dim_split + 2) split_size = l // self.num_splits diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index df23b9aea0..bfb2756ebe 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -39,7 +39,7 @@ SPADEAutoencoderKL, SPADEDiffusionModelUNet, ) -from monai.networks.schedulers import Scheduler +from monai.networks.schedulers import RFlowScheduler, Scheduler from monai.transforms import CenterSpatialCrop, SpatialPad from monai.utils import BlendMode, Ordering, PatchKeys, PytorchPadMode, ensure_tuple, optional_import from monai.visualize import CAM, GradCAM, GradCAMpp @@ -859,12 +859,18 @@ def sample( if not scheduler: scheduler = self.scheduler image = input_noise + + all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype))) if verbose and has_tqdm: - progress_bar = tqdm(scheduler.timesteps) + progress_bar = tqdm( + zip(scheduler.timesteps, all_next_timesteps), + total=min(len(scheduler.timesteps), len(all_next_timesteps)), + ) else: - progress_bar = iter(scheduler.timesteps) + progress_bar = iter(zip(scheduler.timesteps, all_next_timesteps)) intermediates = [] - for t in progress_bar: + + for t, next_t in progress_bar: # 1. predict noise model_output diffusion_model = ( partial(diffusion_model, seg=seg) @@ -882,9 +888,13 @@ def sample( ) # 2. 
compute previous image: x_t -> x_t-1 - image, _ = scheduler.step(model_output, t, image) # type: ignore[operator] + if not isinstance(scheduler, RFlowScheduler): + image, _ = scheduler.step(model_output, t, image) # type: ignore + else: + image, _ = scheduler.step(model_output, t, image, next_t) # type: ignore if save_intermediates and t % intermediate_steps == 0: intermediates.append(image) + if save_intermediates: return image, intermediates else: @@ -1392,12 +1402,18 @@ def sample( # type: ignore[override] if not scheduler: scheduler = self.scheduler image = input_noise + + all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype))) if verbose and has_tqdm: - progress_bar = tqdm(scheduler.timesteps) + progress_bar = tqdm( + zip(scheduler.timesteps, all_next_timesteps), + total=min(len(scheduler.timesteps), len(all_next_timesteps)), + ) else: - progress_bar = iter(scheduler.timesteps) + progress_bar = iter(zip(scheduler.timesteps, all_next_timesteps)) intermediates = [] - for t in progress_bar: + + for t, next_t in progress_bar: diffuse = diffusion_model if isinstance(diffusion_model, SPADEDiffusionModelUNet): diffuse = partial(diffusion_model, seg=seg) @@ -1436,7 +1452,11 @@ def sample( # type: ignore[override] ) # 3. compute previous image: x_t -> x_t-1 - image, _ = scheduler.step(model_output, t, image) # type: ignore[operator] + if not isinstance(scheduler, RFlowScheduler): + image, _ = scheduler.step(model_output, t, image) # type: ignore + else: + image, _ = scheduler.step(model_output, t, image, next_t) # type: ignore + if save_intermediates and t % intermediate_steps == 0: intermediates.append(image) if save_intermediates: diff --git a/monai/networks/schedulers/__init__.py b/monai/networks/schedulers/__init__.py index 29e9020d65..b7b34f9a77 100644 --- a/monai/networks/schedulers/__init__.py +++ b/monai/networks/schedulers/__init__.py @@ -14,4 +14,5 @@ from .ddim import DDIMScheduler from .ddpm import DDPMScheduler from .pndm import PNDMScheduler +from .rectified_flow import RFlowScheduler from .scheduler import NoiseSchedules, Scheduler diff --git a/monai/networks/schedulers/rectified_flow.py b/monai/networks/schedulers/rectified_flow.py new file mode 100644 index 0000000000..452160ae0c --- /dev/null +++ b/monai/networks/schedulers/rectified_flow.py @@ -0,0 +1,306 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/hpcaitech/Open-Sora/blob/main/opensora/schedulers/rf/rectified_flow.py +# which has the following license: +# https://github.com/hpcaitech/Open-Sora/blob/main/LICENSE +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+
+from __future__ import annotations
+
+from typing import Union
+
+import numpy as np
+import torch
+from torch.distributions import LogisticNormal
+
+from .scheduler import Scheduler
+
+
+def timestep_transform(
+    t, input_img_size_numel, base_img_size_numel=32 * 32 * 32, scale=1.0, num_train_timesteps=1000, spatial_dim=3
+):
+    """
+    Applies a transformation to the timestep based on image resolution scaling.
+
+    Args:
+        t (torch.Tensor): The original timestep(s).
+        input_img_size_numel (torch.Tensor): The input image's size (H * W * D).
+        base_img_size_numel (int): reference H*W*D size, usually smaller than input_img_size_numel.
+        scale (float): Scaling factor for the transformation.
+        num_train_timesteps (int): Total number of training timesteps.
+        spatial_dim (int): Number of spatial dimensions in the image.
+
+    Returns:
+        torch.Tensor: Transformed timestep(s).
+    """
+    t = t / num_train_timesteps
+    ratio_space = (input_img_size_numel / base_img_size_numel) ** (1.0 / spatial_dim)
+
+    ratio = ratio_space * scale
+    new_t = ratio * t / (1 + (ratio - 1) * t)
+
+    new_t = new_t * num_train_timesteps
+    return new_t
+
+
+class RFlowScheduler(Scheduler):
+    """
+    A rectified flow scheduler for guiding the diffusion process in a generative model.
+
+    Supports uniform and logit-normal sampling methods, timestep transformation for
+    different resolutions, and noise addition during diffusion.
+
+    Args:
+        num_train_timesteps (int): Total number of training timesteps.
+        use_discrete_timesteps (bool): Whether to use discrete timesteps.
+        sample_method (str): Training time step sampling method ('uniform' or 'logit-normal').
+        loc (float): Location parameter for logit-normal distribution, used only if sample_method='logit-normal'.
+        scale (float): Scale parameter for logit-normal distribution, used only if sample_method='logit-normal'.
+        use_timestep_transform (bool): Whether to apply timestep transformation.
+            If true, there will be more inference timesteps at early (noisy) stages for larger image volumes.
+        transform_scale (float): Scaling factor for timestep transformation, used only if use_timestep_transform=True.
+        steps_offset (int): Offset added to computed timesteps, used only if use_timestep_transform=True.
+        base_img_size_numel (int): Reference image volume size for scaling, used only if use_timestep_transform=True.
+        spatial_dim (int): 2 or 3, indicating 2D or 3D images, used only if use_timestep_transform=True.
+
+    Example:
+
+        .. code-block:: python
+
+            # define a scheduler
+            noise_scheduler = RFlowScheduler(
+                num_train_timesteps = 1000,
+                use_discrete_timesteps = True,
+                sample_method = 'logit-normal',
+                use_timestep_transform = True,
+                base_img_size_numel = 32 * 32 * 32,
+                spatial_dim = 3
+            )
+
+            # during training
+            inputs = torch.ones(2,4,64,64,32)
+            noise = torch.randn_like(inputs)
+            timesteps = noise_scheduler.sample_timesteps(inputs)
+            noisy_inputs = noise_scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
+            predicted_velocity = diffusion_unet(
+                x=noisy_inputs,
+                timesteps=timesteps
+            )
+            loss = loss_l1(predicted_velocity, (inputs - noise))
+
+            # during inference
+            noisy_inputs = torch.randn(2,4,64,64,32)
+            input_img_size_numel = torch.prod(torch.tensor(noisy_inputs.shape[-3:]))
+            noise_scheduler.set_timesteps(
+                num_inference_steps=30, input_img_size_numel=input_img_size_numel
+            )
+            all_next_timesteps = torch.cat(
+                (noise_scheduler.timesteps[1:], torch.tensor([0], dtype=noise_scheduler.timesteps.dtype))
+            )
+            for t, next_t in tqdm(
+                zip(noise_scheduler.timesteps, all_next_timesteps),
+                total=min(len(noise_scheduler.timesteps), len(all_next_timesteps)),
+            ):
+                predicted_velocity = diffusion_unet(
+                    x=noisy_inputs,
+                    timesteps=t
+                )
+                noisy_inputs, _ = noise_scheduler.step(predicted_velocity, t, noisy_inputs, next_t)
+            final_output = noisy_inputs
+    """
+
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        use_discrete_timesteps: bool = True,
+        sample_method: str = "uniform",
+        loc: float = 0.0,
+        scale: float = 1.0,
+        use_timestep_transform: bool = False,
+        transform_scale: float = 1.0,
+        steps_offset: int = 0,
+        base_img_size_numel: int = 32 * 32 * 32,
+        spatial_dim: int = 3,
+    ):
+        self.num_train_timesteps = num_train_timesteps
+        self.use_discrete_timesteps = use_discrete_timesteps
+        self.base_img_size_numel = base_img_size_numel
+        self.spatial_dim = spatial_dim
+
+        # sample method
+        if sample_method not in ["uniform", "logit-normal"]:
+            raise ValueError(
+                f"sample_method = {sample_method}, which has to be chosen from ['uniform', 'logit-normal']."
+            )
+        self.sample_method = sample_method
+        if sample_method == "logit-normal":
+            self.distribution = LogisticNormal(torch.tensor([loc]), torch.tensor([scale]))
+            self.sample_t = lambda x: self.distribution.sample((x.shape[0],))[:, 0].to(x.device)
+
+        # timestep transform
+        self.use_timestep_transform = use_timestep_transform
+        self.transform_scale = transform_scale
+        self.steps_offset = steps_offset
+
+    def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
+        """
+        Add noise to the original samples.
+
+        Args:
+            original_samples: original samples
+            noise: noise to add to samples
+            timesteps: timesteps tensor with shape of (N,), indicating the timestep to be computed for each sample.
+ + Returns: + noisy_samples: sample with added noise + """ + timepoints: torch.Tensor = timesteps.float() / self.num_train_timesteps + timepoints = 1 - timepoints # [1,1/1000] + + # expand timepoint to noise shape + if noise.ndim == 5: + timepoints = timepoints[..., None, None, None, None].expand(-1, *noise.shape[1:]) + elif noise.ndim == 4: + timepoints = timepoints[..., None, None, None].expand(-1, *noise.shape[1:]) + else: + raise ValueError(f"noise tensor has to be 4D or 5D tensor, yet got shape of {noise.shape}") + + noisy_samples: torch.Tensor = timepoints * original_samples + (1 - timepoints) * noise + + return noisy_samples + + def set_timesteps( + self, + num_inference_steps: int, + device: str | torch.device | None = None, + input_img_size_numel: int | None = None, + ) -> None: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. + device: target device to put the data. + input_img_size_numel: int, H*W*D of the image, used with self.use_timestep_transform is True. + """ + if num_inference_steps > self.num_train_timesteps or num_inference_steps < 1: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} should be at least 1, " + "and cannot be larger than `self.num_train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + # prepare timesteps + timesteps = [ + (1.0 - i / self.num_inference_steps) * self.num_train_timesteps for i in range(self.num_inference_steps) + ] + if self.use_discrete_timesteps: + timesteps = [int(round(t)) for t in timesteps] + if self.use_timestep_transform: + timesteps = [ + timestep_transform( + t, + input_img_size_numel=input_img_size_numel, + base_img_size_numel=self.base_img_size_numel, + num_train_timesteps=self.num_train_timesteps, + spatial_dim=self.spatial_dim, + ) + for t in timesteps + ] + timesteps_np = np.array(timesteps).astype(np.float16) + if self.use_discrete_timesteps: + timesteps_np = timesteps_np.astype(np.int64) + self.timesteps = torch.from_numpy(timesteps_np).to(device) + self.timesteps += self.steps_offset + + def sample_timesteps(self, x_start): + """ + Randomly samples training timesteps using the chosen sampling method. + + Args: + x_start (torch.Tensor): The input tensor for sampling. + + Returns: + torch.Tensor: Sampled timesteps. + """ + if self.sample_method == "uniform": + t = torch.rand((x_start.shape[0],), device=x_start.device) * self.num_train_timesteps + elif self.sample_method == "logit-normal": + t = self.sample_t(x_start) * self.num_train_timesteps + + if self.use_discrete_timesteps: + t = t.long() + + if self.use_timestep_transform: + input_img_size_numel = torch.prod(torch.tensor(x_start.shape[2:])) + t = timestep_transform( + t, + input_img_size_numel=input_img_size_numel, + base_img_size_numel=self.base_img_size_numel, + num_train_timesteps=self.num_train_timesteps, + spatial_dim=len(x_start.shape) - 2, + ) + + return t + + def step( + self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep: Union[int, None] = None + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Predicts the next sample in the diffusion process. + + Args: + model_output (torch.Tensor): Output from the trained diffusion model. 
+ timestep (int): Current timestep in the diffusion chain. + sample (torch.Tensor): Current sample in the process. + next_timestep (Union[int, None]): Optional next timestep. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Predicted sample at the next step and additional info. + """ + # Ensure num_inference_steps exists and is a valid integer + if not hasattr(self, "num_inference_steps") or not isinstance(self.num_inference_steps, int): + raise AttributeError( + "num_inference_steps is missing or not an integer in the class." + "Please run self.set_timesteps(num_inference_steps,device,input_img_size_numel) to set it." + ) + + v_pred = model_output + + if next_timestep is not None: + next_timestep = int(next_timestep) + dt: float = ( + float(timestep - next_timestep) / self.num_train_timesteps + ) # Now next_timestep is guaranteed to be int + else: + dt = ( + 1.0 / float(self.num_inference_steps) if self.num_inference_steps > 0 else 0.0 + ) # Avoid division by zero + + pred_post_sample = sample + v_pred * dt + pred_original_sample = sample + v_pred * timestep / self.num_train_timesteps + + return pred_post_sample, pred_original_sample diff --git a/tests/inferers/test_controlnet_inferers.py b/tests/inferers/test_controlnet_inferers.py index 909f2cf398..1ce81a71d5 100644 --- a/tests/inferers/test_controlnet_inferers.py +++ b/tests/inferers/test_controlnet_inferers.py @@ -26,7 +26,7 @@ SPADEAutoencoderKL, SPADEDiffusionModelUNet, ) -from monai.networks.schedulers import DDIMScheduler, DDPMScheduler +from monai.networks.schedulers import DDIMScheduler, DDPMScheduler, RFlowScheduler from monai.utils import optional_import _, has_scipy = optional_import("scipy") @@ -545,6 +545,32 @@ def test_ddim_sampler(self, model_params, controlnet_params, input_shape): ) self.assertEqual(len(intermediates), 10) + @parameterized.expand(CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_rflow_sampler(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = RFlowScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + intermediate_steps=1, + ) + self.assertEqual(len(intermediates), 10) + @parameterized.expand(CNDM_TEST_CASES) @skipUnless(has_einops, "Requires einops") def test_sampler_conditioned(self, model_params, controlnet_params, input_shape): @@ -561,6 +587,8 @@ def test_sampler_conditioned(self, model_params, controlnet_params, input_shape) controlnet.eval() mask = torch.randn(input_shape).to(device) noise = torch.randn(input_shape).to(device) + + # DDIM scheduler = DDIMScheduler(num_train_timesteps=1000) inferer = ControlNetDiffusionInferer(scheduler=scheduler) scheduler.set_timesteps(num_inference_steps=10) @@ -577,6 +605,23 @@ def test_sampler_conditioned(self, model_params, controlnet_params, input_shape) ) self.assertEqual(len(intermediates), 10) + # RFlow + scheduler = RFlowScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + 
scheduler.set_timesteps(num_inference_steps=10) + conditioning = torch.randn([input_shape[0], 1, 3]).to(device) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + ) + self.assertEqual(len(intermediates), 10) + @parameterized.expand(CNDM_TEST_CASES) @skipUnless(has_einops, "Requires einops") def test_get_likelihood(self, model_params, controlnet_params, input_shape): @@ -638,6 +683,8 @@ def test_sampler_conditioned_concat(self, model_params, controlnet_params, input conditioning_shape = list(input_shape) conditioning_shape[1] = n_concat_channel conditioning = torch.randn(conditioning_shape).to(device) + + # DDIM scheduler = DDIMScheduler(num_train_timesteps=1000) inferer = ControlNetDiffusionInferer(scheduler=scheduler) scheduler.set_timesteps(num_inference_steps=10) @@ -654,6 +701,23 @@ def test_sampler_conditioned_concat(self, model_params, controlnet_params, input ) self.assertEqual(len(intermediates), 10) + # RFlow + scheduler = RFlowScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + mode="concat", + ) + self.assertEqual(len(intermediates), 10) + class LatentControlNetTestDiffusionSamplingInferer(unittest.TestCase): @parameterized.expand(LATENT_CNDM_TEST_CASES) @@ -691,39 +755,39 @@ def test_prediction_shape( input = torch.randn(input_shape).to(device) mask = torch.randn(input_shape).to(device) noise = torch.randn(latent_shape).to(device) - scheduler = DDPMScheduler(num_train_timesteps=10) - inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) - scheduler.set_timesteps(num_inference_steps=10) - timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() - if dm_model_type == "SPADEDiffusionModelUNet": - input_shape_seg = list(input_shape) - if "label_nc" in stage_2_params.keys(): - input_shape_seg[1] = stage_2_params["label_nc"] + for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]: + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + seg=input_seg, + noise=noise, + timesteps=timesteps, + ) else: - input_shape_seg[1] = autoencoder_params["label_nc"] - input_seg = torch.randn(input_shape_seg).to(device) - prediction = inferer( - inputs=input, - autoencoder_model=stage_1, - diffusion_model=stage_2, - controlnet=controlnet, - cn_cond=mask, - seg=input_seg, - noise=noise, - timesteps=timesteps, - ) - else: - prediction = inferer( - inputs=input, - 
autoencoder_model=stage_1, - diffusion_model=stage_2, - noise=noise, - timesteps=timesteps, - controlnet=controlnet, - cn_cond=mask, - ) - self.assertEqual(prediction.shape, latent_shape) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + timesteps=timesteps, + controlnet=controlnet, + cn_cond=mask, + ) + self.assertEqual(prediction.shape, latent_shape) @parameterized.expand(LATENT_CNDM_TEST_CASES) @skipUnless(has_einops, "Requires einops") diff --git a/tests/inferers/test_diffusion_inferer.py b/tests/inferers/test_diffusion_inferer.py index 7f37025d3c..59b320d8a7 100644 --- a/tests/inferers/test_diffusion_inferer.py +++ b/tests/inferers/test_diffusion_inferer.py @@ -19,7 +19,7 @@ from monai.inferers import DiffusionInferer from monai.networks.nets import DiffusionModelUNet -from monai.networks.schedulers import DDIMScheduler, DDPMScheduler +from monai.networks.schedulers import DDIMScheduler, DDPMScheduler, RFlowScheduler from monai.utils import optional_import _, has_scipy = optional_import("scipy") @@ -120,6 +120,22 @@ def test_ddim_sampler(self, model_params, input_shape): ) self.assertEqual(len(intermediates), 10) + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_rflow_sampler(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = RFlowScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1 + ) + self.assertEqual(len(intermediates), 10) + @parameterized.expand(TEST_CASES) @skipUnless(has_einops, "Requires einops") def test_sampler_conditioned(self, model_params, input_shape): @@ -144,6 +160,30 @@ def test_sampler_conditioned(self, model_params, input_shape): ) self.assertEqual(len(intermediates), 10) + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_sampler_conditioned_rflow(self, model_params, input_shape): + model_params["with_conditioning"] = True + model_params["cross_attention_dim"] = 3 + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = RFlowScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + conditioning = torch.randn([input_shape[0], 1, 3]).to(device) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + ) + self.assertEqual(len(intermediates), 10) + @parameterized.expand(TEST_CASES) @skipUnless(has_einops, "Requires einops") def test_get_likelihood(self, model_params, input_shape): @@ -204,6 +244,37 @@ def test_sampler_conditioned_concat(self, model_params, input_shape): ) self.assertEqual(len(intermediates), 10) + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_sampler_conditioned_concat_rflow(self, model_params, input_shape): + # copy the model_params dict to prevent from modifying test cases + model_params = 
model_params.copy() + n_concat_channel = 2 + model_params["in_channels"] = model_params["in_channels"] + n_concat_channel + model_params["cross_attention_dim"] = None + model_params["with_conditioning"] = False + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + conditioning_shape = list(input_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + scheduler = RFlowScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + mode="concat", + ) + self.assertEqual(len(intermediates), 10) + @parameterized.expand(TEST_CASES) @skipUnless(has_einops, "Requires einops") def test_call_conditioned_concat(self, model_params, input_shape): @@ -231,6 +302,33 @@ def test_call_conditioned_concat(self, model_params, input_shape): ) self.assertEqual(sample.shape, input_shape) + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_call_conditioned_concat_rflow(self, model_params, input_shape): + # copy the model_params dict to prevent from modifying test cases + model_params = model_params.copy() + n_concat_channel = 2 + model_params["in_channels"] = model_params["in_channels"] + n_concat_channel + model_params["cross_attention_dim"] = None + model_params["with_conditioning"] = False + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + input = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + conditioning_shape = list(input_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + scheduler = RFlowScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + sample = inferer( + inputs=input, noise=noise, diffusion_model=model, timesteps=timesteps, condition=conditioning, mode="concat" + ) + self.assertEqual(sample.shape, input_shape) + if __name__ == "__main__": unittest.main() diff --git a/tests/inferers/test_latent_diffusion_inferer.py b/tests/inferers/test_latent_diffusion_inferer.py index 4f81b96ca1..c20cb5d6ff 100644 --- a/tests/inferers/test_latent_diffusion_inferer.py +++ b/tests/inferers/test_latent_diffusion_inferer.py @@ -19,7 +19,7 @@ from monai.inferers import LatentDiffusionInferer from monai.networks.nets import VQVAE, AutoencoderKL, DiffusionModelUNet, SPADEAutoencoderKL, SPADEDiffusionModelUNet -from monai.networks.schedulers import DDPMScheduler +from monai.networks.schedulers import DDPMScheduler, RFlowScheduler from monai.utils import optional_import _, has_einops = optional_import("einops") @@ -339,31 +339,32 @@ def test_prediction_shape( input = torch.randn(input_shape).to(device) noise = torch.randn(latent_shape).to(device) - scheduler = DDPMScheduler(num_train_timesteps=10) - inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) - scheduler.set_timesteps(num_inference_steps=10) - timesteps = torch.randint(0, 
scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() - if dm_model_type == "SPADEDiffusionModelUNet": - input_shape_seg = list(input_shape) - if "label_nc" in stage_2_params.keys(): - input_shape_seg[1] = stage_2_params["label_nc"] + for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]: + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + seg=input_seg, + noise=noise, + timesteps=timesteps, + ) else: - input_shape_seg[1] = autoencoder_params["label_nc"] - input_seg = torch.randn(input_shape_seg).to(device) - prediction = inferer( - inputs=input, - autoencoder_model=stage_1, - diffusion_model=stage_2, - seg=input_seg, - noise=noise, - timesteps=timesteps, - ) - else: - prediction = inferer( - inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps - ) - self.assertEqual(prediction.shape, latent_shape) + prediction = inferer( + inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps + ) + self.assertEqual(prediction.shape, latent_shape) @parameterized.expand(TEST_CASES) @skipUnless(has_einops, "Requires einops") @@ -388,29 +389,30 @@ def test_sample_shape( stage_2.eval() noise = torch.randn(latent_shape).to(device) - scheduler = DDPMScheduler(num_train_timesteps=10) - inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) - scheduler.set_timesteps(num_inference_steps=10) - if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": - input_shape_seg = list(input_shape) - if "label_nc" in stage_2_params.keys(): - input_shape_seg[1] = stage_2_params["label_nc"] + for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]: + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + seg=input_seg, + ) else: - input_shape_seg[1] = autoencoder_params["label_nc"] - input_seg = torch.randn(input_shape_seg).to(device) - sample = inferer.sample( - input_noise=noise, - autoencoder_model=stage_1, - diffusion_model=stage_2, - scheduler=scheduler, - seg=input_seg, - ) - else: - sample = inferer.sample( - input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler - ) - self.assertEqual(sample.shape, input_shape) + sample = inferer.sample( + input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler + ) + 
self.assertEqual(sample.shape, input_shape) @parameterized.expand(TEST_CASES) @skipUnless(has_einops, "Requires einops") @@ -437,37 +439,38 @@ def test_sample_intermediates( stage_2.eval() noise = torch.randn(latent_shape).to(device) - scheduler = DDPMScheduler(num_train_timesteps=10) - inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) - scheduler.set_timesteps(num_inference_steps=10) - if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": - input_shape_seg = list(input_shape) - if "label_nc" in stage_2_params.keys(): - input_shape_seg[1] = stage_2_params["label_nc"] + for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]: + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample, intermediates = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + seg=input_seg, + save_intermediates=True, + intermediate_steps=1, + ) else: - input_shape_seg[1] = autoencoder_params["label_nc"] - input_seg = torch.randn(input_shape_seg).to(device) - sample, intermediates = inferer.sample( - input_noise=noise, - autoencoder_model=stage_1, - diffusion_model=stage_2, - scheduler=scheduler, - seg=input_seg, - save_intermediates=True, - intermediate_steps=1, - ) - else: - sample, intermediates = inferer.sample( - input_noise=noise, - autoencoder_model=stage_1, - diffusion_model=stage_2, - scheduler=scheduler, - save_intermediates=True, - intermediate_steps=1, - ) - self.assertEqual(len(intermediates), 10) - self.assertEqual(intermediates[0].shape, input_shape) + sample, intermediates = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape, input_shape) @parameterized.expand(TEST_CASES) @skipUnless(has_einops, "Requires einops") @@ -614,40 +617,40 @@ def test_prediction_shape_conditioned_concat( conditioning_shape[1] = n_concat_channel conditioning = torch.randn(conditioning_shape).to(device) - scheduler = DDPMScheduler(num_train_timesteps=10) - inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) - scheduler.set_timesteps(num_inference_steps=10) - - timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() - - if dm_model_type == "SPADEDiffusionModelUNet": - input_shape_seg = list(input_shape) - if "label_nc" in stage_2_params.keys(): - input_shape_seg[1] = stage_2_params["label_nc"] + for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]: + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = 
stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + timesteps=timesteps, + condition=conditioning, + mode="concat", + seg=input_seg, + ) else: - input_shape_seg[1] = autoencoder_params["label_nc"] - input_seg = torch.randn(input_shape_seg).to(device) - prediction = inferer( - inputs=input, - autoencoder_model=stage_1, - diffusion_model=stage_2, - noise=noise, - timesteps=timesteps, - condition=conditioning, - mode="concat", - seg=input_seg, - ) - else: - prediction = inferer( - inputs=input, - autoencoder_model=stage_1, - diffusion_model=stage_2, - noise=noise, - timesteps=timesteps, - condition=conditioning, - mode="concat", - ) - self.assertEqual(prediction.shape, latent_shape) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + timesteps=timesteps, + condition=conditioning, + mode="concat", + ) + self.assertEqual(prediction.shape, latent_shape) @parameterized.expand(TEST_CASES) @skipUnless(has_einops, "Requires einops") @@ -681,36 +684,36 @@ def test_sample_shape_conditioned_concat( conditioning_shape[1] = n_concat_channel conditioning = torch.randn(conditioning_shape).to(device) - scheduler = DDPMScheduler(num_train_timesteps=10) - inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) - scheduler.set_timesteps(num_inference_steps=10) - - if dm_model_type == "SPADEDiffusionModelUNet": - input_shape_seg = list(input_shape) - if "label_nc" in stage_2_params.keys(): - input_shape_seg[1] = stage_2_params["label_nc"] + for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]: + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + conditioning=conditioning, + mode="concat", + seg=input_seg, + ) else: - input_shape_seg[1] = autoencoder_params["label_nc"] - input_seg = torch.randn(input_shape_seg).to(device) - sample = inferer.sample( - input_noise=noise, - autoencoder_model=stage_1, - diffusion_model=stage_2, - scheduler=scheduler, - conditioning=conditioning, - mode="concat", - seg=input_seg, - ) - else: - sample = inferer.sample( - input_noise=noise, - autoencoder_model=stage_1, - diffusion_model=stage_2, - scheduler=scheduler, - conditioning=conditioning, - mode="concat", - ) - self.assertEqual(sample.shape, input_shape) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + conditioning=conditioning, + mode="concat", + ) + self.assertEqual(sample.shape, input_shape) @parameterized.expand(TEST_CASES_DIFF_SHAPES) @skipUnless(has_einops, "Requires einops") @@ -738,39 +741,39 @@ def test_shape_different_latents( input = torch.randn(input_shape).to(device) noise = torch.randn(latent_shape).to(device) - scheduler = DDPMScheduler(num_train_timesteps=10) - # We infer the VAE shape - autoencoder_latent_shape = [i // (2 ** 
(len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]] - inferer = LatentDiffusionInferer( - scheduler=scheduler, - scale_factor=1.0, - ldm_latent_shape=list(latent_shape[2:]), - autoencoder_latent_shape=autoencoder_latent_shape, - ) - scheduler.set_timesteps(num_inference_steps=10) - - timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() - - if dm_model_type == "SPADEDiffusionModelUNet": - input_shape_seg = list(input_shape) - if "label_nc" in stage_2_params.keys(): - input_shape_seg[1] = stage_2_params["label_nc"] - else: - input_shape_seg[1] = autoencoder_params["label_nc"] - input_seg = torch.randn(input_shape_seg).to(device) - prediction = inferer( - inputs=input, - autoencoder_model=stage_1, - diffusion_model=stage_2, - noise=noise, - timesteps=timesteps, - seg=input_seg, - ) - else: - prediction = inferer( - inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps + for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]: + # We infer the VAE shape + autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]] + inferer = LatentDiffusionInferer( + scheduler=scheduler, + scale_factor=1.0, + ldm_latent_shape=list(latent_shape[2:]), + autoencoder_latent_shape=autoencoder_latent_shape, ) - self.assertEqual(prediction.shape, latent_shape) + scheduler.set_timesteps(num_inference_steps=10) + + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + timesteps=timesteps, + seg=input_seg, + ) + else: + prediction = inferer( + inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps + ) + self.assertEqual(prediction.shape, latent_shape) @parameterized.expand(TEST_CASES_DIFF_SHAPES) @skipUnless(has_einops, "Requires einops") @@ -797,40 +800,42 @@ def test_sample_shape_different_latents( stage_2.eval() noise = torch.randn(latent_shape).to(device) - scheduler = DDPMScheduler(num_train_timesteps=10) - # We infer the VAE shape - if ae_model_type == "VQVAE": - autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]))) for i in input_shape[2:]] - else: - autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]] - - inferer = LatentDiffusionInferer( - scheduler=scheduler, - scale_factor=1.0, - ldm_latent_shape=list(latent_shape[2:]), - autoencoder_latent_shape=autoencoder_latent_shape, - ) - scheduler.set_timesteps(num_inference_steps=10) - - if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL": - input_shape_seg = list(input_shape) - if "label_nc" in stage_2_params.keys(): - input_shape_seg[1] = stage_2_params["label_nc"] + for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]: + # We infer the VAE shape + if ae_model_type == "VQVAE": + autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]))) for i in input_shape[2:]] else: - input_shape_seg[1] 
= autoencoder_params["label_nc"] - input_seg = torch.randn(input_shape_seg).to(device) - prediction, _ = inferer.sample( - autoencoder_model=stage_1, - diffusion_model=stage_2, - input_noise=noise, - save_intermediates=True, - seg=input_seg, - ) - else: - prediction = inferer.sample( - autoencoder_model=stage_1, diffusion_model=stage_2, input_noise=noise, save_intermediates=False + autoencoder_latent_shape = [ + i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:] + ] + + inferer = LatentDiffusionInferer( + scheduler=scheduler, + scale_factor=1.0, + ldm_latent_shape=list(latent_shape[2:]), + autoencoder_latent_shape=autoencoder_latent_shape, ) - self.assertEqual(prediction.shape, input_shape) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction, _ = inferer.sample( + autoencoder_model=stage_1, + diffusion_model=stage_2, + input_noise=noise, + save_intermediates=True, + seg=input_seg, + ) + else: + prediction = inferer.sample( + autoencoder_model=stage_1, diffusion_model=stage_2, input_noise=noise, save_intermediates=False + ) + self.assertEqual(prediction.shape, input_shape) @skipUnless(has_einops, "Requires einops") def test_incompatible_spade_setup(self): @@ -866,18 +871,19 @@ def test_incompatible_spade_setup(self): stage_2.eval() noise = torch.randn((1, 3, 4, 4)).to(device) input_seg = torch.randn((1, 3, 8, 8)).to(device) - scheduler = DDPMScheduler(num_train_timesteps=10) - inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) - scheduler.set_timesteps(num_inference_steps=10) - with self.assertRaises(ValueError): - _ = inferer.sample( - input_noise=noise, - autoencoder_model=stage_1, - diffusion_model=stage_2, - scheduler=scheduler, - seg=input_seg, - ) + for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]: + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + with self.assertRaises(ValueError): + _ = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + seg=input_seg, + ) if __name__ == "__main__": diff --git a/tests/networks/schedulers/test_scheduler_rflow.py b/tests/networks/schedulers/test_scheduler_rflow.py new file mode 100644 index 0000000000..08f4ed3730 --- /dev/null +++ b/tests/networks/schedulers/test_scheduler_rflow.py @@ -0,0 +1,105 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.schedulers import RFlowScheduler +from tests.test_utils import assert_allclose + +TEST_2D_CASE = [] +for sample_method in ["uniform", "logit-normal"]: + TEST_2D_CASE.append( + [{"sample_method": sample_method, "use_timestep_transform": False}, (2, 6, 16, 16), (2, 6, 16, 16)] + ) + +for sample_method in ["uniform", "logit-normal"]: + TEST_2D_CASE.append( + [ + {"sample_method": sample_method, "use_timestep_transform": True, "spatial_dim": 2}, + (2, 6, 16, 16), + (2, 6, 16, 16), + ] + ) + + +TEST_3D_CASE = [] +for sample_method in ["uniform", "logit-normal"]: + TEST_3D_CASE.append( + [{"sample_method": sample_method, "use_timestep_transform": False}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)] + ) + +for sample_method in ["uniform", "logit-normal"]: + TEST_3D_CASE.append( + [ + {"sample_method": sample_method, "use_timestep_transform": True, "spatial_dim": 3}, + (2, 6, 16, 16, 16), + (2, 6, 16, 16, 16), + ] + ) + +TEST_CASES = TEST_2D_CASE + TEST_3D_CASE + +TEST_FULl_LOOP = [ + [{"sample_method": "uniform"}, (1, 1, 2, 2), torch.Tensor([[[[-0.786166, -0.057519], [2.442662, -0.407664]]]])] +] + + +class TestRFlowScheduler(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_add_noise(self, input_param, input_shape, expected_shape): + scheduler = RFlowScheduler(**input_param) + original_sample = torch.zeros(input_shape) + timesteps = scheduler.sample_timesteps(original_sample) + noise = torch.randn_like(original_sample) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long() + noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps) + self.assertEqual(noisy.shape, expected_shape) + + @parameterized.expand(TEST_CASES) + def test_step_shape(self, input_param, input_shape, expected_shape): + scheduler = RFlowScheduler(**input_param) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + scheduler.set_timesteps(num_inference_steps=100, input_img_size_numel=torch.numel(sample[0, 0, ...])) + output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample) + self.assertEqual(output_step[0].shape, expected_shape) + self.assertEqual(output_step[1].shape, expected_shape) + + @parameterized.expand(TEST_FULl_LOOP) + def test_full_timestep_loop(self, input_param, input_shape, expected_output): + scheduler = RFlowScheduler(**input_param) + torch.manual_seed(42) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + scheduler.set_timesteps(50, input_img_size_numel=torch.numel(sample[0, 0, ...])) + for t in range(50): + sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample) + assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3) + + def test_set_timesteps(self): + scheduler = RFlowScheduler(num_train_timesteps=1000) + scheduler.set_timesteps(num_inference_steps=100, input_img_size_numel=16 * 16 * 16) + self.assertEqual(scheduler.num_inference_steps, 100) + self.assertEqual(len(scheduler.timesteps), 100) + + def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self): + scheduler = RFlowScheduler(num_train_timesteps=1000) + with self.assertRaises(ValueError): + scheduler.set_timesteps(num_inference_steps=2000, input_img_size_numel=16 * 16 * 16) + + +if __name__ == "__main__": + unittest.main()
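For reference, a minimal sampling sketch that exercises the APIs introduced in this patch (`RFlowScheduler`, the updated `DiffusionInferer.sample`), mirroring the new `test_rflow_sampler` test. The tiny 2D `DiffusionModelUNet` configuration below is only illustrative, not a recommended model size.

```python
import torch

from monai.inferers import DiffusionInferer
from monai.networks.nets import DiffusionModelUNet
from monai.networks.schedulers import RFlowScheduler

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# toy 2D network, chosen only to keep the example small (requires einops)
model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(8, 8),
    attention_levels=(False, False),
    norm_num_groups=8,
    num_res_blocks=1,
).to(device)
model.eval()

scheduler = RFlowScheduler(num_train_timesteps=1000)
inferer = DiffusionInferer(scheduler=scheduler)

noise = torch.randn(1, 1, 16, 16).to(device)
# input_img_size_numel only matters when use_timestep_transform=True; passed here for completeness
scheduler.set_timesteps(num_inference_steps=10, input_img_size_numel=noise.shape[-2] * noise.shape[-1])

with torch.no_grad():
    sample, intermediates = inferer.sample(
        input_noise=noise,
        diffusion_model=model,
        scheduler=scheduler,
        save_intermediates=True,
        intermediate_steps=1,
    )

print(sample.shape, len(intermediates))  # torch.Size([1, 1, 16, 16]) 10
```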