@@ -174,10 +174,15 @@ def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timeste
174174 timepoints : torch .Tensor = timesteps .float () / self .num_train_timesteps
175175 timepoints = 1 - timepoints # [1,1/1000]
176176
177- # timepoint (bsz) noise: (bsz, 4, frame, w ,h)
178177 # expand timepoint to noise shape
179- timepoints = timepoints .unsqueeze (1 ).unsqueeze (1 ).unsqueeze (1 ).unsqueeze (1 )
180- timepoints = timepoints .repeat (1 , noise .shape [1 ], noise .shape [2 ], noise .shape [3 ], noise .shape [4 ])
178+ if len (noise .shape ) == 5 :
179+ timepoints = timepoints .unsqueeze (1 ).unsqueeze (1 ).unsqueeze (1 ).unsqueeze (1 )
180+ timepoints = timepoints .repeat (1 , noise .shape [1 ], noise .shape [2 ], noise .shape [3 ], noise .shape [4 ])
181+ elif len (noise .shape ) == 4 :
182+ timepoints = timepoints .unsqueeze (1 ).unsqueeze (1 ).unsqueeze (1 )
183+ timepoints = timepoints .repeat (1 , noise .shape [1 ], noise .shape [2 ], noise .shape [3 ])
184+ else :
185+ raise ValueError (f"noise has to be 4D or 5D tensor. yet got shape of { noise .shape } ." )
181186 noisy_samples : torch .Tensor = timepoints * original_samples + (1 - timepoints ) * noise
182187
183188 return noisy_samples
@@ -246,7 +251,7 @@ def sample_timesteps(self, x_start):
246251 t = t .long ()
247252
248253 if self .use_timestep_transform :
249- input_img_size_numel = torch .prod (torch .tensor (x_start .shape [- 3 :]))
254+ input_img_size_numel = torch .prod (torch .tensor (x_start .shape [2 :]))
250255 t = timestep_transform (
251256 t ,
252257 input_img_size_numel = input_img_size_numel ,
0 commit comments