
minor doc/test update #9734


Merged · 7 commits · Oct 21, 2024
Changes from 6 commits
@@ -54,6 +54,11 @@ image = pipe(
image.save("sd3_hello_world.png")
```

**Note:** Stable Diffusion 3.5 can also be run using the SD3 pipeline, and all mentioned optimizations and techniques apply to it as well (see the loading sketch after this list). In total there are three official models in the SD3 family:
- [`stabilityai/stable-diffusion-3-medium-diffusers`](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers)
- [`stabilityai/stable-diffusion-3.5-medium`](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium)
- [`stabilityai/stable-diffusion-3.5-large`](https://huggingface.co/stabilityai/stable-diffusion-3.5-large)
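
As a minimal sketch (not part of this diff; it assumes the `stabilityai/stable-diffusion-3.5-large` checkpoint listed above and a CUDA device with enough VRAM), loading a 3.5 checkpoint through the existing pipeline class looks like this:

```python
import torch

from diffusers import StableDiffusion3Pipeline

# SD3.5 checkpoints load through the same pipeline class used for SD3 above.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")

image = pipe("a photo of a cat holding a sign that says hello world").images[0]
image.save("sd3.5_hello_world.png")
```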

## Memory Optimisations for SD3

SD3 uses three text encoders, one of which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low-resource hardware.
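
As one illustrative sketch of the kind of optimization discussed here (not part of this diff; exact VRAM savings depend on the checkpoint and hardware), the T5-XXL encoder can be skipped entirely and the remaining components offloaded to the CPU between uses:

```python
import torch

from diffusers import StableDiffusion3Pipeline

# Skip the memory-heavy T5-XXL text encoder; SD3 falls back to its two CLIP encoders.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    text_encoder_3=None,
    tokenizer_3=None,
    torch_dtype=torch.float16,
)

# Keep sub-models on the CPU and move each one to the GPU only while it runs.
pipe.enable_model_cpu_offload()

image = pipe("a photo of a cat holding a sign that says hello world").images[0]
```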
119 changes: 111 additions & 8 deletions scripts/convert_sd3_to_diffusers.py
@@ -16,10 +16,9 @@
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str)
parser.add_argument("--output_path", type=str)
parser.add_argument("--dtype", type=str, default="fp16")
parser.add_argument("--dtype", type=str)

args = parser.parse_args()
dtype = torch.float16 if args.dtype == "fp16" else torch.float32


def load_original_checkpoint(ckpt_path):
@@ -40,7 +39,9 @@ def swap_scale_shift(weight, dim):
return new_weight


def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_layers, caption_projection_dim):
def convert_sd3_transformer_checkpoint_to_diffusers(
original_state_dict, num_layers, caption_projection_dim, dual_attention_layers, has_qk_norm
):
converted_state_dict = {}

# Positional and patch embeddings.
@@ -110,6 +111,21 @@ def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_lay
converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])

# qk norm
if has_qk_norm:
converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = original_state_dict.pop(
f"joint_blocks.{i}.x_block.attn.ln_q.weight"
)
converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = original_state_dict.pop(
f"joint_blocks.{i}.x_block.attn.ln_k.weight"
)
converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = original_state_dict.pop(
f"joint_blocks.{i}.context_block.attn.ln_q.weight"
)
converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = original_state_dict.pop(
f"joint_blocks.{i}.context_block.attn.ln_k.weight"
)

# output projections.
converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = original_state_dict.pop(
f"joint_blocks.{i}.x_block.attn.proj.weight"
@@ -125,6 +141,39 @@ def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_lay
f"joint_blocks.{i}.context_block.attn.proj.bias"
)

# attn2
if i in dual_attention_layers:
# Q, K, V
sample_q2, sample_k2, sample_v2 = torch.chunk(
original_state_dict.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0
)
sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk(
original_state_dict.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0
)
converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2])
converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias])
converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2])
converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias])
converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2])
converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias])

# qk norm
if has_qk_norm:
converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = original_state_dict.pop(
f"joint_blocks.{i}.x_block.attn2.ln_q.weight"
)
converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = original_state_dict.pop(
f"joint_blocks.{i}.x_block.attn2.ln_k.weight"
)

# output projections.
converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = original_state_dict.pop(
f"joint_blocks.{i}.x_block.attn2.proj.weight"
)
converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = original_state_dict.pop(
f"joint_blocks.{i}.x_block.attn2.proj.bias"
)

# norms.
converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = original_state_dict.pop(
f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
@@ -195,25 +244,79 @@ def is_vae_in_checkpoint(original_state_dict):
)


def get_attn2_layers(state_dict):
attn2_layers = []
for key in state_dict.keys():
if "attn2." in key:
# Extract the layer number from the key
layer_num = int(key.split(".")[1])
attn2_layers.append(layer_num)
return tuple(sorted(set(attn2_layers)))


def get_pos_embed_max_size(state_dict):
num_patches = state_dict["pos_embed"].shape[1]
pos_embed_max_size = int(num_patches**0.5)
return pos_embed_max_size


def get_caption_projection_dim(state_dict):
caption_projection_dim = state_dict["context_embedder.weight"].shape[0]
return caption_projection_dim


def main(args):
original_ckpt = load_original_checkpoint(args.checkpoint_path)
original_dtype = next(iter(original_ckpt.values())).dtype

# Initialize dtype with a default value
dtype = None

if args.dtype is None:
dtype = original_dtype
elif args.dtype == "fp16":
dtype = torch.float16
elif args.dtype == "bf16":
dtype = torch.bfloat16
elif args.dtype == "fp32":
dtype = torch.float32
else:
raise ValueError(f"Unsupported dtype: {args.dtype}")

if dtype != original_dtype:
print(
f"Checkpoint dtype {original_dtype} does not match requested dtype {dtype}. This can lead to unexpected results, proceed with caution."
)

num_layers = list(set(int(k.split(".", 2)[1]) for k in original_ckpt if "joint_blocks" in k))[-1] + 1 # noqa: C401
caption_projection_dim = 1536

caption_projection_dim = get_caption_projection_dim(original_ckpt)

# () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
attn2_layers = get_attn2_layers(original_ckpt)

# sd3.5 uses qk norm ("rms_norm")
has_qk_norm = any("ln_q" in key for key in original_ckpt.keys())

# sd3.5 2b uses pos_embed_max_size=384, while sd3.0 and sd3.5 8b use 192
pos_embed_max_size = get_pos_embed_max_size(original_ckpt)

converted_transformer_state_dict = convert_sd3_transformer_checkpoint_to_diffusers(
original_ckpt, num_layers, caption_projection_dim
original_ckpt, num_layers, caption_projection_dim, attn2_layers, has_qk_norm
)

with CTX():
transformer = SD3Transformer2DModel(
sample_size=64,
sample_size=128,
patch_size=2,
in_channels=16,
joint_attention_dim=4096,
num_layers=num_layers,
caption_projection_dim=caption_projection_dim,
num_attention_heads=24,
pos_embed_max_size=192,
num_attention_heads=num_layers,  # in the SD3 family the head count matches the transformer depth
pos_embed_max_size=pos_embed_max_size,
qk_norm="rms_norm" if has_qk_norm else None,
dual_attention_layers=attn2_layers,
)
if is_accelerate_available():
load_model_dict_into_meta(transformer, converted_transformer_state_dict)
49 changes: 45 additions & 4 deletions src/diffusers/models/attention.py
@@ -22,7 +22,7 @@
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
from .attention_processor import Attention, JointAttnProcessor2_0
from .embeddings import SinusoidalPositionalEmbedding
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX


logger = logging.get_logger(__name__)
@@ -100,13 +100,25 @@ class JointTransformerBlock(nn.Module):
processing of `context` conditions.
"""

def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
context_pre_only: bool = False,
qk_norm: Optional[str] = None,
use_dual_attention: bool = False,
):
super().__init__()

self.use_dual_attention = use_dual_attention
self.context_pre_only = context_pre_only
context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"

self.norm1 = AdaLayerNormZero(dim)
if use_dual_attention:
self.norm1 = SD35AdaLayerNormZeroX(dim)
else:
self.norm1 = AdaLayerNormZero(dim)

if context_norm_type == "ada_norm_continous":
self.norm1_context = AdaLayerNormContinuous(
@@ -118,12 +130,14 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_onl
raise ValueError(
f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
)

if hasattr(F, "scaled_dot_product_attention"):
processor = JointAttnProcessor2_0()
else:
raise ValueError(
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
)

self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
@@ -134,8 +148,25 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_onl
context_pre_only=context_pre_only,
bias=True,
processor=processor,
qk_norm=qk_norm,
eps=1e-6,
)

if use_dual_attention:
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
qk_norm=qk_norm,
eps=1e-6,
)
else:
self.attn2 = None

self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

@@ -159,7 +190,12 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
def forward(
self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
):
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
if self.use_dual_attention:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
hidden_states, emb=temb
)
else:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

if self.context_pre_only:
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
@@ -177,6 +213,11 @@ def forward(
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output

if self.use_dual_attention:
attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
hidden_states = hidden_states + attn_output2

norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
if self._chunk_size is not None:
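
To make the new `JointTransformerBlock` arguments concrete, here is a minimal, hypothetical smoke test of the updated block (the tensor shapes, the `rms_norm` value, and the head configuration are illustrative assumptions, not values taken from the PR):

```python
import torch

from diffusers.models.attention import JointTransformerBlock

# dim must equal num_attention_heads * attention_head_dim for the projections to line up.
block = JointTransformerBlock(
    dim=1536,
    num_attention_heads=24,
    attention_head_dim=64,
    context_pre_only=False,
    qk_norm="rms_norm",       # enables the new per-head query/key RMSNorm
    use_dual_attention=True,  # adds the second self-attention branch (attn2)
)

hidden_states = torch.randn(1, 1024, 1536)         # image tokens
encoder_hidden_states = torch.randn(1, 154, 1536)  # prompt tokens
temb = torch.randn(1, 1536)                        # pooled timestep/text embedding

encoder_hidden_states, hidden_states = block(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    temb=temb,
)
print(hidden_states.shape)  # torch.Size([1, 1024, 1536])
```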