Skip to content

Commit 36abb97

Browse files
author
Virginia Fernandez
committed
Fix tests formatting.
1 parent 76f87e5 commit 36abb97

File tree

2 files changed

+4
-21
lines changed

2 files changed

+4
-21
lines changed

tests/inferers/test_diffusion_inferer.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,7 @@ def test_sample_cfg(self, model_params, input_shape):
100100
inferer = DiffusionInferer(scheduler=scheduler)
101101
scheduler.set_timesteps(num_inference_steps=10)
102102
sample, intermediates = inferer.sample(
103-
input_noise=noise,
104-
network=model,
105-
scheduler=scheduler,
106-
save_intermediates=True,
107-
intermediate_steps=1,
108-
cfg=5,
103+
input_noise=noise, network=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1, cfg=5
109104
)
110105
self.assertEqual(sample.shape, noise.shape)
111106

tests/inferers/test_latent_diffusion_inferer.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -402,11 +402,7 @@ def test_sample_shape(
402402
input_shape_seg[1] = autoencoder_params["label_nc"]
403403
input_seg = torch.randn(input_shape_seg).to(device)
404404
sample = inferer.sample(
405-
input_noise=noise,
406-
autoencoder_model=stage_1,
407-
network=stage_2,
408-
scheduler=scheduler,
409-
seg=input_seg,
405+
input_noise=noise, autoencoder_model=stage_1, network=stage_2, scheduler=scheduler, seg=input_seg
410406
)
411407
else:
412408
sample = inferer.sample(
@@ -567,11 +563,7 @@ def test_get_likelihoods(
567563
)
568564
else:
569565
sample, intermediates = inferer.get_likelihood(
570-
inputs=input,
571-
autoencoder_model=stage_1,
572-
network=stage_2,
573-
scheduler=scheduler,
574-
save_intermediates=True,
566+
inputs=input, autoencoder_model=stage_1, network=stage_2, scheduler=scheduler, save_intermediates=True
575567
)
576568
self.assertEqual(len(intermediates), 10)
577569
self.assertEqual(intermediates[0].shape, latent_shape)
@@ -927,11 +919,7 @@ def test_incompatible_spade_setup(self):
927919

928920
with self.assertRaises(ValueError):
929921
_ = inferer.sample(
930-
input_noise=noise,
931-
autoencoder_model=stage_1,
932-
network=stage_2,
933-
scheduler=scheduler,
934-
seg=input_seg,
922+
input_noise=noise, autoencoder_model=stage_1, network=stage_2, scheduler=scheduler, seg=input_seg
935923
)
936924

937925

0 commit comments

Comments
 (0)