import gradio as gr
import spaces
import torch
from diffusers import AutoencoderKL, ControlNetModel, TCDScheduler
from gradio_imageslider import ImageSlider
from image_gen_aux import LineArtPreprocessor
from PIL import Image, ImageEnhance

from controlnet_union import ControlNetModel_Union
from pipeline_sdxl_recolor import StableDiffusionXLRecolorPipeline

# Line-art extractor; recolor_image uses its output as the conditioning image
# for the second ControlNet below. Loaded onto the GPU at import time.
lineart_preprocessor = LineArtPreprocessor.from_pretrained("OzzyGT/lineart").to("cuda")

# Two ControlNets, both loaded in fp16. Order matters: recolor_image passes its
# conditioning images, conditioning scales and guidance-end values as lists in
# this same order (presumably paired positionally by the custom pipeline —
# confirm in pipeline_sdxl_recolor).
controlnet = [
    ControlNetModel.from_pretrained(
        "OzzyGT/ControlNet-recolorXL", torch_dtype=torch.float16, variant="fp16"
    ),
    ControlNetModel_Union.from_pretrained(
        "OzzyGT/controlnet-union-promax-sdxl-1.0",
        torch_dtype=torch.float16,
        variant="fp16",
    ),
]

# Separate SDXL VAE handed to the pipeline in place of the checkpoint's own
# (the repo name suggests an fp16-safe variant — not verifiable from here).
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
).to("cuda")

# Custom recolor pipeline built on the ColorfulXL-Lightning checkpoint. The
# ControlNets get no .to("cuda") of their own above; this final .to("cuda")
# presumably moves them along with the other pipeline components.
pipe = StableDiffusionXLRecolorPipeline.from_pretrained(
    "recoilme/ColorfulXL-Lightning",
    torch_dtype=torch.float16,
    vae=vae,
    controlnet=controlnet,
    variant="fp16",
).to("cuda")

# Replace the checkpoint's scheduler with TCDScheduler, reusing its config.
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)

# Attach the SDXL IP-Adapter (ViT-H image encoder) so recolor_image can pass
# the source photo itself as `ip_adapter_image`. Must run before the scale
# is configured below.
pipe.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name="ip-adapter_sdxl_vit-h.safetensors",
    image_encoder_folder="models/image_encoder",
)

# Per-block IP-Adapter scale: only up-block 0 is listed, with three entries
# (1.0, 0.0, 1.0) — presumably one per attention layer in that block.
# NOTE(review): in diffusers' dict form of set_ip_adapter_scale, blocks that
# are not listed are presumably disabled — confirm against the installed
# diffusers version.
scale = {
    "up": {"block_0": [1.0, 0.0, 1.0]},
}
pipe.set_ip_adapter_scale(scale)

# Fixed prompts: every request uses the same text, so the embeddings are
# computed once at import time and reused by recolor_image on each call.
prompt = "high quality color photo, sharp, detailed, 4k, colorized, remastered"
negative_prompt = "blurry, low resolution, bad quality, pixelated, black and white, b&w, grayscale, monochrome, sepia"

# NOTE(review): arguments are passed positionally. This assumes the custom
# pipeline's encode_prompt takes (prompt, negative_prompt, device,
# do_classifier_free_guidance). The stock diffusers SDXL signature is
# (prompt, prompt_2, device, num_images_per_prompt, ...), under which this
# call would mean something different — confirm in pipeline_sdxl_recolor.
(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt(prompt, negative_prompt, "cuda", True)


def _blend_colors(
    source_image: Image.Image,
    colorized_image: Image.Image,
    saturation: float = 4.0,
    opacity: float = 0.20,
) -> Image.Image:
    """Tint ``source_image`` with a boosted, translucent copy of ``colorized_image``.

    The colorized image is converted to RGBA, its color saturation is boosted
    with ``ImageEnhance.Color``, and its alpha channel is scaled by
    ``opacity``. The result is then alpha-composited over the source, so the
    output keeps the source's detail and picks up only a color tint.

    Args:
        source_image: The user's original picture.
        colorized_image: The final image produced by the diffusion pipeline.
        saturation: Factor passed to ``ImageEnhance.Color.enhance``.
        opacity: Multiplier applied to every alpha value of the colored layer.

    Returns:
        An RGBA image with the same size as ``source_image``.
    """
    base = source_image.convert("RGBA")
    overlay = colorized_image.convert("RGBA")

    # Image.alpha_composite raises ValueError when the sizes differ. If the
    # pipeline's output resolution does not exactly match the upload, scale
    # the overlay to the source size first (a no-op when they already match).
    if overlay.size != base.size:
        overlay = overlay.resize(base.size)

    overlay = ImageEnhance.Color(overlay).enhance(saturation)

    # Make the colored layer mostly transparent so it only tints the source.
    alpha = overlay.getchannel("A").point(lambda p: p * opacity)
    overlay.putalpha(alpha)

    return Image.alpha_composite(base, overlay)


@spaces.GPU(duration=16)
def recolor_image(image):
    """Colorize the picture uploaded to the ``gr.ImageEditor`` input.

    This is a generator used as a Gradio event handler. For every item the
    pipeline yields, it yields ``(source, generated)`` for the before/after
    ``ImageSlider``. It then yields a final pair whose right side is the
    generated colors blended over the source (see ``_blend_colors``).

    Args:
        image: The ``gr.ImageEditor`` value. The uploaded picture is read from
            its ``"background"`` key.

    Yields:
        ``(source_image, result_image)`` tuples for the slider.

    Raises:
        gr.Error: If no image was uploaded, or the pipeline yields nothing.
    """
    source_image = image.get("background") if image else None
    if source_image is None:
        # Without this guard an empty editor fails with an opaque TypeError.
        raise gr.Error("Please upload an image first.")

    lineart_image = lineart_preprocessor(source_image, resolution_scale=0.7)[0]

    generated = None
    for generated in pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        # One conditioning image, scale and guidance end per ControlNet, in
        # the same order as the module-level `controlnet` list.
        image=[source_image, lineart_image],
        ip_adapter_image=source_image,
        num_inference_steps=8,
        guidance_scale=2.0,
        controlnet_conditioning_scale=[1.0, 0.5],
        control_guidance_end=[1.0, 0.9],
    ):
        # The custom pipeline is iterated as a generator. Forward each item so
        # the slider updates during sampling (presumably per-step previews —
        # confirm in pipeline_sdxl_recolor).
        yield source_image, generated

    if generated is None:
        raise gr.Error("The pipeline did not return an image.")

    blended = _blend_colors(source_image, generated)
    yield source_image.convert("RGBA"), blended


def clear_result():
    """Build a Gradio update that empties the result slider.

    Runs as the first step of the Generate click, so the previous
    before/after pair is removed before a new recoloring starts.
    """
    empty_slider = gr.update(value=None)
    return empty_slider


# Custom CSS for gr.Blocks: pins the app container to a 1024px width.
css = """
.gradio-container {
    width: 1024px !important;
}
"""


# Header HTML rendered at the top of the page through gr.HTML.
title = """<h1 align="center">Diffusers Image Recolor</h1>
<div align="center">Upload a grayscale image to colorize it.</div>
<div align="center">This space is a PoC made for the guide <a href='https://huggingface.co/blog/OzzyGT/diffusers-recolor'>Recoloring photos with diffusers</a>.</div>
"""

# UI layout: header, a Generate button, then a row holding the input editor
# and the before/after result slider. Component creation order inside the
# Blocks context defines the on-page layout.
with gr.Blocks(css=css) as demo:
    gr.HTML(title)

    run_button = gr.Button("Generate")

    with gr.Row():
        # Upload-only editor (layers, eraser and brush disabled). Images are
        # handed to the handler as PIL RGB; recolor_image reads the upload
        # from the value's "background" key.
        input_image = gr.ImageEditor(
            type="pil",
            label="Input Image",
            crop_size=(1024, 1024),
            canvas_size=(1024, 1024),
            layers=False,
            eraser=False,
            brush=False,
            sources=["upload"],
            image_mode="RGB",
        )

        # Read-only before/after slider fed (source, result) pairs.
        result = ImageSlider(interactive=False, label="Generated Image", type="pil")

    # First clear the previous result, then stream recolor_image's yielded
    # (source, result) pairs into the slider.
    run_button.click(
        fn=clear_result,
        inputs=None,
        outputs=result,
    ).then(
        fn=recolor_image,
        inputs=[input_image],
        outputs=result,
    )


demo.launch(share=False)