Skip to content

Commit 99c0483

Browse files
bghira and yiyixuxu
authored
add skip_layers argument to SD3 transformer model class (#9880)
* add skip_layers argument to SD3 transformer model class * add unit test for skip_layers in stable diffusion 3 * sd3: pipeline should support skip layer guidance * up --------- Co-authored-by: bghira <[email protected]> Co-authored-by: yiyixuxu <[email protected]>
1 parent cc7d88f commit 99c0483

File tree

3 files changed

+82
-6
lines changed

3 files changed

+82
-6
lines changed

src/diffusers/models/transformers/transformer_sd3.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ def forward(
268268
block_controlnet_hidden_states: List = None,
269269
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
270270
return_dict: bool = True,
271+
skip_layers: Optional[List[int]] = None,
271272
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
272273
"""
273274
The [`SD3Transformer2DModel`] forward method.
@@ -279,9 +280,9 @@ def forward(
279280
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
280281
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
281282
from the embeddings of input conditions.
282-
timestep ( `torch.LongTensor`):
283+
timestep (`torch.LongTensor`):
283284
Used to indicate denoising step.
284-
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
285+
block_controlnet_hidden_states (`list` of `torch.Tensor`):
285286
A list of tensors that if specified are added to the residuals of transformer blocks.
286287
joint_attention_kwargs (`dict`, *optional*):
287288
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
@@ -290,6 +291,8 @@ def forward(
290291
return_dict (`bool`, *optional*, defaults to `True`):
291292
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
292293
tuple.
294+
skip_layers (`list` of `int`, *optional*):
295+
A list of layer indices to skip during the forward pass.
293296
294297
Returns:
295298
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
@@ -317,7 +320,10 @@ def forward(
317320
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
318321

319322
for index_block, block in enumerate(self.transformer_blocks):
320-
if torch.is_grad_enabled() and self.gradient_checkpointing:
323+
# Skip specified layers
324+
is_skip = True if skip_layers is not None and index_block in skip_layers else False
325+
326+
if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:
321327

322328
def create_custom_forward(module, return_dict=None):
323329
def custom_forward(*inputs):
@@ -336,8 +342,7 @@ def custom_forward(*inputs):
336342
temb,
337343
**ckpt_kwargs,
338344
)
339-
340-
else:
345+
elif not is_skip:
341346
encoder_hidden_states, hidden_states = block(
342347
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
343348
)

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,10 @@ def prepare_latents(
642642
def guidance_scale(self):
643643
return self._guidance_scale
644644

645+
@property
646+
def skip_guidance_layers(self):
647+
return self._skip_guidance_layers
648+
645649
@property
646650
def clip_skip(self):
647651
return self._clip_skip
@@ -694,6 +698,10 @@ def __call__(
694698
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
695699
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
696700
max_sequence_length: int = 256,
701+
skip_guidance_layers: List[int] = None,
702+
skip_layer_guidance_scale: int = 2.8,
703+
skip_layer_guidance_stop: int = 0.2,
704+
skip_layer_guidance_start: int = 0.01,
697705
):
698706
r"""
699707
Function invoked when calling the pipeline for generation.
@@ -778,6 +786,22 @@ def __call__(
778786
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
779787
`._callback_tensor_inputs` attribute of your pipeline class.
780788
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
789+
skip_guidance_layers (`List[int]`, *optional*):
790+
A list of integers that specify the transformer layers to skip for skip-layer guidance. If not provided,
791+
skip-layer guidance is disabled. If provided, an extra forward pass is run with these layers skipped and
792+
its output adds an extra guidance term. Recommended value by StabilityAI for Stable Diffusion 3.5 Medium is [7, 8, 9].
793+
skip_layer_guidance_scale (`float`, *optional*, defaults to 2.8): The scale of the guidance for the layers
794+
specified in `skip_guidance_layers`. The guidance will be applied to the layers specified in
795+
`skip_guidance_layers` with a scale of `skip_layer_guidance_scale`. The guidance will be applied to the
796+
rest of the layers with a scale of `1`.
797+
skip_layer_guidance_stop (`float`, *optional*, defaults to 0.2): The fraction of the inference steps at
798+
which the guidance for the layers specified in `skip_guidance_layers` will stop. The guidance will be
799+
applied to the layers specified in `skip_guidance_layers` until the fraction specified in
800+
`skip_layer_guidance_stop`. Recommended value by StabilityAI for Stable Diffusion 3.5 Medium is 0.2.
801+
skip_layer_guidance_start (`float`, *optional*, defaults to 0.01): The fraction of the inference steps at
802+
which the guidance for the layers specified in `skip_guidance_layers` will start. The guidance will be
803+
applied to the layers specified in `skip_guidance_layers` from the fraction specified in
804+
`skip_layer_guidance_start`. Recommended value by StabilityAI for Stable Diffusion 3.5 Medium is 0.01.
781805
782806
Examples:
783807
@@ -809,6 +833,7 @@ def __call__(
809833
)
810834

811835
self._guidance_scale = guidance_scale
836+
self._skip_layer_guidance_scale = skip_layer_guidance_scale
812837
self._clip_skip = clip_skip
813838
self._joint_attention_kwargs = joint_attention_kwargs
814839
self._interrupt = False
@@ -851,6 +876,9 @@ def __call__(
851876
)
852877

853878
if self.do_classifier_free_guidance:
879+
if skip_guidance_layers is not None:
880+
original_prompt_embeds = prompt_embeds
881+
original_pooled_prompt_embeds = pooled_prompt_embeds
854882
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
855883
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
856884

@@ -879,7 +907,11 @@ def __call__(
879907
continue
880908

881909
# expand the latents if we are doing classifier free guidance
882-
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
910+
latent_model_input = (
911+
torch.cat([latents] * 2)
912+
if self.do_classifier_free_guidance and skip_guidance_layers is None
913+
else latents
914+
)
883915
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
884916
timestep = t.expand(latent_model_input.shape[0])
885917

@@ -896,6 +928,25 @@ def __call__(
896928
if self.do_classifier_free_guidance:
897929
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
898930
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
931+
should_skip_layers = (
932+
True
933+
if i > num_inference_steps * skip_layer_guidance_start
934+
and i < num_inference_steps * skip_layer_guidance_stop
935+
else False
936+
)
937+
if skip_guidance_layers is not None and should_skip_layers:
938+
noise_pred_skip_layers = self.transformer(
939+
hidden_states=latent_model_input,
940+
timestep=timestep,
941+
encoder_hidden_states=original_prompt_embeds,
942+
pooled_projections=original_pooled_prompt_embeds,
943+
joint_attention_kwargs=self.joint_attention_kwargs,
944+
return_dict=False,
945+
skip_layers=skip_guidance_layers,
946+
)[0]
947+
noise_pred = (
948+
noise_pred + (noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale
949+
)
899950

900951
# compute the previous noisy sample x_t -> x_t-1
901952
latents_dtype = latents.dtype

tests/models/transformers/test_models_transformer_sd3.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,23 @@ def test_set_attn_processor_for_determinism(self):
147147
def test_gradient_checkpointing_is_applied(self):
148148
expected_set = {"SD3Transformer2DModel"}
149149
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
150+
151+
def test_skip_layers(self):
152+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
153+
model = self.model_class(**init_dict).to(torch_device)
154+
155+
# Forward pass without skipping layers
156+
output_full = model(**inputs_dict).sample
157+
158+
# Forward pass that skips layer 0 (there is only one layer in this test setup)
159+
inputs_dict_with_skip = inputs_dict.copy()
160+
inputs_dict_with_skip["skip_layers"] = [0]
161+
output_skip = model(**inputs_dict_with_skip).sample
162+
163+
# Check that the outputs are different
164+
self.assertFalse(
165+
torch.allclose(output_full, output_skip, atol=1e-5), "Outputs should differ when layers are skipped"
166+
)
167+
168+
# Check that the outputs have the same shape
169+
self.assertEqual(output_full.shape, output_skip.shape, "Outputs should have the same shape")

0 commit comments

Comments
 (0)