Skip to content

Commit e4e089f

Browse files
authored
Merge branch 'Project-MONAI:dev' into fix-6840
2 parents a34a656 + 0a85eed commit e4e089f

File tree

3 files changed

+151
-10
lines changed

3 files changed

+151
-10
lines changed

monai/inferers/inferer.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,15 +1202,16 @@ def sample( # type: ignore[override]
12021202

12031203
if self.autoencoder_latent_shape is not None:
12041204
latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
1205-
latent_intermediates = [
1206-
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
1207-
]
1205+
if save_intermediates:
1206+
latent_intermediates = [
1207+
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
1208+
for l in latent_intermediates
1209+
]
12081210

12091211
decode = autoencoder_model.decode_stage_2_outputs
12101212
if isinstance(autoencoder_model, SPADEAutoencoderKL):
12111213
decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
12121214
image = decode(latent / self.scale_factor)
1213-
12141215
if save_intermediates:
12151216
intermediates = []
12161217
for latent_intermediate in latent_intermediates:
@@ -1727,9 +1728,11 @@ def sample( # type: ignore[override]
17271728

17281729
if self.autoencoder_latent_shape is not None:
17291730
latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
1730-
latent_intermediates = [
1731-
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
1732-
]
1731+
if save_intermediates:
1732+
latent_intermediates = [
1733+
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
1734+
for l in latent_intermediates
1735+
]
17331736

17341737
decode = autoencoder_model.decode_stage_2_outputs
17351738
if isinstance(autoencoder_model, SPADEAutoencoderKL):

tests/inferers/test_controlnet_inferers.py

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -722,7 +722,7 @@ def test_prediction_shape(
722722

723723
@parameterized.expand(LATENT_CNDM_TEST_CASES)
724724
@skipUnless(has_einops, "Requires einops")
725-
def test_sample_shape(
725+
def test_pred_shape(
726726
self,
727727
ae_model_type,
728728
autoencoder_params,
@@ -1165,7 +1165,7 @@ def test_sample_shape_conditioned_concat(
11651165

11661166
@parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES)
11671167
@skipUnless(has_einops, "Requires einops")
1168-
def test_sample_shape_different_latents(
1168+
def test_shape_different_latents(
11691169
self,
11701170
ae_model_type,
11711171
autoencoder_params,
@@ -1242,6 +1242,84 @@ def test_sample_shape_different_latents(
12421242
)
12431243
self.assertEqual(prediction.shape, latent_shape)
12441244

1245+
@parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES)
1246+
@skipUnless(has_einops, "Requires einops")
1247+
def test_sample_shape_different_latents(
1248+
self,
1249+
ae_model_type,
1250+
autoencoder_params,
1251+
dm_model_type,
1252+
stage_2_params,
1253+
controlnet_params,
1254+
input_shape,
1255+
latent_shape,
1256+
):
1257+
stage_1 = None
1258+
1259+
if ae_model_type == "AutoencoderKL":
1260+
stage_1 = AutoencoderKL(**autoencoder_params)
1261+
if ae_model_type == "VQVAE":
1262+
stage_1 = VQVAE(**autoencoder_params)
1263+
if ae_model_type == "SPADEAutoencoderKL":
1264+
stage_1 = SPADEAutoencoderKL(**autoencoder_params)
1265+
if dm_model_type == "SPADEDiffusionModelUNet":
1266+
stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
1267+
else:
1268+
stage_2 = DiffusionModelUNet(**stage_2_params)
1269+
controlnet = ControlNet(**controlnet_params)
1270+
1271+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
1272+
stage_1.to(device)
1273+
stage_2.to(device)
1274+
controlnet.to(device)
1275+
stage_1.eval()
1276+
stage_2.eval()
1277+
controlnet.eval()
1278+
1279+
noise = torch.randn(latent_shape).to(device)
1280+
mask = torch.randn(input_shape).to(device)
1281+
scheduler = DDPMScheduler(num_train_timesteps=10)
1282+
# We infer the VAE shape
1283+
if ae_model_type == "VQVAE":
1284+
autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]))) for i in input_shape[2:]]
1285+
else:
1286+
autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]]
1287+
1288+
inferer = ControlNetLatentDiffusionInferer(
1289+
scheduler=scheduler,
1290+
scale_factor=1.0,
1291+
ldm_latent_shape=list(latent_shape[2:]),
1292+
autoencoder_latent_shape=autoencoder_latent_shape,
1293+
)
1294+
scheduler.set_timesteps(num_inference_steps=10)
1295+
1296+
if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL":
1297+
input_shape_seg = list(input_shape)
1298+
if "label_nc" in stage_2_params.keys():
1299+
input_shape_seg[1] = stage_2_params["label_nc"]
1300+
else:
1301+
input_shape_seg[1] = autoencoder_params["label_nc"]
1302+
input_seg = torch.randn(input_shape_seg).to(device)
1303+
prediction, _ = inferer.sample(
1304+
autoencoder_model=stage_1,
1305+
diffusion_model=stage_2,
1306+
controlnet=controlnet,
1307+
cn_cond=mask,
1308+
input_noise=noise,
1309+
seg=input_seg,
1310+
save_intermediates=True,
1311+
)
1312+
else:
1313+
prediction = inferer.sample(
1314+
autoencoder_model=stage_1,
1315+
diffusion_model=stage_2,
1316+
input_noise=noise,
1317+
controlnet=controlnet,
1318+
cn_cond=mask,
1319+
save_intermediates=False,
1320+
)
1321+
self.assertEqual(prediction.shape, input_shape)
1322+
12451323
@skipUnless(has_einops, "Requires einops")
12461324
def test_incompatible_spade_setup(self):
12471325
stage_1 = SPADEAutoencoderKL(

tests/inferers/test_latent_diffusion_inferer.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -714,7 +714,7 @@ def test_sample_shape_conditioned_concat(
714714

715715
@parameterized.expand(TEST_CASES_DIFF_SHAPES)
716716
@skipUnless(has_einops, "Requires einops")
717-
def test_sample_shape_different_latents(
717+
def test_shape_different_latents(
718718
self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
719719
):
720720
stage_1 = None
@@ -772,6 +772,66 @@ def test_sample_shape_different_latents(
772772
)
773773
self.assertEqual(prediction.shape, latent_shape)
774774

775+
@parameterized.expand(TEST_CASES_DIFF_SHAPES)
776+
@skipUnless(has_einops, "Requires einops")
777+
def test_sample_shape_different_latents(
778+
self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
779+
):
780+
stage_1 = None
781+
782+
if ae_model_type == "AutoencoderKL":
783+
stage_1 = AutoencoderKL(**autoencoder_params)
784+
if ae_model_type == "VQVAE":
785+
stage_1 = VQVAE(**autoencoder_params)
786+
if ae_model_type == "SPADEAutoencoderKL":
787+
stage_1 = SPADEAutoencoderKL(**autoencoder_params)
788+
if dm_model_type == "SPADEDiffusionModelUNet":
789+
stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
790+
else:
791+
stage_2 = DiffusionModelUNet(**stage_2_params)
792+
793+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
794+
stage_1.to(device)
795+
stage_2.to(device)
796+
stage_1.eval()
797+
stage_2.eval()
798+
799+
noise = torch.randn(latent_shape).to(device)
800+
scheduler = DDPMScheduler(num_train_timesteps=10)
801+
# We infer the VAE shape
802+
if ae_model_type == "VQVAE":
803+
autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]))) for i in input_shape[2:]]
804+
else:
805+
autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]]
806+
807+
inferer = LatentDiffusionInferer(
808+
scheduler=scheduler,
809+
scale_factor=1.0,
810+
ldm_latent_shape=list(latent_shape[2:]),
811+
autoencoder_latent_shape=autoencoder_latent_shape,
812+
)
813+
scheduler.set_timesteps(num_inference_steps=10)
814+
815+
if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL":
816+
input_shape_seg = list(input_shape)
817+
if "label_nc" in stage_2_params.keys():
818+
input_shape_seg[1] = stage_2_params["label_nc"]
819+
else:
820+
input_shape_seg[1] = autoencoder_params["label_nc"]
821+
input_seg = torch.randn(input_shape_seg).to(device)
822+
prediction, _ = inferer.sample(
823+
autoencoder_model=stage_1,
824+
diffusion_model=stage_2,
825+
input_noise=noise,
826+
save_intermediates=True,
827+
seg=input_seg,
828+
)
829+
else:
830+
prediction = inferer.sample(
831+
autoencoder_model=stage_1, diffusion_model=stage_2, input_noise=noise, save_intermediates=False
832+
)
833+
self.assertEqual(prediction.shape, input_shape)
834+
775835
@skipUnless(has_einops, "Requires einops")
776836
def test_incompatible_spade_setup(self):
777837
stage_1 = SPADEAutoencoderKL(

0 commit comments

Comments (0)