import io
import json
import urllib.parse
import urllib.request
from math import pi

import comfy.model_management as model_management
import comfy.utils
import numpy as np
import torch
from PIL import Image

from ..log import log
from ..utils import (
    EASINGS,
    apply_easing,
    get_server_info,
    numpy_NFOV,
    pil2tensor,
    tensor2np,
)


def get_image(filename, subfolder, folder_type):
    """Download an image from the server's ``/view`` endpoint.

    The server address is obtained from ``get_server_info()`` and the three
    arguments are sent as the ``filename``, ``subfolder`` and ``type`` query
    parameters.

    Returns:
        An ``io.BytesIO`` wrapping the raw bytes of the HTTP response body.
    """
    subfolder_note = f"in subfolder: {subfolder}" if subfolder else ""
    log.debug(
        f"Getting image {filename} from foldertype {folder_type} {subfolder_note}"
    )

    query_params = {
        "filename": filename,
        "subfolder": subfolder,
        "type": folder_type,
    }
    host, port = get_server_info()
    query_string = urllib.parse.urlencode(query_params)

    url = f"http://{host}:{port}/view?{query_string}"
    log.debug(f"Fetching image from {url}")

    with urllib.request.urlopen(url) as response:
        payload = response.read()
    return io.BytesIO(payload)


class MTB_ToDevice:
    """Send a image or mask tensor to the given device."""

    @classmethod
    def INPUT_TYPES(cls):
        devices = ["cpu"]
        if torch.backends.mps.is_available():
            devices.append("mps")
        if torch.cuda.is_available():
            devices.append("cuda")
            for i in range(torch.cuda.device_count()):
                devices.append(f"cuda{i}")

        return {
            "required": {
                "ignore_errors": ("BOOLEAN", {"default": False}),
                "device": (devices, {"default": "cpu"}),
            },
            "optional": {
                "image": ("IMAGE",),
                "mask": ("MASK",),
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    RETURN_NAMES = ("images", "masks")
    CATEGORY = "mtb/utils"
    FUNCTION = "to_device"

    def to_device(
        self,
        *,
        ignore_errors=False,
        device="cuda",
        image: torch.Tensor | None = None,
        mask: torch.Tensor | None = None,
    ):
        if not ignore_errors and image is None and mask is None:
            raise ValueError(
                "You must either provide an image or a mask,"
                " use ignore_error to passthrough"
            )
        if image is not None:
            image = image.to(device)
        if mask is not None:
            mask = mask.to(device)
        return (image, mask)


# class MTB_ApplyTextTemplate:
class MTB_ApplyTextTemplate:
    """
    Experimental node that fills a text template from its inputs.

    A placeholder is an input name wrapped in braces, for instance:

    Some string {var_1} and {var_2}
    """

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("string",)
    CATEGORY = "mtb/utils"
    FUNCTION = "execute"

    @classmethod
    def INPUT_TYPES(cls):
        template_spec = ("STRING", {"default": "", "multiline": True})
        return {"required": {"template": template_spec}}

    def execute(self, *, template: str, **kwargs):
        """Replace each ``{name}`` in ``template`` with the matching keyword value.

        Replacements are applied one keyword at a time, in the order the
        keywords were passed. Placeholders with no matching keyword are left
        untouched.

        Returns:
            A one-element tuple holding the resulting string.
        """
        rendered = f"{template}"
        for name in kwargs:
            placeholder = "{" + name + "}"
            rendered = rendered.replace(placeholder, f"{kwargs[name]}")
        return (rendered,)


class MTB_MatchDimensions:
    """Match images dimensions along the given dimension, preserving aspect ratio."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "source": ("IMAGE",),
                "reference": ("IMAGE",),
                "match": (["height", "width"], {"default": "height"}),
            },
        }

    RETURN_TYPES = ("IMAGE", "INT", "INT")
    RETURN_NAMES = ("image", "new_width", "new_height")
    CATEGORY = "mtb/utils"
    FUNCTION = "execute"

    def execute(
        self, source: torch.Tensor, reference: torch.Tensor, match: str
    ):
        """Resize ``source`` so one of its dimensions equals that of ``reference``.

        Both tensors are unpacked as (batch, height, width, channels). The
        other dimension is derived from the aspect ratio of ``source``.

        Args:
            source: Images to resize.
            reference: Images providing the target height or width.
            match: ``"height"`` to match heights; any other value matches
                widths.

        Returns:
            ``(resized_source, new_width, new_height)`` where
            ``resized_source`` keeps the channels-last layout.
        """
        import torchvision.transforms.functional as VF

        _batch_size, height, width, _channels = source.shape
        _rbatch_size, rheight, rwidth, _rchannels = reference.shape

        source_aspect_ratio = width / height

        # int() truncates, so an extreme aspect ratio could yield 0; keep at
        # least one pixel so the resize target stays valid.
        if match == "height":
            new_height = rheight
            new_width = max(1, int(rheight * source_aspect_ratio))
        else:
            new_width = rwidth
            new_height = max(1, int(rwidth / source_aspect_ratio))

        # torchvision resizes tensors shaped (..., C, H, W), so the whole
        # batch goes through a single call instead of one call per image.
        resized_source = VF.resize(
            source.permute(0, 3, 1, 2),
            [new_height, new_width],
            interpolation=VF.InterpolationMode.BICUBIC,
            antialias=True,
        )
        resized_source = resized_source.permute(0, 2, 3, 1)

        return (resized_source, new_width, new_height)


class MTB_FloatToFloats:
    """Compatibility shim exposing a FLOAT input as a FLOATS output.

    Other extensions (AD, IPA, Fitz) use FLOAT to represent a list of floats.
    """

    RETURN_TYPES = ("FLOATS",)
    RETURN_NAMES = ("floats",)
    CATEGORY = "mtb/utils"
    FUNCTION = "convert"

    @classmethod
    def INPUT_TYPES(cls):
        float_spec = ("FLOAT", {"default": 0.0, "forceInput": True})
        return {"required": {"float": float_spec}}

    def convert(self, float: float):
        """Return the incoming value unchanged, wrapped in a one-element tuple."""
        floats = float
        return (floats,)


class MTB_FloatsToInts:
    """Compatibility shim for frame interpolation: turn FLOATS into integers."""

    RETURN_TYPES = ("INTS", "INT")
    CATEGORY = "mtb/utils"
    FUNCTION = "convert"

    @classmethod
    def INPUT_TYPES(cls):
        floats_spec = ("FLOATS", {"forceInput": True})
        return {"required": {"floats": floats_spec}}

    def convert(self, floats: list[float]):
        """Apply ``int()`` to every value and return the same list for both outputs."""
        ints = list(map(int, floats))
        return ints, ints


class MTB_FloatsToFloat:
    """Compatibility shim exposing a FLOATS input as a FLOAT output.

    Other extensions (AD, IPA, Fitz) use FLOAT to represent a list of floats.
    """

    RETURN_TYPES = ("FLOAT",)
    RETURN_NAMES = ("float",)
    CATEGORY = "mtb/utils"
    FUNCTION = "convert"

    @classmethod
    def INPUT_TYPES(cls):
        required_inputs = {"floats": ("FLOATS",)}
        return {"required": required_inputs}

    def convert(self, floats):
        """Return the incoming value unchanged, wrapped in a one-element tuple."""
        passthrough = floats
        return (passthrough,)


class MTB_AutoPanEquilateral:
    """Generate a 360 panning video from an equilateral image."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "equilateral_image": ("IMAGE",),
                "fovX": ("FLOAT", {"default": 45.0}),
                "fovY": ("FLOAT", {"default": 45.0}),
                "elevation": ("FLOAT", {"default": 0.5}),
                "frame_count": ("INT", {"default": 100}),
                "width": ("INT", {"default": 768}),
                "height": ("INT", {"default": 512}),
            },
            "optional": {
                "floats_fovX": ("FLOATS",),
                "floats_fovY": ("FLOATS",),
                "floats_elevation": ("FLOATS",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    CATEGORY = "mtb/utils"
    FUNCTION = "generate_frames"

    def check_floats(self, f: list[float] | None, expected_count: int):
        if f:
            if len(f) == expected_count:
                return True
            return False
        return True

    def generate_frames(
        self,
        equilateral_image: torch.Tensor,
        fovX: float,
        fovY: float,
        elevation: float,
        frame_count: int,
        width: int,
        height: int,
        floats_fovX: list[float] | None = None,
        floats_fovY: list[float] | None = None,
        floats_elevation: list[float] | None = None,
    ):
        source = tensor2np(equilateral_image)

        if len(source) > 1:
            log.warn(
                "You provided more than one image in the equilateral_image input, only the first will be used."
            )
        if not all(
            [
                self.check_floats(x, frame_count)
                for x in [floats_fovX, floats_fovY, floats_elevation]
            ]
        ):
            raise ValueError(
                "You provided less than the expected number of fovX, fovY, or elevation values."
            )

        source = source[0]
        frames = []

        pbar = comfy.utils.ProgressBar(frame_count)
        for i in range(frame_count):
            rotation_angle = (i / frame_count) * 2 * pi

            if floats_elevation:
                elevation = floats_elevation[i]

            if floats_fovX:
                fovX = floats_fovX[i]

            if floats_fovY:
                fovY = floats_fovY[i]

            fov = [fovX / 100, fovY / 100]
            center_point = [rotation_angle / (2 * pi), elevation]

            nfov = numpy_NFOV(fov, height, width)
            frame = nfov.to_nfov(source, center_point=center_point)

            frames.append(frame)

            model_management.throw_exception_if_processing_interrupted()
            pbar.update(1)

        return (pil2tensor(frames),)


class MTB_GetBatchFromHistory:
    """Very experimental node to load images from the history of the server.

    Queue items without output are ignored in the count.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "enable": ("BOOLEAN", {"default": True}),
                "count": ("INT", {"default": 1, "min": 0}),
                "offset": ("INT", {"default": 0, "min": -1e9, "max": 1e9}),
                "internal_count": ("INT", {"default": 0}),
            },
            "optional": {
                "passthrough_image": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    CATEGORY = "mtb/animation"
    FUNCTION = "load_from_history"

    def load_from_history(
        self,
        *,
        enable=True,
        count=0,
        offset=0,
        internal_count=0,  # hacky way to invalidate the node
        passthrough_image=None,
    ):
        """Return the last ``count`` history images, skipping ``offset`` from the end.

        When the node is disabled or ``count`` is 0, ``passthrough_image`` is
        returned if provided, otherwise an empty tensor.

        Returns:
            A one-element tuple holding the image batch, or ``torch.zeros(0)``
            when nothing was found.
        """
        if not enable or count == 0:
            if passthrough_image is not None:
                log.debug("Using passthrough image")
                return (passthrough_image,)
            log.debug("Load from history is disabled for this iteration")
            return (torch.zeros(0),)

        base_url, port = get_server_info()

        history_url = f"http://{base_url}:{port}/history"
        log.debug(f"Fetching history from {history_url}")
        with urllib.request.urlopen(history_url) as response:
            output = self.load_batch_frames(response, offset, count, [])

        if output.size(0) == 0:
            log.warning("No output found in history")

        return (output,)

    @staticmethod
    def _select_range(items, offset, count):
        """Return up to ``count`` items ending ``offset`` entries before the end.

        Both slice bounds are clamped at 0 so that an ``offset`` larger than
        ``len(items)`` yields an empty list rather than a negative end index
        (which Python would interpret as counting from the end).
        """
        end_index = max(len(items) - offset, 0)
        start_index = max(end_index - count, 0)
        return items[start_index:end_index]

    def load_batch_frames(self, response, offset, count, frames):
        """Parse a ``/history`` response and load the requested image range.

        Args:
            response: Readable object whose content is the history JSON.
            offset: Number of most recent images to skip.
            count: Number of images wanted.
            frames: Unused; kept so existing callers keep working.

        Returns:
            The selected images converted with ``pil2tensor``, or
            ``torch.zeros(0)`` when the selection is empty.
        """
        history = json.loads(response.read())

        # Collect only the image descriptors first so that just the selected
        # range is downloaded, not every image in the history.
        image_refs = []
        for run in history.values():
            for node_output in (run.get("outputs") or {}).values():
                if "images" in node_output:
                    image_refs.extend(node_output["images"])

        selected_refs = self._select_range(image_refs, offset, count)
        if not selected_refs:
            return torch.zeros(0)

        frames = [
            Image.open(
                get_image(ref["filename"], ref["subfolder"], ref["type"])
            )
            for ref in selected_refs
        ]

        if len(frames) != count:
            log.warning(f"Expected {count} images, got {len(frames)} instead")

        return pil2tensor(frames)


class MTB_AnyToString:
    """Best-effort conversion of an arbitrary input into a string."""

    RETURN_TYPES = ("STRING",)
    FUNCTION = "do_str"
    CATEGORY = "mtb/converters"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"input": ("*",)}}

    def do_str(self, input):
        """Describe ``input`` as text.

        Strings pass through unchanged. Tensors, PIL images, numpy arrays and
        dictionaries get a short summary instead of their full content.
        Anything else goes through ``str()``.

        Returns:
            A one-element tuple holding the resulting string.
        """
        if isinstance(input, str):
            return (input,)

        if isinstance(input, torch.Tensor):
            text = f"Tensor of shape {input.shape} and dtype {input.dtype}"
        elif isinstance(input, Image.Image):
            text = f"PIL Image of size {input.size} and mode {input.mode}"
        elif isinstance(input, np.ndarray):
            text = f"Numpy array of shape {input.shape} and dtype {input.dtype}"
        elif isinstance(input, dict):
            text = f"Dictionary of {len(input)} items, with keys {input.keys()}"
        else:
            log.debug(f"Falling back to string conversion of {input}")
            text = str(input)

        return (text,)


class MTB_StringReplace:
    """Replace every occurrence of a substring in a string."""

    FUNCTION = "replace_str"
    RETURN_TYPES = ("STRING",)
    CATEGORY = "mtb/string"

    @classmethod
    def INPUT_TYPES(cls):
        required_inputs = {
            "string": ("STRING", {"forceInput": True}),
            "old": ("STRING", {"default": ""}),
            "new": ("STRING", {"default": ""}),
        }
        return {"required": required_inputs}

    def replace_str(self, string: str, old: str, new: str):
        """Return ``string`` with ``old`` swapped for ``new`` via ``str.replace``.

        The inputs and the result are logged at debug level.

        Returns:
            A one-element tuple holding the new string.
        """
        debug_lines = (
            f"Current string: {string}",
            f"Find string: {old}",
            f"Replace string: {new}",
        )
        for line in debug_lines:
            log.debug(line)

        replaced = string.replace(old, new)
        log.debug(f"New string: {replaced}")

        return (replaced,)


class MTB_MathExpression:
    """Node to evaluate a simple math expression string"""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "expression": ("STRING", {"default": "", "multiline": True}),
            }
        }

    FUNCTION = "eval_expression"
    RETURN_TYPES = ("FLOAT", "INT")
    RETURN_NAMES = ("result (float)", "result (int)")
    CATEGORY = "mtb/math"
    DESCRIPTION = (
        "evaluate a simple math expression string, only supports literal_eval"
    )

    def eval_expression(self, expression: str, **kwargs):
        from ast import literal_eval

        for key, value in kwargs.items():
            log.debug(f"Replacing placeholder <{key}> with value {value}")
            expression = expression.replace(f"<{key}>", str(value))

        result = -1
        try:
            result = literal_eval(expression)
        except SyntaxError as e:
            raise ValueError(
                f"The expression syntax is wrong '{expression}': {e}"
            ) from e

        except Exception as e:
            raise ValueError(
                f"Math expression only support literal_eval now: {e}"
            )

        return (result, int(result))


class MTB_FitNumber:
    """Remap a float from a source range onto a target range."""

    FUNCTION = "set_range"
    RETURN_TYPES = ("FLOAT",)
    CATEGORY = "mtb/math"
    DESCRIPTION = "Fit the input float using a source and target range"

    @classmethod
    def INPUT_TYPES(cls):
        def range_bound(default):
            # All four range bounds share the same widget options.
            return ("FLOAT", {"default": default, "step": 0.01, "min": -1e5})

        required_inputs = {
            "value": ("FLOAT", {"default": 0, "forceInput": True}),
            "clamp": ("BOOLEAN", {"default": False}),
            "source_min": range_bound(0.0),
            "source_max": range_bound(1.0),
            "target_min": range_bound(0.0),
            "target_max": range_bound(1.0),
            "easing": (EASINGS, {"default": "Linear"}),
        }
        return {"required": required_inputs}

    def set_range(
        self,
        value: float,
        clamp: bool,
        source_min: float,
        source_max: float,
        target_min: float,
        target_max: float,
        easing: str,
    ):
        """Map ``value`` from the source range to the target range.

        The value is normalized against the source range (0 when that range
        is empty), optionally clamped to [0, 1], passed through
        ``apply_easing`` and finally scaled into the target range.

        Returns:
            A one-element tuple holding the remapped value.
        """
        # An empty source range would divide by zero, so pin it to 0.
        position = (
            0
            if source_min == source_max
            else (value - source_min) / (source_max - source_min)
        )

        if clamp:
            position = max(min(position, 1), 0)

        eased_position = apply_easing(position, easing)

        target_span = target_max - target_min
        return (target_min + target_span * eased_position,)


class MTB_ConcatImages:
    """Add images to batch."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "concatenate_tensors"
    CATEGORY = "mtb/image"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {"reverse": ("BOOLEAN", {"default": False})},
            "optional": {
                "on_mismatch": (
                    ["Error", "Smallest", "Largest"],
                    {"default": "Smallest"},
                )
            },
        }

    def concatenate_tensors(
        self,
        reverse: bool,
        on_mismatch: str = "Smallest",
        **kwargs: torch.Tensor,
    ) -> tuple[torch.Tensor]:
        """Concatenate all image inputs along the first (batch) dimension.

        Args:
            reverse: When true, the inputs are concatenated in reverse order.
            on_mismatch: What to do when image sizes differ. ``"Error"``
                raises; ``"Smallest"`` / ``"Largest"`` bilinearly resize every
                input to the smallest / largest (height, width) among them.
            **kwargs: The image tensors, indexed as
                (batch, height, width, channels).

        Returns:
            A one-element tuple holding the concatenated batch.

        Raises:
            ValueError: If no tensor is provided, or if sizes differ while
                ``on_mismatch`` is ``"Error"``.
        """
        tensors = list(kwargs.values())
        if not tensors:
            raise ValueError(
                "At least one image input is required to concatenate."
            )

        if reverse:
            tensors.reverse()

        if on_mismatch == "Error":
            # Batch sizes may legitimately differ when concatenating along
            # dim 0, so only the per-image shape is compared.
            image_shape = tensors[0].shape[1:]
            if any(tensor.shape[1:] != image_shape for tensor in tensors):
                raise ValueError(
                    "All input tensors must have the same shape when on_mismatch is 'Error'."
                )

        else:
            import torch.nn.functional as F

            # Any value other than "Smallest" is treated as "Largest".
            pick = min if on_mismatch == "Smallest" else max
            target_height, target_width = pick(
                (tensor.shape[1], tensor.shape[2]) for tensor in tensors
            )

            resized_tensors = []
            for tensor in tensors:
                if tensor.shape[1:3] == (target_height, target_width):
                    resized_tensors.append(tensor)
                    continue
                # F.interpolate expects channels-first input.
                resized_tensor = F.interpolate(
                    tensor.permute(0, 3, 1, 2),
                    size=(target_height, target_width),
                    mode="bilinear",
                    align_corners=False,
                )
                resized_tensors.append(resized_tensor.permute(0, 2, 3, 1))

            tensors = resized_tensors

        concatenated = torch.cat(tensors, dim=0)

        return (concatenated,)


# Node classes exposed by this module.
# NOTE(review): presumably collected by the package loader to register the
# nodes — confirm against the loader before reordering or removing entries.
__nodes__ = [
    MTB_StringReplace,
    MTB_FitNumber,
    MTB_GetBatchFromHistory,
    MTB_AnyToString,
    MTB_ConcatImages,
    MTB_MathExpression,
    MTB_ToDevice,
    MTB_ApplyTextTemplate,
    MTB_MatchDimensions,
    MTB_AutoPanEquilateral,
    MTB_FloatsToFloat,
    MTB_FloatToFloats,
    MTB_FloatsToInts,
]