Skip to content

Commit 4bf6c02

Browse files
committed
reformat
Signed-off-by: Can-Zhao <[email protected]>
1 parent 1294ceb commit 4bf6c02

File tree

1 file changed

+30
-18
lines changed

1 file changed

+30
-18
lines changed

monai/networks/schedulers/rectified_flow.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
from __future__ import annotations
3030

31-
from typing import Any
31+
from typing import Any, Union
3232

3333
import numpy as np
3434
import torch
@@ -171,15 +171,16 @@ def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timeste
171171
Returns:
172172
noisy_samples: sample with added noise
173173
"""
174-
timepoints = timesteps.float() / self.num_train_timesteps
174+
timepoints: torch.Tensor = timesteps.float() / self.num_train_timesteps
175175
timepoints = 1 - timepoints # [1,1/1000]
176176

177177
# timepoint (bsz) noise: (bsz, 4, frame, w ,h)
178178
# expand timepoint to noise shape
179179
timepoints = timepoints.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1)
180180
timepoints = timepoints.repeat(1, noise.shape[1], noise.shape[2], noise.shape[3], noise.shape[4])
181+
noisy_samples: torch.Tensor = timepoints * original_samples + (1 - timepoints) * noise
181182

182-
return timepoints * original_samples + (1 - timepoints) * noise
183+
return noisy_samples
183184

184185
def set_timesteps(
185186
self,
@@ -255,27 +256,38 @@ def sample_timesteps(self, x_start):
255256
return t
256257

257258
def step(
258-
self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep: int | None = None
259-
) -> tuple[torch.Tensor, Any]:
259+
self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep: Union[int, None] = None
260+
) -> tuple[torch.Tensor, None]:
260261
"""
261-
Predict the sample at the previous timestep. Core function to propagate the diffusion
262-
process from the learned model outputs.
262+
Predicts the next sample in the diffusion process.
263263
264264
Args:
265-
model_output: direct output from learned diffusion model.
266-
timestep: current discrete timestep in the diffusion chain.
267-
sample: current instance of sample being created by diffusion process.
268-
next_timestep: next discrete timestep in the diffusion chain.
265+
model_output (torch.Tensor): Output from the trained diffusion model.
266+
timestep (int): Current timestep in the diffusion chain.
267+
sample (torch.Tensor): Current sample in the process.
268+
next_timestep (Union[int, None]): Optional next timestep.
269+
269270
Returns:
270-
pred_prev_sample: Predicted previous sample
271-
None
271+
tuple[torch.Tensor, None]: Predicted sample at the next step and additional info.
272272
"""
273+
# Ensure num_inference_steps exists and is a valid integer
274+
if not hasattr(self, "num_inference_steps") or not isinstance(self.num_inference_steps, int):
275+
raise AttributeError(
276+
"num_inference_steps is missing or not an integer in the class."
277+
"Please run self.set_timesteps(num_inference_steps,device,input_img_size_numel) to set it."
278+
)
279+
273280
v_pred = model_output
274-
if next_timestep is None:
275-
dt = 1.0 / self.num_inference_steps
281+
282+
if next_timestep is not None:
283+
next_timestep = int(next_timestep)
284+
dt: float = (
285+
float(timestep - next_timestep) / self.num_train_timesteps
286+
) # Now next_timestep is guaranteed to be int
276287
else:
277-
dt = timestep - next_timestep
278-
dt = dt / self.num_train_timesteps
279-
z = sample + v_pred * dt
288+
dt = (
289+
1.0 / float(self.num_inference_steps) if self.num_inference_steps > 0 else 0.0
290+
) # Avoid division by zero
280291

292+
z = sample + v_pred * dt
281293
return z, None

0 commit comments

Comments
 (0)