import keras_cv
import tensorflow as tf
from diffusers import (AutoencoderKL, StableDiffusionPipeline,
                       UNet2DConditionModel)
from diffusers.pipelines.stable_diffusion.safety_checker import \
    StableDiffusionSafetyChecker
from transformers import CLIPTextModel, CLIPTokenizer

from conversion_utils import populate_text_encoder, populate_unet

PRETRAINED_CKPT = "CompVis/stable-diffusion-v1-4"
REVISION = None
NON_EMA_REVISION = None
IMG_HEIGHT = IMG_WIDTH = 512
MAX_SEQ_LENGTH = 77


def initialize_pt_models(placeholder_token: str = None):
    """Initializes the separate models of Stable Diffusion from diffusers and downloads
    their pre-trained weights."""
    pt_text_encoder = CLIPTextModel.from_pretrained(
        PRETRAINED_CKPT, subfolder="text_encoder", revision=REVISION
    )
    pt_tokenizer = CLIPTokenizer.from_pretrained(PRETRAINED_CKPT, subfolder="tokenizer")
    pt_vae = AutoencoderKL.from_pretrained(
        PRETRAINED_CKPT, subfolder="vae", revision=REVISION
    )
    pt_unet = UNet2DConditionModel.from_pretrained(
        PRETRAINED_CKPT, subfolder="unet", revision=NON_EMA_REVISION
    )
    pt_safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        PRETRAINED_CKPT, subfolder="safety_checker", revision=NON_EMA_REVISION
    )

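    # For Textual Inversion checkpoints, the placeholder token must also be
    # added to the PT tokenizer so that its vocabulary lines up with the
    # fine-tuned embedding matrix.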
    if placeholder_token is not None:
        num_added_tokens = pt_tokenizer.add_tokens(placeholder_token)
        if num_added_tokens == 0:
            raise ValueError(
                f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )
        # Resize the token embeddings as we are adding new special tokens to the tokenizer
        pt_text_encoder.resize_token_embeddings(len(pt_tokenizer))

    return pt_text_encoder, pt_tokenizer, pt_vae, pt_unet, pt_safety_checker


def initialize_tf_models(
    text_encoder_weights: str, unet_weights: str, placeholder_token: str = None
):
    """Initializes the separate models of Stable Diffusion from KerasCV and optionally
    downloads their pre-trained weights."""
    tf_sd_model = keras_cv.models.StableDiffusion(
        img_height=IMG_HEIGHT, img_width=IMG_WIDTH
    )

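    # When fine-tuned weights will be loaded later, instantiate the models with
    # download_weights=False to avoid fetching pre-trained weights that would
    # be overwritten anyway.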
    if text_encoder_weights is None:
        tf_text_encoder = tf_sd_model.text_encoder
    else:
        tf_text_encoder = keras_cv.models.stable_diffusion.TextEncoder(
            MAX_SEQ_LENGTH, download_weights=False
        )

    if unet_weights is None:
        tf_unet = tf_sd_model.diffusion_model
    else:
        tf_unet = keras_cv.models.stable_diffusion.DiffusionModel(
            IMG_HEIGHT, IMG_WIDTH, MAX_SEQ_LENGTH, download_weights=False
        )

    tf_tokenizer = tf_sd_model.tokenizer
    if placeholder_token is not None:
        tf_tokenizer.add_tokens(placeholder_token)

    return tf_text_encoder, tf_unet, tf_tokenizer


def create_new_text_encoder(tf_text_encoder, tf_tokenizer):
    """Initializes a fresh text encoder in case the weights are from Textual Inversion.

    Reference: https://keras.io/examples/generative/fine_tune_via_textual_inversion/
    """
    new_vocab_size = len(tf_tokenizer.vocab)
    new_text_encoder = keras_cv.models.stable_diffusion.TextEncoder(
        MAX_SEQ_LENGTH,
        vocab_size=new_vocab_size,
        download_weights=False,
    )

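    # The position embedding does not depend on the vocabulary size, so it can
    # be copied over from the original text encoder unchanged.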
    old_position_weights = tf_text_encoder.layers[2].position_embedding.get_weights()
    new_text_encoder.layers[2].position_embedding.set_weights(old_position_weights)
    return new_text_encoder


def run_conversion(
    text_encoder_weights: str = None,
    unet_weights: str = None,
    placeholder_token: str = None,
):
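    """Ports the fine-tuned KerasCV weights into the corresponding diffusers
    models and assembles them into a StableDiffusionPipeline."""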
    (
        pt_text_encoder,
        pt_tokenizer,
        pt_vae,
        pt_unet,
        pt_safety_checker,
    ) = initialize_pt_models(placeholder_token)
    tf_text_encoder, tf_unet, tf_tokenizer = initialize_tf_models(
        text_encoder_weights, unet_weights, placeholder_token
    )
    print("Pre-trained model weights downloaded.")

    if placeholder_token is not None:
        print("Initializing a new text encoder with the placeholder token...")
        tf_text_encoder = create_new_text_encoder(tf_text_encoder, tf_tokenizer)

    if text_encoder_weights is not None:
        print("Loading fine-tuned text encoder weights.")
        text_encoder_weights_path = tf.keras.utils.get_file(origin=text_encoder_weights)
        tf_text_encoder.load_weights(text_encoder_weights_path)
        text_encoder_state_dict_from_tf = populate_text_encoder(tf_text_encoder)
        pt_text_encoder.load_state_dict(text_encoder_state_dict_from_tf)
        print("Populated PT text encoder from TF weights.")

    if unet_weights is not None:
        print("Loading fine-tuned UNet weights.")
        unet_weights_path = tf.keras.utils.get_file(origin=unet_weights)
        tf_unet.load_weights(unet_weights_path)
        unet_state_dict_from_tf = populate_unet(tf_unet)
        pt_unet.load_state_dict(unet_state_dict_from_tf)
        print("Populated PT UNet from TF weights.")

    print("Weights ported, preparing StabelDiffusionPipeline...")
    pipeline = StableDiffusionPipeline.from_pretrained(
        PRETRAINED_CKPT,
        unet=pt_unet,
        text_encoder=pt_text_encoder,
        tokenizer=pt_tokenizer,
        vae=pt_vae,
        safety_checker=pt_safety_checker,
        revision=REVISION,
    )
    return pipeline
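

# A minimal usage sketch, not part of the original utilities: the weight
# locations and the output directory below are hypothetical placeholders.
# Pass URLs to the .h5 files produced by your fine-tuning run, or leave them
# as None to keep the pre-trained weights.
if __name__ == "__main__":
    pipeline = run_conversion(
        text_encoder_weights=None,  # e.g. URL to a fine-tuned text encoder .h5
        unet_weights=None,  # e.g. URL to a fine-tuned diffusion model .h5
        placeholder_token=None,  # e.g. "<my-token>" for Textual Inversion
    )
    pipeline.save_pretrained("converted-stable-diffusion")  # hypothetical path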