CogVideoX-5b-I2V support #9418
`scripts/convert_cogvideox_to_diffusers.py`:

```diff
@@ -4,7 +4,13 @@
 import torch
 from transformers import T5EncoderModel, T5Tokenizer
 
-from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
+from diffusers import (
+    AutoencoderKLCogVideoX,
+    CogVideoXDDIMScheduler,
+    CogVideoXImageToVideoPipeline,
+    CogVideoXPipeline,
+    CogVideoXTransformer3DModel,
+)
 
 
 def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
```
```diff
@@ -78,6 +84,7 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
     "mixins.final_layer.norm_final": "norm_out.norm",
     "mixins.final_layer.linear": "proj_out",
     "mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
+    "mixins.pos_embed.pos_embedding": "patch_embed.pos_embedding",  # Specific to CogVideoX-5b-I2V
 }
 
 TRANSFORMER_SPECIAL_KEYS_REMAP = {
```

Review comment: Should we have any if/else to guard that accordingly?
Reply: This layer is absent in the T2V models actually. It's called …
Reply: Yep, this is safe and should not affect the T2V checkpoints, since they follow different layer naming conventions.
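For context, a remap table like this is consumed by renaming matching keys in the original checkpoint's state dict. A minimal sketch of that pattern (the helper name is hypothetical; the script's actual implementation may differ):

```python
from typing import Any, Dict

def rename_keys(state_dict: Dict[str, Any], remap: Dict[str, str]) -> Dict[str, Any]:
    """Rename state-dict keys by substring replacement, mirroring the remap table above."""
    renamed = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in remap.items():
            new_key = new_key.replace(old, new)
        renamed[new_key] = value
    return renamed
```

Because the I2V-specific key simply never matches anything in a T2V state dict, no extra if/else guard is needed, which is what the review thread above concludes.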
```diff
@@ -131,15 +138,18 @@ def convert_transformer(
     num_layers: int,
     num_attention_heads: int,
     use_rotary_positional_embeddings: bool,
+    i2v: bool,
     dtype: torch.dtype,
 ):
     PREFIX_KEY = "model.diffusion_model."
 
     original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
     transformer = CogVideoXTransformer3DModel(
+        in_channels=32 if i2v else 16,
         num_layers=num_layers,
         num_attention_heads=num_attention_heads,
         use_rotary_positional_embeddings=use_rotary_positional_embeddings,
+        use_learned_positional_embeddings=i2v,
     ).to(dtype=dtype)
 
     for key in list(original_state_dict.keys()):
```
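The doubled `in_channels` matches the I2V conditioning scheme, where the encoded image latents are concatenated channel-wise with the noisy video latents before entering the transformer. A minimal sketch of that pattern (shapes and tensor names are illustrative, not the pipeline's exact code):

```python
import torch

# Illustrative latent shapes for CogVideoX: [batch, frames, channels, height, width].
batch, frames, channels, height, width = 1, 13, 16, 60, 90

noisy_video_latents = torch.randn(batch, frames, channels, height, width)
# Image conditioning: encoded first-frame latents, padded across the time axis.
image_latents = torch.randn(batch, frames, channels, height, width)

# Concatenating along the channel axis doubles the transformer input: 16 -> 32.
transformer_input = torch.cat([noisy_video_latents, image_latents], dim=2)
print(transformer_input.shape)  # torch.Size([1, 13, 32, 60, 90])
```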
```diff
@@ -153,7 +163,6 @@ def convert_transformer(
         if special_key not in key:
             continue
         handler_fn_inplace(key, original_state_dict)
-
 
     transformer.load_state_dict(original_state_dict, strict=True)
     return transformer
```
```diff
@@ -205,6 +214,7 @@ def get_args():
     parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
     # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
     parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="SNR shift scale for the scheduler")
+    parser.add_argument("--i2v", action="store_true", default=False, help="Whether the checkpoint is an image-to-video (I2V) model")
     return parser.parse_args()
 
 
```
```diff
@@ -225,6 +235,7 @@ def get_args():
         args.num_layers,
         args.num_attention_heads,
         args.use_rotary_positional_embeddings,
+        args.i2v,
         dtype,
     )
     if args.vae_ckpt_path is not None:
```
```diff
@@ -233,7 +244,6 @@ def get_args():
     text_encoder_id = "google/t5-v1_1-xxl"
     tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
     text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
-
     # Apparently, the conversion does not work any more without this :shrug:
     for param in text_encoder.parameters():
         param.data = param.data.contiguous()
```
```diff
@@ -252,9 +262,17 @@ def get_args():
             "timestep_spacing": "trailing",
         }
     )
 
-    pipe = CogVideoXPipeline(
-        tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+    if args.i2v:
+        pipeline_cls = CogVideoXImageToVideoPipeline
+    else:
+        pipeline_cls = CogVideoXPipeline
+    pipe = pipeline_cls(
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        vae=vae,
+        transformer=transformer,
+        scheduler=scheduler,
     )
 
     if args.fp16:
```
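Once converted, the resulting I2V pipeline can be used roughly as follows (a sketch assuming the converted weights were saved locally; the path, prompt, and parameter values are illustrative):

```python
import torch
from diffusers import CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

# "path/to/converted-cogvideox-5b-i2v" is a placeholder for the conversion script's output directory.
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
    "path/to/converted-cogvideox-5b-i2v", torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()

# The first frame conditions the generated video.
image = load_image("input.png")
video = pipe(
    image=image,
    prompt="A little girl rides a bicycle at high speed, detailed, realistic",
    num_inference_steps=50,
    guidance_scale=6.0,
).frames[0]
export_to_video(video, "output.mp4", fps=8)
```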
`src/diffusers/models/embeddings.py`:

```diff
@@ -350,6 +350,7 @@ def __init__(
         spatial_interpolation_scale: float = 1.875,
         temporal_interpolation_scale: float = 1.0,
         use_positional_embeddings: bool = True,
+        use_learned_positional_embeddings: bool = True,
     ) -> None:
         super().__init__()
 
```
```diff
@@ -363,15 +364,17 @@ def __init__(
         self.spatial_interpolation_scale = spatial_interpolation_scale
         self.temporal_interpolation_scale = temporal_interpolation_scale
         self.use_positional_embeddings = use_positional_embeddings
+        self.use_learned_positional_embeddings = use_learned_positional_embeddings
 
         self.proj = nn.Conv2d(
             in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
         )
         self.text_proj = nn.Linear(text_embed_dim, embed_dim)
 
-        if use_positional_embeddings:
+        if use_positional_embeddings or use_learned_positional_embeddings:
+            persistent = use_learned_positional_embeddings
             pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
-            self.register_buffer("pos_embedding", pos_embedding, persistent=False)
+            self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
 
     def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
         post_patch_height = sample_height // self.patch_size
```
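The `persistent` flag is the crux of this hunk: non-persistent buffers are excluded from the module's `state_dict`, which is fine for sincos embeddings that can be regenerated, but learned embeddings must be serialized with the checkpoint. A minimal illustration of the underlying PyTorch behavior:

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        # Regenerable (e.g. sincos) embeddings need not be saved.
        self.register_buffer("sincos", torch.zeros(4), persistent=False)
        # Learned embeddings must survive save/load, so they stay persistent.
        self.register_buffer("learned", torch.zeros(4), persistent=True)

print(Demo().state_dict().keys())  # odict_keys(['learned'])
```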
```diff
@@ -415,8 +418,15 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
             [text_embeds, image_embeds], dim=1
         ).contiguous()  # [batch, seq_length + num_frames x height x width, channels]
 
-        if self.use_positional_embeddings:
+        if self.use_positional_embeddings or self.use_learned_positional_embeddings:
+            if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height):
+                raise ValueError(
+                    "It is currently not possible to generate videos at a different resolution than the defaults. "
+                    "This should only be the case with 'THUDM/CogVideoX-5b-I2V'. "
+                    "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
+                )
+
             pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
 
             if (
                 self.sample_height != height
                 or self.sample_width != width
```

Review comment: In other words, the 2b variant supports it?
Reply: Yes, we had some success with multiresolution inference quality on 2B T2V. The reason for allowing this is to not confine LoRA training to 720x480 videos on the 2B model. 5B T2V will skip this entire branch. 5B I2V uses positional embeddings that were learned, so we can't generate them on the fly like sincos for the 2B T2V model.
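The distinction the reply draws: sincos embeddings are a deterministic function of position, so they can be recomputed for any resolution, whereas learned embeddings are a fixed-size trained table. A minimal 1D sketch of the sincos construction (simplified from the 3D variant the model actually uses):

```python
import math
import torch

def sincos_embedding(num_positions: int, dim: int) -> torch.Tensor:
    """Deterministic sinusoidal table: regenerable on the fly for any sequence length."""
    position = torch.arange(num_positions, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim))
    emb = torch.zeros(num_positions, dim)
    emb[:, 0::2] = torch.sin(position * div_term)
    emb[:, 1::2] = torch.cos(position * div_term)
    return emb

# Works for any resolution-derived sequence length, unlike a learned table.
print(sincos_embedding(17776, 64).shape)  # torch.Size([17776, 64])
```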