"""Convert original HunyuanVideo checkpoints (transformer, VAE, and optionally the full
pipeline together with its text encoders) to the diffusers format."""

import argparse
from typing import Any, Dict

import torch
from accelerate import init_empty_weights
from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKLHunyuanVideo,
    FlowMatchEulerDiscreteScheduler,
    HunyuanVideoPipeline,
    HunyuanVideoTransformer3DModel,
)


def remap_norm_scale_shift_(key, state_dict):
    # The original checkpoint stores the modulation as (shift, scale); reorder to (scale, shift)
    # for the diffusers norm_out layer.
    weight = state_dict.pop(key)
    shift, scale = weight.chunk(2, dim=0)
    new_weight = torch.cat([scale, shift], dim=0)
    state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight


def remap_txt_in_(key, state_dict):
    def rename_key(key):
        new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
        new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
        new_key = new_key.replace("txt_in", "context_embedder")
        new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
        new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
        new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
        new_key = new_key.replace("mlp", "ff")
        return new_key

    if "self_attn_qkv" in key:
        weight = state_dict.pop(key)
        to_q, to_k, to_v = weight.chunk(3, dim=0)
        state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
        state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
        state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
    else:
        state_dict[rename_key(key)] = state_dict.pop(key)


def remap_img_attn_qkv_(key, state_dict):
    weight = state_dict.pop(key)
    to_q, to_k, to_v = weight.chunk(3, dim=0)
    state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
    state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
    state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v


def remap_txt_attn_qkv_(key, state_dict):
    weight = state_dict.pop(key)
    to_q, to_k, to_v = weight.chunk(3, dim=0)
    state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
    state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
    state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v


def remap_single_transformer_blocks_(key, state_dict):
    # Inner dimension of the transformer; linear1 packs [q, k, v, mlp] along dim 0.
    hidden_size = 3072

    if "linear1.weight" in key:
        linear1_weight = state_dict.pop(key)
        split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
        q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
        new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
        state_dict[f"{new_key}.attn.to_q.weight"] = q
        state_dict[f"{new_key}.attn.to_k.weight"] = k
        state_dict[f"{new_key}.attn.to_v.weight"] = v
        state_dict[f"{new_key}.proj_mlp.weight"] = mlp

    elif "linear1.bias" in key:
        linear1_bias = state_dict.pop(key)
        split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
        q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
        new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
        state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
        state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
        state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
        state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias

    else:
        new_key = key.replace("single_blocks", "single_transformer_blocks")
        new_key = new_key.replace("linear2", "proj_out")
        new_key = new_key.replace("q_norm", "attn.norm_q")
        new_key = new_key.replace("k_norm", "attn.norm_k")
        state_dict[new_key] = state_dict.pop(key)


TRANSFORMER_KEYS_RENAME_DICT = {
    "img_in": "x_embedder",
    "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
    "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
    "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
    "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
    "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
    "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
    "double_blocks": "transformer_blocks",
    "img_attn_q_norm": "attn.norm_q",
    "img_attn_k_norm": "attn.norm_k",
    "img_attn_proj": "attn.to_out.0",
    "txt_attn_q_norm": "attn.norm_added_q",
    "txt_attn_k_norm": "attn.norm_added_k",
    "txt_attn_proj": "attn.to_add_out",
    "img_mod.linear": "norm1.linear",
    "img_norm1": "norm1.norm",
    "img_norm2": "norm2",
    "img_mlp": "ff",
    "txt_mod.linear": "norm1_context.linear",
    "txt_norm1": "norm1.norm",
    "txt_norm2": "norm2_context",
    "txt_mlp": "ff_context",
    "self_attn_proj": "attn.to_out.0",
    "modulation.linear": "norm.linear",
    "pre_norm": "norm.norm",
    "final_layer.norm_final": "norm_out.norm",
    "final_layer.linear": "proj_out",
    "fc1": "net.0.proj",
    "fc2": "net.2",
    "input_embedder": "proj_in",
}

TRANSFORMER_SPECIAL_KEYS_REMAP = {
    "txt_in": remap_txt_in_,
    "img_attn_qkv": remap_img_attn_qkv_,
    "txt_attn_qkv": remap_txt_attn_qkv_,
    "single_blocks": remap_single_transformer_blocks_,
    "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
}
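
# For illustration (hypothetical key, not an exhaustive trace): a checkpoint entry such as
#   "double_blocks.0.img_attn_qkv.weight"
# is first renamed by TRANSFORMER_KEYS_RENAME_DICT to
#   "transformer_blocks.0.img_attn_qkv.weight"
# and then split by remap_img_attn_qkv_ into
#   "transformer_blocks.0.attn.to_q.weight", "...to_k.weight" and "...to_v.weight".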

VAE_KEYS_RENAME_DICT = {}

VAE_SPECIAL_KEYS_REMAP = {}


def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    state_dict[new_key] = state_dict.pop(old_key)


def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    state_dict = saved_dict
    if "model" in saved_dict.keys():
        state_dict = state_dict["model"]
    if "module" in saved_dict.keys():
        state_dict = state_dict["module"]
    if "state_dict" in saved_dict.keys():
        state_dict = state_dict["state_dict"]
    return state_dict


def convert_transformer(ckpt_path: str):
    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))

    with init_empty_weights():
        transformer = HunyuanVideoTransformer3DModel()

    # First pass: plain substring renames.
    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    # Second pass: special handlers that split or reorder fused weights in place.
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    transformer.load_state_dict(original_state_dict, strict=True, assign=True)
    return transformer
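
# A minimal standalone usage sketch (the paths below are hypothetical placeholders; the real
# filenames depend on how the original weights were downloaded):
#
#   transformer = convert_transformer("path/to/original_transformer_checkpoint.pt")
#   transformer = transformer.to(dtype=torch.bfloat16)
#   transformer.save_pretrained("converted/transformer", safe_serialization=True)
#
# convert_vae below applies the same two-pass renaming scheme to the VAE checkpoint.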


def convert_vae(ckpt_path: str):
    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))

    with init_empty_weights():
        vae = AutoencoderKLHunyuanVideo()

    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    vae.load_state_dict(original_state_dict, strict=True, assign=True)
    return vae


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
    )
    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint")
    parser.add_argument("--text_encoder_path", type=str, default=None, help="Path to original LLaMA checkpoint")
    parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to original LLaMA tokenizer")
    parser.add_argument("--text_encoder_2_path", type=str, default=None, help="Path to original CLIP checkpoint")
    parser.add_argument(
        "--save_pipeline", action="store_true", help="Save the full pipeline instead of individual components"
    )
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
    return parser.parse_args()


DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}


if __name__ == "__main__":
    args = get_args()

    transformer = None
    dtype = DTYPE_MAPPING[args.dtype]

    if args.save_pipeline:
        # Saving the full pipeline requires every component to be provided.
        assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
        assert args.text_encoder_path is not None
        assert args.tokenizer_path is not None
        assert args.text_encoder_2_path is not None

    if args.transformer_ckpt_path is not None:
        transformer = convert_transformer(args.transformer_ckpt_path)
        transformer = transformer.to(dtype=dtype)
        if not args.save_pipeline:
            transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

    if args.vae_ckpt_path is not None:
        vae = convert_vae(args.vae_ckpt_path)
        if not args.save_pipeline:
            vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

    if args.save_pipeline:
        text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16)
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right")
        text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
        tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
        scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)

        pipe = HunyuanVideoPipeline(
            transformer=transformer,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            text_encoder_2=text_encoder_2,
            tokenizer_2=tokenizer_2,
            scheduler=scheduler,
        )
        pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
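
# Example invocations (the script filename and all paths are hypothetical placeholders):
#
#   # Convert and save only the transformer:
#   python convert_hunyuan_video_to_diffusers.py \
#       --transformer_ckpt_path /path/to/transformer.pt \
#       --output_path ./hunyuanvideo-transformer --dtype bf16
#
#   # Convert everything and save the assembled pipeline:
#   python convert_hunyuan_video_to_diffusers.py \
#       --transformer_ckpt_path /path/to/transformer.pt --vae_ckpt_path /path/to/vae.pt \
#       --text_encoder_path /path/to/llama --tokenizer_path /path/to/llama \
#       --text_encoder_2_path /path/to/clip --save_pipeline --output_path ./hunyuanvideo-diffusers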