CogView3Plus DiT #9570
Conversation
The current version produces the same output shape, and the model conversion script also runs correctly. However, there is still a lot of work to be done.
Force-pushed from 86a59f9 to 45b6cb6.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Testing script for transformer implementation:

```python
import torch
from omegaconf import DictConfig
from sgm.modules.diffusionmodules.dit import DiffusionTransformer
from diffusers import CogView3PlusTransformer2DModel


@torch.no_grad()
def main():
    config = DictConfig({
        "in_channels": 16,
        "out_channels": 16,
        "hidden_size": 2560,
        "num_layers": 30,
        "patch_size": 2,
        "block_size": 16,
        "num_attention_heads": 64,
        "text_length": 224,
        "time_embed_dim": 512,
        "num_classes": "sequential",
        "adm_in_channels": 1536,
        "modules": {
            "pos_embed_config": {
                "target": "sgm.modules.diffusionmodules.dit.PositionEmbeddingMixin",
                "params": {
                    "max_height": 128,
                    "max_width": 128,
                    "max_length": 4096,
                },
            },
            "patch_embed_config": {
                "target": "sgm.modules.diffusionmodules.dit.ImagePatchEmbeddingMixin",
                "params": {
                    "text_hidden_size": 4096,
                },
            },
            "attention_config": {
                "target": "sgm.modules.diffusionmodules.dit.AdalnAttentionMixin",
                "params": {
                    "qk_ln": True,
                },
            },
            "final_layer_config": {
                "target": "sgm.modules.diffusionmodules.dit.FinalLayerMixin",
            },
        },
    })

    # Load the original SAT checkpoint into the reference implementation
    transformer = DiffusionTransformer(**config)
    ckpt_path_cogview3_plus = "/raid/aryan/CogView3-SAT/cogview3plus_3b/1/mp_rank_00_model_states.pt"
    state_dict = torch.load(ckpt_path_cogview3_plus)["module"]
    state_dict = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items()}
    transformer.load_state_dict(state_dict, strict=False)
    transformer = transformer.to("cuda", dtype=torch.bfloat16)

    # Load the converted diffusers checkpoint
    transformer_diffusers = CogView3PlusTransformer2DModel.from_pretrained(
        "/raid/aryan/CogView3Plus-trial/", subfolder="transformer", torch_dtype=torch.bfloat16
    ).to("cuda")

    # Parameter counts should match between the two implementations
    print(sum(p.numel() for p in transformer.parameters() if p.requires_grad))
    print(sum(p.numel() for p in transformer_diffusers.parameters() if p.requires_grad))

    # Dummy inputs: latents, timesteps, T5 text embeddings, pooled additional conditioning
    x = torch.ones((2, 16, 128, 128), device="cuda", dtype=torch.bfloat16)
    timesteps = torch.ones((2,), device="cuda", dtype=torch.bfloat16)
    context = torch.ones((2, 224, 4096), device="cuda", dtype=torch.bfloat16)
    y = torch.ones((2, 1536), device="cuda", dtype=torch.bfloat16)

    breakpoint()
    kwargs = {"target_size": [(1024, 1024)], "idx": timesteps}
    output = transformer(x, timesteps, context, y, **kwargs)
    output_diffusers = transformer_diffusers(x, context, y, timesteps)[0]

    # Max and total absolute difference between the two outputs
    print((output - output_diffusers).abs().max(), (output - output_diffusers).abs().sum())


main()
```
Based on her testing script, I've updated the conversion script. I see similar outputs between both implementations, but it would be good to verify this on your end as well.
cc @yiyixuxu. I think we could support the shift_scale parameter if it makes sense in the schedulers, unless we need different scheduler implementations like CogVideoX
For the label embeddings of shape [...] but I don't think that's a clean approach. Maybe it would be better to prepare the sinusoidal embeddings in the pipeline and then pass them. @yiyixuxu WDYT?
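Not claiming this is how CogView3's shift_scale is defined, but for reference, the resolution-dependent sigma shift already used by the flow-matching schedulers in diffusers has this generic form (purely an illustrative sketch):

```python
import torch


def shift_sigmas(sigmas: torch.Tensor, shift: float) -> torch.Tensor:
    # Generic "shift" transform: larger shift values push more of the
    # sampling trajectory toward the high-noise end of the schedule.
    return shift * sigmas / (1 + (shift - 1) * sigmas)


sigmas = torch.linspace(1.0, 1e-3, 50)
print(shift_sigmas(sigmas, shift=3.0)[:5])
```

Whether shift_scale maps onto this form or needs a separate scheduler implementation (as with CogVideoX) is exactly the open question above.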
```python
def main(args):
    if args.dtype == "fp16":
```
I think it makes sense to have dtype default to None and, by default, just keep the original dtype when converting (we've had many occasions where we accidentally upcasted a model during conversion, which is very much undesired).
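A minimal sketch of that behavior, assuming an argparse-based conversion script (the flag name and dtype_map are illustrative, not the exact code in this PR):

```python
import argparse
from typing import Optional

import torch


def parse_args():
    parser = argparse.ArgumentParser()
    # Default to None so the converted weights keep the checkpoint's original dtype.
    parser.add_argument("--dtype", type=str, default=None, choices=["fp16", "bf16", "fp32"])
    return parser.parse_args()


def maybe_cast(model: torch.nn.Module, dtype_str: Optional[str]) -> torch.nn.Module:
    # Only cast when a dtype was explicitly requested, so conversion never
    # silently upcasts (or downcasts) the original weights.
    dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
    return model if dtype_str is None else model.to(dtype_map[dtype_str])
```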
Oh okay, will update
Let's make sure to incorporate this change!
```
{'target_size_as_tuple': tensor([[1024, 1024]], device='cuda:0'), 'txt': ['an astronaut riding a horse in space'], 'crop_coords_top_left': tensor([[0, 0]], device='cuda:0'), 'original_size_as_tuple': tensor([[1024, 1024]], device='cuda:0')}
```

What are you talking about there? @a-r-r-o-w
Co-Authored-By: YiYi Xu <[email protected]>
I think this is ready for review of the code parts. The additional embeddings used in the timestep conditioning are similar to SDXL, so I can refactor them out the way we do for SDXL. Apart from that, if there are any changes you'd like, please let me know. The pipeline works and inference runs fine. The outputs are a bit oversaturated and worse in quality, and we also need to validate multiple images per batch - both of which I or Yuxuan will take a look at tomorrow.
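For context, a rough sketch of the SDXL-style additional conditioning referred to above (the function and dims come from diffusers' SDXL setup; whether CogView3Plus uses exactly 6 values x 256 dims is an assumption, though it would match the adm_in_channels=1536 in the config earlier in this thread):

```python
import torch
from diffusers.models.embeddings import get_timestep_embedding

# original_size, crop_coords_top_left and target_size are packed into six integers ...
original_size, crop_coords_top_left, target_size = (1024, 1024), (0, 0), (1024, 1024)
add_time_ids = torch.tensor([original_size + crop_coords_top_left + target_size], dtype=torch.float32)  # (1, 6)

# ... each embedded sinusoidally and flattened into the pooled conditioning vector `y`.
cond_embed = get_timestep_embedding(add_time_ids.flatten(), embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0)
cond_embed = cond_embed.reshape(add_time_ids.shape[0], -1)
print(cond_embed.shape)  # (1, 1536)
```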
""" | ||
|
||
|
||
```python
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
```
we did not use this, no?
```python
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

def _get_t5_prompt_embeds(
```
is this copied from any place?
```python
if do_classifier_free_guidance and negative_prompt_embeds is None:
    negative_prompt = negative_prompt or ""
```
I think the encode_prompt for negative prompts is different - they use zeros for empty prompts: https://github.com/THUDM/CogView3/blob/f80f1001a3bd276a7825bff30d910abeab7e593f/sat/sample_dit.py#L172
Did you look into whether it causes a difference?
Oh yes, sorry - I had it in mind but forgot. I'll update it to something like:

```python
if negative_prompt is None:
    negative_prompt_embeds = torch.zeros(<shape>, device, dtype)
```
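A slightly fuller sketch of that branch (the function and the encode_fn hook are hypothetical stand-ins; the real logic would live inside the pipeline's encode_prompt):

```python
import torch


def build_negative_prompt_embeds(negative_prompt, prompt_embeds, encode_fn, device, dtype):
    # Mirror the original SAT behavior: when no negative prompt is given, the
    # unconditional branch is an all-zeros embedding rather than the encoding of "".
    if negative_prompt is None:
        return torch.zeros_like(prompt_embeds, device=device, dtype=dtype)
    return encode_fn(negative_prompt, device=device, dtype=dtype)
```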
Following up on this comment #9570 (comment): I ran one example to compare using our encode_prompt vs. using the prompt embeds directly - I think there is a difference, and the original one is better.
Generate using prompt:

```python
import torch
from diffusers import CogView3PlusPipeline

device = torch.device("cuda:2")
dtype = torch.bfloat16

pipe = CogView3PlusPipeline.from_pretrained("/raid/yiyi/cogview3_diffusers", torch_dtype=torch.bfloat16)
pipe.to(device)

latents = torch.load("/raid/yiyi/CogView3/sat/randn.pt").to(device).to(dtype)
prompt = "Portrait of a young woman with dark skin, bright violet eyes, and braided hair adorned with beads, standing in a mystical forest with glowing fireflies."

image = pipe(prompt, guidance_scale=5, latents=latents).images[0]
image.save("yiyi_test_10_out.png")
```

Generate using encoded prompt embeds from the original codebase:

```python
import torch
from diffusers import CogView3PlusPipeline

device = torch.device("cuda:2")
dtype = torch.bfloat16

pipe = CogView3PlusPipeline.from_pretrained("/raid/yiyi/cogview3_diffusers", torch_dtype=torch.bfloat16)
pipe.to(device)

latents = torch.load("/raid/yiyi/CogView3/sat/randn.pt").to(device).to(dtype)
prompt_embeds = torch.load("/raid/yiyi/CogView3/sat/cond.pt").to(device).to(dtype)
negative_prompt_embeds = torch.load("/raid/yiyi/CogView3/sat/uc.pt").to(device).to(dtype)

image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, guidance_scale=5, latents=latents).images[0]
image.save("yiyi_test_11_out.png")
```
Test model available here: https://huggingface.co/ZP2HF/CogView-3-Plus/. Once the PR is approved, it can be moved to the THUDM org for release.
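For reference, loading the linked test checkpoint would look roughly like this (the repo id is taken from the link above; the prompt and guidance_scale are just values used earlier in this thread):

```python
import torch
from diffusers import CogView3PlusPipeline

pipe = CogView3PlusPipeline.from_pretrained("ZP2HF/CogView-3-Plus", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe("an astronaut riding a horse in space", guidance_scale=5).images[0]
image.save("cogview3plus_test.png")
```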
thanks!
```python
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
guidance_scale: float = 6,
use_dynamic_cfg: bool = False,
```
Still a question: does this work? Maybe you can test it out too, @a-r-r-o-w.
I believe this version is now ready; my colleagues and I have verified it, and it can run the model properly. Looking forward to the merge.
* merge 9588
* max_shard_size="5GB" for colab running
* conversion script updates; modeling test; refactor transformer
* make fix-copies
* Update convert_cogview3_to_diffusers.py
* initial pipeline draft
* make style
* fight bugs 🐛🪳
* add example
* add tests; refactor
* make style
* make fix-copies
* add co-author YiYi Xu <[email protected]>
* remove files
* add docs
* add co-author Co-Authored-By: YiYi Xu <[email protected]>
* fight docs
* address reviews
* make style
* make model work
* remove qkv fusion
* remove qkv fusion tests
* address review comments
* fix make fix-copies error
* remove None and TODO
* for FP16 (draft)
* make style
* remove dynamic cfg
* remove pooled_projection_dim as a parameter
* fix tests

---------

Co-authored-by: Aryan <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
This is a draft of CogView3-Plus and is not yet fully polished.
The remaining work includes:
- SAT2diffusers conversion script
- VAE implementation and pipeline integration
- automatic documentation and proofreading
Expected to be completed by October 7th.
Keep in touch - looking forward to the community's help.
@a-r-r-o-w @yiyixuxu