File tree Expand file tree Collapse file tree 4 files changed +8
-24
lines changed Expand file tree Collapse file tree 4 files changed +8
-24
lines changed Original file line number Diff line number Diff line change @@ -548,16 +548,12 @@ def __len__(self):
548
548
return self .config .num_train_timesteps
549
549
550
550
def previous_timestep (self , timestep ):
551
- if self .custom_timesteps :
551
+ if self .custom_timesteps or self . num_inference_steps :
552
552
index = (self .timesteps == timestep ).nonzero (as_tuple = True )[0 ][0 ]
553
553
if index == self .timesteps .shape [0 ] - 1 :
554
554
prev_t = torch .tensor (- 1 )
555
555
else :
556
556
prev_t = self .timesteps [index + 1 ]
557
557
else :
558
- num_inference_steps = (
559
- self .num_inference_steps if self .num_inference_steps else self .config .num_train_timesteps
560
- )
561
- prev_t = timestep - self .config .num_train_timesteps // num_inference_steps
562
-
558
+ prev_t = timestep - 1
563
559
return prev_t
Original file line number Diff line number Diff line change @@ -639,16 +639,12 @@ def __len__(self):
639
639
640
640
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
641
641
def previous_timestep (self , timestep ):
642
- if self .custom_timesteps :
642
+ if self .custom_timesteps or self . num_inference_steps :
643
643
index = (self .timesteps == timestep ).nonzero (as_tuple = True )[0 ][0 ]
644
644
if index == self .timesteps .shape [0 ] - 1 :
645
645
prev_t = torch .tensor (- 1 )
646
646
else :
647
647
prev_t = self .timesteps [index + 1 ]
648
648
else :
649
- num_inference_steps = (
650
- self .num_inference_steps if self .num_inference_steps else self .config .num_train_timesteps
651
- )
652
- prev_t = timestep - self .config .num_train_timesteps // num_inference_steps
653
-
649
+ prev_t = timestep - 1
654
650
return prev_t
Original file line number Diff line number Diff line change @@ -643,16 +643,12 @@ def __len__(self):
643
643
644
644
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
645
645
def previous_timestep (self , timestep ):
646
- if self .custom_timesteps :
646
+ if self .custom_timesteps or self . num_inference_steps :
647
647
index = (self .timesteps == timestep ).nonzero (as_tuple = True )[0 ][0 ]
648
648
if index == self .timesteps .shape [0 ] - 1 :
649
649
prev_t = torch .tensor (- 1 )
650
650
else :
651
651
prev_t = self .timesteps [index + 1 ]
652
652
else :
653
- num_inference_steps = (
654
- self .num_inference_steps if self .num_inference_steps else self .config .num_train_timesteps
655
- )
656
- prev_t = timestep - self .config .num_train_timesteps // num_inference_steps
657
-
653
+ prev_t = timestep - 1
658
654
return prev_t
Original file line number Diff line number Diff line change @@ -680,16 +680,12 @@ def __len__(self):
680
680
681
681
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
682
682
def previous_timestep (self , timestep ):
683
- if self .custom_timesteps :
683
+ if self .custom_timesteps or self . num_inference_steps :
684
684
index = (self .timesteps == timestep ).nonzero (as_tuple = True )[0 ][0 ]
685
685
if index == self .timesteps .shape [0 ] - 1 :
686
686
prev_t = torch .tensor (- 1 )
687
687
else :
688
688
prev_t = self .timesteps [index + 1 ]
689
689
else :
690
- num_inference_steps = (
691
- self .num_inference_steps if self .num_inference_steps else self .config .num_train_timesteps
692
- )
693
- prev_t = timestep - self .config .num_train_timesteps // num_inference_steps
694
-
690
+ prev_t = timestep - 1
695
691
return prev_t
You can’t perform that action at this time.
0 commit comments