Commit 14664e8

make it 2D/3D compatible, rm an outdated comment
Signed-off-by: Can-Zhao <[email protected]>
1 parent 8555b67 commit 14664e8

monai/networks/schedulers/rectified_flow.py

Lines changed: 10 additions & 4 deletions
@@ -82,6 +82,7 @@ class RFlowScheduler(Scheduler):
         transform_scale (float): Scaling factor for timestep transformation, used only if use_timestep_transform=True.
         steps_offset (int): Offset added to computed timesteps, used only if use_timestep_transform=True.
         base_img_size_numel (int): Reference image volume size for scaling, used only if use_timestep_transform=True.
+        spatial_dim (int): 2 or 3, incidcating 2D or 3D images, used only if use_timestep_transform=True.
 
     Example:
 
@@ -93,7 +94,8 @@ class RFlowScheduler(Scheduler):
                 use_discrete_timesteps = True,
                 sample_method = 'logit-normal',
                 use_timestep_transform = True,
-                base_img_size_numel = 32 * 32 * 32
+                base_img_size_numel = 32 * 32 * 32,
+                spatial_dim = 3
             )
 
             # during training
@@ -139,10 +141,12 @@ def __init__(
         transform_scale: float = 1.0,
         steps_offset: int = 0,
         base_img_size_numel: int = 32 * 32 * 32,
+        spatial_dim: int = 3
     ):
         self.num_train_timesteps = num_train_timesteps
         self.use_discrete_timesteps = use_discrete_timesteps
         self.base_img_size_numel = base_img_size_numel
+        self.spatial_dim = spatial_dim
 
         # sample method
         if sample_method not in ["uniform", "logit-normal"]:
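
For reference, a minimal sketch of constructing the scheduler for 2D images with the new parameter, mirroring the 3D example in the docstring above; the num_train_timesteps value and the 2D base_img_size_numel reference size are assumptions, not values taken from this diff:

    from monai.networks.schedulers.rectified_flow import RFlowScheduler

    scheduler = RFlowScheduler(
        num_train_timesteps=1000,      # assumed value; not shown in this diff
        use_discrete_timesteps=True,
        sample_method="logit-normal",
        use_timestep_transform=True,
        base_img_size_numel=32 * 32,   # assumed 2D reference size (H * W)
        spatial_dim=2,                 # new parameter introduced by this commit
    )
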
@@ -166,7 +170,7 @@ def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timeste
         Args:
             original_samples: original samples
             noise: noise to add to samples
-            timesteps: timesteps tensor indicating the timestep to be computed for each sample.
+            timesteps: timesteps tensor with shape of (N,), indicating the timestep to be computed for each sample.
 
         Returns:
             noisy_samples: sample with added noise
@@ -175,14 +179,14 @@ def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timeste
         timepoints = 1 - timepoints  # [1,1/1000]
 
         # expand timepoint to noise shape
+        # Just in case timepoints is not 1D or 2D tensor, make it to be same shape as noise
         if len(noise.shape) == 5:
             timepoints = timepoints.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1)
             timepoints = timepoints.repeat(1, noise.shape[1], noise.shape[2], noise.shape[3], noise.shape[4])
         elif len(noise.shape) == 4:
             timepoints = timepoints.unsqueeze(1).unsqueeze(1).unsqueeze(1)
             timepoints = timepoints.repeat(1, noise.shape[1], noise.shape[2], noise.shape[3])
-        else:
-            raise ValueError(f"noise has to be 4D or 5D tensor. yet got shape of {noise.shape}.")
+
         noisy_samples: torch.Tensor = timepoints * original_samples + (1 - timepoints) * noise
 
         return noisy_samples
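
The expansion above broadcasts one timepoint per sample across every element of the noise tensor. A small standalone illustration of the 4D (2D-image) branch, using plain torch:

    import torch

    noise = torch.randn(2, 1, 64, 64)   # (N, C, H, W) batch of 2D images
    timepoints = torch.rand(2)           # one timepoint per sample, shape (N,)

    # mirror the 4D branch of add_noise shown above
    timepoints = timepoints.unsqueeze(1).unsqueeze(1).unsqueeze(1)
    timepoints = timepoints.repeat(1, noise.shape[1], noise.shape[2], noise.shape[3])
    assert timepoints.shape == noise.shape
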
@@ -223,6 +227,7 @@ def set_timesteps(
                 input_img_size_numel=input_img_size_numel,
                 base_img_size_numel=self.base_img_size_numel,
                 num_train_timesteps=self.num_train_timesteps,
+                spatial_dim = self.spatial_dim
             )
             for t in timesteps
         ]
@@ -257,6 +262,7 @@ def sample_timesteps(self, x_start):
             input_img_size_numel=input_img_size_numel,
             base_img_size_numel=self.base_img_size_numel,
             num_train_timesteps=self.num_train_timesteps,
+            spatial_dim = len(x_start.shape)-2
         )
 
         return t
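
In sample_timesteps the spatial dimension is inferred from the training batch itself: len(x_start.shape) - 2 strips the batch and channel axes, giving 2 for (N, C, H, W) inputs and 3 for (N, C, H, W, D). A quick illustration of that arithmetic:

    import torch

    x_2d = torch.zeros(4, 1, 64, 64)       # (N, C, H, W)    -> 4 - 2 = 2 spatial dims
    x_3d = torch.zeros(4, 1, 32, 32, 32)   # (N, C, H, W, D) -> 5 - 2 = 3 spatial dims
    assert len(x_2d.shape) - 2 == 2
    assert len(x_3d.shape) - 2 == 3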
