Skip to content

Scheduler Clip Fix #7855

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monai/inferers/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1607,7 +1607,7 @@ def __init__(
self.autoencoder_latent_shape = autoencoder_latent_shape
if self.ldm_latent_shape is not None and self.autoencoder_latent_shape is not None:
self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape)
self.autoencoder_resizer = CenterSpatialCrop(roi_size=[-1] + self.autoencoder_latent_shape)
self.autoencoder_resizer = CenterSpatialCrop(roi_size=self.autoencoder_latent_shape)

def __call__( # type: ignore[override]
self,
Expand Down
13 changes: 11 additions & 2 deletions monai/networks/schedulers/ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ class DDIMScheduler(Scheduler):
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
prediction_type: member of DDPMPredictionType
clip_sample_min: minimum clipping value used when clip_sample is True
clip_sample_max: maximum clipping value used when clip_sample is True
schedule_args: arguments to pass to the schedule function

"""
Expand All @@ -69,6 +71,8 @@ def __init__(
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = DDIMPredictionType.EPSILON,
clip_sample_min: float = -1.0,
clip_sample_max: float = 1.0,
**schedule_args,
) -> None:
super().__init__(num_train_timesteps, schedule, **schedule_args)
Expand All @@ -90,6 +94,7 @@ def __init__(
self.timesteps = torch.from_numpy(np.arange(0, self.num_train_timesteps)[::-1].astype(np.int64))

self.clip_sample = clip_sample
self.clip_sample_values = [clip_sample_min, clip_sample_max]
self.steps_offset = steps_offset

# default the number of inference timesteps to the number of train steps
Expand Down Expand Up @@ -193,7 +198,9 @@ def step(

# 4. Clip "predicted x_0"
if self.clip_sample:
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
pred_original_sample = torch.clamp(
pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1]
)

# 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
Expand Down Expand Up @@ -266,7 +273,9 @@ def reversed_step(

# 4. Clip "predicted x_0"
if self.clip_sample:
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
pred_original_sample = torch.clamp(
pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1]
)

# 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon
Expand Down
9 changes: 8 additions & 1 deletion monai/networks/schedulers/ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ class DDPMScheduler(Scheduler):
variance_type: member of DDPMVarianceType
clip_sample: option to clip the predicted sample between clip_sample_min and clip_sample_max for numerical stability.
prediction_type: member of DDPMPredictionType
clip_sample_min: minimum clipping value used when clip_sample is True
clip_sample_max: maximum clipping value used when clip_sample is True
schedule_args: arguments to pass to the schedule function
"""

Expand All @@ -87,6 +89,8 @@ def __init__(
variance_type: str = DDPMVarianceType.FIXED_SMALL,
clip_sample: bool = True,
prediction_type: str = DDPMPredictionType.EPSILON,
clip_sample_min: float = -1.0,
clip_sample_max: float = 1.0,
**schedule_args,
) -> None:
super().__init__(num_train_timesteps, schedule, **schedule_args)
Expand All @@ -98,6 +102,7 @@ def __init__(
raise ValueError("Argument `prediction_type` must be a member of `DDPMPredictionType`")

self.clip_sample = clip_sample
self.clip_sample_values = [clip_sample_min, clip_sample_max]
self.variance_type = variance_type
self.prediction_type = prediction_type

Expand Down Expand Up @@ -219,7 +224,9 @@ def step(

# 3. Clip "predicted x_0"
if self.clip_sample:
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
pred_original_sample = torch.clamp(
pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1]
)

# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
torch>=1.9
numpy>=1.20
numpy>=1.20,<2.0
Loading