import os
import io
import re
import time
import random
import torch
from typing import Dict, Final, List, Optional, Tuple, cast

from PIL import Image, ImageDraw, ImageEnhance
from PIL.Image import Image as PILImage
from diffusers import StableDiffusionPipeline

# Pretrained checkpoint identifier passed to `StableDiffusionPipeline.from_pretrained`.
model_id: Final = "Onodofthenorth/SD_PixelArt_SpriteSheet_Generator"
# The pipeline is loaded once, at import time, in float16 with downloaded weights
# cached under the relative directory "cache", then moved to the GPU.
# Importing this module therefore requires a CUDA device.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id, torch_dtype=torch.float16, cache_dir="cache"
)
pipe = pipe.to("cuda")

# Default sides for `generate_sides`: maps side name -> label appended to the prompt.
# Dict order is the generation order.
# NOTE(review): the labels look like trigger tokens of the checkpoint
# (Front/Right/Back/Left Sprite Sheet) — confirm against the model card.
sprite_sides: Final = {
    "front": "PixelArtFSS",
    "right": "PixelArtRSS",
    "back": "PixelArtBSS",
    "left": "PixelArtLSS",
}


def torchGenerator(
    seed: Optional[int], max: int = 1024, device: str = "cuda"
) -> Tuple[torch.Generator, int]:
    """
    Create a seeded torch random generator.

    Args:
        seed: Seed to use. When None, a random seed in [0, max) is drawn.
            0 is a valid seed and is used as-is.
        max: Exclusive upper bound for a randomly drawn seed.
        device: Device the generator is created on (default "cuda").

    Returns:
        The seeded generator and the seed that was actually used.
    """
    # Compare against None explicitly: `seed or ...` would discard seed 0.
    if seed is None:
        seed = random.randrange(0, max)

    return torch.Generator(device).manual_seed(seed), seed


def generate(
    prompt: str,
    sfw_retries: int = 1,
    seed: Optional[int] = None,
) -> PILImage:
    """
    Generate a sprite image from a text description.

    The pipeline runs up to `sfw_retries` times (always at least once). When an
    output is flagged by the pipeline's NSFW check, the next attempt uses a new
    random seed that differs from the one just used. If every attempt is
    flagged, the last output is returned (presumably blanked by the pipeline's
    safety checker).

    Args:
        prompt: Text prompt passed to the pipeline.
        sfw_retries: Maximum number of attempts; values below 1 are treated as 1.
        seed: Seed for the first attempt; None picks a random one.

    Returns:
        The generated 512x512 image.
    """
    attempts = max(1, sfw_retries)
    # Keep the seed actually used so a retry never repeats it (the caller's
    # `seed` may be None).
    generator, used_seed = torchGenerator(seed)
    image: PILImage | None = None

    for attempt in range(attempts):
        pipe_output = pipe(prompt, generator=generator, width=512, height=512)
        image = pipe_output.images[0]

        # May be None when the pipeline has no safety checker configured.
        nsfw_flags = pipe_output.nsfw_content_detected
        if not nsfw_flags or not nsfw_flags[0]:
            break

        # Only reseed when another attempt will actually run.
        if attempt + 1 < attempts:
            print(f"Regenerating `{prompt}` with different seed.")
            new_seed = used_seed
            while new_seed == used_seed:
                new_seed = random.randrange(0, 1024)
            generator, used_seed = torchGenerator(new_seed)

    return cast(PILImage, image)


def generate_sides(
    prompt: str,
    sfw_retries: int = 1,
    sides: Dict[str, str] = sprite_sides,
    seed: Optional[int] = None,
) -> Tuple[Dict[str, PILImage], str]:
    """
    Generate sprite images from a text description of different sides.

    Every side is generated with the same seed. If both "left" and "right" are
    requested, only the left side is generated and the right side is its
    horizontal mirror.

    Args:
        prompt: Text description of the character.
        sfw_retries: Forwarded to `generate`.
        sides: Maps side name -> label appended to the prompt.
        seed: Shared seed for all sides; None picks a random one.

    Returns:
        A dict of side name -> image, keyed in the same order as `sides`
        (`build_spritesheet` uses this order for its rows), and the prompt.
    """
    print(f"Generating sprites for `{prompt}`")

    if seed is None:
        seed = random.randrange(0, 1024)

    mirror_right = "left" in sides and "right" in sides
    generated: Dict[str, PILImage] = {}

    for side, label in sides.items():
        if mirror_right and side == "right":
            continue
        generated[side] = generate(
            f"({prompt}) [nsfw] [photograph] {label}", sfw_retries, seed
        )

    if mirror_right:
        generated["right"] = generated["left"].transpose(Image.Transpose.FLIP_LEFT_RIGHT)

    # Re-key in the caller's order; the mirrored right side was added last.
    sprites = {side: generated[side] for side in sides}

    return sprites, prompt


def clean_sprite(
    image: PILImage,
    size: Tuple[int, int] = (192, 192),
    sharpness: float = 1.5,
    thresh: int = 128,
    rescaling: Optional[int] = None,
) -> PILImage:
    """
    Process image to be more sprite-like.

    Steps: sharpen, convert to RGBA, flood-fill the region connected to the
    top-left pixel with transparent white, optionally pixelate, then resize to
    `size` with nearest-neighbour sampling. The input image is not modified.

    Args:
        image: Source image.
        size: Final (width, height).
        sharpness: Factor passed to `ImageEnhance.Sharpness.enhance`.
        thresh: Colour tolerance for the flood fill from (0, 0).
        rescaling: If an int, first scale down by this factor, then up to
            `size`. Other values (None, non-int) skip this step.

    Raises:
        ValueError: If `rescaling` is an int that is not positive.
    """
    width, height = image.size

    image = ImageEnhance.Sharpness(image).enhance(sharpness).convert("RGBA")
    # Everything connected to the top-left pixel (within `thresh`) becomes
    # transparent; this assumes the corner pixel belongs to the background.
    ImageDraw.floodfill(image, (0, 0), (255, 255, 255, 0), thresh=thresh)

    # bool is an int subclass; exclude it as the original `type(...) is int` did.
    if isinstance(rescaling, int) and not isinstance(rescaling, bool):
        if rescaling <= 0:
            raise ValueError(f"rescaling must be a positive integer, got {rescaling}")
        image = image.resize(
            # Clamp so a large factor never yields a zero-sized image.
            (max(1, width // rescaling), max(1, height // rescaling)),
            resample=Image.Resampling.NEAREST,
        )

    return image.resize(size, resample=Image.Resampling.NEAREST)


def split_sprites(image: PILImage, size: Tuple[int, int] = (96, 96)) -> List[PILImage]:
    """
    Split sprite image into individual frames.

    The image is treated as four equal-width columns. From each column, the
    band between 1/4 and 3/4 of the image height is cropped and centred on a
    transparent RGBA canvas of `size`. Callers in this module pass an image
    exactly twice `size`, in which case each frame spans the full canvas height.

    Args:
        image: Cleaned sprite image.
        size: (width, height) of each output frame.

    Returns:
        Four frames of `size`, ordered left to right.
    """
    width, height = image.size
    w, h = size

    # Crop bounds are derived from the input image only; the original mixed
    # in the output height for the top edge.
    top, bottom = height // 4, height * 3 // 4
    edges = [width * i // 4 for i in range(5)]
    blank = Image.new("RGBA", size, (255, 255, 255, 0))

    frames = []
    for left, right in zip(edges, edges[1:]):
        frame = image.crop((left, top, right, bottom))
        canvas = blank.copy()
        # A 2-tuple offset centres the frame and clips any overflow, whereas a
        # 4-tuple box raises ValueError when its size differs from the frame.
        canvas.paste(frame, ((w - frame.width) // 2, (h - frame.height) // 2))
        frames.append(canvas)

    return frames


def build_spritesheet(
    images: Dict[str, PILImage],
    text: str = "sd_pixelart",
    sprite_size: Tuple[int, int] = (96, 96),
    dir: str = "output",
    save: bool = False,
    timestamp: Optional[int] = None,
    thresh: int = 128,
) -> Tuple[PILImage, str | None]:
    """
    Build sprite sheet from sides.

    1. Clean and scale each image to twice `sprite_size`.
    2. Split each image into individual frames.
    3. Create a transparent canvas with one row per side (in `images` order)
       and one column per frame.
    4. Paste each individual frame onto the canvas.

    Args:
        images: Maps side name -> raw generated image.
        text: Used in the file name; characters outside [\\w()[]_-] are removed.
        sprite_size: (width, height) of a single frame.
        dir: Output directory, created if missing, used only when `save` is set.
        save: Whether to also write the sheet as a PNG file.
        timestamp: File-name prefix; defaults to the current Unix time.
        thresh: Flood-fill tolerance forwarded to `clean_sprite`.

    Returns:
        The sheet re-opened from an in-memory PNG, and the written file path
        (None unless `save`).

    Raises:
        ValueError: If `images` is empty.
    """
    if not images:
        raise ValueError("build_spritesheet needs at least one side image")

    width, height = sprite_size
    text = re.sub(r"[^\w()[\]_-]", "", text)
    filepath = None

    frames = {
        side: split_sprites(
            clean_sprite(image, (width * 2, height * 2), thresh=thresh), sprite_size
        )
        for side, image in images.items()
    }
    # Size the sheet from the data rather than assuming a "front" side exists.
    columns = max(len(side_frames) for side_frames in frames.values())

    canvas = Image.new(
        "RGBA",
        (width * columns, height * len(frames)),
        (255, 255, 255, 0),
    )

    for row, side_frames in enumerate(frames.values()):
        for col, frame in enumerate(side_frames):
            x, y = col * width, row * height
            canvas.paste(frame, (x, y, x + width, y + height))

    spritesheet = io.BytesIO()
    canvas.save(spritesheet, "PNG")

    if save:
        if timestamp is None:
            timestamp = int(time.time())
        if dir:
            os.makedirs(dir, exist_ok=True)
        filepath = os.path.join(dir, f"{timestamp}_{text}.png")
        # Reuse the PNG bytes already encoded above instead of encoding again.
        with open(filepath, "wb") as file:
            file.write(spritesheet.getvalue())

    return Image.open(spritesheet), filepath


def build_gifs(
    images: Dict[str, PILImage],
    text: str = "sd_spritesheet",
    dir: str = "output",
    duration: int | List[int] | Tuple[int, ...] = (300, 450, 300, 450),
    save: bool = False,
    timestamp: Optional[int] = None,
    thresh: int = 128,
) -> Tuple[Dict[str, PILImage], List[str] | None]:
    """
    Build animated GIFs from side frames, one per side.

    Each side image is cleaned with `clean_sprite` (default size) and split into
    frames with `split_sprites`. The frames are then encoded in memory as a
    looping GIF.

    Args:
        images: Maps side name -> raw generated image.
        text: Used in file names; characters outside [\\w()[]_-] are removed.
        dir: Output directory, created if missing, used only when `save` is set.
        duration: Passed to PIL's GIF writer; a single value for all frames or
            one value per frame (milliseconds per PIL's GIF documentation).
        save: Whether to also write each GIF to disk.
        timestamp: File-name prefix; defaults to the current Unix time.
        thresh: Flood-fill tolerance forwarded to `clean_sprite`.

    Returns:
        A dict of side name -> animated GIF image opened from memory, and the
        list of written file paths (None unless `save`).
    """
    gifs: Dict[str, PILImage] = {}
    text = re.sub(r"[^\w()[\]_-]", "", text)
    filepaths: List[str] | None = [] if save else None

    if save:
        if timestamp is None:
            timestamp = int(time.time())
        if dir:
            os.makedirs(dir, exist_ok=True)

    for side, image in images.items():
        frames = split_sprites(clean_sprite(image, thresh=thresh))

        gif = io.BytesIO()
        frames[0].save(
            gif,
            format="GIF",
            save_all=True,
            append_images=frames[1:],
            disposal=3,
            duration=duration,
            loop=0,
        )

        if filepaths is not None:
            filepath = os.path.join(dir, f"{timestamp}_{text}_{side}.gif")
            # Write the already-encoded bytes rather than encoding the GIF twice.
            with open(filepath, "wb") as file:
                file.write(gif.getvalue())
            filepaths.append(filepath)

        gifs[side] = Image.open(gif)

    return gifs, filepaths