Skip to content

Commit f3d67ce

Browse files
author
Virginia Fernandez
committed
Fix tests
1 parent 4d2b365 commit f3d67ce

File tree

2 files changed

+27
-27
lines changed

2 files changed

+27
-27
lines changed

tests/inferers/test_diffusion_inferer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def test_call(self, model_params, input_shape):
6969
inferer = DiffusionInferer(scheduler=scheduler)
7070
scheduler.set_timesteps(num_inference_steps=10)
7171
timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
72-
sample = inferer(inputs=input, noise=noise, diffusion_model=model, timesteps=timesteps)
72+
sample = inferer(inputs=input, noise=noise, network=model, timesteps=timesteps)
7373
self.assertEqual(sample.shape, input_shape)
7474

7575
@parameterized.expand(TEST_CASES)
@@ -84,7 +84,7 @@ def test_sample_intermediates(self, model_params, input_shape):
8484
inferer = DiffusionInferer(scheduler=scheduler)
8585
scheduler.set_timesteps(num_inference_steps=10)
8686
sample, intermediates = inferer.sample(
87-
input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1
87+
input_noise=noise, network=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1
8888
)
8989
self.assertEqual(len(intermediates), 10)
9090

@@ -101,7 +101,7 @@ def test_sample_cfg(self, model_params, input_shape):
101101
scheduler.set_timesteps(num_inference_steps=10)
102102
sample, intermediates = inferer.sample(
103103
input_noise=noise,
104-
diffusion_model=model,
104+
network=model,
105105
scheduler=scheduler,
106106
save_intermediates=True,
107107
intermediate_steps=1,
@@ -121,7 +121,7 @@ def test_ddpm_sampler(self, model_params, input_shape):
121121
inferer = DiffusionInferer(scheduler=scheduler)
122122
scheduler.set_timesteps(num_inference_steps=10)
123123
sample, intermediates = inferer.sample(
124-
input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1
124+
input_noise=noise, network=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1
125125
)
126126
self.assertEqual(len(intermediates), 10)
127127

@@ -137,7 +137,7 @@ def test_ddim_sampler(self, model_params, input_shape):
137137
inferer = DiffusionInferer(scheduler=scheduler)
138138
scheduler.set_timesteps(num_inference_steps=10)
139139
sample, intermediates = inferer.sample(
140-
input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1
140+
input_noise=noise, network=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1
141141
)
142142
self.assertEqual(len(intermediates), 10)
143143

@@ -153,7 +153,7 @@ def test_rflow_sampler(self, model_params, input_shape):
153153
inferer = DiffusionInferer(scheduler=scheduler)
154154
scheduler.set_timesteps(num_inference_steps=10)
155155
sample, intermediates = inferer.sample(
156-
input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1
156+
input_noise=noise, network=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1
157157
)
158158
self.assertEqual(len(intermediates), 10)
159159

tests/inferers/test_latent_diffusion_inferer.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -355,14 +355,14 @@ def test_prediction_shape(
355355
prediction = inferer(
356356
inputs=input,
357357
autoencoder_model=stage_1,
358-
diffusion_model=stage_2,
358+
network=stage_2,
359359
seg=input_seg,
360360
noise=noise,
361361
timesteps=timesteps,
362362
)
363363
else:
364364
prediction = inferer(
365-
inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps
365+
inputs=input, autoencoder_model=stage_1, network=stage_2, noise=noise, timesteps=timesteps
366366
)
367367
self.assertEqual(prediction.shape, latent_shape)
368368

@@ -404,13 +404,13 @@ def test_sample_shape(
404404
sample = inferer.sample(
405405
input_noise=noise,
406406
autoencoder_model=stage_1,
407-
diffusion_model=stage_2,
407+
network=stage_2,
408408
scheduler=scheduler,
409409
seg=input_seg,
410410
)
411411
else:
412412
sample = inferer.sample(
413-
input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler
413+
input_noise=noise, autoencoder_model=stage_1, network=stage_2, scheduler=scheduler
414414
)
415415
self.assertEqual(sample.shape, input_shape)
416416

@@ -452,14 +452,14 @@ def test_sample_shape_with_cfg(
452452
sample = inferer.sample(
453453
input_noise=noise,
454454
autoencoder_model=stage_1,
455-
diffusion_model=stage_2,
455+
network=stage_2,
456456
scheduler=scheduler,
457457
seg=input_seg,
458458
cfg=5,
459459
)
460460
else:
461461
sample = inferer.sample(
462-
input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler, cfg=5
462+
input_noise=noise, autoencoder_model=stage_1, network=stage_2, scheduler=scheduler, cfg=5
463463
)
464464
self.assertEqual(sample.shape, input_shape)
465465

@@ -503,7 +503,7 @@ def test_sample_intermediates(
503503
sample, intermediates = inferer.sample(
504504
input_noise=noise,
505505
autoencoder_model=stage_1,
506-
diffusion_model=stage_2,
506+
network=stage_2,
507507
scheduler=scheduler,
508508
seg=input_seg,
509509
save_intermediates=True,
@@ -513,7 +513,7 @@ def test_sample_intermediates(
513513
sample, intermediates = inferer.sample(
514514
input_noise=noise,
515515
autoencoder_model=stage_1,
516-
diffusion_model=stage_2,
516+
network=stage_2,
517517
scheduler=scheduler,
518518
save_intermediates=True,
519519
intermediate_steps=1,
@@ -560,7 +560,7 @@ def test_get_likelihoods(
560560
sample, intermediates = inferer.get_likelihood(
561561
inputs=input,
562562
autoencoder_model=stage_1,
563-
diffusion_model=stage_2,
563+
network=stage_2,
564564
scheduler=scheduler,
565565
save_intermediates=True,
566566
seg=input_seg,
@@ -569,7 +569,7 @@ def test_get_likelihoods(
569569
sample, intermediates = inferer.get_likelihood(
570570
inputs=input,
571571
autoencoder_model=stage_1,
572-
diffusion_model=stage_2,
572+
network=stage_2,
573573
scheduler=scheduler,
574574
save_intermediates=True,
575575
)
@@ -615,7 +615,7 @@ def test_resample_likelihoods(
615615
sample, intermediates = inferer.get_likelihood(
616616
inputs=input,
617617
autoencoder_model=stage_1,
618-
diffusion_model=stage_2,
618+
network=stage_2,
619619
scheduler=scheduler,
620620
save_intermediates=True,
621621
resample_latent_likelihoods=True,
@@ -625,7 +625,7 @@ def test_resample_likelihoods(
625625
sample, intermediates = inferer.get_likelihood(
626626
inputs=input,
627627
autoencoder_model=stage_1,
628-
diffusion_model=stage_2,
628+
network=stage_2,
629629
scheduler=scheduler,
630630
save_intermediates=True,
631631
resample_latent_likelihoods=True,
@@ -682,7 +682,7 @@ def test_prediction_shape_conditioned_concat(
682682
prediction = inferer(
683683
inputs=input,
684684
autoencoder_model=stage_1,
685-
diffusion_model=stage_2,
685+
network=stage_2,
686686
noise=noise,
687687
timesteps=timesteps,
688688
condition=conditioning,
@@ -693,7 +693,7 @@ def test_prediction_shape_conditioned_concat(
693693
prediction = inferer(
694694
inputs=input,
695695
autoencoder_model=stage_1,
696-
diffusion_model=stage_2,
696+
network=stage_2,
697697
noise=noise,
698698
timesteps=timesteps,
699699
condition=conditioning,
@@ -747,7 +747,7 @@ def test_sample_shape_conditioned_concat(
747747
sample = inferer.sample(
748748
input_noise=noise,
749749
autoencoder_model=stage_1,
750-
diffusion_model=stage_2,
750+
network=stage_2,
751751
scheduler=scheduler,
752752
conditioning=conditioning,
753753
mode="concat",
@@ -757,7 +757,7 @@ def test_sample_shape_conditioned_concat(
757757
sample = inferer.sample(
758758
input_noise=noise,
759759
autoencoder_model=stage_1,
760-
diffusion_model=stage_2,
760+
network=stage_2,
761761
scheduler=scheduler,
762762
conditioning=conditioning,
763763
mode="concat",
@@ -813,14 +813,14 @@ def test_shape_different_latents(
813813
prediction = inferer(
814814
inputs=input,
815815
autoencoder_model=stage_1,
816-
diffusion_model=stage_2,
816+
network=stage_2,
817817
noise=noise,
818818
timesteps=timesteps,
819819
seg=input_seg,
820820
)
821821
else:
822822
prediction = inferer(
823-
inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps
823+
inputs=input, autoencoder_model=stage_1, network=stage_2, noise=noise, timesteps=timesteps
824824
)
825825
self.assertEqual(prediction.shape, latent_shape)
826826

@@ -875,14 +875,14 @@ def test_sample_shape_different_latents(
875875
input_seg = torch.randn(input_shape_seg).to(device)
876876
prediction, _ = inferer.sample(
877877
autoencoder_model=stage_1,
878-
diffusion_model=stage_2,
878+
network=stage_2,
879879
input_noise=noise,
880880
save_intermediates=True,
881881
seg=input_seg,
882882
)
883883
else:
884884
prediction = inferer.sample(
885-
autoencoder_model=stage_1, diffusion_model=stage_2, input_noise=noise, save_intermediates=False
885+
autoencoder_model=stage_1, network=stage_2, input_noise=noise, save_intermediates=False
886886
)
887887
self.assertEqual(prediction.shape, input_shape)
888888

@@ -929,7 +929,7 @@ def test_incompatible_spade_setup(self):
929929
_ = inferer.sample(
930930
input_noise=noise,
931931
autoencoder_model=stage_1,
932-
diffusion_model=stage_2,
932+
network=stage_2,
933933
scheduler=scheduler,
934934
seg=input_seg,
935935
)

0 commit comments

Comments
 (0)