Skip to content

Commit 40be2a6

Browse files
committed
reformat
Signed-off-by: Can-Zhao <[email protected]>
1 parent c2e3cb5 commit 40be2a6

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

tests/inferers/test_latent_diffusion_inferer.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -339,12 +339,12 @@ def test_prediction_shape(
339339

340340
input = torch.randn(input_shape).to(device)
341341
noise = torch.randn(latent_shape).to(device)
342-
342+
343343
for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:
344344
inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
345345
scheduler.set_timesteps(num_inference_steps=10)
346346
timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
347-
347+
348348
if dm_model_type == "SPADEDiffusionModelUNet":
349349
input_shape_seg = list(input_shape)
350350
if "label_nc" in stage_2_params.keys():
@@ -393,7 +393,7 @@ def test_sample_shape(
393393
for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:
394394
inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
395395
scheduler.set_timesteps(num_inference_steps=10)
396-
396+
397397
if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet":
398398
input_shape_seg = list(input_shape)
399399
if "label_nc" in stage_2_params.keys():
@@ -443,7 +443,7 @@ def test_sample_intermediates(
443443
for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:
444444
inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
445445
scheduler.set_timesteps(num_inference_steps=10)
446-
446+
447447
if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet":
448448
input_shape_seg = list(input_shape)
449449
if "label_nc" in stage_2_params.keys():
@@ -620,9 +620,9 @@ def test_prediction_shape_conditioned_concat(
620620
for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:
621621
inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
622622
scheduler.set_timesteps(num_inference_steps=10)
623-
623+
624624
timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
625-
625+
626626
if dm_model_type == "SPADEDiffusionModelUNet":
627627
input_shape_seg = list(input_shape)
628628
if "label_nc" in stage_2_params.keys():
@@ -687,7 +687,7 @@ def test_sample_shape_conditioned_concat(
687687
for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:
688688
inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
689689
scheduler.set_timesteps(num_inference_steps=10)
690-
690+
691691
if dm_model_type == "SPADEDiffusionModelUNet":
692692
input_shape_seg = list(input_shape)
693693
if "label_nc" in stage_2_params.keys():
@@ -751,9 +751,9 @@ def test_shape_different_latents(
751751
autoencoder_latent_shape=autoencoder_latent_shape,
752752
)
753753
scheduler.set_timesteps(num_inference_steps=10)
754-
754+
755755
timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
756-
756+
757757
if dm_model_type == "SPADEDiffusionModelUNet":
758758
input_shape_seg = list(input_shape)
759759
if "label_nc" in stage_2_params.keys():
@@ -805,16 +805,18 @@ def test_sample_shape_different_latents(
805805
if ae_model_type == "VQVAE":
806806
autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]))) for i in input_shape[2:]]
807807
else:
808-
autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]]
809-
808+
autoencoder_latent_shape = [
809+
i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]
810+
]
811+
810812
inferer = LatentDiffusionInferer(
811813
scheduler=scheduler,
812814
scale_factor=1.0,
813815
ldm_latent_shape=list(latent_shape[2:]),
814816
autoencoder_latent_shape=autoencoder_latent_shape,
815817
)
816818
scheduler.set_timesteps(num_inference_steps=10)
817-
819+
818820
if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL":
819821
input_shape_seg = list(input_shape)
820822
if "label_nc" in stage_2_params.keys():
@@ -873,7 +875,7 @@ def test_incompatible_spade_setup(self):
873875
for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:
874876
inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
875877
scheduler.set_timesteps(num_inference_steps=10)
876-
878+
877879
with self.assertRaises(ValueError):
878880
_ = inferer.sample(
879881
input_noise=noise,

0 commit comments

Comments
 (0)