import torch
from DeepCache import DeepCacheSDHelper
from diffusers import ControlNetModel
from diffusers.models.attention_processor import AttnProcessor2_0, IPAdapterAttnProcessor2_0

from .config import Config
from .logger import Logger
from .upscaler import RealESRGAN
from .utils import timer


class Loader:
    """
    A lazy-loading resource manager for Stable Diffusion pipelines. Lifecycles are managed by
    comparing the current state with desired. Can be used as a singleton when created by the
    `get_loader()` helper.

    Each `load()` call first drops whatever no longer matches the request (`unload_all`), then
    loads whatever is missing in this order: ControlNet, pipeline, scheduler, DeepCache,
    IP-Adapter, upscaler.

    Usage:
        loader = get_loader(singleton=True)
        loader.load(
            pipeline_id="controlnet_txt2img",
            ip_adapter_model="full-face",
            model="XpucT/Reliberate",
            scheduler="UniPC",
            controlnet_annotator="canny",
            deepcache_interval=2,
            scale=2,
            use_karras=True
        )
    """

    def __init__(self):
        self.model = ""  # Hugging Face repo id of the loaded checkpoint ("" when none)
        self.pipeline = None
        self.upscaler = None
        self.controlnet = None
        self.annotator = ""  # controlnet annotator (canny)
        self.ip_adapter = ""  # ip-adapter kind (full-face or plus)
        self.log = Logger("Loader")

    def should_unload_upscaler(self, scale=1):
        """Return True if an upscaler is loaded for a scale other than `scale`."""
        return self.upscaler is not None and self.upscaler.scale != scale

    def should_unload_deepcache(self, cache_interval=1):
        """
        Return True if DeepCache is enabled on the pipeline but must be removed.

        That is the case when it is no longer wanted (`cache_interval == 1`) or when it was
        enabled with a different interval (it is then re-enabled by `load_deepcache`).
        """
        if not hasattr(self.pipeline, "deepcache"):
            return False
        if cache_interval == 1:
            return True
        return self.pipeline.deepcache.params["cache_interval"] != cache_interval

    def should_unload_ip_adapter(self, ip_adapter_model=""):
        """
        Return True if an IP-Adapter is loaded and the request is for none or a different kind.
        """
        if not self.ip_adapter:
            return False
        # Also unload when the kind differs so the requested one can be loaded afterwards
        return not ip_adapter_model or self.ip_adapter != ip_adapter_model

    def should_unload_controlnet(self, pipeline_id="", annotator=""):
        """
        Return True if a ControlNet is loaded but the annotator changed or the requested
        pipeline is not a ControlNet pipeline (its id does not start with "controlnet_").
        """
        if self.controlnet is None:
            return False
        return self.annotator != annotator or not pipeline_id.startswith("controlnet_")

    def should_unload_pipeline(self, model=""):
        """Return True if a pipeline is loaded for a checkpoint other than `model`."""
        return self.pipeline is not None and self.model != model

    # Copied from https://github.com/huggingface/diffusers/blob/v0.28.0/src/diffusers/loaders/ip_adapter.py#L300
    def unload_ip_adapter(self):
        """Remove the IP-Adapter from the current pipeline and restore plain attention processors."""
        # Remove the image encoder if text-to-image
        if isinstance(self.pipeline, Config.PIPELINES["txt2img"]):
            self.pipeline.image_encoder = None
            self.pipeline.register_to_config(image_encoder=[None, None])

        # Remove hidden projection layer added by IP-Adapter
        self.pipeline.unet.encoder_hid_proj = None
        self.pipeline.unet.config.encoder_hid_dim_type = None

        # Remove the feature extractor
        self.pipeline.feature_extractor = None
        self.pipeline.register_to_config(feature_extractor=[None, None])

        # Swap IP-Adapter processors for AttnProcessor2_0 (raises if not torch 2) and
        # re-instantiate every other processor from its own class
        self.pipeline.unet.set_attn_processor(
            {
                name: (
                    AttnProcessor2_0()
                    if isinstance(processor, IPAdapterAttnProcessor2_0)
                    else processor.__class__()
                )
                for name, processor in self.pipeline.unet.attn_processors.items()
            }
        )
        self.ip_adapter = ""

    def _disable_deepcache(self):
        """Disable the pipeline's DeepCache helper and detach it, if one is attached."""
        if hasattr(self.pipeline, "deepcache"):
            self.pipeline.deepcache.disable()
            delattr(self.pipeline, "deepcache")

    def unload_all(
        self,
        pipeline_id="",
        ip_adapter_model="",
        model="",
        controlnet_annotator="",
        deepcache_interval=1,
        scale=1,
    ):
        """
        Drop every loaded resource that does not match the requested state.

        The arguments mirror those of `load()`; each `should_unload_*` predicate decides
        whether its resource is released.
        """
        if self.should_unload_deepcache(deepcache_interval):  # remove deepcache first
            self.log.info("Disabling DeepCache")
            self._disable_deepcache()

        if self.should_unload_ip_adapter(ip_adapter_model):
            self.log.info("Unloading IP-Adapter")
            self.unload_ip_adapter()

        if self.should_unload_controlnet(pipeline_id, controlnet_annotator):
            self.log.info("Unloading ControlNet")
            self.controlnet = None
            self.annotator = ""

        if self.should_unload_upscaler(scale):
            self.log.info("Unloading upscaler")
            self.upscaler = None

        if self.should_unload_pipeline(model):
            self.log.info("Unloading pipeline")
            self.pipeline = None
            self.model = ""
            # Any IP-Adapter lived on the dropped pipeline's UNet, so forget it too
            self.ip_adapter = ""

    def should_load_upscaler(self, scale=1):
        """Return True if upscaling is requested (`scale > 1`) and no upscaler is loaded."""
        return self.upscaler is None and scale > 1

    def should_load_deepcache(self, cache_interval=1):
        """Return True if DeepCache is requested (`cache_interval > 1`) and not yet enabled."""
        return cache_interval > 1 and not hasattr(self.pipeline, "deepcache")

    def should_load_controlnet(self, pipeline_id=""):
        """Return True if a ControlNet pipeline is requested and no ControlNet is loaded."""
        return self.controlnet is None and pipeline_id.startswith("controlnet_")

    def should_load_ip_adapter(self, ip_adapter_model=""):
        """
        Return True if an IP-Adapter is requested but the UNet has none loaded, judged by its
        config's `encoder_hid_dim_type` not being "ip_image_proj".
        """
        has_ip_adapter = (
            hasattr(self.pipeline.unet, "encoder_hid_proj")
            and self.pipeline.unet.config.encoder_hid_dim_type == "ip_image_proj"
        )
        return not has_ip_adapter and ip_adapter_model != ""

    def should_load_scheduler(self, cls, use_karras=False):
        """
        Return True if the pipeline's scheduler is not a `cls`, or if its config exposes
        `use_karras_sigmas` with a value other than `use_karras`.
        """
        if not isinstance(self.pipeline.scheduler, cls):
            return True
        config = self.pipeline.scheduler.config
        return hasattr(config, "use_karras_sigmas") and config.use_karras_sigmas != use_karras

    def should_load_pipeline(self, pipeline_id=""):
        """Return True if no pipeline is loaded or it is not the class for `pipeline_id`."""
        if self.pipeline is None:
            return True
        return not isinstance(self.pipeline, Config.PIPELINES[pipeline_id])

    def load_upscaler(self, scale=1):
        """Create a RealESRGAN upscaler for `scale` on the pipeline's device and load its weights."""
        with timer(f"Loading {scale}x upscaler", logger=self.log.info):
            self.upscaler = RealESRGAN(scale, device=self.pipeline.device)
            self.upscaler.load_weights()

    def load_deepcache(self, cache_interval=1):
        """Attach a `DeepCacheSDHelper` to the pipeline with `cache_interval` and enable it."""
        self.log.info(f"Enabling DeepCache interval {cache_interval}")
        self.pipeline.deepcache = DeepCacheSDHelper(self.pipeline)
        self.pipeline.deepcache.set_params(cache_interval=cache_interval)
        self.pipeline.deepcache.enable()

    def load_controlnet(self, controlnet_annotator):
        """
        Load the fp16 ControlNet whose repo is `Config.ANNOTATORS[controlnet_annotator]`.

        The model is only stored on the loader; `load_pipeline` attaches it to the pipeline.
        """
        with timer("Loading ControlNet", logger=self.log.info):
            self.controlnet = ControlNetModel.from_pretrained(
                Config.ANNOTATORS[controlnet_annotator],
                variant="fp16",
                torch_dtype=torch.float16,
            )
            self.annotator = controlnet_annotator

    def load_ip_adapter(self, ip_adapter_model=""):
        """Load the SD 1.5 `h94/IP-Adapter` weights of kind `ip_adapter_model` at scale 0.5."""
        with timer("Loading IP-Adapter", logger=self.log.info):
            self.pipeline.load_ip_adapter(
                "h94/IP-Adapter",
                subfolder="models",
                weight_name=f"ip-adapter-{ip_adapter_model}_sd15.safetensors",
            )
            self.pipeline.set_ip_adapter_scale(0.5)  # 50% works the best
            self.ip_adapter = ip_adapter_model

    def load_scheduler(self, cls, use_karras=False, **kwargs):
        """
        Replace the pipeline's scheduler with `cls(**kwargs)`.

        `use_karras` only affects the log line; the Karras setting itself must be passed in
        `kwargs` as `use_karras_sigmas`.
        """
        self.log.info(f"Loading {cls.__name__}{' with Karras' if use_karras else ''}")
        self.pipeline.scheduler = cls(**kwargs)

    def load_pipeline(
        self,
        pipeline_id,
        model,
        **kwargs,
    ):
        """
        Create the pipeline class registered for `pipeline_id` and move it to CUDA.

        With no pipeline loaded, `model` is loaded from scratch (single-file checkpoint or
        diffusers repo) with `kwargs`. Otherwise the current pipeline is converted with
        `from_pipe` and `kwargs` is not used. Either way the current ControlNet, if any, is
        passed in. Progress bars are disabled on the result.
        """
        Pipeline = Config.PIPELINES[pipeline_id]

        # Load from scratch
        if self.pipeline is None:
            with timer(f"Loading {model} ({pipeline_id})", logger=self.log.info):
                if self.controlnet is not None:
                    kwargs["controlnet"] = self.controlnet
                if model in Config.SINGLE_FILE_MODELS:
                    checkpoint = Config.HF_REPOS[model][0]
                    self.pipeline = Pipeline.from_single_file(
                        f"https://huggingface.co/{model}/{checkpoint}",
                        **kwargs,
                    ).to("cuda")
                else:
                    self.pipeline = Pipeline.from_pretrained(model, **kwargs).to("cuda")

        # Change to a different one
        else:
            with timer(f"Changing pipeline to {pipeline_id}", logger=self.log.info):
                # The DeepCache helper was built around the old pipeline object, but from_pipe
                # shares that pipeline's components with the new one. Disable it now so
                # load_deepcache re-enables it on the new pipeline instead of stacking a
                # second helper on the same components.
                self._disable_deepcache()
                from_pipe_kwargs = {}
                if self.controlnet is not None:
                    from_pipe_kwargs["controlnet"] = self.controlnet
                self.pipeline = Pipeline.from_pipe(
                    self.pipeline,
                    **from_pipe_kwargs,
                ).to("cuda")

        # Update model and disable terminal progress bars
        self.model = model
        self.pipeline.set_progress_bar_config(disable=True)

    def load(
        self,
        pipeline_id,
        ip_adapter_model,
        model,
        scheduler,
        controlnet_annotator,
        deepcache_interval,
        scale,
        use_karras,
    ):
        """
        Bring the loader into the requested state, touching only what changed.

        Args:
            pipeline_id: key of `Config.PIPELINES` (e.g. "txt2img", "controlnet_txt2img").
            ip_adapter_model: IP-Adapter kind ("full-face", "plus") or "" for none.
            model: Hugging Face repo id of the checkpoint.
            scheduler: key of `Config.SCHEDULERS`.
            controlnet_annotator: key of `Config.ANNOTATORS`; used for "controlnet_" pipelines.
            deepcache_interval: DeepCache interval; 1 disables DeepCache.
            scale: upscaling factor; 1 disables the upscaler.
            use_karras: use Karras sigmas (not passed for "Euler a").
        """
        Scheduler = Config.SCHEDULERS[scheduler]

        scheduler_kwargs = {
            "beta_start": 0.00085,
            "beta_end": 0.012,
            "beta_schedule": "scaled_linear",
            "timestep_spacing": "leading",
            "steps_offset": 1,
        }

        if scheduler not in ["Euler a"]:
            scheduler_kwargs["use_karras_sigmas"] = use_karras

        pipeline_kwargs = {
            "torch_dtype": torch.float16,  # defaults to fp32
            "safety_checker": None,
            "requires_safety_checker": False,
            "scheduler": Scheduler(**scheduler_kwargs),
            # Single-file models don't need a variant
            "variant": None if model in Config.SINGLE_FILE_MODELS else "fp16",
        }

        # Prepare state for loading checks
        self.unload_all(
            pipeline_id,
            ip_adapter_model,
            model,
            controlnet_annotator,
            deepcache_interval,
            scale,
        )

        # Load controlnet model before pipeline
        controlnet_loaded = self.should_load_controlnet(pipeline_id)
        if controlnet_loaded:
            self.load_controlnet(controlnet_annotator)

        # A freshly loaded ControlNet (e.g. after an annotator switch) must be attached even when
        # the pipeline class is unchanged, so rebuild the pipeline in that case as well
        if controlnet_loaded or self.should_load_pipeline(pipeline_id):
            self.load_pipeline(pipeline_id, model, **pipeline_kwargs)

        if self.should_load_scheduler(Scheduler, use_karras):
            self.load_scheduler(Scheduler, use_karras, **scheduler_kwargs)

        if self.should_load_deepcache(deepcache_interval):
            self.load_deepcache(deepcache_interval)

        if self.should_load_ip_adapter(ip_adapter_model):
            self.load_ip_adapter(ip_adapter_model)

        if self.should_load_upscaler(scale):
            self.load_upscaler(scale)


# Get a singleton or a new instance of the Loader
def get_loader(singleton=False):
    """
    Return a `Loader` instance.

    With `singleton=True` every call returns the same object, created on the first such call
    and cached as the `_instance` attribute of this function. Otherwise a brand-new `Loader` is
    built on each call.
    """
    if singleton:
        if not hasattr(get_loader, "_instance"):
            get_loader._instance = Loader()
        shared = get_loader._instance
        assert isinstance(shared, Loader)
        return shared
    return Loader()