|
import argparse |
|
|
|
import torch |
|
|
|
from diffusers import HunyuanDiT2DModel |
|
|
|
|
|
def main(args): |
|
state_dict = torch.load(args.pt_checkpoint_path, map_location="cpu") |
|
|
|
if args.load_key != "none": |
|
try: |
|
state_dict = state_dict[args.load_key] |
|
except KeyError: |
|
raise KeyError( |
|
f"{args.load_key} not found in the checkpoint." |
|
f"Please load from the following keys:{state_dict.keys()}" |
|
) |
|
|
|
device = "cuda" |
|
model_config = HunyuanDiT2DModel.load_config("Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer") |
|
model_config[ |
|
"use_style_cond_and_image_meta_size" |
|
] = args.use_style_cond_and_image_meta_size |
|
|
|
|
|
for key in state_dict: |
|
print("local:", key) |
|
|
|
model = HunyuanDiT2DModel.from_config(model_config).to(device) |
|
|
|
for key in model.state_dict(): |
|
print("diffusers:", key) |
|
|
|
num_layers = 40 |
|
for i in range(num_layers): |
|
|
|
|
|
q, k, v = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.weight"], 3, dim=0) |
|
q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.bias"], 3, dim=0) |
|
state_dict[f"blocks.{i}.attn1.to_q.weight"] = q |
|
state_dict[f"blocks.{i}.attn1.to_q.bias"] = q_bias |
|
state_dict[f"blocks.{i}.attn1.to_k.weight"] = k |
|
state_dict[f"blocks.{i}.attn1.to_k.bias"] = k_bias |
|
state_dict[f"blocks.{i}.attn1.to_v.weight"] = v |
|
state_dict[f"blocks.{i}.attn1.to_v.bias"] = v_bias |
|
state_dict.pop(f"blocks.{i}.attn1.Wqkv.weight") |
|
state_dict.pop(f"blocks.{i}.attn1.Wqkv.bias") |
|
|
|
|
|
state_dict[f"blocks.{i}.attn1.norm_q.weight"] = state_dict[f"blocks.{i}.attn1.q_norm.weight"] |
|
state_dict[f"blocks.{i}.attn1.norm_q.bias"] = state_dict[f"blocks.{i}.attn1.q_norm.bias"] |
|
state_dict[f"blocks.{i}.attn1.norm_k.weight"] = state_dict[f"blocks.{i}.attn1.k_norm.weight"] |
|
state_dict[f"blocks.{i}.attn1.norm_k.bias"] = state_dict[f"blocks.{i}.attn1.k_norm.bias"] |
|
|
|
state_dict.pop(f"blocks.{i}.attn1.q_norm.weight") |
|
state_dict.pop(f"blocks.{i}.attn1.q_norm.bias") |
|
state_dict.pop(f"blocks.{i}.attn1.k_norm.weight") |
|
state_dict.pop(f"blocks.{i}.attn1.k_norm.bias") |
|
|
|
|
|
state_dict[f"blocks.{i}.attn1.to_out.0.weight"] = state_dict[f"blocks.{i}.attn1.out_proj.weight"] |
|
state_dict[f"blocks.{i}.attn1.to_out.0.bias"] = state_dict[f"blocks.{i}.attn1.out_proj.bias"] |
|
state_dict.pop(f"blocks.{i}.attn1.out_proj.weight") |
|
state_dict.pop(f"blocks.{i}.attn1.out_proj.bias") |
|
|
|
|
|
|
|
k, v = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.weight"], 2, dim=0) |
|
k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.bias"], 2, dim=0) |
|
state_dict[f"blocks.{i}.attn2.to_k.weight"] = k |
|
state_dict[f"blocks.{i}.attn2.to_k.bias"] = k_bias |
|
state_dict[f"blocks.{i}.attn2.to_v.weight"] = v |
|
state_dict[f"blocks.{i}.attn2.to_v.bias"] = v_bias |
|
state_dict.pop(f"blocks.{i}.attn2.kv_proj.weight") |
|
state_dict.pop(f"blocks.{i}.attn2.kv_proj.bias") |
|
|
|
|
|
state_dict[f"blocks.{i}.attn2.to_q.weight"] = state_dict[f"blocks.{i}.attn2.q_proj.weight"] |
|
state_dict[f"blocks.{i}.attn2.to_q.bias"] = state_dict[f"blocks.{i}.attn2.q_proj.bias"] |
|
state_dict.pop(f"blocks.{i}.attn2.q_proj.weight") |
|
state_dict.pop(f"blocks.{i}.attn2.q_proj.bias") |
|
|
|
|
|
state_dict[f"blocks.{i}.attn2.norm_q.weight"] = state_dict[f"blocks.{i}.attn2.q_norm.weight"] |
|
state_dict[f"blocks.{i}.attn2.norm_q.bias"] = state_dict[f"blocks.{i}.attn2.q_norm.bias"] |
|
state_dict[f"blocks.{i}.attn2.norm_k.weight"] = state_dict[f"blocks.{i}.attn2.k_norm.weight"] |
|
state_dict[f"blocks.{i}.attn2.norm_k.bias"] = state_dict[f"blocks.{i}.attn2.k_norm.bias"] |
|
|
|
state_dict.pop(f"blocks.{i}.attn2.q_norm.weight") |
|
state_dict.pop(f"blocks.{i}.attn2.q_norm.bias") |
|
state_dict.pop(f"blocks.{i}.attn2.k_norm.weight") |
|
state_dict.pop(f"blocks.{i}.attn2.k_norm.bias") |
|
|
|
|
|
state_dict[f"blocks.{i}.attn2.to_out.0.weight"] = state_dict[f"blocks.{i}.attn2.out_proj.weight"] |
|
state_dict[f"blocks.{i}.attn2.to_out.0.bias"] = state_dict[f"blocks.{i}.attn2.out_proj.bias"] |
|
state_dict.pop(f"blocks.{i}.attn2.out_proj.weight") |
|
state_dict.pop(f"blocks.{i}.attn2.out_proj.bias") |
|
|
|
|
|
norm2_weight = state_dict[f"blocks.{i}.norm2.weight"] |
|
norm2_bias = state_dict[f"blocks.{i}.norm2.bias"] |
|
state_dict[f"blocks.{i}.norm2.weight"] = state_dict[f"blocks.{i}.norm3.weight"] |
|
state_dict[f"blocks.{i}.norm2.bias"] = state_dict[f"blocks.{i}.norm3.bias"] |
|
state_dict[f"blocks.{i}.norm3.weight"] = norm2_weight |
|
state_dict[f"blocks.{i}.norm3.bias"] = norm2_bias |
|
|
|
|
|
|
|
state_dict[f"blocks.{i}.norm1.norm.weight"] = state_dict[f"blocks.{i}.norm1.weight"] |
|
state_dict[f"blocks.{i}.norm1.norm.bias"] = state_dict[f"blocks.{i}.norm1.bias"] |
|
state_dict[f"blocks.{i}.norm1.linear.weight"] = state_dict[f"blocks.{i}.default_modulation.1.weight"] |
|
state_dict[f"blocks.{i}.norm1.linear.bias"] = state_dict[f"blocks.{i}.default_modulation.1.bias"] |
|
state_dict.pop(f"blocks.{i}.norm1.weight") |
|
state_dict.pop(f"blocks.{i}.norm1.bias") |
|
state_dict.pop(f"blocks.{i}.default_modulation.1.weight") |
|
state_dict.pop(f"blocks.{i}.default_modulation.1.bias") |
|
|
|
|
|
state_dict[f"blocks.{i}.ff.net.0.proj.weight"] = state_dict[f"blocks.{i}.mlp.fc1.weight"] |
|
state_dict[f"blocks.{i}.ff.net.0.proj.bias"] = state_dict[f"blocks.{i}.mlp.fc1.bias"] |
|
state_dict[f"blocks.{i}.ff.net.2.weight"] = state_dict[f"blocks.{i}.mlp.fc2.weight"] |
|
state_dict[f"blocks.{i}.ff.net.2.bias"] = state_dict[f"blocks.{i}.mlp.fc2.bias"] |
|
state_dict.pop(f"blocks.{i}.mlp.fc1.weight") |
|
state_dict.pop(f"blocks.{i}.mlp.fc1.bias") |
|
state_dict.pop(f"blocks.{i}.mlp.fc2.weight") |
|
state_dict.pop(f"blocks.{i}.mlp.fc2.bias") |
|
|
|
|
|
state_dict["time_extra_emb.pooler.positional_embedding"] = state_dict["pooler.positional_embedding"] |
|
state_dict["time_extra_emb.pooler.k_proj.weight"] = state_dict["pooler.k_proj.weight"] |
|
state_dict["time_extra_emb.pooler.k_proj.bias"] = state_dict["pooler.k_proj.bias"] |
|
state_dict["time_extra_emb.pooler.q_proj.weight"] = state_dict["pooler.q_proj.weight"] |
|
state_dict["time_extra_emb.pooler.q_proj.bias"] = state_dict["pooler.q_proj.bias"] |
|
state_dict["time_extra_emb.pooler.v_proj.weight"] = state_dict["pooler.v_proj.weight"] |
|
state_dict["time_extra_emb.pooler.v_proj.bias"] = state_dict["pooler.v_proj.bias"] |
|
state_dict["time_extra_emb.pooler.c_proj.weight"] = state_dict["pooler.c_proj.weight"] |
|
state_dict["time_extra_emb.pooler.c_proj.bias"] = state_dict["pooler.c_proj.bias"] |
|
state_dict.pop("pooler.k_proj.weight") |
|
state_dict.pop("pooler.k_proj.bias") |
|
state_dict.pop("pooler.q_proj.weight") |
|
state_dict.pop("pooler.q_proj.bias") |
|
state_dict.pop("pooler.v_proj.weight") |
|
state_dict.pop("pooler.v_proj.bias") |
|
state_dict.pop("pooler.c_proj.weight") |
|
state_dict.pop("pooler.c_proj.bias") |
|
state_dict.pop("pooler.positional_embedding") |
|
|
|
|
|
state_dict["time_extra_emb.timestep_embedder.linear_1.bias"] = state_dict["t_embedder.mlp.0.bias"] |
|
state_dict["time_extra_emb.timestep_embedder.linear_1.weight"] = state_dict["t_embedder.mlp.0.weight"] |
|
state_dict["time_extra_emb.timestep_embedder.linear_2.bias"] = state_dict["t_embedder.mlp.2.bias"] |
|
state_dict["time_extra_emb.timestep_embedder.linear_2.weight"] = state_dict["t_embedder.mlp.2.weight"] |
|
|
|
state_dict.pop("t_embedder.mlp.0.bias") |
|
state_dict.pop("t_embedder.mlp.0.weight") |
|
state_dict.pop("t_embedder.mlp.2.bias") |
|
state_dict.pop("t_embedder.mlp.2.weight") |
|
|
|
|
|
state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"] |
|
state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"] |
|
state_dict.pop("x_embedder.proj.weight") |
|
state_dict.pop("x_embedder.proj.bias") |
|
|
|
|
|
state_dict["text_embedder.linear_1.bias"] = state_dict["mlp_t5.0.bias"] |
|
state_dict["text_embedder.linear_1.weight"] = state_dict["mlp_t5.0.weight"] |
|
state_dict["text_embedder.linear_2.bias"] = state_dict["mlp_t5.2.bias"] |
|
state_dict["text_embedder.linear_2.weight"] = state_dict["mlp_t5.2.weight"] |
|
state_dict.pop("mlp_t5.0.bias") |
|
state_dict.pop("mlp_t5.0.weight") |
|
state_dict.pop("mlp_t5.2.bias") |
|
state_dict.pop("mlp_t5.2.weight") |
|
|
|
|
|
state_dict["time_extra_emb.extra_embedder.linear_1.bias"] = state_dict["extra_embedder.0.bias"] |
|
state_dict["time_extra_emb.extra_embedder.linear_1.weight"] = state_dict["extra_embedder.0.weight"] |
|
state_dict["time_extra_emb.extra_embedder.linear_2.bias"] = state_dict["extra_embedder.2.bias"] |
|
state_dict["time_extra_emb.extra_embedder.linear_2.weight"] = state_dict["extra_embedder.2.weight"] |
|
state_dict.pop("extra_embedder.0.bias") |
|
state_dict.pop("extra_embedder.0.weight") |
|
state_dict.pop("extra_embedder.2.bias") |
|
state_dict.pop("extra_embedder.2.weight") |
|
|
|
|
|
def swap_scale_shift(weight): |
|
shift, scale = weight.chunk(2, dim=0) |
|
new_weight = torch.cat([scale, shift], dim=0) |
|
return new_weight |
|
|
|
state_dict["norm_out.linear.weight"] = swap_scale_shift(state_dict["final_layer.adaLN_modulation.1.weight"]) |
|
state_dict["norm_out.linear.bias"] = swap_scale_shift(state_dict["final_layer.adaLN_modulation.1.bias"]) |
|
state_dict.pop("final_layer.adaLN_modulation.1.weight") |
|
state_dict.pop("final_layer.adaLN_modulation.1.bias") |
|
|
|
|
|
state_dict["proj_out.weight"] = state_dict["final_layer.linear.weight"] |
|
state_dict["proj_out.bias"] = state_dict["final_layer.linear.bias"] |
|
state_dict.pop("final_layer.linear.weight") |
|
state_dict.pop("final_layer.linear.bias") |
|
|
|
|
|
if model_config["use_style_cond_and_image_meta_size"]: |
|
print(state_dict["style_embedder.weight"]) |
|
print(state_dict["style_embedder.weight"].shape) |
|
state_dict["time_extra_emb.style_embedder.weight"] = state_dict["style_embedder.weight"][0:1] |
|
state_dict.pop("style_embedder.weight") |
|
|
|
model.load_state_dict(state_dict) |
|
|
|
from diffusers import HunyuanDiTPipeline |
|
|
|
if args.use_style_cond_and_image_meta_size: |
|
pipe = HunyuanDiTPipeline.from_pretrained( |
|
"Tencent-Hunyuan/HunyuanDiT-Diffusers", transformer=model, torch_dtype=torch.float32 |
|
) |
|
else: |
|
pipe = HunyuanDiTPipeline.from_pretrained( |
|
"Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", transformer=model, torch_dtype=torch.float32 |
|
) |
|
pipe.to("cuda") |
|
pipe.to(dtype=torch.float32) |
|
|
|
if args.save: |
|
pipe.save_pretrained(args.output_checkpoint_path) |
|
|
|
|
|
prompt = "一个宇航员在骑马" |
|
|
|
generator = torch.Generator(device="cuda").manual_seed(0) |
|
image = pipe( |
|
height=1024, width=1024, prompt=prompt, generator=generator, num_inference_steps=25, guidance_scale=5.0 |
|
).images[0] |
|
|
|
image.save("img.png") |
|
|
|
|
|
if __name__ == "__main__": |
|
parser = argparse.ArgumentParser() |
|
|
|
parser.add_argument( |
|
"--save", default=True, type=bool, required=False, help="Whether to save the converted pipeline or not." |
|
) |
|
parser.add_argument( |
|
"--pt_checkpoint_path", default=None, type=str, required=True, help="Path to the .pt pretrained model." |
|
) |
|
parser.add_argument( |
|
"--output_checkpoint_path", |
|
default=None, |
|
type=str, |
|
required=False, |
|
help="Path to the output converted diffusers pipeline.", |
|
) |
|
parser.add_argument( |
|
"--load_key", default="none", type=str, required=False, help="The key to load from the pretrained .pt file" |
|
) |
|
parser.add_argument( |
|
"--use_style_cond_and_image_meta_size", |
|
type=bool, |
|
default=False, |
|
help="version <= v1.1: True; version >= v1.2: False", |
|
) |
|
|
|
args = parser.parse_args() |
|
main(args) |
|
|