Skip to content

Commit e320ecc

Browse files
committed
reformat
Signed-off-by: Can-Zhao <[email protected]>
1 parent 0bf0041 commit e320ecc

File tree

3 files changed

+42
-16
lines changed

3 files changed

+42
-16
lines changed

monai/inferers/inferer.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,12 +1402,18 @@ def sample( # type: ignore[override]
14021402
if not scheduler:
14031403
scheduler = self.scheduler
14041404
image = input_noise
1405+
1406+
all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype)))
14051407
if verbose and has_tqdm:
1406-
progress_bar = tqdm(scheduler.timesteps)
1408+
progress_bar = tqdm(
1409+
zip(scheduler.timesteps, all_next_timesteps),
1410+
total=min(len(scheduler.timesteps), len(all_next_timesteps)),
1411+
)
14071412
else:
1408-
progress_bar = iter(scheduler.timesteps)
1413+
progress_bar = iter(zip(scheduler.timesteps, all_next_timesteps))
14091414
intermediates = []
1410-
for t in progress_bar:
1415+
1416+
for t, next_t in progress_bar:
14111417
diffuse = diffusion_model
14121418
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
14131419
diffuse = partial(diffusion_model, seg=seg)
@@ -1446,7 +1452,11 @@ def sample( # type: ignore[override]
14461452
)
14471453

14481454
# 3. compute previous image: x_t -> x_t-1
1449-
image, _ = scheduler.step(model_output, t, image) # type: ignore[operator]
1455+
if not isinstance(scheduler, RFlowScheduler):
1456+
image, _ = scheduler.step(model_output, t, image)
1457+
else:
1458+
image, _ = scheduler.step(model_output, t, image, next_t)
1459+
14501460
if save_intermediates and t % intermediate_steps == 0:
14511461
intermediates.append(image)
14521462
if save_intermediates:

tests/inferers/test_controlnet_inferers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,7 @@ def test_sampler_conditioned(self, model_params, controlnet_params, input_shape)
587587
controlnet.eval()
588588
mask = torch.randn(input_shape).to(device)
589589
noise = torch.randn(input_shape).to(device)
590-
590+
591591
# DDIM
592592
scheduler = DDIMScheduler(num_train_timesteps=1000)
593593
inferer = ControlNetDiffusionInferer(scheduler=scheduler)
@@ -755,11 +755,11 @@ def test_prediction_shape(
755755
input = torch.randn(input_shape).to(device)
756756
mask = torch.randn(input_shape).to(device)
757757
noise = torch.randn(latent_shape).to(device)
758+
758759
for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:
759760
inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
760761
scheduler.set_timesteps(num_inference_steps=10)
761762
timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
762-
763763
if dm_model_type == "SPADEDiffusionModelUNet":
764764
input_shape_seg = list(input_shape)
765765
if "label_nc" in stage_2_params.keys():

tests/networks/schedulers/test_scheduler_rflow.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,34 @@
2121

2222
TEST_2D_CASE = []
2323
for sample_method in ["uniform", "logit-normal"]:
24-
TEST_2D_CASE.append([{"sample_method": sample_method, "use_timestep_transform": False}, (2, 6, 16, 16), (2, 6, 16, 16)])
24+
TEST_2D_CASE.append(
25+
[{"sample_method": sample_method, "use_timestep_transform": False}, (2, 6, 16, 16), (2, 6, 16, 16)]
26+
)
2527

2628
for sample_method in ["uniform", "logit-normal"]:
27-
TEST_2D_CASE.append([{"sample_method": sample_method, "use_timestep_transform": True, "spatial_dim": 2}, (2, 6, 16, 16), (2, 6, 16, 16)])
29+
TEST_2D_CASE.append(
30+
[
31+
{"sample_method": sample_method, "use_timestep_transform": True, "spatial_dim": 2},
32+
(2, 6, 16, 16),
33+
(2, 6, 16, 16),
34+
]
35+
)
2836

2937

3038
TEST_3D_CASE = []
3139
for sample_method in ["uniform", "logit-normal"]:
32-
TEST_3D_CASE.append([{"sample_method": sample_method, "use_timestep_transform": False}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)])
40+
TEST_3D_CASE.append(
41+
[{"sample_method": sample_method, "use_timestep_transform": False}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)]
42+
)
3343

3444
for sample_method in ["uniform", "logit-normal"]:
35-
TEST_3D_CASE.append([{"sample_method": sample_method, "use_timestep_transform": True, "spatial_dim": 3}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)])
45+
TEST_3D_CASE.append(
46+
[
47+
{"sample_method": sample_method, "use_timestep_transform": True, "spatial_dim": 3},
48+
(2, 6, 16, 16, 16),
49+
(2, 6, 16, 16, 16),
50+
]
51+
)
3652

3753
TEST_CASES = TEST_2D_CASE + TEST_3D_CASE
3854

@@ -54,35 +70,35 @@ def test_add_noise(self, input_param, input_shape, expected_shape):
5470

5571
@parameterized.expand(TEST_CASES)
5672
def test_step_shape(self, input_param, input_shape, expected_shape):
57-
scheduler = RFlowScheduler(**input_param)
73+
scheduler = RFlowScheduler(**input_param)
5874
model_output = torch.randn(input_shape)
5975
sample = torch.randn(input_shape)
60-
scheduler.set_timesteps(num_inference_steps=100, input_img_size_numel=torch.numel(sample[0,0,...]))
76+
scheduler.set_timesteps(num_inference_steps=100, input_img_size_numel=torch.numel(sample[0, 0, ...]))
6177
output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample)
6278
self.assertEqual(output_step[0].shape, expected_shape)
6379
self.assertEqual(output_step[1].shape, expected_shape)
6480

6581
@parameterized.expand(TEST_FULl_LOOP)
6682
def test_full_timestep_loop(self, input_param, input_shape, expected_output):
67-
scheduler = RFlowScheduler(**input_param)
83+
scheduler = RFlowScheduler(**input_param)
6884
torch.manual_seed(42)
6985
model_output = torch.randn(input_shape)
7086
sample = torch.randn(input_shape)
71-
scheduler.set_timesteps(50, input_img_size_numel=torch.numel(sample[0,0,...]))
87+
scheduler.set_timesteps(50, input_img_size_numel=torch.numel(sample[0, 0, ...]))
7288
for t in range(50):
7389
sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample)
7490
assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3)
7591

7692
def test_set_timesteps(self):
7793
scheduler = RFlowScheduler(num_train_timesteps=1000)
78-
scheduler.set_timesteps(num_inference_steps=100, input_img_size_numel=16*16*16)
94+
scheduler.set_timesteps(num_inference_steps=100, input_img_size_numel=16 * 16 * 16)
7995
self.assertEqual(scheduler.num_inference_steps, 100)
8096
self.assertEqual(len(scheduler.timesteps), 100)
8197

8298
def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self):
8399
scheduler = RFlowScheduler(num_train_timesteps=1000)
84100
with self.assertRaises(ValueError):
85-
scheduler.set_timesteps(num_inference_steps=2000, input_img_size_numel=16*16*16)
101+
scheduler.set_timesteps(num_inference_steps=2000, input_img_size_numel=16 * 16 * 16)
86102

87103

88104
if __name__ == "__main__":

0 commit comments

Comments
 (0)