Skip to content

Commit 98550c0

Browse files
virginiafdezvirginiafdez
andauthored
Scheduler Clip Fix (#7855)
Fixes # . ### Description Fixes a bug in the inferer and adds clipping parameters to the DDIM/DDPM schedulers. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: virginiafdez <[email protected]> Co-authored-by: virginiafdez <[email protected]>
1 parent 36511cc commit 98550c0

File tree

4 files changed

+21
-5
lines changed

4 files changed

+21
-5
lines changed

monai/inferers/inferer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1607,7 +1607,7 @@ def __init__(
16071607
self.autoencoder_latent_shape = autoencoder_latent_shape
16081608
if self.ldm_latent_shape is not None and self.autoencoder_latent_shape is not None:
16091609
self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape)
1610-
self.autoencoder_resizer = CenterSpatialCrop(roi_size=[-1] + self.autoencoder_latent_shape)
1610+
self.autoencoder_resizer = CenterSpatialCrop(roi_size=self.autoencoder_latent_shape)
16111611

16121612
def __call__( # type: ignore[override]
16131613
self,

monai/networks/schedulers/ddim.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ class DDIMScheduler(Scheduler):
5757
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
5858
stable diffusion.
5959
prediction_type: member of DDPMPredictionType
60+
clip_sample_min: minimum clipping value when clip_sample equals True
61+
clip_sample_max: maximum clipping value when clip_sample equals True
6062
schedule_args: arguments to pass to the schedule function
6163
6264
"""
@@ -69,6 +71,8 @@ def __init__(
6971
set_alpha_to_one: bool = True,
7072
steps_offset: int = 0,
7173
prediction_type: str = DDIMPredictionType.EPSILON,
74+
clip_sample_min: float = -1.0,
75+
clip_sample_max: float = 1.0,
7276
**schedule_args,
7377
) -> None:
7478
super().__init__(num_train_timesteps, schedule, **schedule_args)
@@ -90,6 +94,7 @@ def __init__(
9094
self.timesteps = torch.from_numpy(np.arange(0, self.num_train_timesteps)[::-1].astype(np.int64))
9195

9296
self.clip_sample = clip_sample
97+
self.clip_sample_values = [clip_sample_min, clip_sample_max]
9398
self.steps_offset = steps_offset
9499

95100
# default the number of inference timesteps to the number of train steps
@@ -193,7 +198,9 @@ def step(
193198

194199
# 4. Clip "predicted x_0"
195200
if self.clip_sample:
196-
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
201+
pred_original_sample = torch.clamp(
202+
pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1]
203+
)
197204

198205
# 5. compute variance: "sigma_t(η)" -> see formula (16)
199206
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
@@ -266,7 +273,9 @@ def reversed_step(
266273

267274
# 4. Clip "predicted x_0"
268275
if self.clip_sample:
269-
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
276+
pred_original_sample = torch.clamp(
277+
pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1]
278+
)
270279

271280
# 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
272281
pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon

monai/networks/schedulers/ddpm.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ class DDPMScheduler(Scheduler):
7777
variance_type: member of DDPMVarianceType
7878
clip_sample: option to clip predicted sample between -1 and 1 for numerical stability.
7979
prediction_type: member of DDPMPredictionType
80+
clip_sample_min: minimum clipping value when clip_sample equals True
81+
clip_sample_max: maximum clipping value when clip_sample equals True
8082
schedule_args: arguments to pass to the schedule function
8183
"""
8284

@@ -87,6 +89,8 @@ def __init__(
8789
variance_type: str = DDPMVarianceType.FIXED_SMALL,
8890
clip_sample: bool = True,
8991
prediction_type: str = DDPMPredictionType.EPSILON,
92+
clip_sample_min: float = -1.0,
93+
clip_sample_max: float = 1.0,
9094
**schedule_args,
9195
) -> None:
9296
super().__init__(num_train_timesteps, schedule, **schedule_args)
@@ -98,6 +102,7 @@ def __init__(
98102
raise ValueError("Argument `prediction_type` must be a member of `DDPMPredictionType`")
99103

100104
self.clip_sample = clip_sample
105+
self.clip_sample_values = [clip_sample_min, clip_sample_max]
101106
self.variance_type = variance_type
102107
self.prediction_type = prediction_type
103108

@@ -219,7 +224,9 @@ def step(
219224

220225
# 3. Clip "predicted x_0"
221226
if self.clip_sample:
222-
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
227+
pred_original_sample = torch.clamp(
228+
pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1]
229+
)
223230

224231
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
225232
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
torch>=1.9
2-
numpy>=1.20
2+
numpy>=1.20,<2.0

0 commit comments

Comments
 (0)