import argparse
import os

import torch
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer

from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaText2ImgPipeline
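

# Usage sketch (the script filename and paths below are illustrative, not verified):
#
#   python convert_lumina_to_diffusers.py \
#       --origin_ckpt_path /path/to/lumina_next.safetensors \
#       --dump_path ./lumina-next-diffusers \
#       --only_transformer ""
#
# The script remaps the original Lumina-Next checkpoint keys to the diffusers
# LuminaNextDiT2DModel layout. If --only_transformer is truthy, only the converted
# transformer is saved; otherwise it is wrapped in a LuminaText2ImgPipeline together
# with the SDXL VAE, the Gemma-2B text encoder, and a flow-matching Euler scheduler.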
def main(args):
    # Load the original Lumina-Next checkpoint and remap its keys to the diffusers naming scheme.
    all_sd = load_file(args.origin_ckpt_path, device="cpu")
    converted_state_dict = {}

    # pad token
    converted_state_dict["pad_token"] = all_sd["pad_token"]

    # patch embedder
    converted_state_dict["patch_embedder.weight"] = all_sd["x_embedder.weight"]
    converted_state_dict["patch_embedder.bias"] = all_sd["x_embedder.bias"]

    # timestep and caption embedders
    converted_state_dict["time_caption_embed.timestep_embedder.linear_1.weight"] = all_sd["t_embedder.mlp.0.weight"]
    converted_state_dict["time_caption_embed.timestep_embedder.linear_1.bias"] = all_sd["t_embedder.mlp.0.bias"]
    converted_state_dict["time_caption_embed.timestep_embedder.linear_2.weight"] = all_sd["t_embedder.mlp.2.weight"]
    converted_state_dict["time_caption_embed.timestep_embedder.linear_2.bias"] = all_sd["t_embedder.mlp.2.bias"]
    converted_state_dict["time_caption_embed.caption_embedder.0.weight"] = all_sd["cap_embedder.0.weight"]
    converted_state_dict["time_caption_embed.caption_embedder.0.bias"] = all_sd["cap_embedder.0.bias"]
    converted_state_dict["time_caption_embed.caption_embedder.1.weight"] = all_sd["cap_embedder.1.weight"]
    converted_state_dict["time_caption_embed.caption_embedder.1.bias"] = all_sd["cap_embedder.1.bias"]
    # per-layer weights for the 24 transformer blocks
    for i in range(24):
        # gate and adaLN modulation
        converted_state_dict[f"layers.{i}.gate"] = all_sd[f"layers.{i}.attention.gate"]
        converted_state_dict[f"layers.{i}.adaLN_modulation.1.weight"] = all_sd[f"layers.{i}.adaLN_modulation.1.weight"]
        converted_state_dict[f"layers.{i}.adaLN_modulation.1.bias"] = all_sd[f"layers.{i}.adaLN_modulation.1.bias"]

        # self-attention q, k, v projections
        converted_state_dict[f"layers.{i}.attn1.to_q.weight"] = all_sd[f"layers.{i}.attention.wq.weight"]
        converted_state_dict[f"layers.{i}.attn1.to_k.weight"] = all_sd[f"layers.{i}.attention.wk.weight"]
        converted_state_dict[f"layers.{i}.attn1.to_v.weight"] = all_sd[f"layers.{i}.attention.wv.weight"]

        # cross-attention projections: the query weight is shared with self-attention,
        # while key/value come from the caption (y) branch
        converted_state_dict[f"layers.{i}.attn2.to_q.weight"] = all_sd[f"layers.{i}.attention.wq.weight"]
        converted_state_dict[f"layers.{i}.attn2.to_k.weight"] = all_sd[f"layers.{i}.attention.wk_y.weight"]
        converted_state_dict[f"layers.{i}.attn2.to_v.weight"] = all_sd[f"layers.{i}.attention.wv_y.weight"]

        # attention output projection
        converted_state_dict[f"layers.{i}.attn2.to_out.0.weight"] = all_sd[f"layers.{i}.attention.wo.weight"]

        # q/k normalization for self-attention
        converted_state_dict[f"layers.{i}.attn1.norm_q.weight"] = all_sd[f"layers.{i}.attention.q_norm.weight"]
        converted_state_dict[f"layers.{i}.attn1.norm_q.bias"] = all_sd[f"layers.{i}.attention.q_norm.bias"]
        converted_state_dict[f"layers.{i}.attn1.norm_k.weight"] = all_sd[f"layers.{i}.attention.k_norm.weight"]
        converted_state_dict[f"layers.{i}.attn1.norm_k.bias"] = all_sd[f"layers.{i}.attention.k_norm.bias"]

        # q/k normalization for cross-attention (keys use the caption-branch norm)
        converted_state_dict[f"layers.{i}.attn2.norm_q.weight"] = all_sd[f"layers.{i}.attention.q_norm.weight"]
        converted_state_dict[f"layers.{i}.attn2.norm_q.bias"] = all_sd[f"layers.{i}.attention.q_norm.bias"]
        converted_state_dict[f"layers.{i}.attn2.norm_k.weight"] = all_sd[f"layers.{i}.attention.ky_norm.weight"]
        converted_state_dict[f"layers.{i}.attn2.norm_k.bias"] = all_sd[f"layers.{i}.attention.ky_norm.bias"]

        # attention layer norms
        converted_state_dict[f"layers.{i}.attn_norm1.weight"] = all_sd[f"layers.{i}.attention_norm1.weight"]
        converted_state_dict[f"layers.{i}.attn_norm2.weight"] = all_sd[f"layers.{i}.attention_norm2.weight"]
        converted_state_dict[f"layers.{i}.norm1_context.weight"] = all_sd[f"layers.{i}.attention_y_norm.weight"]

        # feed-forward projections
        converted_state_dict[f"layers.{i}.feed_forward.linear_1.weight"] = all_sd[f"layers.{i}.feed_forward.w1.weight"]
        converted_state_dict[f"layers.{i}.feed_forward.linear_2.weight"] = all_sd[f"layers.{i}.feed_forward.w2.weight"]
        converted_state_dict[f"layers.{i}.feed_forward.linear_3.weight"] = all_sd[f"layers.{i}.feed_forward.w3.weight"]

        # feed-forward norms
        converted_state_dict[f"layers.{i}.ffn_norm1.weight"] = all_sd[f"layers.{i}.ffn_norm1.weight"]
        converted_state_dict[f"layers.{i}.ffn_norm2.weight"] = all_sd[f"layers.{i}.ffn_norm2.weight"]
converted_state_dict["final_layer.linear.weight"] = all_sd["final_layer.linear.weight"] |
|
converted_state_dict["final_layer.linear.bias"] = all_sd["final_layer.linear.bias"] |
|
|
|
converted_state_dict["final_layer.adaLN_modulation.1.weight"] = all_sd["final_layer.adaLN_modulation.1.weight"] |
|
converted_state_dict["final_layer.adaLN_modulation.1.bias"] = all_sd["final_layer.adaLN_modulation.1.bias"] |
|
|
|
|
|
    # Lumina-Next transformer configuration (hidden_size=2304, 24 layers, 32 attention heads)
    transformer = LuminaNextDiT2DModel(
        sample_size=128,
        patch_size=2,
        in_channels=4,
        hidden_size=2304,
        num_layers=24,
        num_attention_heads=32,
        num_kv_heads=8,
        multiple_of=256,
        ffn_dim_multiplier=None,
        norm_eps=1e-5,
        learn_sigma=True,
        qk_norm=True,
        cross_attention_dim=2048,
        scaling_factor=1.0,
    )
    # strict=True raises if any expected key is missing or an unexpected key is present,
    # so a successful load confirms the conversion covered the full state dict.
    transformer.load_state_dict(converted_state_dict, strict=True)

    num_model_params = sum(p.numel() for p in transformer.parameters())
    print(f"Total number of transformer parameters: {num_model_params}")
    if args.only_transformer:
        # Save only the converted transformer weights.
        transformer.save_pretrained(os.path.join(args.dump_path, "transformer"))
    else:
        # Assemble and save a full text-to-image pipeline around the converted transformer.
        scheduler = FlowMatchEulerDiscreteScheduler()

        vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", torch_dtype=torch.float32)

        tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
        text_encoder = AutoModel.from_pretrained("google/gemma-2b")

        pipeline = LuminaText2ImgPipeline(
            tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler
        )
        pipeline.save_pretrained(args.dump_path)

if __name__ == "__main__": |
|
parser = argparse.ArgumentParser() |
|
|
|
parser.add_argument( |
|
"--origin_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert." |
|
) |
|
parser.add_argument( |
|
"--image_size", |
|
default=1024, |
|
type=int, |
|
choices=[256, 512, 1024], |
|
required=False, |
|
help="Image size of pretrained model, either 512 or 1024.", |
|
) |
|
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.") |
|
parser.add_argument("--only_transformer", default=True, type=bool, required=True) |
|
|
|
args = parser.parse_args() |
|
main(args) |
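
# After conversion, the saved artifacts can be reloaded in the usual diffusers way.
# A minimal sketch (the dump path below is illustrative):
#
#   from diffusers import LuminaNextDiT2DModel, LuminaText2ImgPipeline
#
#   # when --only_transformer was truthy, only the transformer subfolder exists
#   transformer = LuminaNextDiT2DModel.from_pretrained("./lumina-next-diffusers/transformer")
#
#   # when the full pipeline was saved
#   pipe = LuminaText2ImgPipeline.from_pretrained("./lumina-next-diffusers")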