import gradio as gr
import spaces
import os
import sys
import subprocess
import numpy as np
from PIL import Image
import cv2

import torch

from diffusers import StableDiffusion3ControlNetPipeline
from diffusers.models import SD3ControlNetModel
from diffusers.utils import load_image

# Load the Canny ControlNet and the SD3-medium ControlNet pipeline once, at import time.
# torch_dtype=torch.float16 loads the weights directly in half precision rather than
# loading them at the default (float32) dtype and casting afterwards, which lowers peak
# CPU memory during startup. The end state is the same as `.to("cuda", torch.float16)`:
# every module is fp16 on the CUDA device.
controlnet_canny = SD3ControlNetModel.from_pretrained(
    "InstantX/SD3-Controlnet-Canny",
    torch_dtype=torch.float16,
)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    controlnet=controlnet_canny,
    torch_dtype=torch.float16,
).to("cuda")

def resize_image(input_path, output_path, target_height):
    """Resize an image to a fixed height, keeping its aspect ratio, and save it.

    Args:
        input_path: Path of the image to read.
        output_path: Destination path; Pillow infers the file format from its
            extension.
        target_height: Height in pixels of the resized image.

    Returns:
        Tuple ``(output_path, new_width, target_height)``.
    """
    # Context manager so the source file handle is released instead of leaked.
    with Image.open(input_path) as img:
        original_width, original_height = img.size
        # Exact integer arithmetic (no float rounding error); clamp to 1 px so an
        # extremely tall image cannot produce a zero width, which resize() rejects.
        new_width = max(1, target_height * original_width // original_height)
        resized = img.resize((new_width, target_height), Image.LANCZOS)

    # JPEG cannot store alpha/palette/other exotic modes (e.g. RGBA PNG uploads);
    # convert those to RGB first so save() does not raise OSError.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    is_jpeg_target = str(output_path).lower().endswith((".jpg", ".jpeg"))
    if is_jpeg_target and resized.mode not in jpeg_modes:
        resized = resized.convert("RGB")

    resized.save(output_path)

    return output_path, new_width, target_height


def _canny_control_image(image, low_threshold=100, high_threshold=200):
    """Build the 3-channel Canny edge map used as the ControlNet conditioning image.

    Args:
        image: PIL image in any mode; it is converted to 8-bit grayscale first.
        low_threshold: Lower hysteresis threshold passed to ``cv2.Canny``.
        high_threshold: Upper hysteresis threshold passed to ``cv2.Canny``.

    Returns:
        PIL image whose three channels all hold the same edge map.
    """
    gray = np.array(image.convert("L"))
    edges = cv2.Canny(gray, low_threshold, high_threshold)
    # Replicate the single-channel edge map into 3 channels to form an RGB image.
    return Image.fromarray(np.stack([edges] * 3, axis=-1))


@spaces.GPU(duration=90)
def infer(
    image_in,
    prompt,
    negative_prompt="",
    inference_steps=25,
    guidance_scale=7.0,
    control_weight=0.7,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate an image with the SD3 Canny ControlNet from a reference image.

    The reference image is turned into a Canny edge map. That map and the
    prompts condition the pipeline. The generated image is then resized to a
    height of 1024 px, using the reference image's aspect ratio.

    Args:
        image_in: Filepath of the uploaded reference image (the UI component uses
            ``type="filepath"``); ``None`` or empty when nothing was uploaded.
        prompt: Text prompt.
        negative_prompt: Negative text prompt.
        inference_steps: Number of denoising steps (``num_inference_steps``).
        guidance_scale: Classifier-free guidance scale.
        control_weight: ``controlnet_conditioning_scale`` for the ControlNet.
        progress: Gradio progress tracker (mirrors the pipeline's tqdm bars).

    Returns:
        Tuple of (generated PIL image, ``gr.update`` that shows the edge map).

    Raises:
        gr.Error: If no reference image was provided.
    """
    if image_in is None or image_in == "":
        raise gr.Error("Please upload a reference image.")

    # diffusers' load_image applies EXIF orientation and returns an RGB PIL image.
    source_image = load_image(image_in)
    control_image = _canny_control_image(source_image)

    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        control_image=control_image,
        controlnet_conditioning_scale=control_weight,
        num_inference_steps=inference_steps,
        guidance_scale=guidance_scale,
    ).images[0]

    # No height/width is passed to the pipeline, so it renders at its default
    # size. Resize the result to 1024 px tall with the reference's aspect ratio.
    # The dimensions come from the already-loaded (EXIF-oriented) image, so no
    # temporary file is written to the shared working directory. Concurrent
    # requests therefore cannot race on it, and alpha-channel uploads cannot
    # fail a JPEG save.
    target_height = 1024
    src_width, src_height = source_image.size
    target_width = max(1, target_height * src_width // src_height)
    image = image.resize((target_width, target_height), Image.LANCZOS)

    return image, gr.update(value=control_image, visible=True)



# Page CSS: centers the main column (elem_id="col-container") and caps its width at 1080px.
css = """
#col-container{
    margin: 0 auto;
    max-width: 1080px;
}
"""
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
        # SD3 ControlNet
        Experiment with Stable Diffusion 3 ControlNet models proposed and maintained by the InstantX team.<br />
        Model card: [InstantX/SD3-Controlnet-Canny](https://huggingface.co/InstantX/SD3-Controlnet-Canny)
        """)
        
        with gr.Column():
            
            with gr.Row():
                # Left column: reference image, prompts and generation controls.
                with gr.Column():
                    # type="filepath" means infer() receives a path string for the upload.
                    image_in = gr.Image(label="Image reference", sources=["upload"], type="filepath")
                    prompt = gr.Textbox(label="Prompt")
                    negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompts here")
                    
                    with gr.Accordion("Advanced settings", open=False):
                        with gr.Column():
                            with gr.Row():
                                # NOTE(review): this UI default (50 steps) differs from infer()'s own default (25).
                                inference_steps = gr.Slider(label="Inference steps", minimum=1, maximum=100, step=1, value=50)
                                guidance_scale = gr.Slider(label="Guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=7.0)
                            # Passed to the pipeline as controlnet_conditioning_scale.
                            control_weight = gr.Slider(label="Control Weight", minimum=0.0, maximum=1.0, step=0.01, value=0.7)
                    
                    submit_canny_btn = gr.Button("Submit")
                    
                # Right column: outputs. The control-image preview starts hidden;
                # infer() returns gr.update(..., visible=True) to reveal it.
                with gr.Column():
                    result = gr.Image(label="Result")
                    canny_used = gr.Image(label="Preprocessed Canny", visible=False)


    # `inputs` must stay in the same order as infer()'s positional parameters;
    # `outputs` map to infer()'s (image, gr.update) return tuple.
    # api_name="predict" also exposes this handler as a named API endpoint.
    submit_canny_btn.click(
        fn=infer,
        inputs=[image_in, prompt, negative_prompt, inference_steps, guidance_scale, control_weight],
        outputs=[result, canny_used],
        api_name="predict",
        show_api=True
    )

# Turn on Gradio's request queue for the app (Blocks.queue() configures `demo`
# in place and returns it), then start the server with the API page exposed.
demo.queue()
demo.launch(show_api=True)