@@ -141,7 +141,7 @@ def __init__(
141141 transform_scale : float = 1.0 ,
142142 steps_offset : int = 0 ,
143143 base_img_size_numel : int = 32 * 32 * 32 ,
144- spatial_dim : int = 3
144+ spatial_dim : int = 3 ,
145145 ):
146146 self .num_train_timesteps = num_train_timesteps
147147 self .use_discrete_timesteps = use_discrete_timesteps
@@ -179,12 +179,12 @@ def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timeste
179179 timepoints = 1 - timepoints # [1,1/1000]
180180
181181 # expand timepoint to noise shape
182- if len ( noise .shape ) == 5 :
183- timepoints = timepoints . unsqueeze ( 1 ). unsqueeze ( 1 ). unsqueeze ( 1 ). unsqueeze ( 1 )
184- timepoints = timepoints . repeat ( 1 , noise .shape [ 1 ], noise . shape [ 2 ], noise . shape [ 3 ], noise . shape [ 4 ])
185- elif len ( noise .shape ) == 4 :
186- timepoints = timepoints . unsqueeze ( 1 ). unsqueeze ( 1 ). unsqueeze ( 1 )
187- timepoints = timepoints . repeat ( 1 , noise . shape [ 1 ], noise . shape [ 2 ], noise .shape [ 3 ] )
182+ if noise .ndim == 5 :
183+ timepoints = timepoints [..., None , None , None , None ]. expand ( - 1 , * noise . shape [ 1 :] )
184+ elif noise .ndim == 4 :
185+ timepoints = timepoints [..., None , None , None ]. expand ( - 1 , * noise .shape [ 1 :])
186+ else :
187+ raise ValueError ( f" noise tensor has to be 4D or 5D tensor, yet got shape of { noise .shape } " )
188188
189189 noisy_samples : torch .Tensor = timepoints * original_samples + (1 - timepoints ) * noise
190190
@@ -226,7 +226,7 @@ def set_timesteps(
226226 input_img_size_numel = input_img_size_numel ,
227227 base_img_size_numel = self .base_img_size_numel ,
228228 num_train_timesteps = self .num_train_timesteps ,
229- spatial_dim = self .spatial_dim
229+ spatial_dim = self .spatial_dim ,
230230 )
231231 for t in timesteps
232232 ]
@@ -261,7 +261,7 @@ def sample_timesteps(self, x_start):
261261 input_img_size_numel = input_img_size_numel ,
262262 base_img_size_numel = self .base_img_size_numel ,
263263 num_train_timesteps = self .num_train_timesteps ,
264- spatial_dim = len (x_start .shape )- 2
264+ spatial_dim = len (x_start .shape ) - 2 ,
265265 )
266266
267267 return t
0 commit comments