diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py
index 72bcb8fd5a..769b6cc0e7 100644
--- a/monai/inferers/inferer.py
+++ b/monai/inferers/inferer.py
@@ -1607,7 +1607,7 @@ def __init__(
         self.autoencoder_latent_shape = autoencoder_latent_shape
         if self.ldm_latent_shape is not None and self.autoencoder_latent_shape is not None:
             self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape)
-            self.autoencoder_resizer = CenterSpatialCrop(roi_size=[-1] + self.autoencoder_latent_shape)
+            self.autoencoder_resizer = CenterSpatialCrop(roi_size=self.autoencoder_latent_shape)
 
     def __call__(  # type: ignore[override]
         self,
diff --git a/monai/networks/schedulers/ddim.py b/monai/networks/schedulers/ddim.py
index 78e3cc2a0c..19e24d94b8 100644
--- a/monai/networks/schedulers/ddim.py
+++ b/monai/networks/schedulers/ddim.py
@@ -57,6 +57,8 @@ class DDIMScheduler(Scheduler):
             `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
             stable diffusion.
         prediction_type: member of DDPMPredictionType
+        clip_sample_min: minimum clipping value when clip_sample equals True
+        clip_sample_max: maximum clipping value when clip_sample equals True
         schedule_args: arguments to pass to the schedule function
     """
 
@@ -69,6 +71,8 @@ def __init__(
         set_alpha_to_one: bool = True,
         steps_offset: int = 0,
         prediction_type: str = DDIMPredictionType.EPSILON,
+        clip_sample_min: float = -1.0,
+        clip_sample_max: float = 1.0,
         **schedule_args,
     ) -> None:
         super().__init__(num_train_timesteps, schedule, **schedule_args)
@@ -90,6 +94,7 @@ def __init__(
         self.timesteps = torch.from_numpy(np.arange(0, self.num_train_timesteps)[::-1].astype(np.int64))
 
         self.clip_sample = clip_sample
+        self.clip_sample_values = [clip_sample_min, clip_sample_max]
         self.steps_offset = steps_offset
 
         # default the number of inference timesteps to the number of train steps
@@ -193,7 +198,9 @@ def step(
 
         # 4. Clip "predicted x_0"
         if self.clip_sample:
-            pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+            pred_original_sample = torch.clamp(
+                pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1]
+            )
 
         # 5. compute variance: "sigma_t(η)" -> see formula (16)
         # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
@@ -266,7 +273,9 @@ def reversed_step(
 
         # 4. Clip "predicted x_0"
         if self.clip_sample:
-            pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+            pred_original_sample = torch.clamp(
+                pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1]
+            )
 
         # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
         pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon
diff --git a/monai/networks/schedulers/ddpm.py b/monai/networks/schedulers/ddpm.py
index a5173a1b65..93ad833031 100644
--- a/monai/networks/schedulers/ddpm.py
+++ b/monai/networks/schedulers/ddpm.py
@@ -77,6 +77,8 @@ class DDPMScheduler(Scheduler):
         variance_type: member of DDPMVarianceType
         clip_sample: option to clip predicted sample between -1 and 1 for numerical stability.
         prediction_type: member of DDPMPredictionType
+        clip_sample_min: minimum clipping value when clip_sample equals True
+        clip_sample_max: maximum clipping value when clip_sample equals True
         schedule_args: arguments to pass to the schedule function
     """
 
@@ -87,6 +89,8 @@ def __init__(
         variance_type: str = DDPMVarianceType.FIXED_SMALL,
         clip_sample: bool = True,
         prediction_type: str = DDPMPredictionType.EPSILON,
+        clip_sample_min: float = -1.0,
+        clip_sample_max: float = 1.0,
         **schedule_args,
     ) -> None:
         super().__init__(num_train_timesteps, schedule, **schedule_args)
@@ -98,6 +102,7 @@ def __init__(
             raise ValueError("Argument `prediction_type` must be a member of `DDPMPredictionType`")
 
         self.clip_sample = clip_sample
+        self.clip_sample_values = [clip_sample_min, clip_sample_max]
         self.variance_type = variance_type
         self.prediction_type = prediction_type
 
@@ -219,7 +224,9 @@ def step(
 
         # 3. Clip "predicted x_0"
         if self.clip_sample:
-            pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+            pred_original_sample = torch.clamp(
+                pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1]
+            )
 
         # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
diff --git a/requirements.txt b/requirements.txt
index 1569646794..1d6ae13eec 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,2 @@
 torch>=1.9
-numpy>=1.20
+numpy>=1.20,<2.0
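
Note (not part of the patch): a minimal usage sketch of the new arguments, assuming the class names and import paths shown in the files above; the chosen bounds are illustrative.

    # Sketch: the new clip_sample_min / clip_sample_max arguments replace the
    # hard-coded [-1, 1] clamp applied to the predicted x_0 when clip_sample=True.
    from monai.networks.schedulers.ddim import DDIMScheduler
    from monai.networks.schedulers.ddpm import DDPMScheduler

    # Clip the predicted x_0 to [0, 1], e.g. for data normalised to that range.
    ddim = DDIMScheduler(num_train_timesteps=1000, clip_sample=True, clip_sample_min=0.0, clip_sample_max=1.0)
    ddpm = DDPMScheduler(num_train_timesteps=1000, clip_sample=True, clip_sample_min=0.0, clip_sample_max=1.0)

The defaults (-1.0 and 1.0) keep the previous behaviour, so existing callers are unaffected.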