Commit 45b6cb6

merge 9588
1 parent 99f6082 commit 45b6cb6

16 files changed: +1590 -5 lines changed

scripts/convert_cogview3_to_diffusers.py

Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
"""
Convert a CogView3 checkpoint to the Diffusers format.

This script converts a CogView3 checkpoint to the Diffusers format so that it can be loaded
with the Diffusers library.

Example usage:
    python scripts/convert_cogview3_to_diffusers.py \
        --original_state_dict_repo_id "THUDM/cogview3" \
        --filename "cogview3.pt" \
        --transformer \
        --output_path "./cogview3_diffusers" \
        --dtype "bf16"

Alternatively, if you have a local checkpoint:
    python scripts/convert_cogview3_to_diffusers.py \
        --checkpoint_path '/raid/.cache/huggingface/models--ZP2HF--CogView3-SAT/snapshots/ca86ce9ba94f9a7f2dd109e7a59e4c8ad04121be/cogview3plus_3b/1/mp_rank_00_model_states.pt' \
        --transformer \
        --output_path "/raid/yiyi/cogview3_diffusers" \
        --dtype "bf16"

Arguments:
    --original_state_dict_repo_id: The Hugging Face repo ID containing the original checkpoint.
    --filename: The filename of the checkpoint in the repo (default: "flux.safetensors").
    --checkpoint_path: Path to a local checkpoint file (alternative to repo_id and filename).
    --transformer: Flag to convert the transformer model.
    --output_path: The path to save the converted model.
    --dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32").

Note: You must provide either --original_state_dict_repo_id or --checkpoint_path.
"""

import argparse
from contextlib import nullcontext

import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download

from diffusers import CogView3PlusTransformer2DModel
from diffusers.utils.import_utils import is_accelerate_available


CTX = init_empty_weights if is_accelerate_available() else nullcontext

parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--filename", default="flux.safetensors", type=str)
parser.add_argument("--checkpoint_path", default=None, type=str)
parser.add_argument("--transformer", action="store_true")
parser.add_argument("--output_path", type=str)
parser.add_argument("--dtype", type=str, default="bf16")

args = parser.parse_args()


def load_original_checkpoint(args):
    if args.original_state_dict_repo_id is not None:
        ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
    elif args.checkpoint_path is not None:
        ckpt_path = args.checkpoint_path
    else:
        raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")

    original_state_dict = torch.load(ckpt_path, map_location="cpu")
    return original_state_dict


# This is specific to `AdaLayerNormContinuous`:
# the diffusers implementation splits the linear projection into (scale, shift), while CogView3 splits it into (shift, scale).
def swap_scale_shift(weight, dim):
    shift, scale = weight.chunk(2, dim=0)
    new_weight = torch.cat([scale, shift], dim=0)
    return new_weight


def convert_cogview3_transformer_checkpoint_to_diffusers(original_state_dict):
    new_state_dict = {}

    # Convert pos_embed
    new_state_dict["pos_embed.proj.weight"] = original_state_dict.pop("mixins.patch_embed.proj.weight")
    new_state_dict["pos_embed.proj.bias"] = original_state_dict.pop("mixins.patch_embed.proj.bias")
    new_state_dict["pos_embed.text_proj.weight"] = original_state_dict.pop("mixins.patch_embed.text_proj.weight")
    new_state_dict["pos_embed.text_proj.bias"] = original_state_dict.pop("mixins.patch_embed.text_proj.bias")

    # Convert time_text_embed
    new_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
        "time_embed.0.weight"
    )
    new_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("time_embed.0.bias")
    new_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
        "time_embed.2.weight"
    )
    new_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("time_embed.2.bias")
    new_state_dict["time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop("label_emb.0.0.weight")
    new_state_dict["time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop("label_emb.0.0.bias")
    new_state_dict["time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop("label_emb.0.2.weight")
    new_state_dict["time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop("label_emb.0.2.bias")

    # Convert transformer blocks
    for i in range(30):
        block_prefix = f"transformer_blocks.{i}."
        old_prefix = f"transformer.layers.{i}."
        adaln_prefix = f"mixins.adaln.adaln_modules.{i}."

        new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(adaln_prefix + "1.weight")
        new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(adaln_prefix + "1.bias")

        qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
        qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
        q, k, v = qkv_weight.chunk(3, dim=0)
        q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)

        new_state_dict[block_prefix + "attn.to_q.weight"] = q
        new_state_dict[block_prefix + "attn.to_q.bias"] = q_bias
        new_state_dict[block_prefix + "attn.to_k.weight"] = k
        new_state_dict[block_prefix + "attn.to_k.bias"] = k_bias
        new_state_dict[block_prefix + "attn.to_v.weight"] = v
        new_state_dict[block_prefix + "attn.to_v.bias"] = v_bias

        new_state_dict[block_prefix + "attn.to_out.0.weight"] = original_state_dict.pop(
            old_prefix + "attention.dense.weight"
        )
        new_state_dict[block_prefix + "attn.to_out.0.bias"] = original_state_dict.pop(
            old_prefix + "attention.dense.bias"
        )

        new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop(
            old_prefix + "mlp.dense_h_to_4h.weight"
        )
        new_state_dict[block_prefix + "ff.net.0.proj.bias"] = original_state_dict.pop(
            old_prefix + "mlp.dense_h_to_4h.bias"
        )
        new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(
            old_prefix + "mlp.dense_4h_to_h.weight"
        )
        new_state_dict[block_prefix + "ff.net.2.bias"] = original_state_dict.pop(old_prefix + "mlp.dense_4h_to_h.bias")

    # Convert final norm and projection
    new_state_dict["norm_out.linear.weight"] = swap_scale_shift(
        original_state_dict.pop("mixins.final_layer.adaln.1.weight"), dim=0
    )
    new_state_dict["norm_out.linear.bias"] = swap_scale_shift(
        original_state_dict.pop("mixins.final_layer.adaln.1.bias"), dim=0
    )
    new_state_dict["proj_out.weight"] = original_state_dict.pop("mixins.final_layer.linear.weight")
    new_state_dict["proj_out.bias"] = original_state_dict.pop("mixins.final_layer.linear.bias")

    return new_state_dict


def main(args):
    original_ckpt = load_original_checkpoint(args)
    original_ckpt = original_ckpt["module"]
    original_ckpt = {k.replace("model.diffusion_model.", ""): v for k, v in original_ckpt.items()}

    original_dtype = next(iter(original_ckpt.values())).dtype
    dtype = None
    if args.dtype is None:
        dtype = original_dtype
    elif args.dtype == "fp16":
        dtype = torch.float16
    elif args.dtype == "bf16":
        dtype = torch.bfloat16
    elif args.dtype == "fp32":
        dtype = torch.float32
    else:
        raise ValueError(f"Unsupported dtype: {args.dtype}")

    if args.transformer:
        converted_transformer_state_dict = convert_cogview3_transformer_checkpoint_to_diffusers(original_ckpt)
        transformer = CogView3PlusTransformer2DModel()
        transformer.load_state_dict(converted_transformer_state_dict, strict=True)

        print(f"Saving CogView3 Transformer in Diffusers format in {args.output_path}/transformer")
        transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer")

    if len(original_ckpt) > 0:
        print(f"Warning: {len(original_ckpt)} keys were not converted and will be saved as is: {original_ckpt.keys()}")


if __name__ == "__main__":
    main(args)
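For reference, a transformer converted and saved by this script can be loaded back through the standard Diffusers `from_pretrained` API. A minimal sketch, assuming the example `--output_path` from the docstring above (adjust the path to your own output directory):

import torch

from diffusers import CogView3PlusTransformer2DModel

# "./cogview3_diffusers" is the example --output_path used in the docstring; adjust as needed.
transformer = CogView3PlusTransformer2DModel.from_pretrained(
    "./cogview3_diffusers/transformer", torch_dtype=torch.bfloat16
)
print(transformer.config)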

show_model.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
import torch
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
from diffusers import AutoencoderKL
from huggingface_hub import hf_hub_download
from sgm.models.autoencoder import AutoencodingEngine

# (1) create vae_sat
# AutoencodingEngine initialization arguments:
encoder_config = {'target': 'sgm.modules.diffusionmodules.model.Encoder', 'params': {'attn_type': 'vanilla', 'double_z': True, 'z_channels': 16, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 4, 8, 8], 'num_res_blocks': 3, 'attn_resolutions': [], 'mid_attn': False, 'dropout': 0.0}}
decoder_config = {'target': 'sgm.modules.diffusionmodules.model.Decoder', 'params': {'attn_type': 'vanilla', 'double_z': True, 'z_channels': 16, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 4, 8, 8], 'num_res_blocks': 3, 'attn_resolutions': [], 'mid_attn': False, 'dropout': 0.0}}
loss_config = {'target': 'torch.nn.Identity'}
regularizer_config = {'target': 'sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer'}
optimizer_config = None
lr_g_factor = 1.0
ckpt_path = "/raid/.cache/huggingface/models--ZP2HF--CogView3-SAT/snapshots/ca86ce9ba94f9a7f2dd109e7a59e4c8ad04121be/3plus_ae/imagekl_ch16.pt"
ignore_keys = []
kwargs = {"monitor": "val/rec_loss"}
vae_sat = AutoencodingEngine(
    encoder_config=encoder_config,
    decoder_config=decoder_config,
    loss_config=loss_config,
    regularizer_config=regularizer_config,
    optimizer_config=optimizer_config,
    lr_g_factor=lr_g_factor,
    ckpt_path=ckpt_path,
    ignore_keys=ignore_keys,
    **kwargs,
)


# (2) create vae (diffusers)
ckpt_path_vae_cogview3 = hf_hub_download(repo_id="ZP2HF/CogView3-SAT", subfolder="3plus_ae", filename="imagekl_ch16.pt")
cogview3_ckpt = torch.load(ckpt_path_vae_cogview3, map_location='cpu')["state_dict"]

in_channels = 3  # Inferred from encoder.conv_in.weight shape
out_channels = 3  # Inferred from decoder.conv_out.weight shape
down_block_types = ("DownEncoderBlock2D",) * 4  # Inferred from the presence of 4 encoder.down blocks
up_block_types = ("UpDecoderBlock2D",) * 4  # Inferred from the presence of 4 decoder.up blocks
block_out_channels = (128, 512, 1024, 1024)  # Inferred from the channel sizes in encoder.down blocks
layers_per_block = 3  # Inferred from the number of blocks in each encoder.down and decoder.up
act_fn = "silu"  # This is the default, cannot be inferred from state_dict
latent_channels = 16  # Inferred from decoder.conv_in.weight shape
norm_num_groups = 32  # This is the default, cannot be inferred from state_dict
sample_size = 1024  # This is the default, cannot be inferred from state_dict
scaling_factor = 0.18215  # This is the default, cannot be inferred from state_dict
force_upcast = True  # This is the default, cannot be inferred from state_dict
use_quant_conv = False  # Inferred from the presence of encoder.conv_out
use_post_quant_conv = False  # Inferred from the presence of decoder.conv_in
mid_block_add_attention = False  # Inferred from the absence of attention layers in mid blocks

vae = AutoencoderKL(
    in_channels=in_channels,
    out_channels=out_channels,
    down_block_types=down_block_types,
    up_block_types=up_block_types,
    block_out_channels=block_out_channels,
    layers_per_block=layers_per_block,
    act_fn=act_fn,
    latent_channels=latent_channels,
    norm_num_groups=norm_num_groups,
    sample_size=sample_size,
    scaling_factor=scaling_factor,
    force_upcast=force_upcast,
    use_quant_conv=use_quant_conv,
    use_post_quant_conv=use_post_quant_conv,
    mid_block_add_attention=mid_block_add_attention,
)

vae.eval()
vae_sat.eval()

converted_vae_state_dict = convert_ldm_vae_checkpoint(cogview3_ckpt, vae.config)
vae.load_state_dict(converted_vae_state_dict, strict=False)

# (3) run forward pass for both models

# [2, 16, 128, 128] -> [2, 3, 1024, 1024]
z = torch.load("z.pt").float().to("cpu")

with torch.no_grad():
    print(" ")
    print(" running forward pass for diffusers vae")
    out = vae.decode(z).sample
    print(" ")
    print(" running forward pass for sgm vae")
    out_sat = vae_sat.decode(z)

print(f" output shape: {out.shape}")
print(f" expected output shape: {out_sat.shape}")
assert out.shape == out_sat.shape
assert (out - out_sat).abs().max() < 1e-4, f"max diff: {(out - out_sat).abs().max()}"
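Once the numerical check passes, the converted `vae` built above could also be written out in the Diffusers layout. A minimal follow-up sketch, continuing from the script above; the target directory is illustrative and simply mirrors the transformer's output root:

# Hypothetical save location, chosen to sit next to the converted transformer weights.
vae.save_pretrained("./cogview3_diffusers/vae")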

show_model_cogview.py

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
import torch
from diffusers import CogView3PlusTransformer2DModel

model = CogView3PlusTransformer2DModel.from_pretrained("/share/home/zyx/Models/CogView3Plus_hf/transformer", torch_dtype=torch.bfloat16)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

batch_size = 1
hidden_states = torch.ones((batch_size, 16, 256, 256), device=device, dtype=torch.bfloat16)
timestep = torch.full((batch_size,), 999.0, device=device, dtype=torch.bfloat16)
y = torch.ones((batch_size, 1536), device=device, dtype=torch.bfloat16)

# Simulate a call to the forward method
outputs = model(
    hidden_states=hidden_states,  # hidden_states input
    timestep=timestep,  # timestep input
    y=y,  # label input
    block_controlnet_hidden_states=None,  # can be omitted if not needed
    return_dict=True,  # keep the default
    target_size=[(2048, 2048)],
)

# Print the model output
print("Output shape:", outputs.sample.shape)

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -84,6 +84,7 @@
 "AutoencoderOobleck",
 "AutoencoderTiny",
 "CogVideoXTransformer3DModel",
+"CogView3PlusTransformer2DModel",
 "ConsistencyDecoderVAE",
 "ControlNetModel",
 "ControlNetXSAdapter",
@@ -258,6 +259,7 @@
 "CogVideoXImageToVideoPipeline",
 "CogVideoXPipeline",
 "CogVideoXVideoToVideoPipeline",
+"CogView3PlusPipeline",
 "CycleDiffusionPipeline",
 "FluxControlNetImg2ImgPipeline",
 "FluxControlNetInpaintPipeline",
@@ -558,6 +560,7 @@
 AutoencoderOobleck,
 AutoencoderTiny,
 CogVideoXTransformer3DModel,
+CogView3PlusTransformer2DModel,
 ConsistencyDecoderVAE,
 ControlNetModel,
 ControlNetXSAdapter,
@@ -710,6 +713,7 @@
 CogVideoXImageToVideoPipeline,
 CogVideoXPipeline,
 CogVideoXVideoToVideoPipeline,
+CogView3PlusPipeline,
 CycleDiffusionPipeline,
 FluxControlNetImg2ImgPipeline,
 FluxControlNetInpaintPipeline,

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -54,6 +54,7 @@
 _import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
 _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
 _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
+_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
 _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
 _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
 _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
@@ -98,6 +99,7 @@
 from .transformers import (
 AuraFlowTransformer2DModel,
 CogVideoXTransformer3DModel,
+CogView3PlusTransformer2DModel,
 DiTTransformer2DModel,
 DualTransformer2DModel,
 FluxTransformer2DModel,

src/diffusers/models/attention_processor.py

Lines changed: 3 additions & 2 deletions
@@ -122,6 +122,7 @@ def __init__(
         out_dim: int = None,
         context_pre_only=None,
         pre_only=False,
+        layrnorm_elementwise_affine: bool = True,
     ):
         super().__init__()

@@ -179,8 +180,8 @@ def __init__(
             self.norm_q = None
             self.norm_k = None
         elif qk_norm == "layer_norm":
-            self.norm_q = nn.LayerNorm(dim_head, eps=eps)
-            self.norm_k = nn.LayerNorm(dim_head, eps=eps)
+            self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=layrnorm_elementwise_affine)
+            self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=layrnorm_elementwise_affine)
         elif qk_norm == "fp32_layer_norm":
             self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
             self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
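For context, the new `layrnorm_elementwise_affine` argument is simply forwarded to `nn.LayerNorm`, presumably so the CogView3 attention blocks can opt out of the learnable affine parameters. A small standalone PyTorch illustration of what toggling `elementwise_affine` changes (not part of the commit):

import torch.nn as nn

# Default behaviour: LayerNorm owns learnable per-element weight and bias tensors.
ln_affine = nn.LayerNorm(64, eps=1e-5, elementwise_affine=True)
print([name for name, _ in ln_affine.named_parameters()])  # ['weight', 'bias']

# With elementwise_affine=False the module is pure normalization with no parameters.
ln_plain = nn.LayerNorm(64, eps=1e-5, elementwise_affine=False)
print([name for name, _ in ln_plain.named_parameters()])  # []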
