diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py
index 769b6cc0e7..7083373859 100644
--- a/monai/inferers/inferer.py
+++ b/monai/inferers/inferer.py
@@ -1202,15 +1202,16 @@ def sample(  # type: ignore[override]
 
         if self.autoencoder_latent_shape is not None:
             latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
-            latent_intermediates = [
-                torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
-            ]
+            if save_intermediates:
+                latent_intermediates = [
+                    torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
+                    for l in latent_intermediates
+                ]
 
         decode = autoencoder_model.decode_stage_2_outputs
         if isinstance(autoencoder_model, SPADEAutoencoderKL):
             decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
         image = decode(latent / self.scale_factor)
-
         if save_intermediates:
             intermediates = []
             for latent_intermediate in latent_intermediates:
@@ -1727,9 +1728,11 @@ def sample(  # type: ignore[override]
 
         if self.autoencoder_latent_shape is not None:
             latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
-            latent_intermediates = [
-                torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
-            ]
+            if save_intermediates:
+                latent_intermediates = [
+                    torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
+                    for l in latent_intermediates
+                ]
 
         decode = autoencoder_model.decode_stage_2_outputs
         if isinstance(autoencoder_model, SPADEAutoencoderKL):
diff --git a/tests/inferers/test_controlnet_inferers.py b/tests/inferers/test_controlnet_inferers.py
index e3b0aeb5a2..2ab5cec335 100644
--- a/tests/inferers/test_controlnet_inferers.py
+++ b/tests/inferers/test_controlnet_inferers.py
@@ -722,7 +722,7 @@ def test_prediction_shape(
 
     @parameterized.expand(LATENT_CNDM_TEST_CASES)
     @skipUnless(has_einops, "Requires einops")
-    def test_sample_shape(
+    def test_pred_shape(
         self,
         ae_model_type,
         autoencoder_params,
@@ -1165,7 +1165,7 @@ def test_sample_shape_conditioned_concat(
 
     @parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES)
     @skipUnless(has_einops, "Requires einops")
-    def test_sample_shape_different_latents(
+    def test_shape_different_latents(
         self,
         ae_model_type,
         autoencoder_params,
@@ -1242,6 +1242,84 @@ def test_sample_shape_different_latents(
         )
         self.assertEqual(prediction.shape, latent_shape)
 
+    @parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES)
+    @skipUnless(has_einops, "Requires einops")
+    def test_sample_shape_different_latents(
+        self,
+        ae_model_type,
+        autoencoder_params,
+        dm_model_type,
+        stage_2_params,
+        controlnet_params,
+        input_shape,
+        latent_shape,
+    ):
+        stage_1 = None
+
+        if ae_model_type == "AutoencoderKL":
+            stage_1 = AutoencoderKL(**autoencoder_params)
+        if ae_model_type == "VQVAE":
+            stage_1 = VQVAE(**autoencoder_params)
+        if ae_model_type == "SPADEAutoencoderKL":
+            stage_1 = SPADEAutoencoderKL(**autoencoder_params)
+        if dm_model_type == "SPADEDiffusionModelUNet":
+            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+        else:
+            stage_2 = DiffusionModelUNet(**stage_2_params)
+        controlnet = ControlNet(**controlnet_params)
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        stage_1.to(device)
+        stage_2.to(device)
+        controlnet.to(device)
+        stage_1.eval()
+        stage_2.eval()
+        controlnet.eval()
+
+        noise = torch.randn(latent_shape).to(device)
+        mask = torch.randn(input_shape).to(device)
+        scheduler = DDPMScheduler(num_train_timesteps=10)
+        # We infer the VAE shape
+        if ae_model_type == "VQVAE":
+            autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]))) for i in input_shape[2:]]
+        else:
+            autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]]
+
+        inferer = ControlNetLatentDiffusionInferer(
+            scheduler=scheduler,
+            scale_factor=1.0,
+            ldm_latent_shape=list(latent_shape[2:]),
+            autoencoder_latent_shape=autoencoder_latent_shape,
+        )
+        scheduler.set_timesteps(num_inference_steps=10)
+
+        if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL":
+            input_shape_seg = list(input_shape)
+            if "label_nc" in stage_2_params.keys():
+                input_shape_seg[1] = stage_2_params["label_nc"]
+            else:
+                input_shape_seg[1] = autoencoder_params["label_nc"]
+            input_seg = torch.randn(input_shape_seg).to(device)
+            prediction, _ = inferer.sample(
+                autoencoder_model=stage_1,
+                diffusion_model=stage_2,
+                controlnet=controlnet,
+                cn_cond=mask,
+                input_noise=noise,
+                seg=input_seg,
+                save_intermediates=True,
+            )
+        else:
+            prediction = inferer.sample(
+                autoencoder_model=stage_1,
+                diffusion_model=stage_2,
+                input_noise=noise,
+                controlnet=controlnet,
+                cn_cond=mask,
+                save_intermediates=False,
+            )
+        self.assertEqual(prediction.shape, input_shape)
+
     @skipUnless(has_einops, "Requires einops")
     def test_incompatible_spade_setup(self):
         stage_1 = SPADEAutoencoderKL(
diff --git a/tests/inferers/test_latent_diffusion_inferer.py b/tests/inferers/test_latent_diffusion_inferer.py
index 2e04ad6c5c..4f81b96ca1 100644
--- a/tests/inferers/test_latent_diffusion_inferer.py
+++ b/tests/inferers/test_latent_diffusion_inferer.py
@@ -714,7 +714,7 @@ def test_sample_shape_conditioned_concat(
 
     @parameterized.expand(TEST_CASES_DIFF_SHAPES)
     @skipUnless(has_einops, "Requires einops")
-    def test_sample_shape_different_latents(
+    def test_shape_different_latents(
         self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
     ):
         stage_1 = None
@@ -772,6 +772,66 @@ def test_sample_shape_different_latents(
         )
         self.assertEqual(prediction.shape, latent_shape)
 
+    @parameterized.expand(TEST_CASES_DIFF_SHAPES)
+    @skipUnless(has_einops, "Requires einops")
+    def test_sample_shape_different_latents(
+        self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
+    ):
+        stage_1 = None
+
+        if ae_model_type == "AutoencoderKL":
+            stage_1 = AutoencoderKL(**autoencoder_params)
+        if ae_model_type == "VQVAE":
+            stage_1 = VQVAE(**autoencoder_params)
+        if ae_model_type == "SPADEAutoencoderKL":
+            stage_1 = SPADEAutoencoderKL(**autoencoder_params)
+        if dm_model_type == "SPADEDiffusionModelUNet":
+            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+        else:
+            stage_2 = DiffusionModelUNet(**stage_2_params)
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        stage_1.to(device)
+        stage_2.to(device)
+        stage_1.eval()
+        stage_2.eval()
+
+        noise = torch.randn(latent_shape).to(device)
+        scheduler = DDPMScheduler(num_train_timesteps=10)
+        # We infer the VAE shape
+        if ae_model_type == "VQVAE":
+            autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]))) for i in input_shape[2:]]
+        else:
+            autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]]
+
+        inferer = LatentDiffusionInferer(
+            scheduler=scheduler,
+            scale_factor=1.0,
+            ldm_latent_shape=list(latent_shape[2:]),
+            autoencoder_latent_shape=autoencoder_latent_shape,
+        )
+        scheduler.set_timesteps(num_inference_steps=10)
+
+        if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL":
+            input_shape_seg = list(input_shape)
+            if "label_nc" in stage_2_params.keys():
+                input_shape_seg[1] = stage_2_params["label_nc"]
+            else:
+                input_shape_seg[1] = autoencoder_params["label_nc"]
+            input_seg = torch.randn(input_shape_seg).to(device)
+            prediction, _ = inferer.sample(
+                autoencoder_model=stage_1,
+                diffusion_model=stage_2,
+                input_noise=noise,
+                save_intermediates=True,
+                seg=input_seg,
+            )
+        else:
+            prediction = inferer.sample(
+                autoencoder_model=stage_1, diffusion_model=stage_2, input_noise=noise, save_intermediates=False
+            )
+        self.assertEqual(prediction.shape, input_shape)
+
     @skipUnless(has_einops, "Requires einops")
     def test_incompatible_spade_setup(self):
         stage_1 = SPADEAutoencoderKL(