Skip to content

Commit ddf6898

Browse files
committed
timestep scheduling with np.linspace
Signed-off-by: ytl0623 <[email protected]>
1 parent 3f4889c commit ddf6898

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

monai/networks/schedulers/ddim.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N
127127

128128
# creates integer timesteps by multiplying by ratio
129129
# casting to int to avoid issues when num_inference_step is power of 3
130-
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
130+
timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps).round().astype(np.int64)
131131
self.timesteps = torch.from_numpy(timesteps).to(device)
132132
self.timesteps += self.steps_offset
133133

monai/networks/schedulers/ddpm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N
125125
step_ratio = self.num_train_timesteps // self.num_inference_steps
126126
# creates integer timesteps by multiplying by ratio
127127
# casting to int to avoid issues when num_inference_step is power of 3
128-
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64)
128+
timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps).round().astype(np.int64)
129129
self.timesteps = torch.from_numpy(timesteps).to(device)
130130

131131
def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)