Skip to content

Commit 1294ceb

Browse files
committed
reformat
Signed-off-by: Can-Zhao <[email protected]>
1 parent c14ce67 commit 1294ceb

File tree

1 file changed

+10
-12
lines changed

1 file changed

+10
-12
lines changed

monai/networks/schedulers/rectified_flow.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -159,19 +159,17 @@ def __init__(
159159
self.transform_scale = transform_scale
160160
self.steps_offset = steps_offset
161161

162-
def add_noise(
163-
self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
164-
) -> torch.FloatTensor:
162+
def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
165163
"""
166-
Adds noise to the original samples based on the given timesteps.
164+
Add noise to the original samples.
167165
168166
Args:
169-
original_samples (torch.FloatTensor): The original sample tensor.
170-
noise (torch.FloatTensor): Noise tensor to be added.
171-
timesteps (torch.IntTensor): Timesteps corresponding to each sample.
167+
original_samples: original samples
168+
noise: noise to add to samples
169+
timesteps: timesteps tensor indicating the timestep to be computed for each sample.
172170
173171
Returns:
174-
torch.FloatTensor: The noisy sample tensor.
172+
noisy_samples: sample with added noise
175173
"""
176174
timepoints = timesteps.float() / self.num_train_timesteps
177175
timepoints = 1 - timepoints # [1,1/1000]
@@ -221,10 +219,10 @@ def set_timesteps(
221219
)
222220
for t in timesteps
223221
]
224-
timesteps = np.array(timesteps).astype(np.float16)
222+
timesteps_np = np.array(timesteps).astype(np.float16)
225223
if self.use_discrete_timesteps:
226-
timesteps = timesteps.astype(np.int64)
227-
self.timesteps = torch.from_numpy(timesteps).to(device)
224+
timesteps_np = timesteps_np.astype(np.int64)
225+
self.timesteps = torch.from_numpy(timesteps_np).to(device)
228226
self.timesteps += self.steps_offset
229227

230228
def sample_timesteps(self, x_start):
@@ -257,7 +255,7 @@ def sample_timesteps(self, x_start):
257255
return t
258256

259257
def step(
260-
self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep=None
258+
self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep: int | None = None
261259
) -> tuple[torch.Tensor, Any]:
262260
"""
263261
Predict the sample at the previous timestep. Core function to propagate the diffusion

0 commit comments

Comments
 (0)