
Commit 4ea9e82

zRzRzRzRzRzRzR, a-r-r-o-w, and yiyixuxu authored and committed
CogView3Plus DiT (#9570)
* merge 9588
* max_shard_size="5GB" for colab running
* conversion script updates; modeling test; refactor transformer
* make fix-copies
* Update convert_cogview3_to_diffusers.py
* initial pipeline draft
* make style
* fight bugs 🐛🪳
* add example
* add tests; refactor
* make style
* make fix-copies
* add co-author YiYi Xu <[email protected]>
* remove files
* add docs
* add co-author Co-Authored-By: YiYi Xu <[email protected]>
* fight docs
* address reviews
* make style
* make model work
* remove qkv fusion
* remove qkv fusion tests
* address review comments
* fix make fix-copies error
* remove None and TODO
* for FP16 (draft)
* make style
* remove dynamic cfg
* remove pooled_projection_dim as a parameter
* fix tests

---------

Co-authored-by: Aryan <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
1 parent 0e33264 · commit 4ea9e82

21 files changed: +1974 −2 lines changed

docs/source/en/_toctree.yml

Lines changed: 4 additions & 0 deletions
@@ -242,6 +242,8 @@
     title: AuraFlowTransformer2DModel
   - local: api/models/cogvideox_transformer3d
     title: CogVideoXTransformer3DModel
+  - local: api/models/cogview3plus_transformer2d
+    title: CogView3PlusTransformer2DModel
   - local: api/models/dit_transformer2d
     title: DiTTransformer2DModel
   - local: api/models/flux_transformer
@@ -320,6 +322,8 @@
     title: BLIP-Diffusion
   - local: api/pipelines/cogvideox
     title: CogVideoX
+  - local: api/pipelines/cogview3
+    title: CogView3
   - local: api/pipelines/consistency_models
     title: Consistency Models
   - local: api/pipelines/controlnet
docs/source/en/api/models/cogview3plus_transformer2d.md

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# CogView3PlusTransformer2DModel

A Diffusion Transformer model for 2D data from [CogView3Plus](https://github.com/THUDM/CogView3) was introduced in [CogView3: Finer and Faster Text-to-Image Generation via Relay Diffusion](https://huggingface.co/papers/2403.05121) by Tsinghua University & ZhipuAI.

The model can be loaded with the following code snippet.

```python
import torch

from diffusers import CogView3PlusTransformer2DModel

transformer = CogView3PlusTransformer2DModel.from_pretrained("THUDM/CogView3Plus-3b", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```

## CogView3PlusTransformer2DModel

[[autodoc]] CogView3PlusTransformer2DModel

## Transformer2DModelOutput

[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
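For context, the transformer loaded in the snippet above is usually handed on to the full pipeline rather than used on its own. A minimal sketch, assuming the `THUDM/CogView3Plus-3b` checkpoint id used in the doc (the final Hub repository name may differ):

```python
import torch

from diffusers import CogView3PlusPipeline, CogView3PlusTransformer2DModel

# Load the transformer separately (e.g. in a different dtype), then pass it to the pipeline;
# `from_pretrained` accepts component overrides as keyword arguments.
transformer = CogView3PlusTransformer2DModel.from_pretrained(
    "THUDM/CogView3Plus-3b", subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = CogView3PlusPipeline.from_pretrained(
    "THUDM/CogView3Plus-3b", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
```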
docs/source/en/api/pipelines/cogview3.md

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-->

# CogView3Plus

[CogView3: Finer and Faster Text-to-Image Generation via Relay Diffusion](https://huggingface.co/papers/2403.05121) from Tsinghua University & ZhipuAI, by Wendi Zheng, Jiayan Teng, Zhuoyi Yang, Weihan Wang, Jidong Chen, Xiaotao Gu, Yuxiao Dong, Ming Ding, and Jie Tang.

The abstract from the paper is:

*Recent advancements in text-to-image generative systems have been largely driven by diffusion models. However, single-stage text-to-image diffusion models still face challenges, in terms of computational efficiency and the refinement of image details. To tackle the issue, we propose CogView3, an innovative cascaded framework that enhances the performance of text-to-image diffusion. CogView3 is the first model implementing relay diffusion in the realm of text-to-image generation, executing the task by first creating low-resolution images and subsequently applying relay-based super-resolution. This methodology not only results in competitive text-to-image outputs but also greatly reduces both training and inference costs. Our experimental results demonstrate that CogView3 outperforms SDXL, the current state-of-the-art open-source text-to-image diffusion model, by 77.0% in human evaluations, all while requiring only about 1/2 of the inference time. The distilled variant of CogView3 achieves comparable performance while only utilizing 1/10 of the inference time by SDXL.*

<Tip>

Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.

</Tip>

This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).

## CogView3PlusPipeline

[[autodoc]] CogView3PlusPipeline
  - all
  - __call__

## CogView3PipelineOutput

[[autodoc]] pipelines.cogview3.pipeline_output.CogView3PipelineOutput
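To make the autodoc entries above concrete, here is a minimal text-to-image sketch for the new pipeline; the checkpoint id mirrors the one used in the transformer doc, and the prompt and sampling parameters are illustrative rather than tuned values:

```python
import torch

from diffusers import CogView3PlusPipeline

pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3Plus-3b", torch_dtype=torch.bfloat16).to("cuda")

# Standard text-to-image call; guidance_scale and num_inference_steps are typical defaults.
image = pipe(
    prompt="a photo of an astronaut riding a horse on mars",
    guidance_scale=7.0,
    num_inference_steps=50,
).images[0]
image.save("cogview3.png")
```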
scripts/convert_cogview3_to_diffusers.py

Lines changed: 242 additions & 0 deletions
@@ -0,0 +1,242 @@
"""
Convert a CogView3 checkpoint to the Diffusers format.

This script converts a CogView3 checkpoint to the Diffusers format, which can then be used
with the Diffusers library.

Example usage:
    python scripts/convert_cogview3_to_diffusers.py \
        --transformer_checkpoint_path 'your path/cogview3plus_3b/1/mp_rank_00_model_states.pt' \
        --vae_checkpoint_path 'your path/3plus_ae/imagekl_ch16.pt' \
        --output_path "/raid/yiyi/cogview3_diffusers" \
        --dtype "bf16"

Arguments:
    --transformer_checkpoint_path: Path to the Transformer state dict.
    --vae_checkpoint_path: Path to the VAE state dict.
    --output_path: The path to save the converted model.
    --push_to_hub: Whether to push the converted checkpoint to the HF Hub. Defaults to `False`.
    --text_encoder_cache_dir: Cache directory where the text encoder is located. Defaults to None, which means HF_HOME will be used.
    --dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32"). If None, the dtype of the state dict is kept.

    The default is "bf16" because CogView3 uses bfloat16 for training.

Note: provide --transformer_checkpoint_path and/or --vae_checkpoint_path for the components you want to convert.
"""

import argparse
from contextlib import nullcontext

import torch
from accelerate import init_empty_weights
from transformers import T5EncoderModel, T5Tokenizer

from diffusers import AutoencoderKL, CogVideoXDDIMScheduler, CogView3PlusPipeline, CogView3PlusTransformer2DModel
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
from diffusers.utils.import_utils import is_accelerate_available


CTX = init_empty_weights if is_accelerate_available() else nullcontext

TOKENIZER_MAX_LENGTH = 224

parser = argparse.ArgumentParser()
parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
parser.add_argument("--vae_checkpoint_path", default=None, type=str)
parser.add_argument("--output_path", required=True, type=str)
parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
parser.add_argument("--dtype", type=str, default="bf16")

args = parser.parse_args()


# This is specific to `AdaLayerNormContinuous`:
# the diffusers implementation splits the linear projection into scale, shift while CogView3 splits it into shift, scale
def swap_scale_shift(weight, dim):
    shift, scale = weight.chunk(2, dim=0)
    new_weight = torch.cat([scale, shift], dim=0)
    return new_weight


def convert_cogview3_transformer_checkpoint_to_diffusers(ckpt_path):
    original_state_dict = torch.load(ckpt_path, map_location="cpu")
    original_state_dict = original_state_dict["module"]
    original_state_dict = {k.replace("model.diffusion_model.", ""): v for k, v in original_state_dict.items()}

    new_state_dict = {}

    # Convert patch_embed
    new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("mixins.patch_embed.proj.weight")
    new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("mixins.patch_embed.proj.bias")
    new_state_dict["patch_embed.text_proj.weight"] = original_state_dict.pop("mixins.patch_embed.text_proj.weight")
    new_state_dict["patch_embed.text_proj.bias"] = original_state_dict.pop("mixins.patch_embed.text_proj.bias")

    # Convert time_condition_embed
    new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
        "time_embed.0.weight"
    )
    new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
        "time_embed.0.bias"
    )
    new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
        "time_embed.2.weight"
    )
    new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
        "time_embed.2.bias"
    )
    new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = original_state_dict.pop(
        "label_emb.0.0.weight"
    )
    new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = original_state_dict.pop(
        "label_emb.0.0.bias"
    )
    new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = original_state_dict.pop(
        "label_emb.0.2.weight"
    )
    new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = original_state_dict.pop(
        "label_emb.0.2.bias"
    )

    # Convert transformer blocks
    for i in range(30):
        block_prefix = f"transformer_blocks.{i}."
        old_prefix = f"transformer.layers.{i}."
        adaln_prefix = f"mixins.adaln.adaln_modules.{i}."

        new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(adaln_prefix + "1.weight")
        new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(adaln_prefix + "1.bias")

        qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
        qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
        q, k, v = qkv_weight.chunk(3, dim=0)
        q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)

        new_state_dict[block_prefix + "attn1.to_q.weight"] = q
        new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
        new_state_dict[block_prefix + "attn1.to_k.weight"] = k
        new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
        new_state_dict[block_prefix + "attn1.to_v.weight"] = v
        new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias

        new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
            old_prefix + "attention.dense.weight"
        )
        new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(
            old_prefix + "attention.dense.bias"
        )

        new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop(
            old_prefix + "mlp.dense_h_to_4h.weight"
        )
        new_state_dict[block_prefix + "ff.net.0.proj.bias"] = original_state_dict.pop(
            old_prefix + "mlp.dense_h_to_4h.bias"
        )
        new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(
            old_prefix + "mlp.dense_4h_to_h.weight"
        )
        new_state_dict[block_prefix + "ff.net.2.bias"] = original_state_dict.pop(old_prefix + "mlp.dense_4h_to_h.bias")

    # Convert final norm and projection
    new_state_dict["norm_out.linear.weight"] = swap_scale_shift(
        original_state_dict.pop("mixins.final_layer.adaln.1.weight"), dim=0
    )
    new_state_dict["norm_out.linear.bias"] = swap_scale_shift(
        original_state_dict.pop("mixins.final_layer.adaln.1.bias"), dim=0
    )
    new_state_dict["proj_out.weight"] = original_state_dict.pop("mixins.final_layer.linear.weight")
    new_state_dict["proj_out.bias"] = original_state_dict.pop("mixins.final_layer.linear.bias")

    return new_state_dict


def convert_cogview3_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
    original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    return convert_ldm_vae_checkpoint(original_state_dict, vae_config)


def main(args):
    if args.dtype == "fp16":
        dtype = torch.float16
    elif args.dtype == "bf16":
        dtype = torch.bfloat16
    elif args.dtype == "fp32":
        dtype = torch.float32
    else:
        raise ValueError(f"Unsupported dtype: {args.dtype}")

    transformer = None
    vae = None

    if args.transformer_checkpoint_path is not None:
        converted_transformer_state_dict = convert_cogview3_transformer_checkpoint_to_diffusers(
            args.transformer_checkpoint_path
        )
        transformer = CogView3PlusTransformer2DModel()
        transformer.load_state_dict(converted_transformer_state_dict, strict=True)
        if dtype is not None:
            # Original checkpoint data type will be preserved
            transformer = transformer.to(dtype=dtype)

    if args.vae_checkpoint_path is not None:
        vae_config = {
            "in_channels": 3,
            "out_channels": 3,
            "down_block_types": ("DownEncoderBlock2D",) * 4,
            "up_block_types": ("UpDecoderBlock2D",) * 4,
            "block_out_channels": (128, 512, 1024, 1024),
            "layers_per_block": 3,
            "act_fn": "silu",
            "latent_channels": 16,
            "norm_num_groups": 32,
            "sample_size": 1024,
            "scaling_factor": 1.0,
            "force_upcast": True,
            "use_quant_conv": False,
            "use_post_quant_conv": False,
            "mid_block_add_attention": False,
        }
        converted_vae_state_dict = convert_cogview3_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config)
        vae = AutoencoderKL(**vae_config)
        vae.load_state_dict(converted_vae_state_dict, strict=True)
        if dtype is not None:
            vae = vae.to(dtype=dtype)

    text_encoder_id = "google/t5-v1_1-xxl"
    tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
    text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

    # Apparently, the conversion does not work anymore without this :shrug:
    for param in text_encoder.parameters():
        param.data = param.data.contiguous()

    scheduler = CogVideoXDDIMScheduler.from_config(
        {
            "snr_shift_scale": 4.0,
            "beta_end": 0.012,
            "beta_schedule": "scaled_linear",
            "beta_start": 0.00085,
            "clip_sample": False,
            "num_train_timesteps": 1000,
            "prediction_type": "v_prediction",
            "rescale_betas_zero_snr": True,
            "set_alpha_to_one": True,
            "timestep_spacing": "trailing",
        }
    )

    pipe = CogView3PlusPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        transformer=transformer,
        scheduler=scheduler,
    )

    # max_shard_size="5GB" is necessary for users with insufficient memory, such as those using Colab and
    # notebooks, as it can save some memory used for model loading.
    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)


if __name__ == "__main__":
    main(args)
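As a quick illustration of the two key remappings the script performs, here is a toy sketch (the shapes below are made up, not the real model's):

```python
import torch

# 1) `swap_scale_shift`: CogView3 stores the AdaLayerNorm projection as [shift; scale],
#    while diffusers expects [scale; shift], so the two halves are swapped.
w = torch.cat([torch.zeros(4), torch.ones(4)])  # pretend shift = 0, scale = 1
shift, scale = w.chunk(2, dim=0)
swapped = torch.cat([scale, shift], dim=0)      # now [scale; shift], i.e. the ones come first

# 2) QKV splitting: the original attention keeps a fused query_key_value projection,
#    while diffusers uses separate to_q / to_k / to_v, so the fused weight is chunked in three.
hidden_dim = 8
qkv_weight = torch.randn(3 * hidden_dim, hidden_dim)
q, k, v = qkv_weight.chunk(3, dim=0)
assert q.shape == (hidden_dim, hidden_dim)
```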

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -84,6 +84,7 @@
         "AutoencoderOobleck",
         "AutoencoderTiny",
         "CogVideoXTransformer3DModel",
+        "CogView3PlusTransformer2DModel",
         "ConsistencyDecoderVAE",
         "ControlNetModel",
         "ControlNetXSAdapter",
@@ -258,6 +259,7 @@
         "CogVideoXImageToVideoPipeline",
         "CogVideoXPipeline",
         "CogVideoXVideoToVideoPipeline",
+        "CogView3PlusPipeline",
         "CycleDiffusionPipeline",
         "FluxControlNetImg2ImgPipeline",
         "FluxControlNetInpaintPipeline",
@@ -559,6 +561,7 @@
         AutoencoderOobleck,
         AutoencoderTiny,
         CogVideoXTransformer3DModel,
+        CogView3PlusTransformer2DModel,
         ConsistencyDecoderVAE,
         ControlNetModel,
         ControlNetXSAdapter,
@@ -711,6 +714,7 @@
         CogVideoXImageToVideoPipeline,
         CogVideoXPipeline,
         CogVideoXVideoToVideoPipeline,
+        CogView3PlusPipeline,
         CycleDiffusionPipeline,
         FluxControlNetImg2ImgPipeline,
         FluxControlNetInpaintPipeline,

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -54,6 +54,7 @@
     _import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
     _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
     _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
+    _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
     _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
     _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
     _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
@@ -98,6 +99,7 @@
         from .transformers import (
             AuraFlowTransformer2DModel,
             CogVideoXTransformer3DModel,
+            CogView3PlusTransformer2DModel,
             DiTTransformer2DModel,
             DualTransformer2DModel,
             FluxTransformer2DModel,

src/diffusers/models/attention_processor.py

Lines changed: 3 additions & 2 deletions
@@ -122,6 +122,7 @@ def __init__(
         out_dim: int = None,
         context_pre_only=None,
         pre_only=False,
+        elementwise_affine: bool = True,
     ):
         super().__init__()

@@ -179,8 +180,8 @@ def __init__(
             self.norm_q = None
             self.norm_k = None
         elif qk_norm == "layer_norm":
-            self.norm_q = nn.LayerNorm(dim_head, eps=eps)
-            self.norm_k = nn.LayerNorm(dim_head, eps=eps)
+            self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+            self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
         elif qk_norm == "fp32_layer_norm":
             self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
             self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
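The new `elementwise_affine` flag is simply forwarded to `nn.LayerNorm` on the `layer_norm` QK-norm path; a small sketch of what the flag changes (presumably added so the CogView3Plus attention can use a QK LayerNorm without learnable parameters):

```python
import torch.nn as nn

# With affine parameters (the previous hard-coded behaviour) vs. without them.
norm_affine = nn.LayerNorm(64, eps=1e-5, elementwise_affine=True)
norm_plain = nn.LayerNorm(64, eps=1e-5, elementwise_affine=False)

print(sum(p.numel() for p in norm_affine.parameters()))  # 128: weight + bias
print(sum(p.numel() for p in norm_plain.parameters()))   # 0: no learnable parameters
```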
