@@ -339,12 +339,12 @@ def test_prediction_shape(
339339
340340 input = torch .randn (input_shape ).to (device )
341341 noise = torch .randn (latent_shape ).to (device )
342-
342+
343343 for scheduler in [DDPMScheduler (num_train_timesteps = 10 ), RFlowScheduler (num_train_timesteps = 1000 )]:
344344 inferer = LatentDiffusionInferer (scheduler = scheduler , scale_factor = 1.0 )
345345 scheduler .set_timesteps (num_inference_steps = 10 )
346346 timesteps = torch .randint (0 , scheduler .num_train_timesteps , (input_shape [0 ],), device = input .device ).long ()
347-
347+
348348 if dm_model_type == "SPADEDiffusionModelUNet" :
349349 input_shape_seg = list (input_shape )
350350 if "label_nc" in stage_2_params .keys ():
@@ -393,7 +393,7 @@ def test_sample_shape(
393393 for scheduler in [DDPMScheduler (num_train_timesteps = 10 ), RFlowScheduler (num_train_timesteps = 1000 )]:
394394 inferer = LatentDiffusionInferer (scheduler = scheduler , scale_factor = 1.0 )
395395 scheduler .set_timesteps (num_inference_steps = 10 )
396-
396+
397397 if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet" :
398398 input_shape_seg = list (input_shape )
399399 if "label_nc" in stage_2_params .keys ():
@@ -443,7 +443,7 @@ def test_sample_intermediates(
443443 for scheduler in [DDPMScheduler (num_train_timesteps = 10 ), RFlowScheduler (num_train_timesteps = 1000 )]:
444444 inferer = LatentDiffusionInferer (scheduler = scheduler , scale_factor = 1.0 )
445445 scheduler .set_timesteps (num_inference_steps = 10 )
446-
446+
447447 if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet" :
448448 input_shape_seg = list (input_shape )
449449 if "label_nc" in stage_2_params .keys ():
@@ -620,9 +620,9 @@ def test_prediction_shape_conditioned_concat(
620620 for scheduler in [DDPMScheduler (num_train_timesteps = 10 ), RFlowScheduler (num_train_timesteps = 1000 )]:
621621 inferer = LatentDiffusionInferer (scheduler = scheduler , scale_factor = 1.0 )
622622 scheduler .set_timesteps (num_inference_steps = 10 )
623-
623+
624624 timesteps = torch .randint (0 , scheduler .num_train_timesteps , (input_shape [0 ],), device = input .device ).long ()
625-
625+
626626 if dm_model_type == "SPADEDiffusionModelUNet" :
627627 input_shape_seg = list (input_shape )
628628 if "label_nc" in stage_2_params .keys ():
@@ -687,7 +687,7 @@ def test_sample_shape_conditioned_concat(
687687 for scheduler in [DDPMScheduler (num_train_timesteps = 10 ), RFlowScheduler (num_train_timesteps = 1000 )]:
688688 inferer = LatentDiffusionInferer (scheduler = scheduler , scale_factor = 1.0 )
689689 scheduler .set_timesteps (num_inference_steps = 10 )
690-
690+
691691 if dm_model_type == "SPADEDiffusionModelUNet" :
692692 input_shape_seg = list (input_shape )
693693 if "label_nc" in stage_2_params .keys ():
@@ -751,9 +751,9 @@ def test_shape_different_latents(
751751 autoencoder_latent_shape = autoencoder_latent_shape ,
752752 )
753753 scheduler .set_timesteps (num_inference_steps = 10 )
754-
754+
755755 timesteps = torch .randint (0 , scheduler .num_train_timesteps , (input_shape [0 ],), device = input .device ).long ()
756-
756+
757757 if dm_model_type == "SPADEDiffusionModelUNet" :
758758 input_shape_seg = list (input_shape )
759759 if "label_nc" in stage_2_params .keys ():
@@ -805,16 +805,18 @@ def test_sample_shape_different_latents(
805805 if ae_model_type == "VQVAE" :
806806 autoencoder_latent_shape = [i // (2 ** (len (autoencoder_params ["channels" ]))) for i in input_shape [2 :]]
807807 else :
808- autoencoder_latent_shape = [i // (2 ** (len (autoencoder_params ["channels" ]) - 1 )) for i in input_shape [2 :]]
809-
808+ autoencoder_latent_shape = [
809+ i // (2 ** (len (autoencoder_params ["channels" ]) - 1 )) for i in input_shape [2 :]
810+ ]
811+
810812 inferer = LatentDiffusionInferer (
811813 scheduler = scheduler ,
812814 scale_factor = 1.0 ,
813815 ldm_latent_shape = list (latent_shape [2 :]),
814816 autoencoder_latent_shape = autoencoder_latent_shape ,
815817 )
816818 scheduler .set_timesteps (num_inference_steps = 10 )
817-
819+
818820 if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL" :
819821 input_shape_seg = list (input_shape )
820822 if "label_nc" in stage_2_params .keys ():
@@ -873,7 +875,7 @@ def test_incompatible_spade_setup(self):
873875 for scheduler in [DDPMScheduler (num_train_timesteps = 10 ), RFlowScheduler (num_train_timesteps = 1000 )]:
874876 inferer = LatentDiffusionInferer (scheduler = scheduler , scale_factor = 1.0 )
875877 scheduler .set_timesteps (num_inference_steps = 10 )
876-
878+
877879 with self .assertRaises (ValueError ):
878880 _ = inferer .sample (
879881 input_noise = noise ,
0 commit comments