import torch

# Replace torch.jit.script with an identity function: any later call to it
# (including use as a decorator) returns the original Python callable
# unchanged instead of a scripted one. The patch is applied before the
# remaining imports so those modules see the patched attribute.
# NOTE(review): presumably a workaround for TorchScript compilation problems
# in one of the dependencies -- confirm which one and whether it is still needed.
torch.jit.script = lambda f: f

import spaces
import gradio as gr
from diffusers import FluxInpaintPipeline
from PIL import Image, ImageFile

# ImageFile.LOAD_TRUNCATED_IMAGES = True

# Initialize the pipeline once at import time; `generate` below reuses this
# module-level object for every request.
# Load the FLUX.1-dev inpainting pipeline with bfloat16 weights.
pipe = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
# Move the whole pipeline to the CUDA device.
pipe.to("cuda")
# Load the two-view in-context LoRA weights on top of the base model.
pipe.load_lora_weights(
    "ysmao/multiview-incontext",
    weight_name="twoview-incontext-b03.safetensors",
)


def fractional_resize_image(img, target_size=864):
    """Return an RGB image scaled so that its longer side is ``target_size``.

    The aspect ratio is preserved up to rounding to whole pixels.

    Args:
        img: Source ``PIL.Image.Image``. Any mode other than "RGB" (for
            example "RGBA", "P", "L" or "1") is converted to "RGB" first, so
            the Lanczos filter always runs on plain RGB data and the result
            matches the RGB canvas it is later pasted onto.
        target_size: Desired length in pixels of the longer side.

    Returns:
        A new ``PIL.Image.Image`` in "RGB" mode, resized with Lanczos
        resampling.
    """
    if img.mode != "RGB":
        img = img.convert("RGB")

    width, height = img.size
    scale_factor = target_size / max(width, height)

    # round() instead of int(): floating-point error can make the scaled long
    # side land just below target_size (e.g. 863.999...), which int() would
    # truncate to target_size - 1. The clamp to 1 keeps very elongated
    # images from collapsing to a zero-pixel short side.
    new_size = (
        max(1, round(width * scale_factor)),
        max(1, round(height * scale_factor)),
    )
    return img.resize(new_size, Image.Resampling.LANCZOS)


def duplicate_horizontally(img):
    """Build a side-by-side pair of ``img`` and a matching two-tone mask.

    Args:
        img: Source ``PIL.Image.Image`` of size ``(w, h)``.

    Returns:
        A ``(pair, mask)`` tuple of RGB images, both of size ``(2 * w, h)``:
        ``pair`` holds ``img`` pasted twice, left and right; ``mask`` is
        black over the left half and white over the right half.
    """
    w, h = img.size
    canvas_size = (w * 2, h)

    # Two copies of the source, one per half of the canvas.
    pair = Image.new("RGB", canvas_size)
    for x_offset in (0, w):
        pair.paste(img, (x_offset, 0))

    # Start from an all-white canvas, then black out the left half.
    mask = Image.new("RGB", canvas_size, (255, 255, 255))
    black_half = Image.new("RGB", (w, h), (0, 0, 0))
    mask.paste(black_half, (0, 0))

    return pair, mask


# NOTE(review): `spaces.GPU` comes from the Hugging Face `spaces` package;
# presumably it requests a ZeroGPU slot for up to `duration` seconds per call
# -- confirm against the ZeroGPU documentation.
@spaces.GPU(duration=55)
def generate(
    image, prompt_description, prompt_user, progress=gr.Progress(track_tqdm=True)
):
    """Generate a second viewpoint of the scene shown in ``image``.

    The source image is resized, duplicated side by side, and passed to the
    inpainting pipeline together with a mask that is black over the left half
    and white over the right half, so the right half is the region to be
    regenerated from the prompt.

    Args:
        image: Source ``PIL.Image.Image`` from the upload component, or
            ``None`` when nothing was uploaded.
        prompt_description: Text describing the source image.
        prompt_user: Optional extra text describing the new viewpoint.
        progress: Gradio progress tracker; not used directly in the body, it
            is declared so Gradio can attach progress reporting to the call.

    Returns:
        A ``(new_view, side_by_side)`` tuple: the right half of the pipeline
        output and the full two-view output image.

    Raises:
        gr.Error: If no source image was provided.
    """
    # Fail with a user-visible message instead of an AttributeError deep in
    # the resize helper when the button is clicked without an upload.
    if image is None:
        raise gr.Error("Please upload a source image before generating.")

    # Treat missing text inputs as empty so the concatenation below cannot
    # raise a TypeError.
    prompt_description = prompt_description or ""
    prompt_user = prompt_user or ""

    prompt_structure = (
        "[TWO-VIEWS] This set of two images presents a scene from two different viewpoints. [IMAGE1] The first image shows "
        + prompt_description
        + " [IMAGE2] The second image shows the same room but in another viewpoint "
    )
    prompt = prompt_structure + prompt_user + "."

    resized_image = fractional_resize_image(image)
    image_twoview, mask_image = duplicate_horizontally(resized_image)

    # Ask the pipeline for an output the same size as the two-view canvas.
    image_width, image_height = image_twoview.size

    out = pipe(
        prompt=prompt,
        image=image_twoview,
        mask_image=mask_image,
        guidance_scale=3.5,
        height=image_height,
        width=image_width,
        num_inference_steps=28,
        max_sequence_length=256,
        strength=1,
    ).images[0]

    # Crop using the size of the returned image rather than the requested
    # size, so the split stays at the midpoint of whatever came back.
    width, height = out.size
    half_width = width // 2
    image_2 = out.crop((half_width, 0, width, height))
    return image_2, out


# Build the Gradio UI: header text, a "Demo" tab with inputs on the left and
# outputs on the right, clickable examples, and usage instructions.
with gr.Blocks() as demo:
    gr.Markdown("# MultiView in Context")
    gr.Markdown(
        "### [In-Context LoRA](https://huggingface.co/ali-vilab/In-Context-LoRA) + Image-to-Image + Inpainting. Diffusers implementation based on the [workflow by WizardWhitebeard/klinter](https://civitai.com/articles/8779)"
    )
    gr.Markdown(
        "### Using [MultiView In-Context LoRA](https://huggingface.co/ysmao/multiview-incontext)"
    )
    gr.Markdown(
        "> **_NOTE:_** This is a beta release of the model. The consistency between views may not be perfect. I am working on improving the consistency and spatial relationships between generated views."
    )

    with gr.Tab("Demo"):
        with gr.Row():
            # Left column: source image, the two prompt fields, and the
            # trigger button. type="pil" means `generate` receives a PIL image.
            with gr.Column():
                input_image = gr.Image(
                    label="Upload Source Image", type="pil", height=384
                )
                prompt_description = gr.Textbox(
                    label="Describe the source image",
                    placeholder="a living room with a sofa set with cushions, side tables with table lamps, a flat screen television on a table, houseplants, wall hangings, electric lights, and a carpet on the floor",
                )
                prompt_input = gr.Textbox(
                    label="Any additional description to the new viewpoint?",
                    placeholder="",
                )
                generate_btn = gr.Button("Generate Application", variant="primary")

            # Right column: the two images returned by `generate`, in order
            # (cropped new view first, full side-by-side output second).
            with gr.Column():
                output_image = gr.Image(label="Generated Application")
                output_side = gr.Image(label="Side by side")

        # Each example row supplies values for the three inputs, in the same
        # order as `inputs` below.
        # NOTE(review): the image paths are relative file names; presumably the
        # files sit next to this script -- confirm they ship with the app.
        # NOTE(review): cache_examples="lazy" presumably defers computing an
        # example's outputs until it is first used -- confirm in Gradio docs.
        gr.Examples(
            examples=[
                [
                    "livingroom_fluxdev.jpg",
                    "a living room with a sofa set with cushions, side tables with table lamps, a flat screen television on a table, houseplants, wall hangings, electric lights, and a carpet on the floor",
                    "",
                ],
                [
                    "bedroom_fluxdev.jpg",
                    "a bedroom with a bed, dresser, and window. The bed is covered with a blanket and pillows, and there is a carpet on the floor. The walls are adorned with photo frames, and the windows have curtains. Through the window, we can see trees outside.",
                    "",
                ],
            ],
            inputs=[input_image, prompt_description, prompt_input],
            outputs=[output_image, output_side],
            fn=generate,
            cache_examples="lazy",
        )

        with gr.Row():
            gr.Markdown(
                """
            ### Instructions:
            1. Upload a source image
            2. Describe the source image
            3. Click 'Generate Application' and wait for the result
    
            Note: The generation process might take a few moments.
            """
            )
    # Wire the button to `generate`: the three input components are passed as
    # its positional arguments and its two return values fill the outputs.
    generate_btn.click(
        fn=generate,
        inputs=[input_image, prompt_description, prompt_input],
        outputs=[output_image, output_side],
    )

# Start the Gradio server; this runs unconditionally when the module executes.
demo.launch()