# Run this script to convert the Stable Cascade model weights to a diffusers pipeline.
import argparse
from contextlib import nullcontext

import torch
from safetensors.torch import load_file
from transformers import (
    AutoTokenizer,
    CLIPConfig,
    CLIPImageProcessor,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
)

from diffusers import (
    DDPMWuerstchenScheduler,
    StableCascadeCombinedPipeline,
    StableCascadeDecoderPipeline,
    StableCascadePriorPipeline,
)
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
from diffusers.models import StableCascadeUNet
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils import is_accelerate_available


if is_accelerate_available():
    from accelerate import init_empty_weights

parser = argparse.ArgumentParser(description="Convert Stable Cascade model weights to a diffusers pipeline")
# Required: without it the checkpoint paths would silently become "None/<checkpoint name>".
parser.add_argument("--model_path", type=str, required=True, help="Location of Stable Cascade weights")
parser.add_argument(
    "--stage_c_name", type=str, default="stage_c_lite.safetensors", help="Name of stage c checkpoint file"
)
parser.add_argument(
    "--stage_b_name", type=str, default="stage_b_lite.safetensors", help="Name of stage b checkpoint file"
)
parser.add_argument("--skip_stage_c", action="store_true", help="Skip converting stage c")
parser.add_argument("--skip_stage_b", action="store_true", help="Skip converting stage b")
parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion")
# The output paths are passed straight to `save_pretrained`: a local directory,
# which is also used as the repo name when --push_to_hub is given.
parser.add_argument(
    "--prior_output_path",
    default="stable-cascade-prior-lite",
    type=str,
    help="Directory (or Hub repo id when --push_to_hub is set) to save the prior pipeline to",
)
parser.add_argument(
    "--decoder_output_path",
    type=str,
    default="stable-cascade-decoder-lite",
    help="Directory (or Hub repo id when --push_to_hub is set) to save the decoder pipeline to",
)
parser.add_argument(
    "--combined_output_path",
    type=str,
    default="stable-cascade-combined-lite",
    help="Directory (or Hub repo id when --push_to_hub is set) to save the combined pipeline to",
)
parser.add_argument(
    "--save_combined",
    action="store_true",
    help="Also save a combined prior + decoder pipeline (requires converting both stages)",
)
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")
parser.add_argument("--variant", type=str, help="Set to bf16 to save bfloat16 weights")

args = parser.parse_args()

# Reject flag combinations that cannot produce a useful result.
if args.skip_stage_c and args.skip_stage_b:
    raise ValueError("At least one stage should be converted")
if args.save_combined and (args.skip_stage_b or args.skip_stage_c):
    # The combined pipeline needs both the prior (stage C) and decoder (stage B) models.
    raise ValueError("Cannot skip stages when creating a combined pipeline")

model_path = args.model_path

# All checkpoint loading happens on the CPU.
device = "cpu"
# The "bf16" variant saves bfloat16 weights; any other value (or none) keeps float32.
dtype = torch.bfloat16 if args.variant == "bf16" else torch.float32

# Full paths of the original stage C (prior) and stage B (decoder) checkpoints.
prior_checkpoint_path = f"{model_path}/{args.stage_c_name}"
decoder_checkpoint_path = f"{model_path}/{args.stage_b_name}"

# Components shared by the prior and decoder pipelines.
CLIP_TEXT_REPO = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"

# CLIP text encoder and tokenizer. The text sub-config inherits the top-level
# projection_dim so that the projection head is built with the right size.
config = CLIPConfig.from_pretrained(CLIP_TEXT_REPO)
config.text_config.projection_dim = config.projection_dim
text_encoder = CLIPTextModelWithProjection.from_pretrained(CLIP_TEXT_REPO, config=config.text_config)
tokenizer = AutoTokenizer.from_pretrained(CLIP_TEXT_REPO)

# Image preprocessing and the CLIP vision encoder used by the prior.
feature_extractor = CLIPImageProcessor()
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

# One scheduler instance, reused by both the prior and the decoder.
scheduler = DDPMWuerstchenScheduler()

# With accelerate, models are built under init_empty_weights (weights are filled in
# later by load_model_dict_into_meta); without it, fall back to a no-op context.
if is_accelerate_available():
    ctx = init_empty_weights
else:
    ctx = nullcontext

if not args.skip_stage_c:
    # Prior (stage C): load the original checkpoint, convert it with the diffusers
    # single-file helper, build the "lite" StableCascadeUNet and save the pipeline.
    #
    # The default checkpoint names end in ".safetensors", which torch.load cannot
    # read. Use the safetensors loader when asked explicitly or when the extension
    # says so.
    if args.use_safetensors or prior_checkpoint_path.endswith(".safetensors"):
        prior_orig_state_dict = load_file(prior_checkpoint_path, device=device)
    else:
        # NOTE(review): torch.load unpickles the file; only use it on trusted checkpoints.
        prior_orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)

    prior_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(prior_orig_state_dict)
    with ctx():
        prior_model = StableCascadeUNet(
            in_channels=16,
            out_channels=16,
            timestep_ratio_embedding_dim=64,
            patch_size=1,
            conditioning_dim=1536,
            block_out_channels=[1536, 1536],
            num_attention_heads=[24, 24],
            down_num_layers_per_block=[4, 12],
            up_num_layers_per_block=[12, 4],
            down_blocks_repeat_mappers=[1, 1],
            up_blocks_repeat_mappers=[1, 1],
            block_types_per_layer=[
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
            ],
            clip_text_in_channels=1280,
            clip_text_pooled_in_channels=1280,
            clip_image_in_channels=768,
            clip_seq=4,
            kernel_size=3,
            dropout=[0.1, 0.1],
            self_attn=True,
            timestep_conditioning_type=["sca", "crp"],
            switch_level=[False],
        )

    # Under init_empty_weights the model was created without real weights, so they
    # are materialized from the state dict; otherwise a regular (strict) load is used.
    if is_accelerate_available():
        load_model_dict_into_meta(prior_model, prior_state_dict)
    else:
        prior_model.load_state_dict(prior_state_dict)

    # The raw checkpoints are no longer needed. Drop these references so they can
    # be freed before stage B is loaded.
    del prior_orig_state_dict, prior_state_dict

    # Prior pipeline
    prior_pipeline = StableCascadePriorPipeline(
        prior=prior_model,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        image_encoder=image_encoder,
        scheduler=scheduler,
        feature_extractor=feature_extractor,
    )
    prior_pipeline.to(dtype).save_pretrained(
        args.prior_output_path, push_to_hub=args.push_to_hub, variant=args.variant
    )

if not args.skip_stage_b:
    # Decoder (stage B): load the original checkpoint, convert it with the diffusers
    # single-file helper, build the "lite" StableCascadeUNet and save the pipeline.
    #
    # The default checkpoint names end in ".safetensors", which torch.load cannot
    # read. Use the safetensors loader when asked explicitly or when the extension
    # says so.
    if args.use_safetensors or decoder_checkpoint_path.endswith(".safetensors"):
        decoder_orig_state_dict = load_file(decoder_checkpoint_path, device=device)
    else:
        # NOTE(review): torch.load unpickles the file; only use it on trusted checkpoints.
        decoder_orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)

    decoder_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(decoder_orig_state_dict)

    with ctx():
        decoder = StableCascadeUNet(
            in_channels=4,
            out_channels=4,
            timestep_ratio_embedding_dim=64,
            patch_size=2,
            conditioning_dim=1280,
            block_out_channels=[320, 576, 1152, 1152],
            down_num_layers_per_block=[2, 4, 14, 4],
            up_num_layers_per_block=[4, 14, 4, 2],
            down_blocks_repeat_mappers=[1, 1, 1, 1],
            up_blocks_repeat_mappers=[2, 2, 2, 2],
            num_attention_heads=[0, 9, 18, 18],
            block_types_per_layer=[
                ["SDCascadeResBlock", "SDCascadeTimestepBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
            ],
            clip_text_pooled_in_channels=1280,
            clip_seq=4,
            effnet_in_channels=16,
            pixel_mapper_in_channels=3,
            kernel_size=3,
            dropout=[0, 0, 0.1, 0.1],
            self_attn=True,
            timestep_conditioning_type=["sca"],
        )

    # Under init_empty_weights the model was created without real weights, so they
    # are materialized from the state dict; otherwise a regular (strict) load is used.
    if is_accelerate_available():
        load_model_dict_into_meta(decoder, decoder_state_dict)
    else:
        decoder.load_state_dict(decoder_state_dict)

    # The raw checkpoints are no longer needed. Drop these references so they can
    # be freed before the VQGAN is loaded.
    del decoder_orig_state_dict, decoder_state_dict

    # VQGAN from Wuerstchen-V2
    vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")

    # Decoder pipeline
    decoder_pipeline = StableCascadeDecoderPipeline(
        decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
    )
    decoder_pipeline.to(dtype).save_pretrained(
        args.decoder_output_path, push_to_hub=args.push_to_hub, variant=args.variant
    )

if args.save_combined:
    # The argument checks forbid --save_combined together with a skipped stage, so
    # prior_model, decoder and vqmodel were all created above. The text encoder,
    # tokenizer and scheduler objects are shared by both halves of the pipeline.
    combined_pipeline = StableCascadeCombinedPipeline(
        # Prior (stage C) components
        prior_prior=prior_model,
        prior_text_encoder=text_encoder,
        prior_tokenizer=tokenizer,
        prior_image_encoder=image_encoder,
        prior_feature_extractor=feature_extractor,
        prior_scheduler=scheduler,
        # Decoder (stage B) components
        decoder=decoder,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        vqgan=vqmodel,
        scheduler=scheduler,
    )
    combined_pipeline = combined_pipeline.to(dtype)
    combined_pipeline.save_pretrained(
        args.combined_output_path,
        push_to_hub=args.push_to_hub,
        variant=args.variant,
    )