
Commit 27b5786

CogVideoX-5b-I2V support (#9418)
* draft Init
* draft
* vae encode image
* make style
* image latents preparation
* remove image encoder from conversion script
* fix minor bugs
* make pipeline work
* make style
* remove debug prints
* fix imports
* update example
* make fix-copies
* add fast tests
* fix import
* update vae
* update docs
* update image link
* apply suggestions from review
* apply suggestions from review
* add slow test
* make use of learned positional embeddings
* apply suggestions from review
* doc change
* Update convert_cogvideox_to_diffusers.py
* make style
* final changes
* make style
* fix tests

---------

Co-authored-by: Aryan <[email protected]>
1 parent c7407a3 commit 27b5786

12 files changed: +1328 −25 lines

docs/source/en/api/loaders/single_file.md

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,8 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
 ## Supported pipelines

 - [`CogVideoXPipeline`]
+- [`CogVideoXImageToVideoPipeline`]
+- [`CogVideoXVideoToVideoPipeline`]
 - [`StableDiffusionPipeline`]
 - [`StableDiffusionImg2ImgPipeline`]
 - [`StableDiffusionInpaintPipeline`]

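The two newly listed CogVideoX entries fall under the same `from_single_file` loader this page documents. As a hedged illustration only (the checkpoint path below is a placeholder, not a file referenced by this PR):

```python
import torch
from diffusers import CogVideoXImageToVideoPipeline

# Placeholder path; point this at an actual CogVideoX-5b-I2V single-file checkpoint.
ckpt_path = "path/to/cogvideox_5b_i2v.safetensors"

# bf16 is the dtype recommended elsewhere in this PR for the 5b checkpoints.
pipe = CogVideoXImageToVideoPipeline.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
```
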
docs/source/en/api/pipelines/cogvideox.md

Lines changed: 22 additions & 8 deletions
@@ -29,9 +29,12 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m

 This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).

-There are two models available that can be used with the CogVideoX pipeline:
-- [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b)
-- [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b)
+There are two models available that can be used with the text-to-video and video-to-video CogVideoX pipelines:
+- [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b): The recommended dtype for running this model is `fp16`.
+- [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b): The recommended dtype for running this model is `bf16`.
+
+There is one model available that can be used with the image-to-video CogVideoX pipeline:
+- [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V): The recommended dtype for running this model is `bf16`.

 ## Inference

@@ -41,10 +44,15 @@ First, load the pipeline:

 ```python
 import torch
-from diffusers import CogVideoXPipeline
-from diffusers.utils import export_to_video
+from diffusers import CogVideoXPipeline, CogVideoXImageToVideoPipeline
+from diffusers.utils import export_to_video, load_image
+pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b").to("cuda") # or "THUDM/CogVideoX-2b"
+```

-pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b").to("cuda")
+If you are using the image-to-video pipeline, load it as follows:
+
+```python
+pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V").to("cuda")
 ```

 Then change the memory layout of the pipelines `transformer` component to `torch.channels_last`:
@@ -53,7 +61,7 @@ Then change the memory layout of the pipelines `transformer` component to `torch
 pipe.transformer.to(memory_format=torch.channels_last)
 ```

-Finally, compile the components and run inference:
+Compile the components and run inference:

 ```python
 pipe.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
@@ -63,7 +71,7 @@ prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wood
 video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
 ```

-The [benchmark](https://gist.github.com/a-r-r-o-w/5183d75e452a368fd17448fcc810bd3f) results on an 80GB A100 machine are:
+The [T2V benchmark](https://gist.github.com/a-r-r-o-w/5183d75e452a368fd17448fcc810bd3f) results on an 80GB A100 machine are:

 ```
 Without torch.compile(): Average inference time: 96.89 seconds.
@@ -87,6 +95,12 @@ CogVideoX-2b requires about 19 GB of GPU memory to decode 49 frames (6 seconds o
 - all
 - __call__

+## CogVideoXImageToVideoPipeline
+
+[[autodoc]] CogVideoXImageToVideoPipeline
+- all
+- __call__
+
 ## CogVideoXVideoToVideoPipeline

 [[autodoc]] CogVideoXVideoToVideoPipeline

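The documentation hunk above stops after loading the pipelines. For completeness, a hedged sketch of image-to-video inference with the new pipeline, using only arguments that appear in this PR; the conditioning-image path and the prompt are placeholders:

```python
import torch
from diffusers import CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

# bf16 is the dtype the updated docs recommend for the 5b checkpoints.
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16
).to("cuda")

image = load_image("path/or/url/to/first_frame.png")  # placeholder conditioning image
prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool."

# The I2V pipeline conditions on the input image in addition to the text prompt.
video = pipe(image=image, prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
```
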
scripts/convert_cogvideox_to_diffusers.py

Lines changed: 29 additions & 7 deletions
@@ -4,7 +4,13 @@
 import torch
 from transformers import T5EncoderModel, T5Tokenizer

-from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
+from diffusers import (
+    AutoencoderKLCogVideoX,
+    CogVideoXDDIMScheduler,
+    CogVideoXImageToVideoPipeline,
+    CogVideoXPipeline,
+    CogVideoXTransformer3DModel,
+)


 def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
@@ -78,6 +84,7 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
     "mixins.final_layer.norm_final": "norm_out.norm",
     "mixins.final_layer.linear": "proj_out",
     "mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
+    "mixins.pos_embed.pos_embedding": "patch_embed.pos_embedding",  # Specific to CogVideoX-5b-I2V
 }

 TRANSFORMER_SPECIAL_KEYS_REMAP = {
@@ -131,15 +138,18 @@ def convert_transformer(
     num_layers: int,
     num_attention_heads: int,
     use_rotary_positional_embeddings: bool,
+    i2v: bool,
     dtype: torch.dtype,
 ):
     PREFIX_KEY = "model.diffusion_model."

     original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
     transformer = CogVideoXTransformer3DModel(
+        in_channels=32 if i2v else 16,
         num_layers=num_layers,
         num_attention_heads=num_attention_heads,
         use_rotary_positional_embeddings=use_rotary_positional_embeddings,
+        use_learned_positional_embeddings=i2v,
     ).to(dtype=dtype)

     for key in list(original_state_dict.keys()):
@@ -153,7 +163,6 @@ def convert_transformer(
         if special_key not in key:
             continue
         handler_fn_inplace(key, original_state_dict)
-
     transformer.load_state_dict(original_state_dict, strict=True)
     return transformer

@@ -205,6 +214,7 @@ def get_args():
     parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
     # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
     parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
+    parser.add_argument("--i2v", action="store_true", default=False, help="Whether the checkpoint being converted is an image-to-video (I2V) model")
     return parser.parse_args()


@@ -225,6 +235,7 @@ def get_args():
         args.num_layers,
         args.num_attention_heads,
         args.use_rotary_positional_embeddings,
+        args.i2v,
         dtype,
     )
     if args.vae_ckpt_path is not None:
@@ -234,7 +245,7 @@ def get_args():
     tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
     text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

-    # Apparently, the conversion does not work any more without this :shrug:
+    # Apparently, the conversion does not work anymore without this :shrug:
     for param in text_encoder.parameters():
         param.data = param.data.contiguous()

@@ -252,9 +263,17 @@ def get_args():
             "timestep_spacing": "trailing",
         }
     )
-
-    pipe = CogVideoXPipeline(
-        tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+    if args.i2v:
+        pipeline_cls = CogVideoXImageToVideoPipeline
+    else:
+        pipeline_cls = CogVideoXPipeline
+
+    pipe = pipeline_cls(
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        vae=vae,
+        transformer=transformer,
+        scheduler=scheduler,
     )

     if args.fp16:
@@ -265,4 +284,7 @@ def get_args():
     # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
     # for users to specify variant when the default is not fp32 and they want to run with the correct default (which
     # is either fp16/bf16 here).
-    pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)
+
+    # This is necessary for users with insufficient memory, such as those using Colab or notebooks,
+    # as it can save some memory used for model loading.
+    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)

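To make the effect of the new `--i2v` flag concrete: the converter now doubles the transformer's input channels (the image latents are concatenated with the noisy video latents) and enables learned positional embeddings. A minimal sketch of the two configurations it builds; `num_layers` and `num_attention_heads` below are small placeholder values, not the real checkpoint sizes:

```python
from diffusers import CogVideoXTransformer3DModel

def build_transformer(i2v: bool) -> CogVideoXTransformer3DModel:
    # Mirrors convert_transformer() above: I2V doubles in_channels and turns on the
    # learned positional embeddings that the new "mixins.pos_embed.pos_embedding" remap targets.
    return CogVideoXTransformer3DModel(
        in_channels=32 if i2v else 16,
        num_layers=2,  # placeholder; the real checkpoints are much deeper
        num_attention_heads=4,  # placeholder
        use_rotary_positional_embeddings=True,
        use_learned_positional_embeddings=i2v,
    )

t2v_transformer = build_transformer(i2v=False)  # CogVideoX-2b/5b style config
i2v_transformer = build_transformer(i2v=True)   # CogVideoX-5b-I2V style config
```
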
src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -252,6 +252,7 @@
             "BlipDiffusionControlNetPipeline",
             "BlipDiffusionPipeline",
             "CLIPImageProjection",
+            "CogVideoXImageToVideoPipeline",
             "CogVideoXPipeline",
             "CogVideoXVideoToVideoPipeline",
             "CycleDiffusionPipeline",
@@ -692,6 +693,7 @@
         AudioLDMPipeline,
         AuraFlowPipeline,
         CLIPImageProjection,
+        CogVideoXImageToVideoPipeline,
         CogVideoXPipeline,
         CogVideoXVideoToVideoPipeline,
         CycleDiffusionPipeline,

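A quick sanity check that the addition above exposes the new pipeline at the package root, next to the existing CogVideoX classes:

```python
# All three CogVideoX pipelines are now importable from the top-level package.
from diffusers import (
    CogVideoXImageToVideoPipeline,
    CogVideoXPipeline,
    CogVideoXVideoToVideoPipeline,
)

print(CogVideoXImageToVideoPipeline.__name__)  # CogVideoXImageToVideoPipeline
```
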
src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py

Lines changed: 10 additions & 4 deletions
@@ -1089,8 +1089,10 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
             return self.tiled_encode(x)

         frame_batch_size = self.num_sample_frames_batch_size
+        # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
+        num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
         enc = []
-        for i in range(num_frames // frame_batch_size):
+        for i in range(num_batches):
             remaining_frames = num_frames % frame_batch_size
             start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
             end_frame = frame_batch_size * (i + 1) + remaining_frames
@@ -1140,8 +1142,9 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut
             return self.tiled_decode(z, return_dict=return_dict)

         frame_batch_size = self.num_latent_frames_batch_size
+        num_batches = num_frames // frame_batch_size
         dec = []
-        for i in range(num_frames // frame_batch_size):
+        for i in range(num_batches):
             remaining_frames = num_frames % frame_batch_size
             start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
             end_frame = frame_batch_size * (i + 1) + remaining_frames
@@ -1233,8 +1236,10 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
         for i in range(0, height, overlap_height):
             row = []
             for j in range(0, width, overlap_width):
+                # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
+                num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
                 time = []
-                for k in range(num_frames // frame_batch_size):
+                for k in range(num_batches):
                     remaining_frames = num_frames % frame_batch_size
                     start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
                     end_frame = frame_batch_size * (k + 1) + remaining_frames
@@ -1309,8 +1314,9 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
         for i in range(0, height, overlap_height):
             row = []
             for j in range(0, width, overlap_width):
+                num_batches = num_frames // frame_batch_size
                 time = []
-                for k in range(num_frames // frame_batch_size):
+                for k in range(num_batches):
                     remaining_frames = num_frames % frame_batch_size
                     start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
                     end_frame = frame_batch_size * (k + 1) + remaining_frames

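The batching change above exists so that a single conditioning image (one frame) still produces one VAE batch instead of zero. A standalone sketch of the same arithmetic; `frame_batch_size=8` is an assumed value for illustration only:

```python
def frame_batches(num_frames: int, frame_batch_size: int = 8):
    # Yields (start_frame, end_frame) slices the way the CogVideoX VAE loops above do.
    # num_frames is expected to be 1, frame_batch_size * k, or frame_batch_size * k + 1.
    # The guard added in this commit: without it, 1 // frame_batch_size == 0 skips the loop entirely.
    num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
    remaining_frames = num_frames % frame_batch_size
    for i in range(num_batches):
        start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
        end_frame = frame_batch_size * (i + 1) + remaining_frames
        yield start_frame, end_frame

print(list(frame_batches(1)))   # [(0, 9)] -> the single image frame is processed once
print(list(frame_batches(49)))  # [(0, 9), (9, 17), ...] -> 6 batches covering all 49 frames
```
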
src/diffusers/models/embeddings.py

Lines changed: 13 additions & 3 deletions
@@ -350,6 +350,7 @@ def __init__(
         spatial_interpolation_scale: float = 1.875,
         temporal_interpolation_scale: float = 1.0,
         use_positional_embeddings: bool = True,
+        use_learned_positional_embeddings: bool = True,
     ) -> None:
         super().__init__()

@@ -363,15 +364,17 @@ def __init__(
         self.spatial_interpolation_scale = spatial_interpolation_scale
         self.temporal_interpolation_scale = temporal_interpolation_scale
         self.use_positional_embeddings = use_positional_embeddings
+        self.use_learned_positional_embeddings = use_learned_positional_embeddings

         self.proj = nn.Conv2d(
             in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
         )
         self.text_proj = nn.Linear(text_embed_dim, embed_dim)

-        if use_positional_embeddings:
+        if use_positional_embeddings or use_learned_positional_embeddings:
+            persistent = use_learned_positional_embeddings
             pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
-            self.register_buffer("pos_embedding", pos_embedding, persistent=False)
+            self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)

     def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
         post_patch_height = sample_height // self.patch_size
@@ -415,8 +418,15 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
             [text_embeds, image_embeds], dim=1
         ).contiguous()  # [batch, seq_length + num_frames x height x width, channels]

-        if self.use_positional_embeddings:
+        if self.use_positional_embeddings or self.use_learned_positional_embeddings:
+            if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height):
+                raise ValueError(
+                    "It is currently not possible to generate videos at a different resolution than the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'. "
+                    "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
+                )
+
             pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
+
             if (
                 self.sample_height != height
                 or self.sample_width != width

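For readers unfamiliar with the `persistent` flag toggled above: only persistent buffers are written to a module's `state_dict`, so a learned positional-embedding table must be registered with `persistent=True` or it would be dropped at save time, while the recomputable sinusoidal table can stay non-persistent. A small self-contained sketch (not diffusers code):

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self, persistent: bool):
        super().__init__()
        # Same pattern as CogVideoXPatchEmbed: a precomputed table registered as a buffer.
        self.register_buffer("pos_embedding", torch.zeros(1, 4, 8), persistent=persistent)

print("pos_embedding" in Toy(persistent=False).state_dict())  # False -> not serialized
print("pos_embedding" in Toy(persistent=True).state_dict())   # True  -> saved with the checkpoint
```
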
src/diffusers/models/transformers/cogvideox_transformer_3d.py

Lines changed: 13 additions & 1 deletion
@@ -235,10 +235,18 @@ def __init__(
         spatial_interpolation_scale: float = 1.875,
         temporal_interpolation_scale: float = 1.0,
         use_rotary_positional_embeddings: bool = False,
+        use_learned_positional_embeddings: bool = False,
     ):
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim

+        if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
+            raise ValueError(
+                "There are no CogVideoX checkpoints available with disabled rotary embeddings and learned positional "
+                "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
+                "issue at https://github.com/huggingface/diffusers/issues."
+            )
+
         # 1. Patch embedding
         self.patch_embed = CogVideoXPatchEmbed(
             patch_size=patch_size,
@@ -254,6 +262,7 @@ def __init__(
             spatial_interpolation_scale=spatial_interpolation_scale,
             temporal_interpolation_scale=temporal_interpolation_scale,
             use_positional_embeddings=not use_rotary_positional_embeddings,
+            use_learned_positional_embeddings=use_learned_positional_embeddings,
         )
         self.embedding_dropout = nn.Dropout(dropout)

@@ -465,8 +474,11 @@ def custom_forward(*inputs):
         hidden_states = self.proj_out(hidden_states)

         # 5. Unpatchify
+        # Note: we use `-1` instead of `channels`:
+        # - It is okay to use `channels` for CogVideoX-2b and CogVideoX-5b (the number of input channels equals the number of output channels)
+        # - However, CogVideoX-5b-I2V also takes concatenated input image latents, so its number of input channels is twice the number of output channels
         p = self.config.patch_size
-        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
+        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
         output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

         if not return_dict:

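The `-1` in the unpatchify reshape works because the tensor being reshaped was produced by `proj_out`, whose size depends only on the number of predicted (output) channels, not on how many channels were fed in. A toy numeric check under assumed shapes (2 output channels, patch size 2, a 4x4 latent grid; the real models predict more channels):

```python
import torch

batch_size, num_frames, height, width, p, out_channels = 1, 1, 4, 4, 2, 2

# After proj_out, each patch token carries out_channels * p * p values, regardless of whether
# 16 (T2V) or 32 (I2V, with concatenated image latents) channels entered the transformer.
hidden_states = torch.randn(batch_size, num_frames * (height // p) * (width // p), out_channels * p * p)

# `-1` recovers the output channel count; hard-coding the *input* channel count would
# break for CogVideoX-5b-I2V, where input channels are twice the output channels.
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
print(output.shape)  # torch.Size([1, 1, 2, 4, 4]) -> [batch, frames, channels, height, width]
```
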
src/diffusers/pipelines/__init__.py

Lines changed: 6 additions & 2 deletions
@@ -132,7 +132,11 @@
         "AudioLDM2UNet2DConditionModel",
     ]
     _import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
-    _import_structure["cogvideo"] = ["CogVideoXPipeline", "CogVideoXVideoToVideoPipeline"]
+    _import_structure["cogvideo"] = [
+        "CogVideoXPipeline",
+        "CogVideoXImageToVideoPipeline",
+        "CogVideoXVideoToVideoPipeline",
+    ]
     _import_structure["controlnet"].extend(
         [
             "BlipDiffusionControlNetPipeline",
@@ -452,7 +456,7 @@
         )
         from .aura_flow import AuraFlowPipeline
         from .blip_diffusion import BlipDiffusionPipeline
-        from .cogvideo import CogVideoXPipeline, CogVideoXVideoToVideoPipeline
+        from .cogvideo import CogVideoXImageToVideoPipeline, CogVideoXPipeline, CogVideoXVideoToVideoPipeline
         from .controlnet import (
             BlipDiffusionControlNetPipeline,
             StableDiffusionControlNetImg2ImgPipeline,

src/diffusers/pipelines/cogvideo/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
     _import_structure["pipeline_cogvideox"] = ["CogVideoXPipeline"]
+    _import_structure["pipeline_cogvideox_image2video"] = ["CogVideoXImageToVideoPipeline"]
     _import_structure["pipeline_cogvideox_video2video"] = ["CogVideoXVideoToVideoPipeline"]


@@ -34,6 +35,7 @@
         from ...utils.dummy_torch_and_transformers_objects import *
     else:
         from .pipeline_cogvideox import CogVideoXPipeline
+        from .pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline
         from .pipeline_cogvideox_video2video import CogVideoXVideoToVideoPipeline

 else:
