@@ -355,14 +355,14 @@ def test_prediction_shape(
355355 prediction = inferer (
356356 inputs = input ,
357357 autoencoder_model = stage_1 ,
358- diffusion_model = stage_2 ,
358+ network = stage_2 ,
359359 seg = input_seg ,
360360 noise = noise ,
361361 timesteps = timesteps ,
362362 )
363363 else :
364364 prediction = inferer (
365- inputs = input , autoencoder_model = stage_1 , diffusion_model = stage_2 , noise = noise , timesteps = timesteps
365+ inputs = input , autoencoder_model = stage_1 , network = stage_2 , noise = noise , timesteps = timesteps
366366 )
367367 self .assertEqual (prediction .shape , latent_shape )
368368
@@ -404,13 +404,13 @@ def test_sample_shape(
404404 sample = inferer .sample (
405405 input_noise = noise ,
406406 autoencoder_model = stage_1 ,
407- diffusion_model = stage_2 ,
407+ network = stage_2 ,
408408 scheduler = scheduler ,
409409 seg = input_seg ,
410410 )
411411 else :
412412 sample = inferer .sample (
413- input_noise = noise , autoencoder_model = stage_1 , diffusion_model = stage_2 , scheduler = scheduler
413+ input_noise = noise , autoencoder_model = stage_1 , network = stage_2 , scheduler = scheduler
414414 )
415415 self .assertEqual (sample .shape , input_shape )
416416
@@ -452,14 +452,14 @@ def test_sample_shape_with_cfg(
452452 sample = inferer .sample (
453453 input_noise = noise ,
454454 autoencoder_model = stage_1 ,
455- diffusion_model = stage_2 ,
455+ network = stage_2 ,
456456 scheduler = scheduler ,
457457 seg = input_seg ,
458458 cfg = 5 ,
459459 )
460460 else :
461461 sample = inferer .sample (
462- input_noise = noise , autoencoder_model = stage_1 , diffusion_model = stage_2 , scheduler = scheduler , cfg = 5
462+ input_noise = noise , autoencoder_model = stage_1 , network = stage_2 , scheduler = scheduler , cfg = 5
463463 )
464464 self .assertEqual (sample .shape , input_shape )
465465
@@ -503,7 +503,7 @@ def test_sample_intermediates(
503503 sample , intermediates = inferer .sample (
504504 input_noise = noise ,
505505 autoencoder_model = stage_1 ,
506- diffusion_model = stage_2 ,
506+ network = stage_2 ,
507507 scheduler = scheduler ,
508508 seg = input_seg ,
509509 save_intermediates = True ,
@@ -513,7 +513,7 @@ def test_sample_intermediates(
513513 sample , intermediates = inferer .sample (
514514 input_noise = noise ,
515515 autoencoder_model = stage_1 ,
516- diffusion_model = stage_2 ,
516+ network = stage_2 ,
517517 scheduler = scheduler ,
518518 save_intermediates = True ,
519519 intermediate_steps = 1 ,
@@ -560,7 +560,7 @@ def test_get_likelihoods(
560560 sample , intermediates = inferer .get_likelihood (
561561 inputs = input ,
562562 autoencoder_model = stage_1 ,
563- diffusion_model = stage_2 ,
563+ network = stage_2 ,
564564 scheduler = scheduler ,
565565 save_intermediates = True ,
566566 seg = input_seg ,
@@ -569,7 +569,7 @@ def test_get_likelihoods(
569569 sample , intermediates = inferer .get_likelihood (
570570 inputs = input ,
571571 autoencoder_model = stage_1 ,
572- diffusion_model = stage_2 ,
572+ network = stage_2 ,
573573 scheduler = scheduler ,
574574 save_intermediates = True ,
575575 )
@@ -615,7 +615,7 @@ def test_resample_likelihoods(
615615 sample , intermediates = inferer .get_likelihood (
616616 inputs = input ,
617617 autoencoder_model = stage_1 ,
618- diffusion_model = stage_2 ,
618+ network = stage_2 ,
619619 scheduler = scheduler ,
620620 save_intermediates = True ,
621621 resample_latent_likelihoods = True ,
@@ -625,7 +625,7 @@ def test_resample_likelihoods(
625625 sample , intermediates = inferer .get_likelihood (
626626 inputs = input ,
627627 autoencoder_model = stage_1 ,
628- diffusion_model = stage_2 ,
628+ network = stage_2 ,
629629 scheduler = scheduler ,
630630 save_intermediates = True ,
631631 resample_latent_likelihoods = True ,
@@ -682,7 +682,7 @@ def test_prediction_shape_conditioned_concat(
682682 prediction = inferer (
683683 inputs = input ,
684684 autoencoder_model = stage_1 ,
685- diffusion_model = stage_2 ,
685+ network = stage_2 ,
686686 noise = noise ,
687687 timesteps = timesteps ,
688688 condition = conditioning ,
@@ -693,7 +693,7 @@ def test_prediction_shape_conditioned_concat(
693693 prediction = inferer (
694694 inputs = input ,
695695 autoencoder_model = stage_1 ,
696- diffusion_model = stage_2 ,
696+ network = stage_2 ,
697697 noise = noise ,
698698 timesteps = timesteps ,
699699 condition = conditioning ,
@@ -747,7 +747,7 @@ def test_sample_shape_conditioned_concat(
747747 sample = inferer .sample (
748748 input_noise = noise ,
749749 autoencoder_model = stage_1 ,
750- diffusion_model = stage_2 ,
750+ network = stage_2 ,
751751 scheduler = scheduler ,
752752 conditioning = conditioning ,
753753 mode = "concat" ,
@@ -757,7 +757,7 @@ def test_sample_shape_conditioned_concat(
757757 sample = inferer .sample (
758758 input_noise = noise ,
759759 autoencoder_model = stage_1 ,
760- diffusion_model = stage_2 ,
760+ network = stage_2 ,
761761 scheduler = scheduler ,
762762 conditioning = conditioning ,
763763 mode = "concat" ,
@@ -813,14 +813,14 @@ def test_shape_different_latents(
813813 prediction = inferer (
814814 inputs = input ,
815815 autoencoder_model = stage_1 ,
816- diffusion_model = stage_2 ,
816+ network = stage_2 ,
817817 noise = noise ,
818818 timesteps = timesteps ,
819819 seg = input_seg ,
820820 )
821821 else :
822822 prediction = inferer (
823- inputs = input , autoencoder_model = stage_1 , diffusion_model = stage_2 , noise = noise , timesteps = timesteps
823+ inputs = input , autoencoder_model = stage_1 , network = stage_2 , noise = noise , timesteps = timesteps
824824 )
825825 self .assertEqual (prediction .shape , latent_shape )
826826
@@ -875,14 +875,14 @@ def test_sample_shape_different_latents(
875875 input_seg = torch .randn (input_shape_seg ).to (device )
876876 prediction , _ = inferer .sample (
877877 autoencoder_model = stage_1 ,
878- diffusion_model = stage_2 ,
878+ network = stage_2 ,
879879 input_noise = noise ,
880880 save_intermediates = True ,
881881 seg = input_seg ,
882882 )
883883 else :
884884 prediction = inferer .sample (
885- autoencoder_model = stage_1 , diffusion_model = stage_2 , input_noise = noise , save_intermediates = False
885+ autoencoder_model = stage_1 , network = stage_2 , input_noise = noise , save_intermediates = False
886886 )
887887 self .assertEqual (prediction .shape , input_shape )
888888
@@ -929,7 +929,7 @@ def test_incompatible_spade_setup(self):
929929 _ = inferer .sample (
930930 input_noise = noise ,
931931 autoencoder_model = stage_1 ,
932- diffusion_model = stage_2 ,
932+ network = stage_2 ,
933933 scheduler = scheduler ,
934934 seg = input_seg ,
935935 )
0 commit comments