Skip to content

Commit 7b71a61

Browse files
committed
Update monai/networks/schedulers/ddim.py
Signed-off-by: ytl0623 <[email protected]>
1 parent dfd1626 commit 7b71a61

File tree

2 files changed

+3
-8
lines changed

2 files changed

+3
-8
lines changed

monai/networks/schedulers/ddim.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,7 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N
120120
if self.steps_offset < 0 or self.steps_offset >= self.num_train_timesteps:
121121
raise ValueError(f"`steps_offset`: {self.steps_offset} must be in range [0, {self.num_train_timesteps}).")
122122

123-
timesteps = (
124-
np.linspace((self.num_train_timesteps - 1) - self.steps_offset, 0, num_inference_steps)
125-
.round()
126-
.astype(np.int64)
127-
)
128-
self.timesteps = torch.from_numpy(timesteps).to(device)
123+
self.timesteps = torch.linspace((self.num_train_timesteps - 1) - self.steps_offset, 0, num_inference_steps, device=device).round().long()
129124
self.timesteps += self.steps_offset
130125

131126
def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:

monai/networks/schedulers/ddpm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,8 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N
122122
)
123123

124124
self.num_inference_steps = num_inference_steps
125-
timesteps = np.linspace(self.num_train_timesteps - 1, 0, self.num_inference_steps).round().astype(np.int64)
126-
self.timesteps = torch.from_numpy(timesteps).to(device)
125+
self.timesteps = torch.linspace(self.num_train_timesteps - 1, 0, self.num_inference_steps, device=device).round().long()
126+
127127

128128
def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor:
129129
"""

0 commit comments

Comments
 (0)