[core] Hunyuan Video #10136

Merged: 67 commits, Dec 16, 2024
Changes from 61 commits

Commits (67)
4d4be9a
copy transformer
a-r-r-o-w Dec 5, 2024
2e61a9d
copy vae
a-r-r-o-w Dec 5, 2024
d885a6b
copy pipeline
a-r-r-o-w Dec 5, 2024
332c771
make fix-copies
a-r-r-o-w Dec 5, 2024
2257ff8
Merge branch 'main' into hunyuan-video
a-r-r-o-w Dec 10, 2024
7709747
refactor; make original code work with diffusers; test latents for co…
a-r-r-o-w Dec 10, 2024
fbe5031
move rope into pipeline; remove flash attention; refactor
a-r-r-o-w Dec 10, 2024
0894c38
begin conversion script
a-r-r-o-w Dec 10, 2024
a159e58
make style
a-r-r-o-w Dec 10, 2024
5bce938
refactor attention
a-r-r-o-w Dec 10, 2024
491a5b4
refactor
a-r-r-o-w Dec 10, 2024
a47a710
refactor final layer
a-r-r-o-w Dec 11, 2024
ee6880d
their mlp -> our feedforward
a-r-r-o-w Dec 11, 2024
a23cfa1
make style
a-r-r-o-w Dec 11, 2024
ab319fe
add docs
a-r-r-o-w Dec 11, 2024
e3abe38
refactor layer names
a-r-r-o-w Dec 11, 2024
43f6295
refactor modulation
a-r-r-o-w Dec 11, 2024
bb6f023
cleanup
a-r-r-o-w Dec 11, 2024
d727684
refactor norms
a-r-r-o-w Dec 11, 2024
a247ca6
refactor activations
a-r-r-o-w Dec 11, 2024
7ba4609
refactor single blocks attention
a-r-r-o-w Dec 11, 2024
cb4fc37
refactor attention processor
a-r-r-o-w Dec 11, 2024
1e80f7c
make style
a-r-r-o-w Dec 11, 2024
a9bd457
cleanup a bit
a-r-r-o-w Dec 11, 2024
ca4c81e
Merge branch 'main' into hunyuan-video
a-r-r-o-w Dec 11, 2024
f637479
refactor double transformer block attention
a-r-r-o-w Dec 11, 2024
19b2d56
update mochi attn proc
a-r-r-o-w Dec 11, 2024
c1faf0d
use diffusers attention implementation in all modules; checkpoint for…
a-r-r-o-w Dec 11, 2024
cfcb5f0
Merge branch 'main' into hunyuan-video
a-r-r-o-w Dec 11, 2024
e258480
Merge branch 'main' into hunyuan-video
a-r-r-o-w Dec 12, 2024
6915d62
remove helper functions in vae
a-r-r-o-w Dec 12, 2024
1b27c3a
refactor upsample
a-r-r-o-w Dec 12, 2024
bea9e1b
refactor causal conv
a-r-r-o-w Dec 12, 2024
d6c16ef
refactor resnet
a-r-r-o-w Dec 12, 2024
d0036ff
refactor
a-r-r-o-w Dec 12, 2024
da53620
refactor
a-r-r-o-w Dec 12, 2024
f143b02
refactor
a-r-r-o-w Dec 12, 2024
2a72d20
grad checkpointing
a-r-r-o-w Dec 12, 2024
d0c61e0
autoencoder test
a-r-r-o-w Dec 12, 2024
59c8552
Merge branch 'main' into hunyuan-video
a-r-r-o-w Dec 12, 2024
bae257d
Merge branch 'main' into hunyuan-video
a-r-r-o-w Dec 13, 2024
845f303
fix scaling factor
a-r-r-o-w Dec 13, 2024
556c6e9
refactor clip
a-r-r-o-w Dec 13, 2024
4c6cf2d
refactor llama text encoding
a-r-r-o-w Dec 13, 2024
d9ae8de
add coauthor
a-r-r-o-w Dec 13, 2024
e713660
refactor rope; diff: 0.14990234375; reason and fix: create rope grid …
a-r-r-o-w Dec 13, 2024
9039db4
use diffusers timesteps embedding; diff: 0.10205078125
a-r-r-o-w Dec 13, 2024
16778b1
rename
a-r-r-o-w Dec 13, 2024
b6c7ae0
convert
a-r-r-o-w Dec 13, 2024
e7c382e
update
a-r-r-o-w Dec 13, 2024
166194f
Merge branch 'main' into hunyuan-video
a-r-r-o-w Dec 13, 2024
dbba9c7
add tests for transformer
a-r-r-o-w Dec 14, 2024
36dea10
add pipeline tests; text encoder 2 is not optional
a-r-r-o-w Dec 14, 2024
1c7b317
fix attention implementation for torch
a-r-r-o-w Dec 14, 2024
ca98227
add example
a-r-r-o-w Dec 14, 2024
154b31c
update docs
a-r-r-o-w Dec 14, 2024
ae0b359
update docs
a-r-r-o-w Dec 15, 2024
eee00ab
apply suggestions from review
a-r-r-o-w Dec 15, 2024
0461475
refactor vae
a-r-r-o-w Dec 15, 2024
0578c2a
Merge branch 'main' into hunyuan-video
a-r-r-o-w Dec 15, 2024
edfc64b
update
a-r-r-o-w Dec 15, 2024
bfe9c46
Apply suggestions from code review
a-r-r-o-w Dec 15, 2024
f906aa8
Update src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
a-r-r-o-w Dec 15, 2024
9795469
Update src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
a-r-r-o-w Dec 15, 2024
6867c52
make fix-copies
a-r-r-o-w Dec 15, 2024
ce7b0b9
update
a-r-r-o-w Dec 15, 2024
fc2f124
Merge branch 'main' into hunyuan-video
a-r-r-o-w Dec 16, 2024
6 changes: 6 additions & 0 deletions docs/source/en/_toctree.yml
@@ -270,6 +270,8 @@
title: FluxTransformer2DModel
- local: api/models/hunyuan_transformer2d
title: HunyuanDiT2DModel
- local: api/models/hunyuan_video_transformer_3d
title: HunyuanVideoTransformer3DModel
- local: api/models/latte_transformer3d
title: LatteTransformer3DModel
- local: api/models/lumina_nextdit2d
@@ -314,6 +316,8 @@
title: AutoencoderKLAllegro
- local: api/models/autoencoderkl_cogvideox
title: AutoencoderKLCogVideoX
- local: api/models/autoencoder_kl_hunyuan_video
title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoderkl_ltx_video
title: AutoencoderKLLTXVideo
- local: api/models/autoencoderkl_mochi
@@ -392,6 +396,8 @@
title: Flux
- local: api/pipelines/hunyuandit
title: Hunyuan-DiT
- local: api/pipelines/hunyuan_video
title: HunyuanVideo
- local: api/pipelines/i2vgenxl
title: I2VGen-XL
- local: api/pipelines/pix2pix
32 changes: 32 additions & 0 deletions docs/source/en/api/models/autoencoder_kl_hunyuan_video.md
@@ -0,0 +1,32 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoencoderKLHunyuanVideo

The 3D variational autoencoder (VAE) model with KL loss used in [HunyuanVideo](https://github.com/Tencent/HunyuanVideo/), which was introduced in [HunyuanVideo: A Systematic Framework For Large Video Generative Models](https://huggingface.co/papers/2412.03603) by Tencent.

The model can be loaded with the following code snippet.

```python
import torch

from diffusers import AutoencoderKLHunyuanVideo

vae = AutoencoderKLHunyuanVideo.from_pretrained("tencent/HunyuanVideo", subfolder="vae", torch_dtype=torch.float16)
```
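
Continuing from the snippet above, the sketch below round-trips a small random clip through the VAE as a quick sanity check. It assumes the standard diffusers autoencoder interface (`encode(...).latent_dist` and `decode(...).sample`); the clip shape is illustrative, with a `4 * k + 1` frame count and height/width divisible by 8.

```python
import torch

vae = vae.to("cuda")  # fp16 3D convolutions are intended to run on GPU

# Hypothetical input: a short clip in [batch, channels, frames, height, width] layout
video = torch.randn(1, 3, 9, 320, 512, device="cuda", dtype=torch.float16)

with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()  # compress to the 3D latent space
    reconstruction = vae.decode(latents).sample       # decode back to pixel space

print(latents.shape, reconstruction.shape)
```

The printed shapes are the easiest way to confirm the latent channel count and the spatial/temporal compression ratios for a given checkpoint.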

## AutoencoderKLHunyuanVideo

[[autodoc]] AutoencoderKLHunyuanVideo
- decode
- all

## DecoderOutput

[[autodoc]] models.autoencoders.vae.DecoderOutput
30 changes: 30 additions & 0 deletions docs/source/en/api/models/hunyuan_video_transformer_3d.md
@@ -0,0 +1,30 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# HunyuanVideoTransformer3DModel

A Diffusion Transformer model for 3D video-like data was introduced in [HunyuanVideo: A Systematic Framework For Large Video Generative Models](https://huggingface.co/papers/2412.03603) by Tencent.

The model can be loaded with the following code snippet.

```python
import torch

from diffusers import HunyuanVideoTransformer3DModel

transformer = HunyuanVideoTransformer3DModel.from_pretrained("tencent/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16)
```
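
Continuing from the snippet above, two quick checks can be useful after loading: counting parameters (the report describes a model with over 13 billion parameters) and enabling gradient checkpointing, which this PR adds support for. This is an illustrative sketch rather than part of the documented API.

```python
# Rough size check: the paper reports a video generator with over 13 billion parameters.
num_params = sum(p.numel() for p in transformer.parameters())
print(f"{num_params / 1e9:.2f}B parameters")

# Trade compute for memory during training; gradient checkpointing support is added in this PR.
transformer.enable_gradient_checkpointing()
```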

## HunyuanVideoTransformer3DModel

[[autodoc]] HunyuanVideoTransformer3DModel

## Transformer2DModelOutput

[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
43 changes: 43 additions & 0 deletions docs/source/en/api/pipelines/hunyuan_video.md
@@ -0,0 +1,43 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->

# HunyuanVideo

[HunyuanVideo: A Systematic Framework For Large Video Generative Models](https://www.arxiv.org/abs/2412.03603) by Tencent.

*Recent advancements in video generation have significantly impacted daily life for both individuals and industries. However, the leading video generation models remain closed-source, resulting in a notable performance gap between industry capabilities and those available to the public. In this report, we introduce HunyuanVideo, an innovative open-source video foundation model that demonstrates performance in video generation comparable to, or even surpassing, that of leading closed-source models. HunyuanVideo encompasses a comprehensive framework that integrates several key elements, including data curation, advanced architectural design, progressive model scaling and training, and an efficient infrastructure tailored for large-scale model training and inference. As a result, we successfully trained a video generative model with over 13 billion parameters, making it the largest among all open-source models. We conducted extensive experiments and implemented a series of targeted designs to ensure high visual quality, motion dynamics, text-video alignment, and advanced filming techniques. According to evaluations by professionals, HunyuanVideo outperforms previous state-of-the-art models, including Runway Gen-3, Luma 1.6, and three top-performing Chinese video generative models. By releasing the code for the foundation model and its applications, we aim to bridge the gap between closed-source and open-source communities. This initiative will empower individuals within the community to experiment with their ideas, fostering a more dynamic and vibrant video generation ecosystem. The code is publicly available at [this https URL](https://github.com/Tencent/HunyuanVideo).*

<Tip>

Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.

</Tip>

Recommendations for inference (see the sketch after this list for how they fit together):
- Both text encoders should be in `torch.float16`.
- Transformer should be in `torch.bfloat16`.
- VAE should be in `torch.float16`.
- `num_frames` should be of the form `4 * k + 1`, for example `49` or `129`.
- For smaller resolutions, try lower values of `shift` (between `2.0` and `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolutions, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo.
- For more information about supported resolutions and other details, please refer to the original repository [here](https://github.com/Tencent/HunyuanVideo/).
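
The example below is a sketch of how these recommendations fit together for a small text-to-video run. It assumes a diffusers-format checkpoint at the repository id used in the model documentation; the prompt, resolution, frame count, and step count are illustrative, and the scheduler `shift` is lowered because the resolution is small.

```python
import torch

from diffusers import FlowMatchEulerDiscreteScheduler, HunyuanVideoPipeline
from diffusers.utils import export_to_video

# Text encoders and VAE stay in float16; the transformer is cast to bfloat16 as recommended.
pipe = HunyuanVideoPipeline.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.float16)
pipe.transformer.to(torch.bfloat16)
pipe.enable_model_cpu_offload()  # the 13B transformer plus the llama text encoder is large

# Smaller resolutions tend to work better with a lower shift (the default is 7.0).
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, shift=5.0)

video = pipe(
    prompt="A cat walks on the grass, realistic style.",
    height=320,
    width=512,
    num_frames=61,  # 4 * k + 1
    num_inference_steps=30,
).frames[0]
export_to_video(video, "output.mp4", fps=15)
```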

## HunyuanVideoPipeline

[[autodoc]] HunyuanVideoPipeline
- all
- __call__

## HunyuanVideoPipelineOutput

[[autodoc]] pipelines.hunyuan_video.pipeline_output.HunyuanVideoPipelineOutput
257 changes: 257 additions & 0 deletions scripts/convert_hunyuan_video_to_diffusers.py
@@ -0,0 +1,257 @@
import argparse
from typing import Any, Dict

import torch
from accelerate import init_empty_weights
from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKLHunyuanVideo,
    FlowMatchEulerDiscreteScheduler,
    HunyuanVideoPipeline,
    HunyuanVideoTransformer3DModel,
)


def remap_norm_scale_shift_(key, state_dict):
    # The original final layer stores modulation as (shift, scale); diffusers' norm_out.linear expects (scale, shift).
    weight = state_dict.pop(key)
    shift, scale = weight.chunk(2, dim=0)
    new_weight = torch.cat([scale, shift], dim=0)
    state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight


def remap_txt_in_(key, state_dict):
    def rename_key(key):
        new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
        new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
        new_key = new_key.replace("txt_in", "context_embedder")
        new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
        new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
        new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
        new_key = new_key.replace("mlp", "ff")
        return new_key

    # The token refiner's fused self-attention QKV projection is split into separate q/k/v weights.
    if "self_attn_qkv" in key:
        weight = state_dict.pop(key)
        to_q, to_k, to_v = weight.chunk(3, dim=0)
        state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
        state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
        state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
    else:
        state_dict[rename_key(key)] = state_dict.pop(key)


def remap_img_attn_qkv_(key, state_dict):
    weight = state_dict.pop(key)
    to_q, to_k, to_v = weight.chunk(3, dim=0)
    state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
    state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
    state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v


def remap_txt_attn_qkv_(key, state_dict):
    weight = state_dict.pop(key)
    to_q, to_k, to_v = weight.chunk(3, dim=0)
    state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
    state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
    state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v


def remap_single_transformer_blocks_(key, state_dict):
    hidden_size = 3072

    # linear1 in the original single-stream blocks fuses the Q/K/V projections and the MLP input
    # projection into one matrix; split it into the separate diffusers parameters.
    if "linear1.weight" in key:
        linear1_weight = state_dict.pop(key)
        split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
        q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
        new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
        state_dict[f"{new_key}.attn.to_q.weight"] = q
        state_dict[f"{new_key}.attn.to_k.weight"] = k
        state_dict[f"{new_key}.attn.to_v.weight"] = v
        state_dict[f"{new_key}.proj_mlp.weight"] = mlp

    elif "linear1.bias" in key:
        linear1_bias = state_dict.pop(key)
        split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
        q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
        new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
        state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
        state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
        state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
        state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias

    else:
        new_key = key.replace("single_blocks", "single_transformer_blocks")
        new_key = new_key.replace("linear2", "proj_out")
        new_key = new_key.replace("q_norm", "attn.norm_q")
        new_key = new_key.replace("k_norm", "attn.norm_k")
        state_dict[new_key] = state_dict.pop(key)


TRANSFORMER_KEYS_RENAME_DICT = {
"img_in": "x_embedder",
"time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
"time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
"guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
"guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
"vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
"vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
"double_blocks": "transformer_blocks",
"img_attn_q_norm": "attn.norm_q",
"img_attn_k_norm": "attn.norm_k",
"img_attn_proj": "attn.to_out.0",
"txt_attn_q_norm": "attn.norm_added_q",
"txt_attn_k_norm": "attn.norm_added_k",
"txt_attn_proj": "attn.to_add_out",
"img_mod.linear": "norm1.linear",
"img_norm1": "norm1.norm",
"img_norm2": "norm2",
"img_mlp": "ff",
"txt_mod.linear": "norm1_context.linear",
"txt_norm1": "norm1.norm",
"txt_norm2": "norm2_context",
"txt_mlp": "ff_context",
"self_attn_proj": "attn.to_out.0",
"modulation.linear": "norm.linear",
"pre_norm": "norm.norm",
"final_layer.norm_final": "norm_out.norm",
"final_layer.linear": "proj_out",
"fc1": "net.0.proj",
"fc2": "net.2",
"input_embedder": "proj_in",
}

TRANSFORMER_SPECIAL_KEYS_REMAP = {
"txt_in": remap_txt_in_,
"img_attn_qkv": remap_img_attn_qkv_,
"txt_attn_qkv": remap_txt_attn_qkv_,
"single_blocks": remap_single_transformer_blocks_,
"final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
}

VAE_KEYS_RENAME_DICT = {}

VAE_SPECIAL_KEYS_REMAP = {}


def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    state_dict[new_key] = state_dict.pop(old_key)


def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    state_dict = saved_dict
    if "model" in saved_dict.keys():
        state_dict = state_dict["model"]
    if "module" in saved_dict.keys():
        state_dict = state_dict["module"]
    if "state_dict" in saved_dict.keys():
        state_dict = state_dict["state_dict"]
    return state_dict


def convert_transformer(ckpt_path: str):
    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))

    with init_empty_weights():
        transformer = HunyuanVideoTransformer3DModel()

    # First pass: plain substring renames from TRANSFORMER_KEYS_RENAME_DICT.
    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    # Second pass: keys that need splitting or reordering go through dedicated handlers.
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    transformer.load_state_dict(original_state_dict, strict=True, assign=True)
    return transformer


def convert_vae(ckpt_path: str):
    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))

    with init_empty_weights():
        vae = AutoencoderKLHunyuanVideo()

    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    vae.load_state_dict(original_state_dict, strict=True, assign=True)
    return vae


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
    )
    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint")
    parser.add_argument("--text_encoder_path", type=str, default=None, help="Path to original llama checkpoint")
    parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to original llama tokenizer")
    parser.add_argument("--text_encoder_2_path", type=str, default=None, help="Path to original clip checkpoint")
    parser.add_argument("--save_pipeline", action="store_true")
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
    return parser.parse_args()


DTYPE_MAPPING = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}


if __name__ == "__main__":
    args = get_args()

    transformer = None
    dtype = DTYPE_MAPPING[args.dtype]

    if args.save_pipeline:
        assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
        assert args.text_encoder_path is not None
        assert args.tokenizer_path is not None
        assert args.text_encoder_2_path is not None

    if args.transformer_ckpt_path is not None:
        transformer = convert_transformer(args.transformer_ckpt_path)
        transformer = transformer.to(dtype=dtype)
        if not args.save_pipeline:
            transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

    if args.vae_ckpt_path is not None:
        vae = convert_vae(args.vae_ckpt_path)
        if not args.save_pipeline:
            vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

    if args.save_pipeline:
        text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16)
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right")
        text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
        tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
        scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)

        pipe = HunyuanVideoPipeline(
            transformer=transformer,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            text_encoder_2=text_encoder_2,
            tokenizer_2=tokenizer_2,
            scheduler=scheduler,
        )
        pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
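
# Example invocation (sketch only; every path below is a placeholder for local files):
#   python scripts/convert_hunyuan_video_to_diffusers.py \
#     --transformer_ckpt_path /path/to/original/transformer_checkpoint.pt \
#     --vae_ckpt_path /path/to/original/vae_checkpoint.pt \
#     --text_encoder_path /path/to/llama/text_encoder \
#     --tokenizer_path /path/to/llama/tokenizer \
#     --text_encoder_2_path /path/to/clip/text_encoder \
#     --save_pipeline \
#     --output_path /path/to/output/hunyuan-video-diffusers \
#     --dtype bf16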