Skip to content

Commit d7ea3ed

Browse files
committed
add rectified flow for accelerated diffusion model
Signed-off-by: Can-Zhao <[email protected]>
1 parent 83a4676 commit d7ea3ed

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

monai/networks/schedulers/rectified_flow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ class RFlowScheduler(Scheduler):
9797
)
9898
9999
# during training
100-
inputs = torch.ones(2,4,64,64,64)
100+
inputs = torch.ones(2,4,64,64,32)
101101
noise = torch.randn_like(inputs)
102102
timesteps = noise_scheduler.sample_timesteps(inputs)
103103
noisy_inputs = noise_scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
@@ -108,7 +108,7 @@ class RFlowScheduler(Scheduler):
108108
loss = loss_l1(predicted_velocity, (inputs - noise))
109109
110110
# during inference
111-
noisy_inputs = torch.randn(2,4,64,64,64)
111+
noisy_inputs = torch.randn(2,4,64,64,32)
112112
input_img_size_numel = torch.prod(torch.tensor(noisy_inputs.shape[-3:])
113113
noise_scheduler.set_timesteps(
114114
num_inference_steps=30, input_img_size_numel=input_img_size_numel)

0 commit comments

Comments
 (0)