@@ -82,6 +82,7 @@ class RFlowScheduler(Scheduler):
8282 transform_scale (float): Scaling factor for timestep transformation, used only if use_timestep_transform=True.
8383 steps_offset (int): Offset added to computed timesteps, used only if use_timestep_transform=True.
8484 base_img_size_numel (int): Reference image volume size for scaling, used only if use_timestep_transform=True.
85+ spatial_dim (int): 2 or 3, indicating 2D or 3D images, used only if use_timestep_transform=True.
8586
8687 Example:
8788
@@ -93,7 +94,8 @@ class RFlowScheduler(Scheduler):
9394 use_discrete_timesteps = True,
9495 sample_method = 'logit-normal',
9596 use_timestep_transform = True,
96- base_img_size_numel = 32 * 32 * 32
97+ base_img_size_numel = 32 * 32 * 32,
98+ spatial_dim = 3
9799 )
98100
99101 # during training
@@ -139,10 +141,12 @@ def __init__(
139141 transform_scale : float = 1.0 ,
140142 steps_offset : int = 0 ,
141143 base_img_size_numel : int = 32 * 32 * 32 ,
144+ spatial_dim : int = 3
142145 ):
143146 self .num_train_timesteps = num_train_timesteps
144147 self .use_discrete_timesteps = use_discrete_timesteps
145148 self .base_img_size_numel = base_img_size_numel
149+ self .spatial_dim = spatial_dim
146150
147151 # sample method
148152 if sample_method not in ["uniform" , "logit-normal" ]:
@@ -166,7 +170,7 @@ def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timeste
166170 Args:
167171 original_samples: original samples
168172 noise: noise to add to samples
169- timesteps: timesteps tensor indicating the timestep to be computed for each sample.
173+ timesteps: timesteps tensor with shape of (N,), indicating the timestep to be computed for each sample.
170174
171175 Returns:
172176 noisy_samples: sample with added noise
@@ -175,14 +179,14 @@ def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timeste
175179 timepoints = 1 - timepoints # [1,1/1000]
176180
177181 # expand timepoint to noise shape
182+ # In case timepoints is not a 1D or 2D tensor, expand it to the same shape as noise
178183 if len (noise .shape ) == 5 :
179184 timepoints = timepoints .unsqueeze (1 ).unsqueeze (1 ).unsqueeze (1 ).unsqueeze (1 )
180185 timepoints = timepoints .repeat (1 , noise .shape [1 ], noise .shape [2 ], noise .shape [3 ], noise .shape [4 ])
181186 elif len (noise .shape ) == 4 :
182187 timepoints = timepoints .unsqueeze (1 ).unsqueeze (1 ).unsqueeze (1 )
183188 timepoints = timepoints .repeat (1 , noise .shape [1 ], noise .shape [2 ], noise .shape [3 ])
184- else :
185- raise ValueError (f"noise has to be 4D or 5D tensor. yet got shape of { noise .shape } ." )
189+
186190 noisy_samples : torch .Tensor = timepoints * original_samples + (1 - timepoints ) * noise
187191
188192 return noisy_samples
@@ -223,6 +227,7 @@ def set_timesteps(
223227 input_img_size_numel = input_img_size_numel ,
224228 base_img_size_numel = self .base_img_size_numel ,
225229 num_train_timesteps = self .num_train_timesteps ,
230+ spatial_dim = self .spatial_dim
226231 )
227232 for t in timesteps
228233 ]
@@ -257,6 +262,7 @@ def sample_timesteps(self, x_start):
257262 input_img_size_numel = input_img_size_numel ,
258263 base_img_size_numel = self .base_img_size_numel ,
259264 num_train_timesteps = self .num_train_timesteps ,
265+ spatial_dim = len (x_start .shape )- 2
260266 )
261267
262268 return t
0 commit comments