28 | 28 |
29 | 29 | from __future__ import annotations |
30 | 30 |
31 | | -from typing import Any |
| 31 | +from typing import Any, Union |
32 | 32 |
33 | 33 | import numpy as np |
34 | 34 | import torch |
@@ -171,15 +171,16 @@ def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timeste |
171 | 171 | Returns: |
172 | 172 | noisy_samples: sample with added noise |
173 | 173 | """ |
174 | | - timepoints = timesteps.float() / self.num_train_timesteps |
| 174 | + timepoints: torch.Tensor = timesteps.float() / self.num_train_timesteps |
175 | 175 | timepoints = 1 - timepoints # [1,1/1000] |
176 | 176 |
177 | 177 | # timepoint (bsz) noise: (bsz, 4, frame, w ,h) |
178 | 178 | # expand timepoint to noise shape |
179 | 179 | timepoints = timepoints.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1) |
180 | 180 | timepoints = timepoints.repeat(1, noise.shape[1], noise.shape[2], noise.shape[3], noise.shape[4]) |
| 181 | + noisy_samples: torch.Tensor = timepoints * original_samples + (1 - timepoints) * noise |
181 | 182 |
182 | | - return timepoints * original_samples + (1 - timepoints) * noise |
| 183 | + return noisy_samples |
183 | 184 |
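For context, this hunk implements the rectified-flow forward process: x_t = tau * x_0 + (1 - tau) * noise with tau = 1 - t / num_train_timesteps, so timestep 0 keeps the clean sample and the last timestep is almost pure noise. A minimal standalone sketch of the same interpolation (the function name and the view/expand_as broadcasting are illustrative; the method above uses repeated unsqueeze/repeat):

```python
import torch

def add_noise_sketch(
    original: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor, num_train_timesteps: int = 1000
) -> torch.Tensor:
    # tau = 1 - t/T: 1.0 at timestep 0 (pure signal), ~1/T at the last timestep (almost pure noise)
    tau = 1 - timesteps.float() / num_train_timesteps
    # broadcast (bsz,) -> (bsz, c, frame, w, h) to match the noise shape
    tau = tau.view(-1, 1, 1, 1, 1).expand_as(noise)
    return tau * original + (1 - tau) * noise

x0 = torch.randn(2, 4, 8, 16, 16)    # clean latents
eps = torch.randn_like(x0)           # Gaussian noise
xt = add_noise_sketch(x0, eps, torch.tensor([0, 999]))
assert torch.allclose(xt[0], x0[0])  # timestep 0 leaves the sample untouched
```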
184 | 185 | def set_timesteps( |
185 | 186 | self, |
@@ -255,27 +256,38 @@ def sample_timesteps(self, x_start): |
255 | 256 | return t |
256 | 257 |
257 | 258 | def step( |
258 | | - self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep: int | None = None |
259 | | - ) -> tuple[torch.Tensor, Any]: |
| 259 | + self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep: Union[int, None] = None |
| 260 | + ) -> tuple[torch.Tensor, None]: |
260 | 261 | """ |
261 | | - Predict the sample at the previous timestep. Core function to propagate the diffusion |
262 | | - process from the learned model outputs. |
| 262 | + Predict the sample at the next timestep. Core function to propagate the diffusion process from the learned model output. |
263 | 263 |
264 | 264 | Args: |
265 | | - model_output: direct output from learned diffusion model. |
266 | | - timestep: current discrete timestep in the diffusion chain. |
267 | | - sample: current instance of sample being created by diffusion process. |
268 | | - next_timestep: next discrete timestep in the diffusion chain. |
| 265 | + model_output (torch.Tensor): Output from the trained diffusion model. |
| 266 | + timestep (int): Current timestep in the diffusion chain. |
| 267 | + sample (torch.Tensor): Current sample in the process. |
| 268 | + next_timestep (Union[int, None]): Next discrete timestep in the diffusion chain. If None, a uniform step of 1 / num_inference_steps is used. |
| 269 | +
269 | 270 | Returns: |
270 | | - pred_prev_sample: Predicted previous sample |
271 | | - None |
| 271 | + tuple[torch.Tensor, None]: The predicted sample at the next timestep, and None. |
272 | 272 | """ |
| 273 | + # Ensure num_inference_steps exists and is a valid integer |
| 274 | + if not hasattr(self, "num_inference_steps") or not isinstance(self.num_inference_steps, int): |
| 275 | + raise AttributeError( |
| 276 | + "num_inference_steps is missing or not an integer in the class. " |
| 277 | + "Please run self.set_timesteps(num_inference_steps, device, input_img_size_numel) to set it." |
| 278 | + ) |
| 279 | + |
273 | 280 | v_pred = model_output |
274 | | - if next_timestep is None: |
275 | | - dt = 1.0 / self.num_inference_steps |
| 281 | + |
| 282 | + if next_timestep is not None: |
| 283 | + next_timestep = int(next_timestep) # ensure a plain int for the arithmetic below |
| 284 | + dt: float = ( |
| 285 | + float(timestep - next_timestep) / self.num_train_timesteps |
| 286 | + ) # fraction of the training schedule covered by this step |
276 | 287 | else: |
277 | | - dt = timestep - next_timestep |
278 | | - dt = dt / self.num_train_timesteps |
279 | | - z = sample + v_pred * dt |
| 288 | + dt = ( |
| 289 | + 1.0 / float(self.num_inference_steps) if self.num_inference_steps > 0 else 0.0 |
| 290 | + ) # Avoid division by zero |
280 | 291 |
| 292 | + z = sample + v_pred * dt |
281 | 293 | return z, None |
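Taken together, step performs a single Euler update z = sample + v_pred * dt along the predicted velocity, with dt either the normalized gap to next_timestep or a uniform 1 / num_inference_steps fallback. A sketch of how step might be driven from a sampling loop; the RFlowScheduler class name, the timesteps attribute, and model are assumptions for illustration, not part of this diff:

```python
import torch

scheduler = RFlowScheduler(num_train_timesteps=1000)  # assumed class name and constructor
scheduler.set_timesteps(num_inference_steps=30, device="cpu", input_img_size_numel=8 * 16 * 16)

sample = torch.randn(1, 4, 8, 16, 16)  # start from pure noise
timesteps = scheduler.timesteps        # assumed: descending sequence of discrete timesteps
for i, t in enumerate(timesteps):
    v_pred = model(sample, t)          # assumed velocity-prediction network
    next_t = int(timesteps[i + 1]) if i + 1 < len(timesteps) else None
    # one Euler step toward the data distribution; the second return value is always None
    sample, _ = scheduler.step(v_pred, int(t), sample, next_timestep=next_t)
```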