CogView3Plus DiT #9570


Merged: 41 commits into huggingface:main on Oct 14, 2024
Conversation

zRzRzRzRzRzRzR
Contributor

This is a draft of CogView3-Plus and is not yet complete. The remaining work includes:

  • SAT2diffusers link
  • VAE implementation and pipeline integration
  • Automatic documentation and proofreading

Expected to be completed by October 7th.
Please keep in touch; looking forward to the community's help.
@a-r-r-o-w @yiyixuxu

@zRzRzRzRzRzRzR
Contributor Author

The current version produces the same output shape, and the model conversion script works as expected. However, there is still a lot of work to do.

  • Layer-by-layer verification of the transformer has not been done yet. The pos embed implementation is the same, but the resulting tensors differ; this is still being verified (see the hook-based comparison sketch after this list).
  • @yiyixuxu mentioned that the VAE part can be directly converted using AutoEncoderKL, so I directly used the solution she provided.
  • Regarding the scheduler, I started with DDPM, but this likely needs adjustment, because I did not see any handling of shift_scale.
  • The pipeline has not been implemented yet; I will try to get it running tomorrow.
  • The documentation is incomplete.
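Not part of this PR; just a minimal sketch of how the layer-by-layer comparison could be done with forward hooks. The model instances and the layer-name mapping below are hypothetical placeholders, not identifiers from this codebase:

```python
import torch
import torch.nn as nn


def collect_activations(model: nn.Module, inputs: dict) -> dict:
    """Run one forward pass and record every submodule's tensor output by name."""
    activations, hooks = {}, []

    def make_hook(name):
        def hook(module, args, output):
            if isinstance(output, torch.Tensor):
                activations[name] = output.detach()
        return hook

    for name, module in model.named_modules():
        hooks.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(**inputs)
    finally:
        for h in hooks:
            h.remove()
    return activations


# Hypothetical usage: `sat_model`, `diffusers_model`, their inputs, and `layer_name_mapping`
# are placeholders for illustration only.
# acts_sat = collect_activations(sat_model, sat_inputs)
# acts_diff = collect_activations(diffusers_model, diffusers_inputs)
# for sat_name, diff_name in layer_name_mapping.items():
#     delta = (acts_sat[sat_name].float() - acts_diff[diff_name].float()).abs().max()
#     print(f"{sat_name} vs {diff_name}: max abs diff = {delta:.6f}")
```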

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w
Member

Testing script for transformer implementation:

import torch

from omegaconf import DictConfig
from sgm.modules.diffusionmodules.dit import DiffusionTransformer
from diffusers import CogView3PlusTransformer2DModel

@torch.no_grad()
def main():
    config = DictConfig({
        "in_channels": 16,
        "out_channels": 16,
        "hidden_size": 2560,
        "num_layers": 30,
        "patch_size": 2,
        "block_size": 16,
        "num_attention_heads": 64,
        "text_length": 224,
        "time_embed_dim": 512,
        "num_classes": "sequential",
        "adm_in_channels": 1536,
        "modules": {
            "pos_embed_config": {
                "target": "sgm.modules.diffusionmodules.dit.PositionEmbeddingMixin",
                "params": {
                    "max_height": 128,
                    "max_width": 128,
                    "max_length": 4096
                }
            },
            "patch_embed_config": {
                "target": "sgm.modules.diffusionmodules.dit.ImagePatchEmbeddingMixin",
                "params": {
                    "text_hidden_size": 4096
                }
            },
            "attention_config": {
                "target": "sgm.modules.diffusionmodules.dit.AdalnAttentionMixin",
                "params": {
                    "qk_ln": True
                }
            },
            "final_layer_config": {
                "target": "sgm.modules.diffusionmodules.dit.FinalLayerMixin"
            }
        },
    })

    transformer = DiffusionTransformer(**config)

    # Load the original SAT checkpoint and strip the "model.diffusion_model." prefix from parameter names
    ckpt_path_cogview3_plus = "/raid/aryan/CogView3-SAT/cogview3plus_3b/1/mp_rank_00_model_states.pt"
    state_dict = torch.load(ckpt_path_cogview3_plus)["module"]
    state_dict = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items()}
    transformer.load_state_dict(state_dict, strict=False)
    transformer = transformer.to("cuda", dtype=torch.bfloat16)

    # Load the converted diffusers implementation
    transformer_diffusers = CogView3PlusTransformer2DModel.from_pretrained("/raid/aryan/CogView3Plus-trial/", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")

    # Parameter counts should match between the two implementations
    print(sum(p.numel() for p in transformer.parameters() if p.requires_grad))
    print(sum(p.numel() for p in transformer_diffusers.parameters() if p.requires_grad))

    # Dummy inputs with the expected shapes for a 2-sample batch
    x = torch.ones((2, 16, 128, 128), device="cuda", dtype=torch.bfloat16)
    timesteps = torch.ones((2,), device="cuda", dtype=torch.bfloat16)
    context = torch.ones((2, 224, 4096), device="cuda", dtype=torch.bfloat16)
    y = torch.ones((2, 1536), device="cuda", dtype=torch.bfloat16)

    breakpoint()
    kwargs = {'target_size': [(1024, 1024)], "idx": timesteps}
    output = transformer(x, timesteps, context, y, **kwargs)
    output_diffusers = transformer_diffusers(x, context, y, timesteps)[0]

    # Report the maximum and total absolute difference between the two outputs
    print((output - output_diffusers).abs().max(), (output - output_diffusers).abs().sum())

main()

> @yiyixuxu mentioned that the VAE part can be directly converted using AutoEncoderKL, so I directly used the solution she provided.

Based on her testing script, I've updated the conversion script. I see similar outputs between both implementations, but I think it would be good to verify this on your end as well.

> Regarding the scheduler, I first used DDPM, but this should need adjustment, because I did not see any related operations for shift_scale.

cc @yiyixuxu. I think we could support the shift_scale parameter if it makes sense in the schedulers, unless we need a different scheduler implementation like CogVideoX. A hedged sketch of one way such a shift can enter a schedule is below.
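Purely for illustration (not from this PR): one common way a shift parameter enters a noise schedule is the resolution-dependent time shift used by flow-matching schedulers such as SD3's; whether CogView3 needs this exact form is precisely the open question above.

```python
import torch


def apply_shift(sigmas: torch.Tensor, shift: float = 3.0) -> torch.Tensor:
    """Shift a (0, 1] sigma schedule toward higher noise levels (SD3-style time shift)."""
    return shift * sigmas / (1.0 + (shift - 1.0) * sigmas)


# Example: a uniform 10-step schedule before and after shifting
sigmas = torch.linspace(1.0, 1e-3, steps=10)
print(apply_shift(sigmas, shift=3.0))
```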

> The pipeline has not been implemented today, so I will try to run it in tomorrow's work.

For the label embeddings of shape [B, 1536], I think we will have to pass target_size_as_tuple, crop_coords_top_left and original_size_as_tuple to the transformer.

{'target_size_as_tuple': tensor([[1024, 1024]], device='cuda:0'), 'txt': ['an astronaut riding a horse in space'], 'crop_coords_top_left': tensor([[0, 0]], device='cuda:0'), 'original_size_as_tuple': tensor([[1024, 1024]], device='cuda:0')}

But I don't think that's a clean approach. Maybe it would be better to prepare the sinusoidal embeddings in the pipeline and then pass those; a sketch of that idea follows below. @yiyixuxu WDYT?
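Not from this PR; a minimal sketch of what "preparing the sinusoidal embeddings in the pipeline" could look like, assuming an SDXL-style scheme where six micro-conditioning values are each encoded with a 256-dim sinusoidal embedding (6 × 256 = 1536). The function name and embedding settings are assumptions for illustration, not the pipeline's actual API:

```python
import torch
from diffusers.models.embeddings import get_timestep_embedding


def build_condition_embedding(original_size, crop_coords_top_left, target_size, dtype):
    # Pack (orig_h, orig_w, crop_top, crop_left, target_h, target_w) into a [B, 6] tensor
    add_ids = torch.tensor(
        [list(original_size) + list(crop_coords_top_left) + list(target_size)], dtype=torch.long
    )
    # Sinusoidal embedding of each value (settings here mirror SDXL, not necessarily CogView3)
    emb = get_timestep_embedding(
        add_ids.flatten(), embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0
    )  # [B * 6, 256]
    return emb.reshape(add_ids.shape[0], -1).to(dtype)  # [B, 1536]


y = build_condition_embedding((1024, 1024), (0, 0), (1024, 1024), torch.bfloat16)
print(y.shape)  # torch.Size([1, 1536])
```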



def main(args):
if args.dtype == "fp16":
Collaborator

I think it makes sense to have dtype default to None and, by default, keep the original dtype (we have had many occasions where we accidentally upcast a model during conversion, which is very much undesired).
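A minimal sketch of that suggestion for the conversion script; only the --dtype argument comes from the snippet above, everything else (including the `transformer` variable) is illustrative:

```python
import argparse

import torch

parser = argparse.ArgumentParser()
# Default to None so the converted model keeps the checkpoint's original dtype
parser.add_argument("--dtype", type=str, default=None, choices=["fp16", "bf16", "fp32"])
args = parser.parse_args()

dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
dtype = dtype_map[args.dtype] if args.dtype is not None else None

# Later, after the model is built (hypothetical `transformer` variable):
# if dtype is not None:
#     transformer = transformer.to(dtype=dtype)  # cast only when explicitly requested
```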

Member

Oh okay, will update

Member

Let's make sure to incorporate this change!

@yiyixuxu
Collaborator

yiyixuxu commented Oct 7, 2024

> For the label embeddings of shape [B, 1536], I think we will have to pass target_size_as_tuple, crop_coords_top_left and original_size_as_tuple to the transformer.
>
> {'target_size_as_tuple': tensor([[1024, 1024]], device='cuda:0'), 'txt': ['an astronaut riding a horse in space'], 'crop_coords_top_left': tensor([[0, 0]], device='cuda:0'), 'original_size_as_tuple': tensor([[1024, 1024]], device='cuda:0')}

What are you talking about there? @a-r-r-o-w

@a-r-r-o-w
Member

I think this is ready for review of the code parts. The additional embeddings used in the timestep conditioning are similar to SDXL, so I can refactor them out the way we do for SDXL. Apart from that, if there are any changes you'd like, please let me know. The pipeline works and inference runs fine. The outputs are a bit oversaturated and worse in quality, and we also need to validate multiple images per batch - both of which I or Yuxuan will look at tomorrow.

@a-r-r-o-w
Member

Some results:

(Image grid: six sample generations from the pipeline.)

"""


# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
Collaborator

we did not use this, no?


self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

def _get_t5_prompt_embeds(
Collaborator

Is this copied from anywhere?

)

if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
Collaborator

I think the encode_prompt for negative prompts is different - they use zeros for empty prompts: https://github.com/THUDM/CogView3/blob/f80f1001a3bd276a7825bff30d910abeab7e593f/sat/sample_dit.py#L172

Did you look into whether it causes a difference?

Member

@a-r-r-o-w commented Oct 10, 2024

Oh yes, sorry, I had it in mind but forgot. I'll update it to something like:

if negative_prompt is None:
    negative_prompt_embeds = torch.zeros(<shape>, device=device, dtype=dtype)
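For illustration, a hedged sketch of that idea using the positive embeds as the shape reference; the helper name and variable names are assumptions, not the pipeline's final code:

```python
import torch


def default_negative_embeds(prompt_embeds: torch.Tensor, negative_prompt=None, negative_prompt_embeds=None):
    # If no negative prompt is given, use zero embeddings for the unconditional branch,
    # matching the positive embeds' shape, dtype, and device.
    if negative_prompt is None and negative_prompt_embeds is None:
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
    return negative_prompt_embeds
```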

@yiyixuxu
Collaborator

yiyixuxu commented Oct 10, 2024

Follow-up on this comment: #9570 (comment)

So I ran one example to compare using our encode_prompt vs. using the prompt embeds directly - I think there is a difference, and the original one is better.

(Image comparison: output using our encode_prompt to process the prompt | output using encoded prompt embeds from the cogview3 codebase | output from the cogview3 codebase)

Generating using the prompt:

import torch
device = torch.device("cuda:2") 
dtype = torch.bfloat16

from diffusers import CogView3PlusPipeline
pipe = CogView3PlusPipeline.from_pretrained("/raid/yiyi/cogview3_diffusers", torch_dtype=torch.bfloat16)
pipe.to(device)

latents = torch.load("/raid/yiyi/CogView3/sat/randn.pt").to(device).to(dtype)
prompt = "Portrait of a young woman with dark skin, bright violet eyes, and braided hair adorned with beads, standing in a mystical forest with glowing fireflies."

image = pipe(prompt, guidance_scale=5, latents=latents).images[0]
image.save("yiyi_test_10_out.png")

Using encoded prompt embeds from the original codebase:

import torch
device = torch.device("cuda:2") 
dtype = torch.bfloat16

from diffusers import CogView3PlusPipeline
pipe = CogView3PlusPipeline.from_pretrained("/raid/yiyi/cogview3_diffusers", torch_dtype=torch.bfloat16)
pipe.to(device)

latents = torch.load("/raid/yiyi/CogView3/sat/randn.pt").to(device).to(dtype)
prompt_embeds = torch.load("/raid/yiyi/CogView3/sat/cond.pt").to(device).to(dtype)    
negative_prompt_embeds = torch.load("/raid/yiyi/CogView3/sat/uc.pt").to(device).to(dtype)
image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, guidance_scale=5, latents=latents).images[0]
image.save("yiyi_test_11_out.png")

@a-r-r-o-w
Member

Test model available here: https://huggingface.co/ZP2HF/CogView-3-Plus/. Once the PR is approved, it can be moved to the THUDM org for release. A minimal usage sketch is below.
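For reference, a hedged usage snippet against the test repo linked above, assuming it follows the standard diffusers layout; the prompt and guidance scale are borrowed from earlier in this thread:

```python
import torch

from diffusers import CogView3PlusPipeline

pipe = CogView3PlusPipeline.from_pretrained("ZP2HF/CogView-3-Plus", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe("an astronaut riding a horse in space", guidance_scale=5).images[0]
image.save("cogview3_plus_sample.png")
```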

@zRzRzRzRzRzRzR
Contributor Author

Does this code have issues when running in FP16? I encountered a black-image issue: there were no errors, but it output a completely black image. The logs are below.

Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.02it/s]
Loading pipeline components...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:06<00:00,  1.37s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:20<00:00,  2.45it/s]
/share/home/zyx/Code/diffusers/src/diffusers/image_processor.py:111: RuntimeWarning: invalid value encountered in cast
  images = (images * 255).round().astype("uint8")

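Not a fix from this thread; just a hedged debugging sketch. The RuntimeWarning above usually means NaN/Inf values appeared somewhere in the half-precision forward pass, which can be narrowed down like this:

```python
import torch


def report_non_finite(name: str, tensor: torch.Tensor) -> None:
    """Print how many NaN/Inf entries a tensor contains, if any."""
    n_nan = torch.isnan(tensor).sum().item()
    n_inf = torch.isinf(tensor).sum().item()
    if n_nan or n_inf:
        print(f"{name}: {n_nan} NaN, {n_inf} Inf out of {tensor.numel()} values")


# Hypothetical usage on intermediate tensors, e.g. inside the denoising loop:
# report_non_finite("latents", latents)
# report_non_finite("decoded image", image_tensor)
```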

Collaborator

@yiyixuxu left a comment

thanks!

num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
guidance_scale: float = 6,
use_dynamic_cfg: bool = False,
Collaborator

Still a question: does this work? Maybe @a-r-r-o-w, you can test it out too.

@zRzRzRzRzRzRzR
Contributor Author

use_dynamic_cfg is not used in this pipeline with DDIM, I believe.

@zRzRzRzRzRzRzR
Contributor Author

zRzRzRzRzRzRzR commented Oct 14, 2024

I believe this version is now ready; my colleagues and I have verified it, and it can run the model properly. Looking forward to the merge.

@a-r-r-o-w merged commit 8d81564 into huggingface:main Oct 14, 2024
15 checks passed
sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
* merge 9588

* max_shard_size="5GB" for colab running

* conversion script updates; modeling test; refactor transformer

* make fix-copies

* Update convert_cogview3_to_diffusers.py

* initial pipeline draft

* make style

* fight bugs 🐛🪳

* add example

* add tests; refactor

* make style

* make fix-copies

* add co-author

YiYi Xu <[email protected]>

* remove files

* add docs

* add co-author

Co-Authored-By: YiYi Xu <[email protected]>

* fight docs

* address reviews

* make style

* make model work

* remove qkv fusion

* remove qkv fusion tets

* address review comments

* fix make fix-copies error

* remove None and TODO

* for FP16(draft)

* make style

* remove dynamic cfg

* remove pooled_projection_dim as a parameter

* fix tests

---------

Co-authored-by: Aryan <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
@zRzRzRzRzRzRzR deleted the cogview3-plus branch January 14, 2025 06:47