Skip to content

Commit 8555b67

Browse files
committed
make it 2D/3D compartible, rm a outdated comment
Signed-off-by: Can-Zhao <[email protected]>
1 parent eaa803f commit 8555b67

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

monai/networks/schedulers/rectified_flow.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,15 @@ def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timeste
174174
timepoints: torch.Tensor = timesteps.float() / self.num_train_timesteps
175175
timepoints = 1 - timepoints # [1,1/1000]
176176

177-
# timepoint (bsz) noise: (bsz, 4, frame, w ,h)
178177
# expand timepoint to noise shape
179-
timepoints = timepoints.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1)
180-
timepoints = timepoints.repeat(1, noise.shape[1], noise.shape[2], noise.shape[3], noise.shape[4])
178+
if len(noise.shape) == 5:
179+
timepoints = timepoints.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1)
180+
timepoints = timepoints.repeat(1, noise.shape[1], noise.shape[2], noise.shape[3], noise.shape[4])
181+
elif len(noise.shape) == 4:
182+
timepoints = timepoints.unsqueeze(1).unsqueeze(1).unsqueeze(1)
183+
timepoints = timepoints.repeat(1, noise.shape[1], noise.shape[2], noise.shape[3])
184+
else:
185+
raise ValueError(f"noise has to be 4D or 5D tensor. yet got shape of {noise.shape}.")
181186
noisy_samples: torch.Tensor = timepoints * original_samples + (1 - timepoints) * noise
182187

183188
return noisy_samples
@@ -246,7 +251,7 @@ def sample_timesteps(self, x_start):
246251
t = t.long()
247252

248253
if self.use_timestep_transform:
249-
input_img_size_numel = torch.prod(torch.tensor(x_start.shape[-3:]))
254+
input_img_size_numel = torch.prod(torch.tensor(x_start.shape[2:]))
250255
t = timestep_transform(
251256
t,
252257
input_img_size_numel=input_img_size_numel,

0 commit comments

Comments
 (0)