import os
import random
import time
from typing import Tuple

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxInpaintPipeline
from gradio_imageslider import ImageSlider
from huggingface_hub import login
from PIL import Image


MARKDOWN = """
# FLUX.1 Inpainting with LoRA
"""

MAX_SEED = np.iinfo(np.int32).max
IMAGE_SIZE = 1024
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
HF_TOKEN = os.environ.get("HF_TOKEN")

# Log in only when a token is configured; calling login() without one would
# try an interactive prompt and fail in a headless environment.
if HF_TOKEN:
    login(token=HF_TOKEN)


class CalculateDuration:
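    """Context manager that prints the start time, end time, and elapsed
    seconds of a named activity."""
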
    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        self.start_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.start_time))
        print(f"Activity: {self.activity_name}, Start time: {self.start_time_formatted}")
        return self
    
    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        self.end_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.end_time))
        
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
        
        print(f"Activity: {self.activity_name}, End time: {self.end_time_formatted}")


def remove_background(image: Image.Image, threshold: int = 50) -> Image.Image:
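    """Return an RGBA copy of `image` with pixels darker than `threshold`
    (mean of the RGB channels) made fully transparent. Currently unused by
    the app below."""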
    image = image.convert("RGBA")
    data = image.getdata()
    new_data = []
    for item in data:
        avg = sum(item[:3]) / 3
        if avg < threshold:
            new_data.append((0, 0, 0, 0))
        else:
            new_data.append(item)

    image.putdata(new_data)
    return image


# Load the FLUX.1-schnell inpainting pipeline once at startup; per-request
# LoRA weights are attached later in `process`.
pipe = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)


def resize_image_dimensions(
    original_resolution_wh: Tuple[int, int],
    maximum_dimension: int = IMAGE_SIZE
) -> Tuple[int, int]:
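    """Scale (width, height) so the longer side equals `maximum_dimension`,
    then round each side down to the nearest multiple of 32 so the
    dimensions stay compatible with the pipeline."""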
    width, height = original_resolution_wh

    if width > height:
        scaling_factor = maximum_dimension / width
    else:
        scaling_factor = maximum_dimension / height

    new_width = int(width * scaling_factor)
    new_height = int(height * scaling_factor)

    new_width = new_width - (new_width % 32)
    new_height = new_height - (new_height % 32)

    return new_width, new_height


@spaces.GPU(duration=100)
def process(
    input_image_editor: dict,
    lora_path: str,
    lora_weights: str,
    lora_scale: float,
    trigger_word: str,
    input_text: str,
    seed_slicer: int,
    randomize_seed_checkbox: bool,
    strength_slider: float,
    num_inference_steps_slider: int,
    progress=gr.Progress(track_tqdm=True)
):
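    """Validate the editor inputs, resize the image and mask, load the
    requested LoRA weights, and run FLUX inpainting. Returns an
    (original, result) pair for the ImageSlider and the resized mask."""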
    if not input_text:
        gr.Info("Please enter a text prompt.")
        return None, None

    # The ImageEditor value is a dict; guard against a missing value or an
    # empty layer list before indexing into it.
    image = input_image_editor.get('background') if input_image_editor else None
    layers = input_image_editor.get('layers') if input_image_editor else None
    mask = layers[0] if layers else None

    if not image:
        gr.Info("Please upload an image.")
        return None, None

    if not mask:
        gr.Info("Please draw a mask on the image.")
        return None, None

    with CalculateDuration("resize image"):
        width, height = resize_image_dimensions(original_resolution_wh=image.size)
        resized_image = image.resize((width, height), Image.LANCZOS)
        resized_mask = mask.resize((width, height), Image.LANCZOS)
    
    with CalculateDuration("load lora"):
        # Unload any previously loaded LoRA first so repeated runs don't
        # stack adapters on top of each other.
        pipe.unload_lora_weights()
        pipe.load_lora_weights(lora_path, weight_name=lora_weights)
    
    if randomize_seed_checkbox:
        seed_slicer = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed_slicer)

    with CalculateDuration("run pipe"):
        result = pipe(
            prompt=f"{input_text} {trigger_word}".strip(),
            image=resized_image,
            mask_image=resized_mask,
            width=width,
            height=height,
            strength=strength_slider,
            generator=generator,
            num_inference_steps=num_inference_steps_slider,
            joint_attention_kwargs={"scale": lora_scale},
        ).images[0]
    
    return [resized_image, result], resized_mask


with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            input_image_editor_component = gr.ImageEditor(
                label='Image',
                type='pil',
                sources=["upload", "webcam"],
                image_mode='RGB',
                layers=False,
                brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"))

            with gr.Accordion("Prompt Settings", open=True):

                input_text_component = gr.Textbox(
                    label="Inpaint prompt",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your prompt",
                )
                trigger_word = gr.Textbox(
                    label="LoRA trigger word",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your LoRA trigger word here",
                    value="a photo of TOK"
                )

                submit_button_component = gr.Button(
                    value='Submit', variant='primary', scale=0)

            with gr.Accordion("LoRA Settings", open=True):
                lora_path = gr.Textbox(
                    label="LoRA model path",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your model path",
                    info="Currently, only LoRA weights hosted on the Hugging Face Hub can be loaded.",
                    value="XLabs-AI/flux-lora-collection"
                )
                lora_weights = gr.Textbox(
                    label="LoRA weights",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your LoRA weights filename",
                    value="anime_lora.safetensors"
                )
                lora_scale = gr.Slider(
                    label="LoRA scale",
                    show_label=True,
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=0.9,
                )
                
            with gr.Accordion("Advanced Settings", open=True):
                seed_slicer_component = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42,
                )

                randomize_seed_checkbox_component = gr.Checkbox(
                    label="Randomize seed", value=True)

                with gr.Row():
                    strength_slider_component = gr.Slider(
                        label="Strength",
                        info="How much to transform the reference `image`. "
                             "Must be between 0 and 1. `image` is used as a "
                             "starting point, and more noise is added the "
                             "higher the `strength`.",
                        minimum=0,
                        maximum=1,
                        step=0.01,
                        value=0.85,
                    )

                    num_inference_steps_slider_component = gr.Slider(
                        label="Number of inference steps",
                        info="The number of denoising steps. More denoising steps "
                             "usually lead to a higher quality image at the "
                             "expense of slower inference.",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=20,
                    )
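        # Right-hand column: before/after comparison slider, with the resized
        # mask available under a collapsible debug accordion.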
        with gr.Column():
            output_image_component = ImageSlider(label="Generated image", type="pil", slider_color="pink")
            
            with gr.Accordion("Debug", open=False):
                output_mask_component = gr.Image(
                    type='pil', image_mode='RGB', label='Input mask', format="png")

    submit_button_component.click(
        fn=process,
        inputs=[
            input_image_editor_component,
            lora_path,
            lora_weights,
            lora_scale,
            trigger_word,
            input_text_component,
            seed_slicer_component,
            randomize_seed_checkbox_component,
            strength_slider_component,
            num_inference_steps_slider_component
        ],
        outputs=[
            output_image_component,
            output_mask_component
        ]
    )

demo.launch(debug=False, show_error=True)