Skip to content

Commit 3144c8a

Browse files
committed
make it 2D/3D compartible
Signed-off-by: Can-Zhao <[email protected]>
1 parent 20aa7fd commit 3144c8a

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

monai/networks/schedulers/rectified_flow.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def __init__(
141141
transform_scale: float = 1.0,
142142
steps_offset: int = 0,
143143
base_img_size_numel: int = 32 * 32 * 32,
144-
spatial_dim: int = 3
144+
spatial_dim: int = 3,
145145
):
146146
self.num_train_timesteps = num_train_timesteps
147147
self.use_discrete_timesteps = use_discrete_timesteps
@@ -179,12 +179,12 @@ def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timeste
179179
timepoints = 1 - timepoints # [1,1/1000]
180180

181181
# expand timepoint to noise shape
182-
if len(noise.shape) == 5:
183-
timepoints = timepoints.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1)
184-
timepoints = timepoints.repeat(1, noise.shape[1], noise.shape[2], noise.shape[3], noise.shape[4])
185-
elif len(noise.shape) == 4:
186-
timepoints = timepoints.unsqueeze(1).unsqueeze(1).unsqueeze(1)
187-
timepoints = timepoints.repeat(1, noise.shape[1], noise.shape[2], noise.shape[3])
182+
if noise.ndim == 5:
183+
timepoints = timepoints[..., None, None, None, None].expand(-1, *noise.shape[1:])
184+
elif noise.ndim == 4:
185+
timepoints = timepoints[..., None, None, None].expand(-1, *noise.shape[1:])
186+
else:
187+
raise ValueError(f"noise tensor has to be 4D or 5D tensor, yet got shape of {noise.shape}")
188188

189189
noisy_samples: torch.Tensor = timepoints * original_samples + (1 - timepoints) * noise
190190

@@ -226,7 +226,7 @@ def set_timesteps(
226226
input_img_size_numel=input_img_size_numel,
227227
base_img_size_numel=self.base_img_size_numel,
228228
num_train_timesteps=self.num_train_timesteps,
229-
spatial_dim = self.spatial_dim
229+
spatial_dim=self.spatial_dim,
230230
)
231231
for t in timesteps
232232
]
@@ -261,7 +261,7 @@ def sample_timesteps(self, x_start):
261261
input_img_size_numel=input_img_size_numel,
262262
base_img_size_numel=self.base_img_size_numel,
263263
num_train_timesteps=self.num_train_timesteps,
264-
spatial_dim = len(x_start.shape)-2
264+
spatial_dim=len(x_start.shape) - 2,
265265
)
266266

267267
return t

0 commit comments

Comments
 (0)