Skip to content
14 changes: 3 additions & 11 deletions monai/networks/schedulers/ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,18 +117,10 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N
)

self.num_inference_steps = num_inference_steps
step_ratio = self.num_train_timesteps // self.num_inference_steps
if self.steps_offset >= step_ratio:
raise ValueError(
f"`steps_offset`: {self.steps_offset} cannot be greater than or equal to "
f"`num_train_timesteps // num_inference_steps : {step_ratio}` as this will cause timesteps to exceed"
f" the max train timestep."
)
if self.steps_offset < 0 or self.steps_offset >= self.num_train_timesteps:
raise ValueError(f"`steps_offset`: {self.steps_offset} must be in range [0, {self.num_train_timesteps}).")

# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
self.timesteps = torch.from_numpy(timesteps).to(device)
self.timesteps = torch.linspace((self.num_train_timesteps - 1) - self.steps_offset, 0, num_inference_steps, device=device).round().long()
self.timesteps += self.steps_offset

def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:
Expand Down
8 changes: 2 additions & 6 deletions monai/networks/schedulers/ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@

from __future__ import annotations

import numpy as np
import torch

from monai.utils import StrEnum
Expand Down Expand Up @@ -122,11 +121,8 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N
)

self.num_inference_steps = num_inference_steps
step_ratio = self.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64)
self.timesteps = torch.from_numpy(timesteps).to(device)
self.timesteps = torch.linspace(self.num_train_timesteps - 1, 0, self.num_inference_steps, device=device).round().long()


def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor:
"""
Expand Down
Loading