
Commit 5e3b7d2

bubbliiiing authored (with co-authors a-r-r-o-w, yiyixuxu, and DN6)
Add EasyAnimateV5.1 text-to-video, image-to-video, control-to-video generation model (#10626)
* Update EasyAnimate V5.1
* Add docs && add tests && Fix comments problems in transformer3d and vae
* delete comments and remove useless import
* delete process
* Update EXAMPLE_DOC_STRING
* rename transformer file
* make fix-copies
* make style
* refactor pt. 1
* update toctree.yml
* add model tests
* Update layer_norm for norm_added_q and norm_added_k in Attention
* Fix processor problem
* refactor vae
* Fix problem in comments
* refactor tiling; remove einops dependency
* fix docs path
* make fix-copies
* Update src/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py
* update _toctree.yml
* fix test
* update
* update
* update
* make fix-copies
* fix tests

---------

Co-authored-by: Aryan <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Dhruv Nair <[email protected]>
1 parent 7513162 commit 5e3b7d2

24 files changed: 5,432 additions, 1 deletion

docs/source/en/_toctree.yml

Lines changed: 6 additions & 0 deletions
@@ -290,6 +290,8 @@
       title: CogView4Transformer2DModel
     - local: api/models/dit_transformer2d
       title: DiTTransformer2DModel
+    - local: api/models/easyanimate_transformer3d
+      title: EasyAnimateTransformer3DModel
     - local: api/models/flux_transformer
       title: FluxTransformer2DModel
     - local: api/models/hunyuan_transformer2d
@@ -352,6 +354,8 @@
       title: AutoencoderKLHunyuanVideo
     - local: api/models/autoencoderkl_ltx_video
       title: AutoencoderKLLTXVideo
+    - local: api/models/autoencoderkl_magvit
+      title: AutoencoderKLMagvit
     - local: api/models/autoencoderkl_mochi
       title: AutoencoderKLMochi
     - local: api/models/autoencoder_kl_wan
@@ -430,6 +434,8 @@
       title: DiffEdit
     - local: api/pipelines/dit
       title: DiT
+    - local: api/pipelines/easyanimate
+      title: EasyAnimate
     - local: api/pipelines/flux
       title: Flux
     - local: api/pipelines/control_flux_inpaint
docs/source/en/api/models/autoencoderkl_magvit.md

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoencoderKLMagvit

The 3D variational autoencoder (VAE) model with KL loss used in [EasyAnimate](https://github.com/aigc-apps/EasyAnimate) was introduced by Alibaba PAI.

The model can be loaded with the following code snippet.

```python
import torch

from diffusers import AutoencoderKLMagvit

vae = AutoencoderKLMagvit.from_pretrained("alibaba-pai/EasyAnimateV5.1-12b-zh", subfolder="vae", torch_dtype=torch.float16).to("cuda")
```

## AutoencoderKLMagvit

[[autodoc]] AutoencoderKLMagvit
  - decode
  - encode
  - all

## AutoencoderKLOutput

[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput

## DecoderOutput

[[autodoc]] models.autoencoders.vae.DecoderOutput
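For illustration, here is a minimal encode/decode round-trip sketch with the loaded VAE. The `latent_dist`/`sample` access pattern follows the standard diffusers VAE outputs listed above (`AutoencoderKLOutput`, `DecoderOutput`); the `[batch, channels, frames, height, width]` layout and the tensor sizes are illustrative assumptions, not EasyAnimate defaults.

```python
import torch
from diffusers import AutoencoderKLMagvit

# Assumes a CUDA device is available.
vae = AutoencoderKLMagvit.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-12b-zh", subfolder="vae", torch_dtype=torch.float16
).to("cuda")

# Placeholder input: 1 RGB clip of 9 frames at 256x256 (shape chosen for illustration).
video = torch.randn(1, 3, 9, 256, 256, dtype=torch.float16, device="cuda")

with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()  # AutoencoderKLOutput -> sampled latents
    reconstruction = vae.decode(latents).sample       # DecoderOutput -> reconstructed video

print(latents.shape, reconstruction.shape)
```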
docs/source/en/api/models/easyanimate_transformer3d.md

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# EasyAnimateTransformer3DModel

A Diffusion Transformer model for 3D (video) data used in [EasyAnimate](https://github.com/aigc-apps/EasyAnimate), introduced by Alibaba PAI.

The model can be loaded with the following code snippet.

```python
import torch

from diffusers import EasyAnimateTransformer3DModel

transformer = EasyAnimateTransformer3DModel.from_pretrained("alibaba-pai/EasyAnimateV5.1-12b-zh", subfolder="transformer", torch_dtype=torch.float16).to("cuda")
```

## EasyAnimateTransformer3DModel

[[autodoc]] EasyAnimateTransformer3DModel

## Transformer2DModelOutput

[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
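As a usage note, a separately loaded transformer can be passed to [`EasyAnimatePipeline`]. This is a minimal sketch that mirrors the quantization example later in this commit, just without the quantized weights; the `.to("cuda")` placement is an illustrative choice.

```python
import torch
from diffusers import EasyAnimatePipeline, EasyAnimateTransformer3DModel

# Load the transformer on its own (e.g. to customize dtype or quantization) ...
transformer = EasyAnimateTransformer3DModel.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-12b-zh", subfolder="transformer", torch_dtype=torch.float16
)

# ... then reuse it inside the full pipeline.
pipeline = EasyAnimatePipeline.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-12b-zh", transformer=transformer, torch_dtype=torch.float16
).to("cuda")
```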
docs/source/en/api/pipelines/easyanimate.md

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-->
# EasyAnimate

[EasyAnimate](https://github.com/aigc-apps/EasyAnimate) is a text-to-video, image-to-video, and control-to-video generation model by Alibaba PAI.

The description from its GitHub page:

*EasyAnimate is a pipeline based on the transformer architecture, designed for generating AI images and videos, and for training baseline models and Lora models for Diffusion Transformer. We support direct prediction from pre-trained EasyAnimate models, allowing for the generation of videos with various resolutions, approximately 6 seconds in length, at 8fps (EasyAnimateV5.1, 1 to 49 frames). Additionally, users can train their own baseline and Lora models for specific style transformations.*

This pipeline was contributed by [bubbliiiing](https://github.com/bubbliiiing). The original codebase can be found [here](https://github.com/aigc-apps/EasyAnimate), and the original weights under [hf.co/alibaba-pai](https://huggingface.co/alibaba-pai).
There are two official EasyAnimate checkpoints available for text-to-video and video-to-video.

| checkpoints | recommended inference dtype |
|:---:|:---:|
| [`alibaba-pai/EasyAnimateV5.1-12b-zh`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh) | torch.float16 |
| [`alibaba-pai/EasyAnimateV5.1-12b-zh-InP`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-InP) | torch.float16 |

There is one official EasyAnimate checkpoint available for image-to-video and video-to-video.

| checkpoints | recommended inference dtype |
|:---:|:---:|
| [`alibaba-pai/EasyAnimateV5.1-12b-zh-InP`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-InP) | torch.float16 |

There are two official EasyAnimate checkpoints available for control-to-video.

| checkpoints | recommended inference dtype |
|:---:|:---:|
| [`alibaba-pai/EasyAnimateV5.1-12b-zh-Control`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-Control) | torch.float16 |
| [`alibaba-pai/EasyAnimateV5.1-12b-zh-Control-Camera`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-Control-Camera) | torch.float16 |

For the EasyAnimateV5.1 series:
- Text-to-video (T2V) and image-to-video (I2V) work at multiple resolutions; the width and height can vary from 256 to 1024 (see the sketch after this list).
- Both the T2V and I2V models support generation with 1 to 49 frames and work best within this range. Exporting videos at 8 fps is recommended.
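To make these settings concrete, here is a minimal text-to-video sketch. It mirrors the quantized example in the Quantization section below; the `height`/`width` arguments are assumed to follow the usual diffusers video-pipeline signature, and the values shown are illustrative.

```py
import torch
from diffusers import EasyAnimatePipeline
from diffusers.utils import export_to_video

pipeline = EasyAnimatePipeline.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-12b-zh", torch_dtype=torch.float16
).to("cuda")

video = pipeline(
    prompt="A cat walks on the grass, realistic style.",
    negative_prompt="bad detailed",
    height=512,             # any value in the supported 256-1024 range (illustrative)
    width=512,
    num_frames=49,          # 1-49 frames are supported
    num_inference_steps=30,
).frames[0]

export_to_video(video, "cat_t2v.mp4", fps=8)  # export at 8 fps, as recommended above
```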
## Quantization

Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.

Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`EasyAnimatePipeline`] for inference with bitsandbytes.

```py
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, EasyAnimateTransformer3DModel, EasyAnimatePipeline
from diffusers.utils import export_to_video

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = EasyAnimateTransformer3DModel.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-12b-zh",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

pipeline = EasyAnimatePipeline.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-12b-zh",
    transformer=transformer_8bit,
    torch_dtype=torch.float16,
    device_map="balanced",
)

prompt = "A cat walks on the grass, realistic style."
negative_prompt = "bad detailed"
video = pipeline(prompt=prompt, negative_prompt=negative_prompt, num_frames=49, num_inference_steps=30).frames[0]
export_to_video(video, "cat.mp4", fps=8)
```
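As a possible variation on the snippet above (not part of this commit), 4-bit bitsandbytes loading follows the same pattern; the options shown are standard `BitsAndBytesConfig` arguments, and the choice of NF4 here is an assumption about what trades quality for memory reasonably.

```py
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, EasyAnimateTransformer3DModel

# 4-bit NF4 quantization instead of 8-bit; reduces the transformer's weight memory further.
quant_config_4bit = DiffusersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
transformer_4bit = EasyAnimateTransformer3DModel.from_pretrained(
    "alibaba-pai/EasyAnimateV5.1-12b-zh",
    subfolder="transformer",
    quantization_config=quant_config_4bit,
    torch_dtype=torch.float16,
)
```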
## EasyAnimatePipeline

[[autodoc]] EasyAnimatePipeline
  - all
  - __call__

## EasyAnimatePipelineOutput

[[autodoc]] pipelines.easyanimate.pipeline_output.EasyAnimatePipelineOutput

src/diffusers/__init__.py

Lines changed: 10 additions & 0 deletions
@@ -94,6 +94,7 @@
        "AutoencoderKLCogVideoX",
        "AutoencoderKLHunyuanVideo",
        "AutoencoderKLLTXVideo",
+        "AutoencoderKLMagvit",
        "AutoencoderKLMochi",
        "AutoencoderKLTemporalDecoder",
        "AutoencoderKLWan",
@@ -109,6 +110,7 @@
        "ControlNetUnionModel",
        "ControlNetXSAdapter",
        "DiTTransformer2DModel",
+        "EasyAnimateTransformer3DModel",
        "FluxControlNetModel",
        "FluxMultiControlNetModel",
        "FluxTransformer2DModel",
@@ -293,6 +295,9 @@
        "CogView4Pipeline",
        "ConsisIDPipeline",
        "CycleDiffusionPipeline",
+        "EasyAnimateControlPipeline",
+        "EasyAnimateInpaintPipeline",
+        "EasyAnimatePipeline",
        "FluxControlImg2ImgPipeline",
        "FluxControlInpaintPipeline",
        "FluxControlNetImg2ImgPipeline",
@@ -620,6 +625,7 @@
        AutoencoderKLCogVideoX,
        AutoencoderKLHunyuanVideo,
        AutoencoderKLLTXVideo,
+        AutoencoderKLMagvit,
        AutoencoderKLMochi,
        AutoencoderKLTemporalDecoder,
        AutoencoderKLWan,
@@ -635,6 +641,7 @@
        ControlNetUnionModel,
        ControlNetXSAdapter,
        DiTTransformer2DModel,
+        EasyAnimateTransformer3DModel,
        FluxControlNetModel,
        FluxMultiControlNetModel,
        FluxTransformer2DModel,
@@ -798,6 +805,9 @@
        CogView4Pipeline,
        ConsisIDPipeline,
        CycleDiffusionPipeline,
+        EasyAnimateControlPipeline,
+        EasyAnimateInpaintPipeline,
+        EasyAnimatePipeline,
        FluxControlImg2ImgPipeline,
        FluxControlInpaintPipeline,
        FluxControlNetImg2ImgPipeline,

src/diffusers/models/__init__.py

File mode changed from 100644 to 100755
Lines changed: 4 additions & 0 deletions
@@ -33,6 +33,7 @@
    _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
    _import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
    _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
+    _import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
    _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
    _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
    _import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"]
@@ -72,6 +73,7 @@
    _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
    _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
    _import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
+    _import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
    _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
    _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
    _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
@@ -109,6 +111,7 @@
        AutoencoderKLCogVideoX,
        AutoencoderKLHunyuanVideo,
        AutoencoderKLLTXVideo,
+        AutoencoderKLMagvit,
        AutoencoderKLMochi,
        AutoencoderKLTemporalDecoder,
        AutoencoderKLWan,
@@ -144,6 +147,7 @@
        ConsisIDTransformer3DModel,
        DiTTransformer2DModel,
        DualTransformer2DModel,
+        EasyAnimateTransformer3DModel,
        FluxTransformer2DModel,
        HunyuanDiT2DModel,
        HunyuanVideoTransformer3DModel,

src/diffusers/models/attention_processor.py

File mode changed from 100644 to 100755
Lines changed: 4 additions & 1 deletion
@@ -274,7 +274,10 @@ def __init__(
        self.to_add_out = None

        if qk_norm is not None and added_kv_proj_dim is not None:
-            if qk_norm == "fp32_layer_norm":
+            if qk_norm == "layer_norm":
+                self.norm_added_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+                self.norm_added_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+            elif qk_norm == "fp32_layer_norm":
                self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
                self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
            elif qk_norm == "rms_norm":
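For context, a minimal sketch of how this new branch is exercised: passing `qk_norm="layer_norm"` together with `added_kv_proj_dim` now builds `nn.LayerNorm` modules for the added key/query projections, where previously only `"fp32_layer_norm"` and `"rms_norm"` were handled here. The constructor arguments below are illustrative sizes, not tied to EasyAnimate's actual config.

```python
from diffusers.models.attention_processor import Attention

# With qk_norm="layer_norm" and added_kv_proj_dim set, norm_added_q / norm_added_k
# are created as plain nn.LayerNorm modules by the new branch above.
attn = Attention(
    query_dim=64,        # illustrative sizes
    heads=4,
    dim_head=16,
    added_kv_proj_dim=64,
    qk_norm="layer_norm",
)
print(type(attn.norm_added_q), type(attn.norm_added_k))  # both torch.nn.LayerNorm
```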

src/diffusers/models/autoencoders/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
+from .autoencoder_kl_magvit import AutoencoderKLMagvit
from .autoencoder_kl_mochi import AutoencoderKLMochi
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_kl_wan import AutoencoderKLWan
