CogView3Plus DiT #9570


Merged
merged 41 commits into main from zRzRzRzRzRzRzR:cogview3-plus on Oct 14, 2024
Changes from 26 commits

Commits (41)
45b6cb6
merge 9588
zRzRzRzRzRzRzR Oct 5, 2024
8abfe00
max_shard_size="5GB" for colab running
zRzRzRzRzRzRzR Oct 5, 2024
d668ad9
conversion script updates; modeling test; refactor transformer
a-r-r-o-w Oct 6, 2024
53935c0
make fix-copies
a-r-r-o-w Oct 6, 2024
8af6709
Merge branch 'huggingface:main' into cogview3-plus
zRzRzRzRzRzRzR Oct 7, 2024
56b5599
Update convert_cogview3_to_diffusers.py
zRzRzRzRzRzRzR Oct 7, 2024
8d14495
Merge branch 'main' into cogview3-plus
a-r-r-o-w Oct 8, 2024
8e2ddd5
initial pipeline draft
a-r-r-o-w Oct 8, 2024
0e33401
make style
a-r-r-o-w Oct 8, 2024
3873c57
fight bugs 🐛🪳
a-r-r-o-w Oct 8, 2024
e77b988
add example
a-r-r-o-w Oct 8, 2024
958d3e7
add tests; refactor
a-r-r-o-w Oct 8, 2024
ab8c65f
make style
a-r-r-o-w Oct 8, 2024
059812a
make fix-copies
a-r-r-o-w Oct 8, 2024
3138ad1
add co-author
a-r-r-o-w Oct 8, 2024
86909dc
remove files
a-r-r-o-w Oct 8, 2024
7da234b
add docs
a-r-r-o-w Oct 8, 2024
0114e3b
add co-author
a-r-r-o-w Oct 8, 2024
4e8de65
fight docs
a-r-r-o-w Oct 8, 2024
23a093e
Merge branch 'huggingface:main' into cogview3-plus
zRzRzRzRzRzRzR Oct 8, 2024
0c1ebc3
Merge branch 'main' into cogview3-plus
a-r-r-o-w Oct 9, 2024
f35850c
address reviews
a-r-r-o-w Oct 9, 2024
b439f4c
make style
a-r-r-o-w Oct 9, 2024
9d9b0b2
make model work
a-r-r-o-w Oct 9, 2024
2158f00
remove qkv fusion
a-r-r-o-w Oct 9, 2024
1b06020
remove qkv fusion tets
a-r-r-o-w Oct 9, 2024
80e7cca
address review comments
a-r-r-o-w Oct 10, 2024
5cc2ade
Merge branch 'main' into cogview3-plus
a-r-r-o-w Oct 10, 2024
0e4577d
fix make fix-copies error
a-r-r-o-w Oct 11, 2024
53ff253
Merge branch 'main' into cogview3-plus
a-r-r-o-w Oct 11, 2024
5c68cfd
Merge branch 'huggingface:main' into cogview3-plus
zRzRzRzRzRzRzR Oct 11, 2024
9c3a81d
remove None and TODO
zRzRzRzRzRzRzR Oct 11, 2024
6d1f1c9
Merge branch 'cogview3-plus' of https://github.com/zRzRzRzRzRzRzR/dif…
zRzRzRzRzRzRzR Oct 11, 2024
6603901
for FP16(draft)
zRzRzRzRzRzRzR Oct 11, 2024
fe33f0a
Merge branch 'main' into cogview3-plus
a-r-r-o-w Oct 11, 2024
db2a958
make style
a-r-r-o-w Oct 11, 2024
52f97f6
Merge branch 'huggingface:main' into cogview3-plus
zRzRzRzRzRzRzR Oct 12, 2024
270d407
remove dynamic cfg
a-r-r-o-w Oct 14, 2024
4ac4e52
Merge branch 'main' into cogview3-plus
a-r-r-o-w Oct 14, 2024
21dd890
remove pooled_projection_dim as a parameter
a-r-r-o-w Oct 14, 2024
221b486
fix tests
a-r-r-o-w Oct 14, 2024
4 changes: 4 additions & 0 deletions docs/source/en/_toctree.yml
@@ -242,6 +242,8 @@
title: AuraFlowTransformer2DModel
- local: api/models/cogvideox_transformer3d
title: CogVideoXTransformer3DModel
- local: api/models/cogview3plus_transformer2d
title: CogView3PlusTransformer2DModel
- local: api/models/dit_transformer2d
title: DiTTransformer2DModel
- local: api/models/flux_transformer
@@ -320,6 +322,8 @@
title: BLIP-Diffusion
- local: api/pipelines/cogvideox
title: CogVideoX
- local: api/pipelines/cogview3
title: CogView3
- local: api/pipelines/consistency_models
title: Consistency Models
- local: api/pipelines/controlnet
30 changes: 30 additions & 0 deletions docs/source/en/api/models/cogview3plus_transformer2d.md
@@ -0,0 +1,30 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# CogView3PlusTransformer2DModel

A Diffusion Transformer model for 2D data from [CogView3Plus](https://github.com/THUDM/CogView3). The model was introduced in [CogView3: Finer and Faster Text-to-Image Generation via Relay Diffusion](https://huggingface.co/papers/2403.05121) by Tsinghua University & ZhipuAI.

The model can be loaded with the following code snippet.

```python
import torch

from diffusers import CogView3PlusTransformer2DModel

transformer = CogView3PlusTransformer2DModel.from_pretrained("THUDM/CogView3Plus-3b", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```

## CogView3PlusTransformer2DModel

[[autodoc]] CogView3PlusTransformer2DModel

## Transformer2DModelOutput

[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
40 changes: 40 additions & 0 deletions docs/source/en/api/pipelines/cogview3.md
@@ -0,0 +1,40 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-->

# CogView3Plus

[CogView3: Finer and Faster Text-to-Image Generation via Relay Diffusion](https://huggingface.co/papers/2403.05121) is a paper from Tsinghua University & ZhipuAI by Wendi Zheng, Jiayan Teng, Zhuoyi Yang, Weihan Wang, Jidong Chen, Xiaotao Gu, Yuxiao Dong, Ming Ding, and Jie Tang.

The abstract from the paper is:

*Recent advancements in text-to-image generative systems have been largely driven by diffusion models. However, single-stage text-to-image diffusion models still face challenges, in terms of computational efficiency and the refinement of image details. To tackle the issue, we propose CogView3, an innovative cascaded framework that enhances the performance of text-to-image diffusion. CogView3 is the first model implementing relay diffusion in the realm of text-to-image generation, executing the task by first creating low-resolution images and subsequently applying relay-based super-resolution. This methodology not only results in competitive text-to-image outputs but also greatly reduces both training and inference costs. Our experimental results demonstrate that CogView3 outperforms SDXL, the current state-of-the-art open-source text-to-image diffusion model, by 77.0% in human evaluations, all while requiring only about 1/2 of the inference time. The distilled variant of CogView3 achieves comparable performance while only utilizing 1/10 of the inference time by SDXL.*

<Tip>

Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.

</Tip>

This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found at [THUDM/CogView3](https://github.com/THUDM/CogView3), and the original weights under [hf.co/THUDM](https://huggingface.co/THUDM).
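
A minimal text-to-image sketch is shown below. It is hedged: it assumes the converted weights are published under the `THUDM/CogView3Plus-3b` repository used elsewhere in this PR, and it relies only on the standard diffusers pipeline call signature.

```python
import torch

from diffusers import CogView3PlusPipeline

# Load the full pipeline (transformer, VAE, T5 text encoder, scheduler) in bfloat16.
pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3Plus-3b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A photograph of an astronaut riding a horse on the beach at sunset"
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.0).images[0]
image.save("cogview3plus.png")
```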

## CogView3PlusPipeline

[[autodoc]] CogView3PlusPipeline
- all
- __call__

## CogView3PipelineOutput

[[autodoc]] pipelines.cogview3.pipeline_output.CogView3PipelineOutput
245 changes: 245 additions & 0 deletions scripts/convert_cogview3_to_diffusers.py
@@ -0,0 +1,245 @@
"""
Convert a CogView3 checkpoint to the Diffusers format.

This script converts a CogView3 checkpoint to the Diffusers format, which can then be used
with the Diffusers library.

Example usage:
python scripts/convert_cogview3_to_diffusers.py \
--transformer_checkpoint_path 'your path/cogview3plus_3b/1/mp_rank_00_model_states.pt' \
--vae_checkpoint_path 'your path/3plus_ae/imagekl_ch16.pt' \
--output_path "/raid/yiyi/cogview3_diffusers" \
--dtype "bf16"

Arguments:
--transformer_checkpoint_path: Path to Transformer state dict.
--vae_checkpoint_path: Path to VAE state dict.
--output_path: The path to save the converted model.
--push_to_hub: Whether to push the converted checkpoint to the HF Hub or not. Defaults to `False`.
--text_encoder_cache_dir: Cache directory where the text encoder is located. Defaults to None, which means HF_HOME will be used.
--dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32"). If None, the dtype of the state dict is considered.

Default is "bf16" because CogView3 uses bfloat16 for Training.

Note: Provide --transformer_checkpoint_path and/or --vae_checkpoint_path pointing to the original CogView3 checkpoints.
"""

import argparse
from contextlib import nullcontext

import torch
from accelerate import init_empty_weights
from transformers import T5EncoderModel, T5Tokenizer

from diffusers import AutoencoderKL, CogVideoXDDIMScheduler, CogView3PlusPipeline, CogView3PlusTransformer2DModel
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
from diffusers.utils.import_utils import is_accelerate_available


CTX = init_empty_weights if is_accelerate_available() else nullcontext

TOKENIZER_MAX_LENGTH = 224

parser = argparse.ArgumentParser()
parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
parser.add_argument("--vae_checkpoint_path", default=None, type=str)
parser.add_argument("--output_path", required=True, type=str)
parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
parser.add_argument("--dtype", type=str, default="bf16")

args = parser.parse_args()


# This is specific to `AdaLayerNormContinuous`:
# the diffusers implementation splits the linear projection into (scale, shift), while CogView3 splits it into (shift, scale).
def swap_scale_shift(weight, dim):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
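
# Illustrative note (added for clarity, not part of the original script): for a hypothetical
# weight of shape (2 * dim, in_features) stored in CogView3's [shift; scale] order,
# swap_scale_shift reorders the rows to [scale; shift], the layout diffusers'
# AdaLayerNormContinuous expects. For example, with dim = 2:
#   w = torch.arange(4).float().unsqueeze(1)   # rows 0,1 = shift, rows 2,3 = scale
#   swap_scale_shift(w, dim=0)                 # -> rows reordered to 2,3,0,1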


def convert_cogview3_transformer_checkpoint_to_diffusers(ckpt_path):
original_state_dict = torch.load(ckpt_path, map_location="cpu")
original_state_dict = original_state_dict["module"]
original_state_dict = {k.replace("model.diffusion_model.", ""): v for k, v in original_state_dict.items()}

new_state_dict = {}

# Convert patch_embed
new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("mixins.patch_embed.proj.weight")
new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("mixins.patch_embed.proj.bias")
new_state_dict["patch_embed.text_proj.weight"] = original_state_dict.pop("mixins.patch_embed.text_proj.weight")
new_state_dict["patch_embed.text_proj.bias"] = original_state_dict.pop("mixins.patch_embed.text_proj.bias")

# Convert time_condition_embed
new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
"time_embed.0.weight"
)
new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
"time_embed.0.bias"
)
new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
"time_embed.2.weight"
)
new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
"time_embed.2.bias"
)
new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = original_state_dict.pop(
"label_emb.0.0.weight"
)
new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = original_state_dict.pop(
"label_emb.0.0.bias"
)
new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = original_state_dict.pop(
"label_emb.0.2.weight"
)
new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = original_state_dict.pop(
"label_emb.0.2.bias"
)

# Convert transformer blocks
for i in range(30):
block_prefix = f"transformer_blocks.{i}."
old_prefix = f"transformer.layers.{i}."
adaln_prefix = f"mixins.adaln.adaln_modules.{i}."

new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(adaln_prefix + "1.weight")
new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(adaln_prefix + "1.bias")

qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
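# The original checkpoint stores the Q, K and V projections as a single fused `query_key_value`
# tensor; diffusers keeps separate to_q / to_k / to_v layers, so the fused weight and bias are
# split into three equal chunks along the output dimension.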
q, k, v = qkv_weight.chunk(3, dim=0)
q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)

new_state_dict[block_prefix + "attn1.to_q.weight"] = q
new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
new_state_dict[block_prefix + "attn1.to_k.weight"] = k
new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
new_state_dict[block_prefix + "attn1.to_v.weight"] = v
new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias

new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
old_prefix + "attention.dense.weight"
)
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(
old_prefix + "attention.dense.bias"
)

new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop(
old_prefix + "mlp.dense_h_to_4h.weight"
)
new_state_dict[block_prefix + "ff.net.0.proj.bias"] = original_state_dict.pop(
old_prefix + "mlp.dense_h_to_4h.bias"
)
new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(
old_prefix + "mlp.dense_4h_to_h.weight"
)
new_state_dict[block_prefix + "ff.net.2.bias"] = original_state_dict.pop(old_prefix + "mlp.dense_4h_to_h.bias")

# Convert final norm and projection
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(
original_state_dict.pop("mixins.final_layer.adaln.1.weight"), dim=0
)
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(
original_state_dict.pop("mixins.final_layer.adaln.1.bias"), dim=0
)
new_state_dict["proj_out.weight"] = original_state_dict.pop("mixins.final_layer.linear.weight")
new_state_dict["proj_out.bias"] = original_state_dict.pop("mixins.final_layer.linear.bias")

return new_state_dict


def convert_cogview3_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
return convert_ldm_vae_checkpoint(original_state_dict, vae_config)


def main(args):
if args.dtype is None:
dtype = None
if args.dtype == "fp16":
Review comment (Collaborator): I think it makes sense to have dtype default to None and, by default, convert to the original dtype (we have had many occasions where we accidentally upcasted a model during conversion, which is very much undesired).

Reply (Member): Oh okay, will update.

Reply (Member): Let's make sure to incorporate this change!
dtype = torch.float16
elif args.dtype == "bf16":
dtype = torch.bfloat16
elif args.dtype == "fp32":
dtype = torch.float32
else:
raise ValueError(f"Unsupported dtype: {args.dtype}")
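
# Note (sketch of the suggestion in the review thread above, not part of this commit): making
# --dtype default to None instead of "bf16" would keep the original checkpoint dtype unless a
# cast is explicitly requested; the `if dtype is not None` guards below already support that.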

transformer = None
vae = None

if args.transformer_checkpoint_path is not None:
converted_transformer_state_dict = convert_cogview3_transformer_checkpoint_to_diffusers(
args.transformer_checkpoint_path
)
transformer = CogView3PlusTransformer2DModel()
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
if dtype is not None:
# Cast to the requested dtype (when --dtype is None, the original checkpoint dtype is preserved)
transformer = transformer.to(dtype=dtype)

if args.vae_checkpoint_path is not None:
vae_config = {
"in_channels": 3,
"out_channels": 3,
"down_block_types": ("DownEncoderBlock2D",) * 4,
"up_block_types": ("UpDecoderBlock2D",) * 4,
"block_out_channels": (128, 512, 1024, 1024),
"layers_per_block": 3,
"act_fn": "silu",
"latent_channels": 16,
"norm_num_groups": 32,
"sample_size": 1024,
"scaling_factor": 1.0,
"force_upcast": True,
"use_quant_conv": False,
"use_post_quant_conv": False,
"mid_block_add_attention": False,
}
converted_vae_state_dict = convert_cogview3_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_state_dict, strict=True)
if dtype is not None:
vae = vae.to(dtype=dtype)

text_encoder_id = "google/t5-v1_1-xxl"
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

# Apparently, the conversion does not work anymore without this :shrug:
for param in text_encoder.parameters():
param.data = param.data.contiguous()

# TODO: figure out whether the correct scheduler is the same as CogVideoXDDIMScheduler
scheduler = CogVideoXDDIMScheduler.from_config(
Review comment (Member): I guess this is the major source of investigation pending from our end now?
{
"snr_shift_scale": 4.0,
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"clip_sample": False,
"num_train_timesteps": 1000,
"prediction_type": "v_prediction",
"rescale_betas_zero_snr": True,
"set_alpha_to_one": True,
"timestep_spacing": "trailing",
}
)
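
# The config above mirrors CogVideoX's DDIM setup (v-prediction, zero-terminal-SNR rescaling,
# trailing timestep spacing) with snr_shift_scale=4.0; per the TODO above, whether this exactly
# matches the original CogView3 sampler is still being verified.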

pipe = CogView3PlusPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
transformer=transformer,
scheduler=scheduler,
)

# Shard the saved weights into 5GB files; this helps users with limited memory (for example on
# Colab or other notebook environments) load the model without exhausting RAM.
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)


if __name__ == "__main__":
main(args)
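
After conversion, the directory written to `--output_path` can be loaded like any other diffusers checkpoint (a minimal sketch; the path below is whatever was passed to the script, not a published repository):

```python
import torch

from diffusers import CogView3PlusPipeline

# Load the locally converted pipeline and move it to GPU for inference.
pipe = CogView3PlusPipeline.from_pretrained("/raid/yiyi/cogview3_diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")
```
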
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
@@ -84,6 +84,7 @@
"AutoencoderOobleck",
"AutoencoderTiny",
"CogVideoXTransformer3DModel",
"CogView3PlusTransformer2DModel",
"ConsistencyDecoderVAE",
"ControlNetModel",
"ControlNetXSAdapter",
@@ -258,6 +259,7 @@
"CogVideoXImageToVideoPipeline",
"CogVideoXPipeline",
"CogVideoXVideoToVideoPipeline",
"CogView3PlusPipeline",
"CycleDiffusionPipeline",
"FluxControlNetImg2ImgPipeline",
"FluxControlNetInpaintPipeline",
@@ -558,6 +560,7 @@
AutoencoderOobleck,
AutoencoderTiny,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
ConsistencyDecoderVAE,
ControlNetModel,
ControlNetXSAdapter,
@@ -710,6 +713,7 @@
CogVideoXImageToVideoPipeline,
CogVideoXPipeline,
CogVideoXVideoToVideoPipeline,
CogView3PlusPipeline,
CycleDiffusionPipeline,
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
@@ -54,6 +54,7 @@
_import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
@@ -98,6 +99,7 @@
from .transformers import (
AuraFlowTransformer2DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
DiTTransformer2DModel,
DualTransformer2DModel,
FluxTransformer2DModel,