import random

import gradio as gr
import torch
import spaces
from huggingface_hub import hf_hub_download
from diffusers import FluxControlPipeline, FluxTransformer2DModel
import numpy as np

####################################
#   Load the model(s) on GPU       #
####################################
# Edit-tuned Flux transformer; it is swapped into the base FLUX.1-dev pipeline below.
path = "sayakpaul/FLUX.1-dev-edit-v0"
edit_transformer = FluxTransformer2DModel.from_pretrained(path, torch_dtype=torch.bfloat16)
# Base FLUX.1-dev control pipeline using the edit transformer, moved to CUDA at
# import time. NOTE(review): on ZeroGPU Spaces this presumably relies on the
# `spaces` package handling the import-time `.to("cuda")` — confirm.
pipeline = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=edit_transformer, torch_dtype=torch.bfloat16
).to("cuda")
# Hyper-SD 8-step LoRA; pairs with `num_inference_steps=8` in `generate`.
pipeline.load_lora_weights(
    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
)
# Activate the LoRA at weight 0.125. NOTE(review): assumed to be the scale
# recommended for this LoRA — verify against the Hyper-SD model card.
pipeline.set_adapters(["hyper-sd"], adapter_weights=[0.125])

# Largest signed 32-bit integer; used as the upper bound of the seed slider.
MAX_SEED = np.iinfo(np.int32).max

def get_seed(randomize_seed: bool, seed: int) -> int:
    """
    Return the seed to use for a generation.

    Args:
        randomize_seed: if True, ignore ``seed`` and draw a fresh seed
            uniformly from the inclusive range ``[0, MAX_SEED]``.
        seed: the seed to use when ``randomize_seed`` is False. It is
            coerced to ``int`` because UI widgets may deliver floats.

    Returns:
        The chosen seed as an ``int``.
    """
    if randomize_seed:
        # random.randint is inclusive on both ends, matching the slider range.
        return random.randint(0, MAX_SEED)
    return int(seed)

#####################################
#  The function for our Gradio app  #
#####################################
@spaces.GPU(duration=120)
def generate(prompt, input_image, seed):
    """
    Edit `input_image` according to `prompt` with the Flux Control pipeline.

    The pipeline lives on GPU (it was moved to "cuda" at load time).

    Args:
        prompt: text instruction describing the edit.
        input_image: PIL image to edit (the Gradio image input uses
            `type="pil"`); its height and width are used as the output size.
        seed: seed for the CUDA random generator, so identical inputs
            reproduce identical outputs. Coerced to ``int`` because the
            value may arrive from a slider as a float.

    Returns:
        The first image of the pipeline output.

    Raises:
        gr.Error: if no input image was provided.
    """
    if input_image is None:
        # Show a readable message in the UI instead of an AttributeError on
        # `input_image.height` further down.
        raise gr.Error("Please upload an image to edit.")

    # Generator.manual_seed requires an integer seed.
    generator = torch.Generator(device="cuda").manual_seed(int(seed))

    output_image = pipeline(
        control_image=input_image,
        prompt=prompt,
        guidance_scale=30.,
        num_inference_steps=8,  # pairs with the 8-step Hyper-SD LoRA
        max_sequence_length=512,
        height=input_image.height,
        width=input_image.width,
        generator=generator,  # seeded generator for reproducibility
    ).images[0]

    return output_image

def launch_app():
    """
    Build (but do not launch) the Gradio Blocks UI for the editing demo.

    The page has four parts:
    - an intro banner;
    - a row with the edit controls (image, prompt, seed, Generate button)
      beside the output image;
    - clickable examples;
    - an acknowledgements footer.

    Returns:
        The assembled `gr.Blocks` instance; the caller starts it with `.launch()`.
    """
    custom_css = '''
    .gradio-container{max-width: 1100px !important}
    '''
    # Each row is (prompt, input image path, seed), in the order of the inputs below.
    example_rows = [
        ["Turn the color of the mushroom to gray", "mushroom.jpg", 42],
        ["Make the mushroom polka-dotted", "mushroom.jpg", 100],
    ]

    with gr.Blocks(css=custom_css) as demo:
        gr.Markdown(
            """
            # Flux Control Editing 🖌️
            Edit any image with the [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) 
            [Flux Control edit framework](https://github.com/sayakpaul/flux-image-editing) by [Sayak Paul](https://huggingface.co/sayakpaul).
            """
        )
        with gr.Row():
            # Left side: the grouped input controls.
            with gr.Column(), gr.Group():
                input_image = gr.Image(label="Image you would like to edit", type="pil")
                prompt = gr.Textbox(
                    label="Your edit prompt",
                    placeholder="e.g. 'Turn the color of the mushroom to blue'",
                )
                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                generate_button = gr.Button("Generate")
            # Right side: the result.
            output_image = gr.Image(label="Edited Image")

        # Clicking Generate runs the edit and shows the result.
        generate_button.click(
            fn=generate,
            inputs=[prompt, input_image, seed],
            outputs=[output_image],
        )

        # Examples call the same function; with "lazy", per Gradio's docs, each
        # example's output is cached the first time it is selected.
        gr.Examples(
            examples=example_rows,
            inputs=[prompt, input_image, seed],
            outputs=[output_image],
            fn=generate,
            cache_examples="lazy",
        )
        gr.Markdown(
            """
            **Acknowledgements**: 
            - [Sayak Paul](https://huggingface.co/sayakpaul) for open-sourcing FLUX.1-dev-edit-v0 
            - [black-forest-labs](https://huggingface.co/black-forest-labs) for FLUX.1-dev
            - [ByteDance/Hyper-SD](https://huggingface.co/ByteDance/Hyper-SD) for the Turbo LoRA which we use to speed up inference
            """
        )
    return demo

if __name__ == "__main__":
    # Build the UI and start the server. The assignment expression still binds
    # `demo` at module level, just like a separate assignment statement.
    (demo := launch_app()).launch()