Skip to content

Commit ef7e83c

Browse files
committed
With the linspace approach, max_timestep = (num_train_timesteps - 1 - steps_offset) + steps_offset = num_train_timesteps - 1 regardless of the relationship between steps_offset and step_ratio. The actual constraint is 0 <= steps_offset < num_train_timesteps.
Signed-off-by: ytl0623 <[email protected]>
1 parent f7ca165 commit ef7e83c

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

monai/networks/schedulers/ddim.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -117,15 +117,14 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N
117117
)
118118

119119
self.num_inference_steps = num_inference_steps
120-
step_ratio = self.num_train_timesteps // self.num_inference_steps
121-
if self.steps_offset >= step_ratio:
122-
raise ValueError(
123-
f"`steps_offset`: {self.steps_offset} cannot be greater than or equal to "
124-
f"`num_train_timesteps // num_inference_steps : {step_ratio}` as this will cause timesteps to exceed"
125-
f" the max train timestep."
126-
)
127-
128-
timesteps = np.linspace((self.num_train_timesteps - 1) - self.steps_offset, 0, num_inference_steps).round().astype(np.int64)
120+
if self.steps_offset < 0 or self.steps_offset >= self.num_train_timesteps:
121+
raise ValueError(f"`steps_offset`: {self.steps_offset} must be in range [0, {self.num_train_timesteps}).")
122+
123+
timesteps = (
124+
np.linspace((self.num_train_timesteps - 1) - self.steps_offset, 0, num_inference_steps)
125+
.round()
126+
.astype(np.int64)
127+
)
129128
self.timesteps = torch.from_numpy(timesteps).to(device)
130129
self.timesteps += self.steps_offset
131130

0 commit comments

Comments
 (0)