from __future__ import annotations

import gc
import json
import tempfile
from typing import Generator

import numpy as np
import PIL.Image
import torch
from diffusers import DiffusionPipeline, StableDiffusionUpscalePipeline
from diffusers.pipelines.deepfloyd_if import (fast27_timesteps,
                                              smart27_timesteps,
                                              smart50_timesteps,
                                              smart100_timesteps,
                                              smart185_timesteps)

from settings import (DISABLE_AUTOMATIC_CPU_OFFLOAD, DISABLE_SD_X4_UPSCALER,
                      HF_TOKEN, MAX_NUM_IMAGES, MAX_NUM_STEPS, MAX_SEED,
                      RUN_GARBAGE_COLLECTION)


class Model:
    """Wrapper around the DeepFloyd IF cascade and the optional SD x4 upscaler.

    Stage 1 (``DeepFloyd/IF-I-XL-v1.0``) generates images from a prompt,
    stage 2 (``DeepFloyd/IF-II-L-v1.0``) super-resolves one selected stage-1
    image, and stage 3 (``stabilityai/stable-diffusion-x4-upscaler``) upscales
    the stage-2 output. Weights are only loaded when CUDA is available;
    otherwise all pipeline attributes stay ``None``.
    """

    def __init__(self):
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.pipe = None
        self.super_res_1_pipe = None
        self.super_res_2_pipe = None
        self.watermark_image = None

        if torch.cuda.is_available():
            self.load_weights()
            # Keep the stage-1 pipeline's watermark as an RGBA PIL image so
            # apply_watermark_to_sd_x4_upscaler_results can paste it.
            self.watermark_image = PIL.Image.fromarray(
                self.pipe.watermarker.watermark_image.to(
                    torch.uint8).cpu().numpy(),
                mode='RGBA')

    def load_weights(self) -> None:
        """Load the fp16 pipelines and place them for inference.

        The SD x4 upscaler is skipped when ``DISABLE_SD_X4_UPSCALER`` is set.
        With ``DISABLE_AUTOMATIC_CPU_OFFLOAD`` set, every loaded pipeline is
        moved to ``self.device``; otherwise model CPU offload is enabled on
        each of them.
        """
        self.pipe = DiffusionPipeline.from_pretrained(
            'DeepFloyd/IF-I-XL-v1.0',
            torch_dtype=torch.float16,
            variant='fp16',
            use_safetensors=True,
            use_auth_token=HF_TOKEN)
        self.super_res_1_pipe = DiffusionPipeline.from_pretrained(
            'DeepFloyd/IF-II-L-v1.0',
            # Stage 2 is driven by the prompt embeddings saved by stage 1,
            # so its own text encoder is not loaded.
            text_encoder=None,
            torch_dtype=torch.float16,
            variant='fp16',
            use_safetensors=True,
            use_auth_token=HF_TOKEN)
        loaded_pipes = [self.pipe, self.super_res_1_pipe]

        if not DISABLE_SD_X4_UPSCALER:
            self.super_res_2_pipe = StableDiffusionUpscalePipeline.from_pretrained(
                'stabilityai/stable-diffusion-x4-upscaler',
                torch_dtype=torch.float16)
            loaded_pipes.append(self.super_res_2_pipe)

        for loaded_pipe in loaded_pipes:
            if DISABLE_AUTOMATIC_CPU_OFFLOAD:
                loaded_pipe.to(self.device)
            else:
                loaded_pipe.enable_model_cpu_offload()

    def apply_watermark_to_sd_x4_upscaler_results(
            self, images: list[PIL.Image.Image]) -> None:
        """Paste ``self.watermark_image`` in place onto each image.

        The watermark size and position are scaled from the size of the first
        image, using a reference sample size of 128. An empty list is a no-op.
        """
        if not images:
            return

        w, h = images[0].size

        stability_x4_upscaler_sample_size = 128

        coef = min(h / stability_x4_upscaler_sample_size,
                   w / stability_x4_upscaler_sample_size)
        img_h, img_w = (int(h / coef), int(w / coef)) if coef < 1 else (h, w)

        # Scale the watermark relative to a 1024x1024 reference image.
        S1, S2 = 1024**2, img_w * img_h
        K = (S2 / S1)**0.5
        watermark_size = int(K * 62)
        watermark_x = img_w - int(14 * K)
        watermark_y = img_h - int(14 * K)

        watermark_image = self.watermark_image.copy().resize(
            (watermark_size, watermark_size),
            PIL.Image.Resampling.BICUBIC,
            reducing_gap=None)
        # The alpha channel is used as the paste mask.
        watermark_mask = watermark_image.split()[-1]

        for image in images:
            image.paste(watermark_image,
                        box=(
                            watermark_x - watermark_size,
                            watermark_y - watermark_size,
                            watermark_x,
                            watermark_y,
                        ),
                        mask=watermark_mask)

    @staticmethod
    def to_pil_images(images: torch.Tensor) -> list[PIL.Image.Image]:
        """Convert a batch of images with values in ``[-1, 1]`` to uint8 PIL images.

        ``images`` is indexed as (batch, channel, height, width); the permute
        moves channels last as PIL expects.
        """
        images = (images / 2 + 0.5).clamp(0, 1)
        images = images.cpu().permute(0, 2, 3, 1).float().numpy()
        images = np.round(images * 255).astype(np.uint8)
        return [PIL.Image.fromarray(image) for image in images]

    @staticmethod
    def check_seed(seed: int) -> None:
        """Raise ``ValueError`` unless ``0 <= seed <= MAX_SEED``."""
        if not 0 <= seed <= MAX_SEED:
            raise ValueError(f'seed must be in [0, {MAX_SEED}], got {seed}')

    @staticmethod
    def check_num_images(num_images: int) -> None:
        """Raise ``ValueError`` unless ``1 <= num_images <= MAX_NUM_IMAGES``."""
        if not 1 <= num_images <= MAX_NUM_IMAGES:
            raise ValueError(
                f'num_images must be in [1, {MAX_NUM_IMAGES}], '
                f'got {num_images}')

    @staticmethod
    def check_num_inference_steps(num_steps: int) -> None:
        """Raise ``ValueError`` unless ``1 <= num_steps <= MAX_NUM_STEPS``."""
        if not 1 <= num_steps <= MAX_NUM_STEPS:
            raise ValueError(
                f'num_inference_steps must be in [1, {MAX_NUM_STEPS}], '
                f'got {num_steps}')

    @staticmethod
    def get_custom_timesteps(name: str) -> list[int] | None:
        """Map a preset name to a DeepFloyd IF timestep schedule.

        ``'none'`` maps to ``None``, which lets the pipeline use
        ``num_inference_steps`` instead. Unknown names raise ``ValueError``.
        """
        if name == 'none':
            return None
        if name == 'fast27':
            return fast27_timesteps
        if name == 'smart27':
            return smart27_timesteps
        if name == 'smart50':
            return smart50_timesteps
        if name == 'smart100':
            return smart100_timesteps
        if name == 'smart185':
            return smart185_timesteps
        raise ValueError(f'Unknown custom timesteps preset: {name!r}')

    @staticmethod
    def run_garbage_collection():
        """Run Python GC and release cached CUDA allocator memory."""
        gc.collect()
        torch.cuda.empty_cache()

    def run_stage1(
        self,
        prompt: str,
        negative_prompt: str = '',
        seed: int = 0,
        num_images: int = 1,
        guidance_scale_1: float = 7.0,
        custom_timesteps_1: str = 'smart100',
        num_inference_steps_1: int = 100,
    ) -> tuple[list[PIL.Image.Image], str, str]:
        """Generate stage-1 images and persist what later stages need.

        Returns:
            ``(pil_images, params_path, result_path)``. ``params_path`` is a
            temp file holding the stage-1 arguments as JSON. ``result_path`` is
            a temp file written with ``torch.save`` that holds the prompt
            embeddings, the raw image tensor and the PIL images; it is the
            input to ``run_stage2``. Neither file is deleted automatically.

        Raises:
            ValueError: If an argument is out of range or the timestep preset
                is unknown.
        """
        self.check_seed(seed)
        self.check_num_images(num_images)
        self.check_num_inference_steps(num_inference_steps_1)

        if RUN_GARBAGE_COLLECTION:
            self.run_garbage_collection()

        generator = torch.Generator(device=self.device).manual_seed(seed)

        prompt_embeds, negative_embeds = self.pipe.encode_prompt(
            prompt=prompt, negative_prompt=negative_prompt)

        timesteps = self.get_custom_timesteps(custom_timesteps_1)

        images = self.pipe(prompt_embeds=prompt_embeds,
                           negative_prompt_embeds=negative_embeds,
                           num_images_per_prompt=num_images,
                           guidance_scale=guidance_scale_1,
                           timesteps=timesteps,
                           num_inference_steps=num_inference_steps_1,
                           generator=generator,
                           output_type='pt').images
        pil_images = self.to_pil_images(images)
        # Watermarking is currently disabled. To re-enable:
        # self.pipe.watermarker.apply_watermark(
        #     pil_images, self.pipe.unet.config.sample_size)

        stage1_params = {
            'prompt': prompt,
            'negative_prompt': negative_prompt,
            'seed': seed,
            'num_images': num_images,
            'guidance_scale_1': guidance_scale_1,
            'custom_timesteps_1': custom_timesteps_1,
            'num_inference_steps_1': num_inference_steps_1,
        }
        with tempfile.NamedTemporaryFile(mode='w', delete=False) as param_file:
            json.dump(stage1_params, param_file)
        stage1_result = {
            'prompt_embeds': prompt_embeds,
            'negative_embeds': negative_embeds,
            'images': images,
            'pil_images': pil_images,
        }
        # Only reserve the path here, and let torch.save write after the
        # handle is closed: reopening a still-open NamedTemporaryFile by name
        # fails on Windows.
        with tempfile.NamedTemporaryFile(delete=False) as result_file:
            result_path = result_file.name
        torch.save(stage1_result, result_path)
        return pil_images, param_file.name, result_path

    @staticmethod
    def _select_stage1_image(images: torch.Tensor,
                             index: int) -> torch.Tensor:
        """Return ``images[[index]]`` (leading batch dim kept) after a bounds check.

        The check runs on the host, so a bad index gives a clear
        ``IndexError`` whatever device the tensor is on. Negative indices are
        accepted, as with plain tensor indexing.
        """
        num_images = images.shape[0]
        if not -num_images <= index < num_images:
            raise IndexError(
                f'stage2_index {index} is out of range for '
                f'{num_images} stage-1 image(s)')
        return images[[index]]

    def run_stage2(
        self,
        stage1_result_path: str,
        stage2_index: int,
        seed_2: int = 0,
        guidance_scale_2: float = 4.0,
        custom_timesteps_2: str = 'smart50',
        num_inference_steps_2: int = 50,
        disable_watermark: bool = False,
    ) -> PIL.Image.Image:
        """Super-resolve one stage-1 image with the IF-II pipeline.

        Args:
            stage1_result_path: Path returned by ``run_stage1``.
            stage2_index: Which stage-1 image to upscale.
            disable_watermark: Kept for interface compatibility. Watermarking
                is currently disabled, so this flag has no effect.

        Raises:
            ValueError: If an argument is out of range or the timestep preset
                is unknown.
            IndexError: If ``stage2_index`` is out of range.
        """
        self.check_seed(seed_2)
        self.check_num_inference_steps(num_inference_steps_2)

        if RUN_GARBAGE_COLLECTION:
            self.run_garbage_collection()

        generator = torch.Generator(device=self.device).manual_seed(seed_2)

        # NOTE(review): torch.load unpickles the file, so only paths produced
        # by run_stage1 should ever reach here. The saved dict also contains
        # PIL images, which newer torch releases reject under the
        # weights_only=True default. Confirm against the pinned torch version.
        stage1_result = torch.load(stage1_result_path)
        prompt_embeds = stage1_result['prompt_embeds']
        negative_embeds = stage1_result['negative_embeds']
        images = self._select_stage1_image(stage1_result['images'],
                                           stage2_index)

        timesteps = self.get_custom_timesteps(custom_timesteps_2)

        out = self.super_res_1_pipe(image=images,
                                    prompt_embeds=prompt_embeds,
                                    negative_prompt_embeds=negative_embeds,
                                    num_images_per_prompt=1,
                                    guidance_scale=guidance_scale_2,
                                    timesteps=timesteps,
                                    num_inference_steps=num_inference_steps_2,
                                    generator=generator,
                                    output_type='pt',
                                    noise_level=250).images
        pil_images = self.to_pil_images(out)

        # Watermarking is currently disabled. To re-enable it, guard it with
        # `if not disable_watermark:`:
        # self.super_res_1_pipe.watermarker.apply_watermark(
        #     pil_images, self.super_res_1_pipe.unet.config.sample_size)
        return pil_images[0]

    def run_stage3(
        self,
        image: PIL.Image.Image,
        prompt: str = '',
        negative_prompt: str = '',
        seed_3: int = 0,
        guidance_scale_3: float = 9.0,
        num_inference_steps_3: int = 75,
    ) -> PIL.Image.Image:
        """Upscale ``image`` with the Stable Diffusion x4 upscaler.

        Raises:
            ValueError: If the seed or step count is out of range.
            RuntimeError: If the upscaler pipeline was not loaded.
        """
        self.check_seed(seed_3)
        self.check_num_inference_steps(num_inference_steps_3)

        if self.super_res_2_pipe is None:
            raise RuntimeError(
                'Stage 3 is unavailable: the SD x4 upscaler is not loaded '
                '(DISABLE_SD_X4_UPSCALER is set or CUDA is unavailable).')

        if RUN_GARBAGE_COLLECTION:
            self.run_garbage_collection()

        generator = torch.Generator(device=self.device).manual_seed(seed_3)
        out = self.super_res_2_pipe(image=image,
                                    prompt=prompt,
                                    negative_prompt=negative_prompt,
                                    num_images_per_prompt=1,
                                    guidance_scale=guidance_scale_3,
                                    num_inference_steps=num_inference_steps_3,
                                    generator=generator,
                                    noise_level=100).images
        # Watermarking is currently disabled. To re-enable:
        # self.apply_watermark_to_sd_x4_upscaler_results(out)
        return out[0]

    def run_stage2_3(
        self,
        stage1_result_path: str,
        stage2_index: int,
        seed_2: int = 0,
        guidance_scale_2: float = 4.0,
        custom_timesteps_2: str = 'smart50',
        num_inference_steps_2: int = 50,
        prompt: str = '',
        negative_prompt: str = '',
        seed_3: int = 0,
        guidance_scale_3: float = 9.0,
        num_inference_steps_3: int = 75,
    ) -> Generator[PIL.Image.Image, None, None]:
        """Run stages 2 and 3 back to back, yielding each result.

        Yields a copy of the stage-2 image first, then the stage-3 image. The
        stage-3 arguments are validated before any work starts, so a bad
        value fails before the expensive stage 2 runs.
        """
        self.check_seed(seed_3)
        self.check_num_inference_steps(num_inference_steps_3)

        out_image = self.run_stage2(
            stage1_result_path=stage1_result_path,
            stage2_index=stage2_index,
            seed_2=seed_2,
            guidance_scale_2=guidance_scale_2,
            custom_timesteps_2=custom_timesteps_2,
            num_inference_steps_2=num_inference_steps_2,
            disable_watermark=True)
        # Yield a copy so that any watermark applied to the preview does not
        # leak into the image passed to stage 3.
        temp_image = out_image.copy()
        # Watermarking is currently disabled. To re-enable:
        # self.super_res_1_pipe.watermarker.apply_watermark(
        #     [temp_image], self.super_res_1_pipe.unet.config.sample_size)
        yield temp_image
        yield self.run_stage3(image=out_image,
                              prompt=prompt,
                              negative_prompt=negative_prompt,
                              seed_3=seed_3,
                              guidance_scale_3=guidance_scale_3,
                              num_inference_steps_3=num_inference_steps_3)