import torch
from enum import Enum
import gc
import numpy as np
import jax.numpy as jnp
import tomesd
import jax

from flax.training.common_utils import shard
from flax.jax_utils import replicate
from flax import jax_utils
import einops

from transformers import CLIPTokenizer, CLIPFeatureExtractor, FlaxCLIPTextModel
from diffusers import (
    FlaxDDIMScheduler,
    FlaxAutoencoderKL,
    FlaxStableDiffusionControlNetPipeline,
    StableDiffusionPipeline,
)
from text_to_animation.models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from text_to_animation.models.controlnet_flax import FlaxControlNetModel

from text_to_animation.pipelines.text_to_video_pipeline_flax import (
    FlaxTextToVideoPipeline,
)

import utils.utils as utils
import utils.gradio_utils as gradio_utils
import os

# True when the SPACE_AUTHOR_NAME environment variable equals "PAIR"
# (presumably: running on the PAIR Hugging Face Space).
on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"


def unshard(x):
    """Merge the two leading axes of ``x`` into one.

    An array of shape ``(d, b, ...)`` becomes ``(d * b, ...)``, e.g. folding a
    per-device axis back into the batch axis.
    """
    return einops.rearrange(x, "d b ... -> (d b) ...")


class ModelType(Enum):
    """Which kind of pipeline is currently loaded into the model wrapper."""

    # Text-to-video generation.
    Text2Video = 1
    # Pose-conditioned ControlNet generation.
    ControlNetPose = 2
    # Plain Stable Diffusion.
    StableDiffusion = 3


def replicate_devices(array):
    """Stack one copy of ``array`` per JAX device along a new leading axis.

    The result has shape ``(jax.device_count(), *array.shape)``.
    """
    n_devices = jax.device_count()
    with_device_axis = jnp.expand_dims(array, axis=0)
    return jnp.repeat(with_device_axis, n_devices, axis=0)


class ControlAnimationModel:
    """Holds a Flax text-to-video pipeline and runs animation generation.

    ``set_model`` (re)loads the pipeline components and their parameters; the
    ``process_*`` and ``generate_*`` methods are the user-facing entry points.
    """

    def __init__(self, device, dtype, **kwargs):
        """Record device/dtype settings; no model is loaded until ``set_model``.

        Args:
            device: Device used for torch tensors/generators and passed to
                ``utils.prepare_video``.
            dtype: Passed as ``dtype`` to the Flax ``from_pretrained`` loaders
                and also used for torch tensors.
                NOTE(review): the same value is used both as a Flax and as a
                torch dtype -- confirm which type callers actually pass.
            **kwargs: Accepted for interface compatibility; ignored.
        """
        self.device = device
        self.dtype = dtype
        # JAX PRNG state; split on every ``inference`` call.
        self.rng = jax.random.PRNGKey(0)
        self.pipe_dict = {
            ModelType.Text2Video: FlaxTextToVideoPipeline,  # TODO: Replace with our TextToVideo JAX Pipeline
            ModelType.ControlNetPose: FlaxStableDiffusionControlNetPipeline,
        }
        self.pipe = None
        self.model_type = None

        self.states = {}
        self.model_name = ""

        # True if the attention-adapted UNet (produced by adapt_attn.py) is
        # available locally; set_model then loads it from ``./<repo-name>``.
        self.from_local = True

    def set_model(
        self,
        model_type: ModelType,
        model_id: str,
        controlnet=None,
        controlnet_params=None,
        tokenizer=None,
        scheduler=None,
        scheduler_state=None,
        **kwargs,
    ):
        """Load all components for ``model_id`` and build ``self.pipe``.

        Args:
            model_type: Stored in ``self.model_type``. NOTE: the pipeline class
                is always ``FlaxTextToVideoPipeline``; ``pipe_dict`` is not
                consulted.
            model_id: Hub repo id of the base Stable Diffusion checkpoint.
            controlnet: Flax ControlNet module, or ``None``.
            controlnet_params: Parameters for ``controlnet``; stored under
                ``self.params["controlnet"]``.
            tokenizer: Pre-loaded ``CLIPTokenizer``; loaded from ``model_id``
                when ``None``.
            scheduler: Pre-loaded scheduler; loaded together with its state
                from ``model_id`` when either this or ``scheduler_state`` is
                ``None``.
            scheduler_state: State belonging to ``scheduler``.
            **kwargs: Ignored; kept so existing call sites keep working.

        Side effects:
            Replaces ``self.pipe``, ``self.params``, ``self.p_params``,
            ``self.model_type`` and ``self.model_name``.
        """
        # Drop references to the previous pipeline and its weights before
        # loading new ones so the old buffers can be reclaimed.
        self.pipe = None
        self.params = None
        self.p_params = None
        gc.collect()

        if scheduler is None or scheduler_state is None:
            scheduler, scheduler_state = FlaxDDIMScheduler.from_pretrained(
                model_id, subfolder="scheduler", from_pt=True
            )
        if tokenizer is None:
            tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
        feature_extractor = CLIPFeatureExtractor.from_pretrained(
            model_id, subfolder="feature_extractor"
        )
        # The adapted UNet lives in a local directory named after the last
        # component of the repo id; otherwise load it from the hub.
        unet_source = f'./{model_id.split("/")[-1]}' if self.from_local else model_id
        unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
            unet_source, subfolder="unet", from_pt=True, dtype=self.dtype
        )
        vae, vae_params = FlaxAutoencoderKL.from_pretrained(
            model_id, subfolder="vae", from_pt=True, dtype=self.dtype
        )
        text_encoder = FlaxCLIPTextModel.from_pretrained(
            model_id, subfolder="text_encoder", from_pt=True, dtype=self.dtype
        )
        self.pipe = FlaxTextToVideoPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=None,
            feature_extractor=feature_extractor,
        )
        self.params = {
            "unet": unet_params,
            "vae": vae_params,
            "scheduler": scheduler_state,
            "controlnet": controlnet_params,
            "text_encoder": text_encoder.params,
        }
        # Per-device copy of the parameters, used by the ``jit=True`` path.
        self.p_params = jax_utils.replicate(self.params)

        self.model_type = model_type
        self.model_name = model_id

    # def inference_chunk(self, image, frame_ids, prompt, negative_prompt, **kwargs):

    #     prompt_ids = self.pipe.prepare_text_inputs(prompt)
    #     n_prompt_ids = self.pipe.prepare_text_inputs(negative_prompt)
    #     latents = kwargs.pop('latents')
    #     # rng = jax.random.split(self.rng, jax.device_count())
    #     prng, self.rng = jax.random.split(self.rng)
    #     #prng = jax.numpy.stack([prng] * jax.device_count())#same prng seed on every device
    #     prng_seed = jax.random.split(prng, jax.device_count())
    #     image = replicate_devices(image[frame_ids])
    #     latents = replicate_devices(latents)
    #     prompt_ids = replicate_devices(prompt_ids)
    #     n_prompt_ids = replicate_devices(n_prompt_ids)
    #     return (self.pipe(image=image,
    #                         latents=latents,
    #                         prompt_ids=prompt_ids,
    #                         neg_prompt_ids=n_prompt_ids,
    #                         params=self.p_params,
    #                         prng_seed=prng_seed, jit = True,
    #                         ).images)[0]

    def inference(self, image, split_to_chunks=False, chunk_size=8, **kwargs):
        """Run the loaded pipeline and return the generated images.

        Args:
            image: Conditioning input passed to the pipeline as ``image``.
            split_to_chunks: Chunked processing is not implemented for the JAX
                pipeline; passing True raises ``NotImplementedError``.
            chunk_size: Only meaningful for chunked processing; unused.
            **kwargs: Must contain ``prompt`` (str) and ``latents``. Optional:
                ``negative_prompt`` (str or None), ``merging_ratio`` (ToMe
                ratio, applied only when > 0), ``seed`` (re-seeds the JAX
                PRNG), ``jit`` (use the replicated multi-device call). Any
                other keys (e.g. ``num_inference_steps``, ``guidance_scale``)
                are accepted but NOT forwarded to the pipeline.

        Returns:
            ``None`` if no pipeline is loaded; otherwise the pipeline's
            ``images`` (for ``jit=True``, only index 0 of the leading axis).

        Raises:
            ValueError: ``prompt`` was not supplied.
            NotImplementedError: ``split_to_chunks`` is True.
            KeyError: ``latents`` was not supplied.
        """
        if getattr(self, "pipe", None) is None:
            return None

        merging_ratio = kwargs.pop("merging_ratio", None)
        # A ratio of 0 merges no tokens, so only patch for positive ratios.
        # NOTE(review): tomesd patches PyTorch modules; confirm it can patch
        # this Flax pipeline before relying on merging_ratio > 0.
        if merging_ratio is not None and merging_ratio > 0:
            tomesd.apply_patch(self.pipe, ratio=merging_ratio)

        if "prompt" not in kwargs:
            raise ValueError("inference() requires a 'prompt' keyword argument")
        prompt = [kwargs.pop("prompt")]
        # Treat an explicit None like the empty default so prepare_text_inputs
        # never receives [None].
        negative_prompt = [kwargs.pop("negative_prompt", None) or ""]

        seed = kwargs.pop("seed", None)
        if seed is not None:
            # Re-seed so the same ``seed`` reproduces the same sample.
            self.rng = jax.random.PRNGKey(seed)

        if split_to_chunks:
            # Fail loudly instead of returning None, which callers would pass
            # straight on to create_gif/create_video.
            # Port sketch of the chunked loop (uses inference_chunk above):
            # f = image.shape[0]
            # chunk_ids = np.arange(0, f, chunk_size - 1)
            # result = []
            # for i in range(len(chunk_ids)):
            #     ch_start = chunk_ids[i]
            #     ch_end = f if i == len(chunk_ids) - 1 else chunk_ids[i + 1]
            #     frame_ids = [0] + list(range(ch_start, ch_end))
            #     result.append(self.inference_chunk(image=image, frame_ids=frame_ids,
            #                                        prompt=prompt, negative_prompt=negative_prompt,
            #                                        **kwargs).images[1:])
            #     if on_huggingspace and frames_counter >= 80:
            #         break
            # return np.concatenate(result)
            raise NotImplementedError(
                "split_to_chunks=True is not implemented for the JAX pipeline"
            )

        prompt_ids = self.pipe.prepare_text_inputs(prompt)
        n_prompt_ids = self.pipe.prepare_text_inputs(negative_prompt)
        latents = kwargs.pop("latents")
        prng, self.rng = jax.random.split(self.rng)

        if kwargs.pop("jit", False):
            # Every input gets a leading axis of size jax.device_count() and
            # each slice gets its own PRNG key; params were replicated in
            # set_model.
            prng_seed = jax.random.split(prng, jax.device_count())
            output = self.pipe(
                image=replicate_devices(image),
                latents=replicate_devices(latents),
                prompt_ids=replicate_devices(prompt_ids),
                neg_prompt_ids=replicate_devices(n_prompt_ids),
                params=self.p_params,
                prng_seed=prng_seed,
                jit=True,
            )
            # Keep index 0 of the leading (presumably per-device) axis.
            return output.images[0]

        return self.pipe(
            image=image,
            latents=latents,
            prompt_ids=prompt_ids,
            neg_prompt_ids=n_prompt_ids,
            params=self.params,
            prng_seed=prng,
            jit=False,
        ).images

    def process_controlnet_pose(
        self,
        video_path,
        prompt,
        chunk_size=8,
        watermark="Picsart AI Research",
        merging_ratio=0.0,
        num_inference_steps=20,
        controlnet_conditioning_scale=1.0,
        guidance_scale=9.0,
        seed=42,
        eta=0.0,
        resolution=512,
        use_cf_attn=True,
        save_path=None,
    ):
        """Generate a pose-conditioned animation and write it as a GIF.

        Args:
            video_path: Motion name or video path, resolved through
                ``gradio_utils.motion_to_video_path``.
            prompt: User prompt; a fixed quality suffix is appended.
            chunk_size, merging_ratio: Forwarded to ``inference``.
            num_inference_steps, controlnet_conditioning_scale,
            guidance_scale, eta: Forwarded to ``inference``.
                NOTE(review): ``inference`` does not currently pass these on
                to the pipeline.
            seed: Seeds the torch generator for the initial latents and is
                forwarded to ``inference``.
            resolution: Passed to ``utils.prepare_video``.
            use_cf_attn: Install cross-frame attention processors on reload.
            save_path: Output path passed to ``utils.create_gif``.
            watermark: Logo name resolved via ``gradio_utils.logo_name_to_path``.

        Returns:
            Whatever ``utils.create_gif`` returns.
        """
        print("Module Pose")
        video_path = gradio_utils.motion_to_video_path(video_path)
        if self.model_type != ModelType.ControlNetPose:
            # Flax ``from_pretrained`` returns a (module, params) pair, like
            # the UNet/VAE loads in set_model.
            # NOTE(review): every other Flax load here uses from_pt=True;
            # confirm this repo ships Flax weights.
            controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
                "fusing/stable-diffusion-v1-5-controlnet-openpose"
            )
            self.set_model(
                ModelType.ControlNetPose,
                model_id="runwayml/stable-diffusion-v1-5",
                controlnet=controlnet,
                controlnet_params=controlnet_params,
            )
            self.pipe.scheduler = FlaxDDIMScheduler.from_config(
                self.pipe.scheduler.config
            )
            if use_cf_attn:
                # NOTE(review): ``controlnet_attn_proc`` is never assigned in
                # this class, and set_attn_processor is a PyTorch diffusers
                # API -- confirm the Flax modules support it.
                self.pipe.unet.set_attn_processor(processor=self.controlnet_attn_proc)
                self.pipe.controlnet.set_attn_processor(
                    processor=self.controlnet_attn_proc
                )

        # NOTE(review): video_path was already converted above; this second
        # conversion looks redundant unless motion_to_video_path is not
        # idempotent -- confirm before removing.
        video_path = (
            gradio_utils.motion_to_video_path(video_path)
            if "Motion" in video_path
            else video_path
        )

        added_prompt = "best quality, extremely detailed, HD, ultra-realistic, 8K, HQ, masterpiece, trending on artstation, art, smooth"
        negative_prompts = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer difits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic"

        video, fps = utils.prepare_video(
            video_path, resolution, self.device, self.dtype, False, output_fps=4
        )
        control = (
            utils.pre_process_pose(video, apply_pose_detect=False)
            .to(self.device)
            .to(self.dtype)
        )
        f, _, h, w = video.shape
        # One seeded noise tensor shared by every frame.
        generator = torch.Generator(device=self.device).manual_seed(seed)
        latents = torch.randn(
            (1, 4, h // 8, w // 8),
            dtype=self.dtype,
            device=self.device,
            generator=generator,
        )
        latents = latents.repeat(f, 1, 1, 1)
        result = self.inference(
            image=control,
            prompt=prompt + ", " + added_prompt,
            height=h,
            width=w,
            negative_prompt=negative_prompts,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            eta=eta,
            latents=latents,
            seed=seed,
            output_type="numpy",
            split_to_chunks=True,
            chunk_size=chunk_size,
            merging_ratio=merging_ratio,
        )
        return utils.create_gif(
            result,
            fps,
            path=save_path,
            watermark=gradio_utils.logo_name_to_path(watermark),
        )

    def process_text2video(
        self,
        prompt,
        model_name="dreamlike-art/dreamlike-photoreal-2.0",
        motion_field_strength_x=12,
        motion_field_strength_y=12,
        t0=44,
        t1=47,
        n_prompt="",
        chunk_size=8,
        video_length=8,
        watermark="Picsart AI Research",
        merging_ratio=0.0,
        seed=0,
        resolution=512,
        fps=2,
        use_cf_attn=True,
        use_motion_field=True,
        smooth_bg=False,
        smooth_bg_strength=0.4,
        path=None,
    ):
        """Generate a video from a text prompt and write it with create_video.

        Reloads the pipeline when the model type or ``model_name`` changed.
        A trailing ``,`` or ``.`` is stripped from ``prompt`` before a fixed
        quality suffix is appended. An empty ``n_prompt`` means "no negative
        prompt". ``seed`` is forwarded to ``inference``, which re-seeds the
        JAX PRNG with it. Remaining arguments are forwarded to ``inference``.

        Returns:
            Whatever ``utils.create_video`` returns.
        """
        print("Module Text2Video")
        if self.model_type != ModelType.Text2Video or model_name != self.model_name:
            print("Model update")
            # set_model loads the UNet itself; a separate load here would be
            # discarded.
            self.set_model(ModelType.Text2Video, model_id=model_name)
            self.pipe.scheduler = FlaxDDIMScheduler.from_config(
                self.pipe.scheduler.config
            )
            if use_cf_attn:
                # NOTE(review): ``text2video_attn_proc`` is never assigned in
                # this class -- confirm where it is supposed to come from.
                self.pipe.unet.set_attn_processor(processor=self.text2video_attn_proc)

        added_prompt = "high quality, HD, 8K, trending on artstation, high focus, dramatic lighting"

        prompt = prompt.rstrip()
        if prompt.endswith((",", ".")):
            prompt = prompt[:-1].rstrip()
        prompt = prompt + ", " + added_prompt
        negative_prompt = n_prompt or None

        # NOTE(review): no ``latents`` are passed here, but inference()
        # requires them.
        result = self.inference(
            prompt=prompt,
            video_length=video_length,
            height=resolution,
            width=resolution,
            num_inference_steps=50,
            guidance_scale=7.5,
            guidance_stop_step=1.0,
            t0=t0,
            t1=t1,
            motion_field_strength_x=motion_field_strength_x,
            motion_field_strength_y=motion_field_strength_y,
            use_motion_field=use_motion_field,
            smooth_bg=smooth_bg,
            smooth_bg_strength=smooth_bg_strength,
            seed=seed,
            output_type="numpy",
            negative_prompt=negative_prompt,
            merging_ratio=merging_ratio,
            split_to_chunks=True,
            chunk_size=chunk_size,
        )
        return utils.create_video(
            result, fps, path=path, watermark=gradio_utils.logo_name_to_path(watermark)
        )

    def generate_animation(
        self,
        prompt: str,
        model_link: str = "dreamlike-art/dreamlike-photoreal-2.0",
        is_safetensor: bool = False,
        motion_field_strength_x: int = 12,
        motion_field_strength_y: int = 12,
        t0: int = 44,
        t1: int = 47,
        n_prompt: str = "",
        chunk_size: int = 8,
        video_length: int = 8,
        merging_ratio: float = 0.0,
        seed: int = 0,
        resolution: int = 512,
        fps: int = 2,
        use_cf_attn: bool = True,
        use_motion_field: bool = True,
        smooth_bg: bool = False,
        smooth_bg_strength: float = 0.4,
        path: str = None,
    ):
        """Placeholder for end-to-end animation generation; returns None.

        Only loads a ``.safetensors`` checkpoint when ``is_safetensor`` is
        set. All other arguments are currently unused.
        """
        if is_safetensor and model_link.endswith(".safetensors"):
            # TODO: stub -- the loaded pipeline is not used yet.
            utils.load_safetensors_model(model_link)
        return None

    def generate_initial_frames(
        self,
        prompt: str,
        model_link: str = "dreamlike-art/dreamlike-photoreal-2.0",
        is_safetensor: bool = False,
        n_prompt: str = "",
        width: int = 512,
        height: int = 512,
        # batch_count: int = 4,
        # batch_size: int = 1,
        cfg_scale: float = 7.0,
        seed: int = 0,
    ):
        """Generate four candidate first frames with a PyTorch SD pipeline.

        Args:
            prompt: Text prompt, repeated for each of the 4 images.
            model_link: Repo id/path passed to
                ``StableDiffusionPipeline.from_pretrained``.
            is_safetensor: Currently unused.
            n_prompt: Negative prompt, repeated for each image.
            width: Output image width.
            height: Output image height.
            cfg_scale: Classifier-free guidance scale.
            seed: Seeds the sampling generator so results are reproducible.

        Returns:
            The pipeline output's ``images``.
        """
        print(f">>> prompt: {prompt}, model_link: {model_link}")

        pipe = StableDiffusionPipeline.from_pretrained(model_link)

        batch_size = 4
        # Seeded generator so the same ``seed`` reproduces the same frames.
        generator = torch.Generator().manual_seed(seed)
        images = pipe(
            [prompt] * batch_size,
            negative_prompt=[n_prompt] * batch_size,
            width=width,
            height=height,
            guidance_scale=cfg_scale,
            generator=generator,
        ).images

        return images