@@ -159,19 +159,17 @@ def __init__(
159159 self .transform_scale = transform_scale
160160 self .steps_offset = steps_offset
161161
162- def add_noise (
163- self , original_samples : torch .FloatTensor , noise : torch .FloatTensor , timesteps : torch .IntTensor
164- ) -> torch .FloatTensor :
162+ def add_noise (self , original_samples : torch .Tensor , noise : torch .Tensor , timesteps : torch .Tensor ) -> torch .Tensor :
165163 """
166- Adds noise to the original samples based on the given timesteps .
164+ Add noise to the original samples.
167165
168166 Args:
169- original_samples (torch.FloatTensor): The original sample tensor.
170- noise (torch.FloatTensor): Noise tensor to be added.
171- timesteps (torch.IntTensor): Timesteps corresponding to each sample.
167+ original_samples: original samples
168+ noise: noise to add to samples
169+ timesteps: timesteps tensor indicating the timestep to be computed for each sample.
172170
173171 Returns:
174- torch.FloatTensor: The noisy sample tensor.
172+ noisy_samples: sample with added noise
175173 """
176174 timepoints = timesteps .float () / self .num_train_timesteps
177175 timepoints = 1 - timepoints # [1,1/1000]
@@ -221,10 +219,10 @@ def set_timesteps(
221219 )
222220 for t in timesteps
223221 ]
224- timesteps = np .array (timesteps ).astype (np .float16 )
222+ timesteps_np = np .array (timesteps ).astype (np .float16 )
225223 if self .use_discrete_timesteps :
226- timesteps = timesteps .astype (np .int64 )
227- self .timesteps = torch .from_numpy (timesteps ).to (device )
224+ timesteps_np = timesteps_np .astype (np .int64 )
225+ self .timesteps = torch .from_numpy (timesteps_np ).to (device )
228226 self .timesteps += self .steps_offset
229227
230228 def sample_timesteps (self , x_start ):
@@ -257,7 +255,7 @@ def sample_timesteps(self, x_start):
257255 return t
258256
259257 def step (
260- self , model_output : torch .Tensor , timestep : int , sample : torch .Tensor , next_timestep = None
258+ self , model_output : torch .Tensor , timestep : int , sample : torch .Tensor , next_timestep : int | None = None
261259 ) -> tuple [torch .Tensor , Any ]:
262260 """
263261 Predict the sample at the previous timestep. Core function to propagate the diffusion
0 commit comments