@@ -642,6 +642,10 @@ def prepare_latents(
642
642
def guidance_scale (self ):
643
643
return self ._guidance_scale
644
644
645
+ @property
646
+ def skip_guidance_layers (self ):
647
+ return self ._skip_guidance_layers
648
+
645
649
@property
646
650
def clip_skip (self ):
647
651
return self ._clip_skip
@@ -694,6 +698,10 @@ def __call__(
694
698
callback_on_step_end : Optional [Callable [[int , int , Dict ], None ]] = None ,
695
699
callback_on_step_end_tensor_inputs : List [str ] = ["latents" ],
696
700
max_sequence_length : int = 256 ,
701
+ skip_guidance_layers : List [int ] = None ,
702
+ skip_layer_guidance_scale : int = 2.8 ,
703
+ skip_layer_guidance_stop : int = 0.2 ,
704
+ skip_layer_guidance_start : int = 0.01 ,
697
705
):
698
706
r"""
699
707
Function invoked when calling the pipeline for generation.
@@ -778,6 +786,22 @@ def __call__(
778
786
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
779
787
`._callback_tensor_inputs` attribute of your pipeline class.
780
788
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
789
+ skip_guidance_layers (`List[int]`, *optional*):
790
+ A list of integers that specify layers to skip during guidance. If not provided, all layers will be
791
+ used for guidance. If provided, the guidance will only be applied to the layers specified in the list.
792
+ Recommended value by StabilityAI for Stable Diffusion 3.5 Medium is [7, 8, 9].
793
+ skip_layer_guidance_scale (`float`, *optional*): The scale of the guidance for the layers specified in
794
+ `skip_guidance_layers`. The guidance will be applied to the layers specified in `skip_guidance_layers`
795
+ with a scale of `skip_layer_guidance_scale`. The guidance will be applied to the rest of the layers
796
+ with a scale of `1`.
797
+ skip_layer_guidance_stop (`float`, *optional*): The step at which the guidance for the layers specified in
798
+ `skip_guidance_layers` will stop. The guidance will be applied to the layers specified in
799
+ `skip_guidance_layers` until the fraction specified in `skip_layer_guidance_stop`. Recommended value by
800
+ StabilityAI for Stable Diffusion 3.5 Medium is 0.2.
801
+ skip_layer_guidance_start (`float`, *optional*): The step at which the guidance for the layers specified in
802
+ `skip_guidance_layers` will start. The guidance will be applied to the layers specified in
803
+ `skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by
804
+ StabilityAI for Stable Diffusion 3.5 Medium is 0.01.
781
805
782
806
Examples:
783
807
@@ -809,6 +833,7 @@ def __call__(
809
833
)
810
834
811
835
self ._guidance_scale = guidance_scale
836
+ self ._skip_layer_guidance_scale = skip_layer_guidance_scale
812
837
self ._clip_skip = clip_skip
813
838
self ._joint_attention_kwargs = joint_attention_kwargs
814
839
self ._interrupt = False
@@ -851,6 +876,9 @@ def __call__(
851
876
)
852
877
853
878
if self .do_classifier_free_guidance :
879
+ if skip_guidance_layers is not None :
880
+ original_prompt_embeds = prompt_embeds
881
+ original_pooled_prompt_embeds = pooled_prompt_embeds
854
882
prompt_embeds = torch .cat ([negative_prompt_embeds , prompt_embeds ], dim = 0 )
855
883
pooled_prompt_embeds = torch .cat ([negative_pooled_prompt_embeds , pooled_prompt_embeds ], dim = 0 )
856
884
@@ -879,7 +907,11 @@ def __call__(
879
907
continue
880
908
881
909
# expand the latents if we are doing classifier free guidance
882
- latent_model_input = torch .cat ([latents ] * 2 ) if self .do_classifier_free_guidance else latents
910
+ latent_model_input = (
911
+ torch .cat ([latents ] * 2 )
912
+ if self .do_classifier_free_guidance and skip_guidance_layers is None
913
+ else latents
914
+ )
883
915
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
884
916
timestep = t .expand (latent_model_input .shape [0 ])
885
917
@@ -896,6 +928,25 @@ def __call__(
896
928
if self .do_classifier_free_guidance :
897
929
noise_pred_uncond , noise_pred_text = noise_pred .chunk (2 )
898
930
noise_pred = noise_pred_uncond + self .guidance_scale * (noise_pred_text - noise_pred_uncond )
931
+ should_skip_layers = (
932
+ True
933
+ if i > num_inference_steps * skip_layer_guidance_start
934
+ and i < num_inference_steps * skip_layer_guidance_stop
935
+ else False
936
+ )
937
+ if skip_guidance_layers is not None and should_skip_layers :
938
+ noise_pred_skip_layers = self .transformer (
939
+ hidden_states = latent_model_input ,
940
+ timestep = timestep ,
941
+ encoder_hidden_states = original_prompt_embeds ,
942
+ pooled_projections = original_pooled_prompt_embeds ,
943
+ joint_attention_kwargs = self .joint_attention_kwargs ,
944
+ return_dict = False ,
945
+ skip_layers = skip_guidance_layers ,
946
+ )[0 ]
947
+ noise_pred = (
948
+ noise_pred + (noise_pred_text - noise_pred_skip_layers ) * self ._skip_layer_guidance_scale
949
+ )
899
950
900
951
# compute the previous noisy sample x_t -> x_t-1
901
952
latents_dtype = latents .dtype
0 commit comments