CogView3Plus DiT #9570
@@ -0,0 +1,30 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# CogView3PlusTransformer2DModel

A Diffusion Transformer model for 2D data from [CogView3Plus](https://github.com/THUDM/CogView3), introduced in [CogView3: Finer and Faster Text-to-Image Generation via Relay Diffusion](https://huggingface.co/papers/2403.05121) by Tsinghua University & ZhipuAI.

The model can be loaded with the following code snippet.

```python
import torch

from diffusers import CogView3PlusTransformer2DModel

transformer = CogView3PlusTransformer2DModel.from_pretrained("THUDM/CogView3Plus-3b", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```
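
As a quick, illustrative sanity check (not part of the official docs example), the loaded module behaves like any `torch.nn.Module`:

```python
num_params = sum(p.numel() for p in transformer.parameters())
print(f"{num_params / 1e9:.2f}B parameters, dtype={next(transformer.parameters()).dtype}")
```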

## CogView3PlusTransformer2DModel

[[autodoc]] CogView3PlusTransformer2DModel

## Transformer2DModelOutput

[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
@@ -0,0 +1,40 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-->

# CogView3Plus

[CogView3: Finer and Faster Text-to-Image Generation via Relay Diffusion](https://huggingface.co/papers/2403.05121) from Tsinghua University & ZhipuAI, by Wendi Zheng, Jiayan Teng, Zhuoyi Yang, Weihan Wang, Jidong Chen, Xiaotao Gu, Yuxiao Dong, Ming Ding, Jie Tang.

The abstract from the paper is:

*Recent advancements in text-to-image generative systems have been largely driven by diffusion models. However, single-stage text-to-image diffusion models still face challenges, in terms of computational efficiency and the refinement of image details. To tackle the issue, we propose CogView3, an innovative cascaded framework that enhances the performance of text-to-image diffusion. CogView3 is the first model implementing relay diffusion in the realm of text-to-image generation, executing the task by first creating low-resolution images and subsequently applying relay-based super-resolution. This methodology not only results in competitive text-to-image outputs but also greatly reduces both training and inference costs. Our experimental results demonstrate that CogView3 outperforms SDXL, the current state-of-the-art open-source text-to-image diffusion model, by 77.0% in human evaluations, all while requiring only about 1/2 of the inference time. The distilled variant of CogView3 achieves comparable performance while only utilizing 1/10 of the inference time by SDXL.*

<Tip>

Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.

</Tip>

This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found at [THUDM/CogView3](https://github.com/THUDM/CogView3), and the original weights under [hf.co/THUDM](https://huggingface.co/THUDM).
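
A minimal text-to-image sketch is shown below; it assumes the standard diffusers text-to-image call signature, and the `num_inference_steps`/`guidance_scale` values are illustrative defaults rather than tuned settings:

```python
import torch

from diffusers import CogView3PlusPipeline

pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3Plus-3b", torch_dtype=torch.bfloat16).to("cuda")

prompt = "A photo of an astronaut riding a horse on Mars"
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.0).images[0]
image.save("cogview3plus.png")
```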

## CogView3PlusPipeline

[[autodoc]] CogView3PlusPipeline
  - all
  - __call__

## CogView3PipelineOutput

[[autodoc]] pipelines.cogview3.pipeline_output.CogView3PipelineOutput
@@ -0,0 +1,245 @@
""" | ||
Convert a CogView3 checkpoint to the Diffusers format. | ||
|
||
This script converts a CogView3 checkpoint to the Diffusers format, which can then be used | ||
with the Diffusers library. | ||
|
||
Example usage: | ||
python scripts/convert_cogview3_to_diffusers.py \ | ||
--transformer_checkpoint_path 'your path/cogview3plus_3b/1/mp_rank_00_model_states.pt' \ | ||
--vae_checkpoint_path 'your path/3plus_ae/imagekl_ch16.pt' \ | ||
--output_path "/raid/yiyi/cogview3_diffusers" \ | ||
--dtype "bf16" | ||
|
||
Arguments: | ||
--transformer_checkpoint_path: Path to Transformer state dict. | ||
--vae_checkpoint_path: Path to VAE state dict. | ||
--output_path: The path to save the converted model. | ||
--push_to_hub: Whether to push the converted checkpoint to the HF Hub or not. Defaults to `False`. | ||
--text_encoder_cache_dir: Cache directory where text encoder is located. Defaults to None, which means HF_HOME will be used | ||
--dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32"). If None, the dtype of the state dict is considered. | ||
|
||
Default is "bf16" because CogView3 uses bfloat16 for Training. | ||
|
||
Note: You must provide either --original_state_dict_repo_id or --checkpoint_path. | ||
""" | ||
|
||
import argparse
from contextlib import nullcontext

import torch
from accelerate import init_empty_weights
from transformers import T5EncoderModel, T5Tokenizer

from diffusers import AutoencoderKL, CogVideoXDDIMScheduler, CogView3PlusPipeline, CogView3PlusTransformer2DModel
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
from diffusers.utils.import_utils import is_accelerate_available


# `is_accelerate_available` must be called; referencing the function itself is always truthy.
CTX = init_empty_weights if is_accelerate_available() else nullcontext

TOKENIZER_MAX_LENGTH = 224
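
# NOTE: `CTX` is currently unused below; if it were used, the intended pattern (an assumption,
# mirroring other diffusers conversion scripts) would be to instantiate the model on the meta
# device before loading real weights:
#   with CTX():
#       transformer = CogView3PlusTransformer2DModel()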

parser = argparse.ArgumentParser()
parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
parser.add_argument("--vae_checkpoint_path", default=None, type=str)
parser.add_argument("--output_path", required=True, type=str)
parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
parser.add_argument("--dtype", type=str, default="bf16")

args = parser.parse_args()

# This is specific to `AdaLayerNormContinuous`:
# the diffusers implementation splits the linear projection into (scale, shift) while CogView3 splits it into (shift, scale)
def swap_scale_shift(weight, dim):
    shift, scale = weight.chunk(2, dim=dim)
    new_weight = torch.cat([scale, shift], dim=dim)
    return new_weight
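
# Illustrative example (comment only): for w = torch.arange(6.0), swap_scale_shift(w, dim=0)
# treats [0, 1, 2] as the shift half and [3, 4, 5] as the scale half, and so
# returns tensor([3., 4., 5., 0., 1., 2.]).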


def convert_cogview3_transformer_checkpoint_to_diffusers(ckpt_path):
    original_state_dict = torch.load(ckpt_path, map_location="cpu")
    original_state_dict = original_state_dict["module"]
    original_state_dict = {k.replace("model.diffusion_model.", ""): v for k, v in original_state_dict.items()}

    new_state_dict = {}

    # Convert patch_embed
    new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("mixins.patch_embed.proj.weight")
    new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("mixins.patch_embed.proj.bias")
    new_state_dict["patch_embed.text_proj.weight"] = original_state_dict.pop("mixins.patch_embed.text_proj.weight")
    new_state_dict["patch_embed.text_proj.bias"] = original_state_dict.pop("mixins.patch_embed.text_proj.bias")

    # Convert time_condition_embed
    new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
        "time_embed.0.weight"
    )
    new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
        "time_embed.0.bias"
    )
    new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
        "time_embed.2.weight"
    )
    new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
        "time_embed.2.bias"
    )
    new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = original_state_dict.pop(
        "label_emb.0.0.weight"
    )
    new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = original_state_dict.pop(
        "label_emb.0.0.bias"
    )
    new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = original_state_dict.pop(
        "label_emb.0.2.weight"
    )
    new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = original_state_dict.pop(
        "label_emb.0.2.bias"
    )

    # Convert transformer blocks
    for i in range(30):
        block_prefix = f"transformer_blocks.{i}."
        old_prefix = f"transformer.layers.{i}."
        adaln_prefix = f"mixins.adaln.adaln_modules.{i}."

        new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(adaln_prefix + "1.weight")
        new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(adaln_prefix + "1.bias")

        # The original checkpoint stores Q, K, V as a single fused projection; split it into
        # the separate to_q/to_k/to_v projections used by diffusers attention modules.
        qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
        qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
        q, k, v = qkv_weight.chunk(3, dim=0)
        q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)

        new_state_dict[block_prefix + "attn1.to_q.weight"] = q
        new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
        new_state_dict[block_prefix + "attn1.to_k.weight"] = k
        new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
        new_state_dict[block_prefix + "attn1.to_v.weight"] = v
        new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias

        new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
            old_prefix + "attention.dense.weight"
        )
        new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(
            old_prefix + "attention.dense.bias"
        )

        new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop(
            old_prefix + "mlp.dense_h_to_4h.weight"
        )
        new_state_dict[block_prefix + "ff.net.0.proj.bias"] = original_state_dict.pop(
            old_prefix + "mlp.dense_h_to_4h.bias"
        )
        new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(
            old_prefix + "mlp.dense_4h_to_h.weight"
        )
        new_state_dict[block_prefix + "ff.net.2.bias"] = original_state_dict.pop(old_prefix + "mlp.dense_4h_to_h.bias")

    # Convert final norm and projection
    new_state_dict["norm_out.linear.weight"] = swap_scale_shift(
        original_state_dict.pop("mixins.final_layer.adaln.1.weight"), dim=0
    )
    new_state_dict["norm_out.linear.bias"] = swap_scale_shift(
        original_state_dict.pop("mixins.final_layer.adaln.1.bias"), dim=0
    )
    new_state_dict["proj_out.weight"] = original_state_dict.pop("mixins.final_layer.linear.weight")
    new_state_dict["proj_out.bias"] = original_state_dict.pop("mixins.final_layer.linear.bias")

    return new_state_dict
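
# Note: because each mapped weight is pop()'d off `original_state_dict`, any keys that
# remain after conversion would indicate tensors this script does not handle.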


def convert_cogview3_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
    original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    return convert_ldm_vae_checkpoint(original_state_dict, vae_config)


def main(args):
    if args.dtype is None:
        dtype = None
    elif args.dtype == "fp16":
        dtype = torch.float16
    elif args.dtype == "bf16":
        dtype = torch.bfloat16
    elif args.dtype == "fp32":
        dtype = torch.float32
    else:
        raise ValueError(f"Unsupported dtype: {args.dtype}")

Review comment (on the dtype handling above): I think it makes sense to have […]
Reply: Oh okay, will update
Reply: Let's make sure to incorporate this change!

    transformer = None
    vae = None

    if args.transformer_checkpoint_path is not None:
        converted_transformer_state_dict = convert_cogview3_transformer_checkpoint_to_diffusers(
            args.transformer_checkpoint_path
        )
        transformer = CogView3PlusTransformer2DModel()
        transformer.load_state_dict(converted_transformer_state_dict, strict=True)
        if dtype is not None:
            # if dtype is None, the original checkpoint data type is preserved (no cast)
            transformer = transformer.to(dtype=dtype)

    if args.vae_checkpoint_path is not None:
        vae_config = {
            "in_channels": 3,
            "out_channels": 3,
            "down_block_types": ("DownEncoderBlock2D",) * 4,
            "up_block_types": ("UpDecoderBlock2D",) * 4,
            "block_out_channels": (128, 512, 1024, 1024),
            "layers_per_block": 3,
            "act_fn": "silu",
            "latent_channels": 16,
            "norm_num_groups": 32,
            "sample_size": 1024,
            "scaling_factor": 1.0,
            "force_upcast": True,
            "use_quant_conv": False,
            "use_post_quant_conv": False,
            "mid_block_add_attention": False,
        }
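        # With four down blocks the VAE downsamples 8x spatially, so 1024px inputs map to
        # 128x128 latents with 16 channels (per `latent_channels` above).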
        converted_vae_state_dict = convert_cogview3_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config)
        vae = AutoencoderKL(**vae_config)
        vae.load_state_dict(converted_vae_state_dict, strict=True)
        if dtype is not None:
            vae = vae.to(dtype=dtype)

    text_encoder_id = "google/t5-v1_1-xxl"
    tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
    text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

    # Apparently, the conversion does not work anymore without this :shrug:
    for param in text_encoder.parameters():
        param.data = param.data.contiguous()

    # TODO: figure out the correct scheduler if it is the same as CogVideoXDDIMScheduler
    scheduler = CogVideoXDDIMScheduler.from_config(
        {
            "snr_shift_scale": 4.0,
            "beta_end": 0.012,
            "beta_schedule": "scaled_linear",
            "beta_start": 0.00085,
            "clip_sample": False,
            "num_train_timesteps": 1000,
            "prediction_type": "v_prediction",
            "rescale_betas_zero_snr": True,
            "set_alpha_to_one": True,
            "timestep_spacing": "trailing",
        }
    )

Review comment (on the scheduler choice): I guess this is the major source of investigation pending from our end now?

    pipe = CogView3PlusPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        transformer=transformer,
        scheduler=scheduler,
    )

    # Sharding the saved checkpoint into 5GB pieces keeps it loadable for users with limited
    # memory, such as those on Colab and other notebook environments.
    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)


if __name__ == "__main__":
    main(args)
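
# After conversion, the saved pipeline can be loaded back from `--output_path`
# (illustrative; the saved directory follows standard diffusers layout conventions):
#   from diffusers import CogView3PlusPipeline
#   pipe = CogView3PlusPipeline.from_pretrained("/raid/yiyi/cogview3_diffusers", torch_dtype=torch.bfloat16)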