Commit 0ef36dd

Fixed typo and reverted removal of skip_layers in SD3Transformer2DModel
1 parent 50d09d9 commit 0ef36dd

2 files changed: +9 -4 lines changed

src/diffusers/models/transformers/transformer_sd3.py

Lines changed: 8 additions & 3 deletions
@@ -341,6 +341,7 @@ def forward(
         block_controlnet_hidden_states: List = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
+        skip_layers: Optional[List[int]] = None,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
         """
         The [`SD3Transformer2DModel`] forward method.
@@ -363,6 +364,8 @@ def forward(
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                 tuple.
+            skip_layers (`list` of `int`, *optional*):
+                A list of layer indices to skip during the forward pass.
 
         Returns:
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
@@ -390,7 +393,10 @@ def forward(
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
 
         for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            # Skip specified layers
+            is_skip = True if skip_layers is not None and index_block in skip_layers else False
+
+            if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -410,8 +416,7 @@ def custom_forward(*inputs):
                     joint_attention_kwargs,
                     **ckpt_kwargs,
                 )
-
-            else:
+            elif not is_skip:
                 encoder_hidden_states, hidden_states = block(
                     hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb,
                     joint_attention_kwargs=joint_attention_kwargs,
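
To illustrate the control flow this hunk restores, here is a minimal, self-contained sketch of the block loop. The `DummyBlock` module and `run_blocks` helper are illustrative stand-ins, not part of diffusers; the real forward pass threads `hidden_states`, `encoder_hidden_states`, and `temb` through joint transformer blocks, but the skip logic is the same: a block whose index is listed in `skip_layers` is bypassed in both the gradient-checkpointing path and the regular path.

```python
import torch
import torch.nn as nn


class DummyBlock(nn.Module):
    """Illustrative stand-in for a transformer block."""

    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)


def run_blocks(blocks, x, skip_layers=None, gradient_checkpointing=False):
    for index_block, block in enumerate(blocks):
        # Same test as the commit: a listed index skips the block entirely.
        is_skip = skip_layers is not None and index_block in skip_layers
        if torch.is_grad_enabled() and gradient_checkpointing and not is_skip:
            # Checkpointed path, only taken for non-skipped blocks.
            x = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
        elif not is_skip:
            # Plain path for non-skipped blocks when checkpointing is off.
            x = block(x)
    return x


blocks = nn.ModuleList(DummyBlock(8) for _ in range(6))
x = torch.randn(2, 8, requires_grad=True)
out = run_blocks(blocks, x, skip_layers=[2, 4], gradient_checkpointing=True)
print(out.shape)  # torch.Size([2, 8])
```

Note that a skipped block is never called at all, so it contributes nothing to the output and receives no gradients.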

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py

Lines changed: 1 addition & 1 deletion
@@ -855,7 +855,7 @@ def set_ip_adapter_scale(self, scale):
         only conditioned by the text prompt. Lowering this value encourages the model to produce more diverse images, but they
         may not be as aligned with the image prompt.
         """
-        for attn_processor in self.transformes.attn_processors.values():
+        for attn_processor in self.transformer.attn_processors.values():
             if isinstance(attn_processor, IPAdapterJointAttnProcessor2_0):
                 attn_processor.scale = scale
 
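
For context, a hedged usage sketch of the corrected method. It assumes `pipe` is an already-constructed `StableDiffusion3Pipeline` with an IP-Adapter attached (loading is outside the scope of this commit); with the attribute typo fixed, the call below reaches `pipe.transformer.attn_processors` instead of failing on the non-existent `transformes` attribute.

```python
# Assumes `pipe` is a StableDiffusion3Pipeline with an IP-Adapter attached;
# the variable name and scale value are illustrative.
pipe.set_ip_adapter_scale(0.6)

# Equivalent effect, spelled out: every IP-Adapter attention processor on the
# transformer receives the same scale. hasattr is used here only to keep the
# sketch free of the IPAdapterJointAttnProcessor2_0 import.
for attn_processor in pipe.transformer.attn_processors.values():
    if hasattr(attn_processor, "scale"):
        attn_processor.scale = 0.6
```

Per the docstring in the diff, a scale of 0.0 conditions only on the text prompt, while higher values follow the image prompt more closely.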
