import os

import comfy
import comfy.utils
import cv2
import folder_paths
import numpy as np
import torch
from comfy import model_management
from PIL import Image

from ..log import NullWriter, log
from ..utils import get_model_path, np2tensor, pil2tensor, tensor2np


class MTB_LoadFaceEnhanceModel:
    """Loads a GFPGan or RestoreFormer model for face enhancement."""

    def __init__(self) -> None:
        pass

    @classmethod
    def get_models_root(cls):
        fr = get_model_path("face_restore")
        # fr = Path(folder_paths.models_dir) / "face_restore"
        if fr.exists():
            return (fr, None)

        um = get_model_path("upscale_models")
        return (fr, um) if um.exists() else (None, None)

    @classmethod
    def get_models(cls):
        fr_models_path, um_models_path = cls.get_models_root()

        if fr_models_path is None and um_models_path is None:
            if not hasattr(cls, "_warned"):
                log.warning("Face restoration models not found.")
                cls._warned = True
            return []
        if not fr_models_path.exists():
            # log.warning(
            #     f"No Face Restore checkpoints found at {fr_models_path} (if you've used mtb before these checkpoints were saved in upscale_models before)"
            # )
            # log.warning(
            #     "For now we fallback to upscale_models but this will be removed in a future version"
            # )
            if um_models_path.exists():
                return [
                    x
                    for x in um_models_path.iterdir()
                    if x.name.endswith(".pth")
                    and ("GFPGAN" in x.name or "RestoreFormer" in x.name)
                ]
            return []

        return [
            x
            for x in fr_models_path.iterdir()
            if x.name.endswith(".pth")
            and ("GFPGAN" in x.name or "RestoreFormer" in x.name)
        ]

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model_name": (
                    [x.name for x in cls.get_models()],
                    {"default": "None"},
                ),
                "upscale": ("INT", {"default": 1}),
            },
            "optional": {"bg_upsampler": ("UPSCALE_MODEL", {"default": None})},
        }

    RETURN_TYPES = ("FACEENHANCE_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_model"
    CATEGORY = "mtb/facetools"
    DEPRECATED = True

    def load_model(self, model_name, upscale=2, bg_upsampler=None):
        from gfpgan import GFPGANer

        basic = "RestoreFormer" not in model_name

        fr_root, um_root = self.get_models_root()

        if bg_upsampler is not None:
            log.warning(
                f"Upscale value overridden to {bg_upsampler.scale} from bg_upsampler"
            )
            upscale = bg_upsampler.scale
            bg_upsampler = BGUpscaleWrapper(bg_upsampler)

        sys.stdout = NullWriter()
        model = GFPGANer(
            model_path=(
                (fr_root if fr_root.exists() else um_root) / model_name
            ).as_posix(),
            upscale=upscale,
            arch="clean"
            if basic
            else "RestoreFormer",  # or original for v1.0 only
            channel_multiplier=2,  # 1 for v1.0 only
            bg_upsampler=bg_upsampler,
        )

        sys.stdout = sys.__stdout__
        return (model,)


class BGUpscaleWrapper:
    def __init__(self, upscale_model) -> None:
        self.upscale_model = upscale_model

    def enhance(self, img: Image.Image, outscale=2):
        device = model_management.get_torch_device()
        self.upscale_model.to(device)

        tile = 128 + 64
        overlap = 8

        imgt = np2tensor(img)
        imgt = imgt.movedim(-1, -3).to(device)

        steps = imgt.shape[0] * comfy.utils.get_tiled_scale_steps(
            imgt.shape[3],
            imgt.shape[2],
            tile_x=tile,
            tile_y=tile,
            overlap=overlap,
        )

        log.debug(f"Steps: {steps}")

        pbar = comfy.utils.ProgressBar(steps)

        s = comfy.utils.tiled_scale(
            imgt,
            lambda a: self.upscale_model(a),
            tile_x=tile,
            tile_y=tile,
            overlap=overlap,
            upscale_amount=self.upscale_model.scale,
            pbar=pbar,
        )

        self.upscale_model.cpu()
        s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
        return (tensor2np(s)[0],)


import sys


class MTB_RestoreFace:
    """Uses GFPGan to restore faces"""

    def __init__(self) -> None:
        pass

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "restore"
    CATEGORY = "mtb/facetools"
    DEPRECATED = True

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "model": ("FACEENHANCE_MODEL",),
                # Input are aligned faces
                "aligned": ("BOOLEAN", {"default": False}),
                # Only restore the center face
                "only_center_face": ("BOOLEAN", {"default": False}),
                # Adjustable weights
                "weight": ("FLOAT", {"default": 0.5}),
                "save_tmp_steps": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "preserve_alpha": ("BOOLEAN", {"default": True}),
            },
        }

    def do_restore(
        self,
        image: torch.Tensor,
        model,
        aligned,
        only_center_face,
        weight,
        save_tmp_steps,
        preserve_alpha: bool = False,
    ) -> torch.Tensor:
        pimage = tensor2np(image)[0]
        width, height = pimage.shape[1], pimage.shape[0]
        source_img = cv2.cvtColor(np.array(pimage), cv2.COLOR_RGB2BGR)

        alpha_channel = None
        if (
            preserve_alpha and image.size(-1) == 4
        ):  # Check if the image has an alpha channel
            alpha_channel = pimage[:, :, 3]
            pimage = pimage[:, :, :3]  # Remove alpha channel for processing

        sys.stdout = NullWriter()
        cropped_faces, restored_faces, restored_img = model.enhance(
            source_img,
            has_aligned=aligned,
            only_center_face=only_center_face,
            paste_back=True,
            # TODO: weight has no effect in 1.3 and 1.4 (only tested these for now...)
            weight=weight,
        )
        sys.stdout = sys.__stdout__
        log.warning(f"Weight value has no effect for now. (value: {weight})")

        if save_tmp_steps:
            self.save_intermediate_images(
                cropped_faces, restored_faces, height, width
            )
        output = None
        if restored_img is not None:
            restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
            output = Image.fromarray(restored_img)

            if alpha_channel is not None:
                alpha_resized = Image.fromarray(alpha_channel).resize(
                    output.size, Image.LANCZOS
                )
                output.putalpha(alpha_resized)
            # imwrite(restored_img, save_restore_path)

        return pil2tensor(output)

    def restore(
        self,
        image: torch.Tensor,
        model,
        aligned=False,
        only_center_face=False,
        weight=0.5,
        save_tmp_steps=True,
        preserve_alpha: bool = False,
    ) -> tuple[torch.Tensor]:
        out = [
            self.do_restore(
                image[i],
                model,
                aligned,
                only_center_face,
                weight,
                save_tmp_steps,
                preserve_alpha,
            )
            for i in range(image.size(0))
        ]

        return (torch.cat(out, dim=0),)

    def get_step_image_path(self, step, idx):
        (
            full_output_folder,
            filename,
            counter,
            _subfolder,
            _filename_prefix,
        ) = folder_paths.get_save_image_path(
            f"{step}_{idx:03}",
            folder_paths.temp_directory,
        )
        file = f"{filename}_{counter:05}_.png"

        return os.path.join(full_output_folder, file)

    def save_intermediate_images(
        self, cropped_faces, restored_faces, height, width
    ):
        for idx, (cropped_face, restored_face) in enumerate(
            zip(cropped_faces, restored_faces, strict=False)
        ):
            face_id = idx + 1
            file = self.get_step_image_path("cropped_faces", face_id)
            cv2.imwrite(file, cropped_face)

            file = self.get_step_image_path("cropped_faces_restored", face_id)
            cv2.imwrite(file, restored_face)

            file = self.get_step_image_path("cropped_faces_compare", face_id)

            # save comparison image
            cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
            cv2.imwrite(file, cmp_img)


__nodes__ = [MTB_RestoreFace, MTB_LoadFaceEnhanceModel]