Skip to content

Commit 69acaee

Browse files
AnandK27yiyixuxuhlky
authored andcommitted
[Bug fix] "previous_timestep()" in DDPM scheduling compatible with "trailing" and "linspace" options (#9384)
* Update scheduling_ddpm.py * fix copies --------- Co-authored-by: YiYi Xu <[email protected]> Co-authored-by: hlky <[email protected]>
1 parent 94ee4c1 commit 69acaee

File tree

4 files changed

+8
-24
lines changed

4 files changed

+8
-24
lines changed

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -548,16 +548,12 @@ def __len__(self):
548548
return self.config.num_train_timesteps
549549

550550
def previous_timestep(self, timestep):
551-
if self.custom_timesteps:
551+
if self.custom_timesteps or self.num_inference_steps:
552552
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
553553
if index == self.timesteps.shape[0] - 1:
554554
prev_t = torch.tensor(-1)
555555
else:
556556
prev_t = self.timesteps[index + 1]
557557
else:
558-
num_inference_steps = (
559-
self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
560-
)
561-
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
562-
558+
prev_t = timestep - 1
563559
return prev_t

src/diffusers/schedulers/scheduling_ddpm_parallel.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -639,16 +639,12 @@ def __len__(self):
639639

640640
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
641641
def previous_timestep(self, timestep):
642-
if self.custom_timesteps:
642+
if self.custom_timesteps or self.num_inference_steps:
643643
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
644644
if index == self.timesteps.shape[0] - 1:
645645
prev_t = torch.tensor(-1)
646646
else:
647647
prev_t = self.timesteps[index + 1]
648648
else:
649-
num_inference_steps = (
650-
self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
651-
)
652-
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
653-
649+
prev_t = timestep - 1
654650
return prev_t

src/diffusers/schedulers/scheduling_lcm.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -643,16 +643,12 @@ def __len__(self):
643643

644644
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
645645
def previous_timestep(self, timestep):
646-
if self.custom_timesteps:
646+
if self.custom_timesteps or self.num_inference_steps:
647647
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
648648
if index == self.timesteps.shape[0] - 1:
649649
prev_t = torch.tensor(-1)
650650
else:
651651
prev_t = self.timesteps[index + 1]
652652
else:
653-
num_inference_steps = (
654-
self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
655-
)
656-
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
657-
653+
prev_t = timestep - 1
658654
return prev_t

src/diffusers/schedulers/scheduling_tcd.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -680,16 +680,12 @@ def __len__(self):
680680

681681
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
682682
def previous_timestep(self, timestep):
683-
if self.custom_timesteps:
683+
if self.custom_timesteps or self.num_inference_steps:
684684
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
685685
if index == self.timesteps.shape[0] - 1:
686686
prev_t = torch.tensor(-1)
687687
else:
688688
prev_t = self.timesteps[index + 1]
689689
else:
690-
num_inference_steps = (
691-
self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
692-
)
693-
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
694-
690+
prev_t = timestep - 1
695691
return prev_t

0 commit comments

Comments
 (0)