import os
import re
import requests
import tempfile

import gradio as gr
from PIL import Image, ImageDraw

from config import theme
from public.data.images.loras.flux1 import loras as flux1_loras

# os.makedirs(os.getenv("HF_HOME"), exist_ok=True)

# UI
with gr.Blocks(
    theme=theme,
    fill_width=True,
    css_paths=[os.path.join("static/css", f) for f in os.listdir("static/css")],
) as demo:

    # States
    data_state = gr.State()
    local_state = gr.BrowserState(
        {
            "selected_loras": [],
        }
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Label("AllFlux", show_label=False)

            with gr.Accordion("Settings", open=True):
                with gr.Group():
                    height_slider = gr.Slider(
                        minimum=64,
                        maximum=2048,
                        value=1024,
                        step=64,
                        label="Height",
                        interactive=True,
                    )
                    width_slider = gr.Slider(
                        minimum=64,
                        maximum=2048,
                        value=1024,
                        step=64,
                        label="Width",
                        interactive=True,
                    )

                with gr.Group():
                    num_images_slider = gr.Slider(
                        minimum=1,
                        maximum=4,
                        value=1,
                        step=1,
                        label="Number of Images",
                        interactive=True,
                    )

                toggles = gr.CheckboxGroup(
                    choices=["Realtime", "Randomize Seed"],
                    value=["Randomize Seed"],
                    show_label=False,
                    interactive=True,
                )

            with gr.Accordion("Advanced", open=False):
                num_steps_slider = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=20,
                    step=1,
                    label="Steps",
                    interactive=True,
                )
                guidance_scale_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=3.5,
                    step=0.1,
                    label="Guidance Scale",
                    interactive=True,
                )
                seed_slider = gr.Slider(
                    minimum=0,
                    maximum=4294967295,
                    value=42,
                    step=1,
                    label="Seed",
                    interactive=True,
                )
                upscale_slider = gr.Slider(
                    minimum=2,
                    maximum=4,
                    value=2,
                    step=2,
                    label="Upscale",
                    interactive=True,
                )
                scheduler_dropdown = gr.Dropdown(
                    label="Scheduler",
                    choices=[
                        "Euler a",
                        "Euler",
                        "LMS",
                        "Heun",
                        "DPM++ 2",
                        "DPM++ 2 a",
                        "DPM++ SDE",
                        "DPM++ SDE Karras",
                        "DDIM",
                        "PLMS",
                    ],
                    value="Euler a",
                    interactive=True,
                )

            gr.LoginButton()

            gr.Markdown(
                """
            Yurrrrrrrrrrrr, WIP
            """
            )

        with gr.Column(scale=3):
            with gr.Group():
                with gr.Row():
                    prompt = gr.Textbox(
                        show_label=False,
                        placeholder="Enter your prompt here...",
                        lines=3,
                    )

                with gr.Row():
                    with gr.Column(scale=3):
                        submit_btn = gr.Button("Submit")
                    with gr.Column(scale=1):
                        ai_improve_btn = gr.Button("💡", link="#improve-prompt")

            with gr.Group():
                output_gallery = gr.Gallery(
                    label="Outputs", interactive=False, height=500
                )

                with gr.Row():
                    upscale_selected_btn = gr.Button("Upscale Selected", size="sm")
                    upscale_all_btn = gr.Button("Upscale All", size="sm")
                    create_similar_btn = gr.Button("Create Similar", size="sm")

            with gr.Accordion("Output History", open=False):
                with gr.Group():
                    output_history_gallery = gr.Gallery(
                        show_label=False, interactive=False, height=500
                    )

                    with gr.Row():
                        clear_history_btn = gr.Button("Clear All", size="sm")
                        download_history_btn = gr.Button("Download All", size="sm")

            with gr.Accordion("Image Playground", open=True):

                def show_info(content: str | None = None):
                    info_checkbox = gr.Checkbox(
                        value=False, label="Show Info", interactive=True
                    )

                    @gr.render(inputs=info_checkbox)
                    def show_info(info_checkbox):
                        return (
                            gr.Markdown(
                                f"""Sup, need some help here, please check the community tab. {content}"""
                            )
                            if info_checkbox
                            else None
                        )

                with gr.Tabs():
                    with gr.Tab("Img 2 Img"):
                        with gr.Group():
                            img2img_img = gr.Image(show_label=False, interactive=True)
                            img2img_strength_slider = gr.Slider(
                                minimum=0,
                                maximum=1,
                                value=1.0,
                                step=0.1,
                                label="Strength",
                                interactive=True,
                            )

                        show_info()

                    with gr.Tab("Inpaint"):
                        with gr.Group():
                            inpaint_img = gr.ImageMask(
                                show_label=False, interactive=True, type="pil"
                            )
                            generate_mask_btn = gr.Button(
                                "Remove Background", size="sm"
                            )
                        
                        use_fill_pipe_inpaint = gr.Checkbox(
                            value=True,
                            label="Use Fill Pipeline 🧪",
                            interactive=True,
                        )

                        show_info()

                        inpaint_img.upload(
                            fn=lambda x: (
                                gr.update(height=x["layers"][0].height + 96)
                                if x is not None
                                else None
                            ),
                            inputs=inpaint_img,
                            outputs=inpaint_img,
                        )
                    with gr.Tab("Outpaint"):
                        outpaint_img = gr.Image(
                            show_label=False, interactive=True, type="pil"
                        )

                        with gr.Row(equal_height=True):
                            with gr.Column(scale=3):
                                ratio_9_16 = gr.Radio(
                                    label="Image Ratio",
                                    choices=["9:16", "16:9", "1:1", "Height & Width"],
                                    value="9:16",
                                    container=True,
                                    interactive=True,
                                )

                            with gr.Column(scale=1):
                                mask_position = gr.Dropdown(
                                    choices=[
                                        "Middle",
                                        "Left",
                                        "Right",
                                        "Top",
                                        "Bottom",
                                    ],
                                    value="Middle",
                                    label="Alignment",
                                    interactive=True,
                                )

                        with gr.Group():
                            resize_options = gr.Radio(
                                choices=["Full", "75%", "50%", "33%", "25%", "Custom"],
                                value="Full",
                                label="Resize",
                                interactive=True,
                            )

                            # Custom size slider. It is always part of the layout (hidden
                            # unless "Custom" is selected) so it can be used as an event
                            # input: a component created inside ``gr.render`` would only be
                            # bound to a local variable of the render function, leaving the
                            # preview handler with ``None`` for the custom percentage.
                            resize_option_custom = gr.Slider(
                                minimum=1,
                                maximum=100,
                                value=50,
                                step=1,
                                label="Custom Size %",
                                interactive=True,
                                visible=False,
                            )

                            def resize_options_render(resize_option):
                                """Show the custom size slider only for the "Custom" option."""
                                return gr.update(visible=resize_option == "Custom")

                            resize_options.change(
                                fn=resize_options_render,
                                inputs=resize_options,
                                outputs=resize_option_custom,
                            )

                        with gr.Accordion("Advanced settings", open=False):
                            with gr.Group():
                                mask_overlap_slider = gr.Slider(
                                    label="Mask Overlap %",
                                    minimum=1,
                                    maximum=50,
                                    value=10,
                                    step=1,
                                    interactive=True,
                                )
                                with gr.Row():
                                    overlap_top = gr.Checkbox(
                                        value=True,
                                        label="Overlap Top",
                                        interactive=True,
                                    )
                                    overlap_right = gr.Checkbox(
                                        value=True,
                                        label="Overlap Right",
                                        interactive=True,
                                    )
                                with gr.Row():
                                    overlap_left = gr.Checkbox(
                                        value=True,
                                        label="Overlap Left",
                                        interactive=True,
                                    )
                                    overlap_bottom = gr.Checkbox(
                                        value=True,
                                        label="Overlap Bottom",
                                        interactive=True,
                                    )
                                mask_preview_btn = gr.Button(
                                    "Preview", interactive=True
                                )

                            mask_preview_img = gr.Image(
                                show_label=False, visible=False, interactive=True
                            )

                            def prepare_image_and_mask(
                                image,
                                width,
                                height,
                                overlap_percentage,
                                resize_option,
                                custom_resize_percentage,
                                alignment,
                                overlap_left,
                                overlap_right,
                                overlap_top,
                                overlap_bottom,
                            ):
                                target_size = (width, height)

                                scale_factor = min(
                                    target_size[0] / image.width,
                                    target_size[1] / image.height,
                                )
                                new_width = int(image.width * scale_factor)
                                new_height = int(image.height * scale_factor)

                                source = image.resize(
                                    (new_width, new_height), Image.LANCZOS
                                )

                                if resize_option == "Full":
                                    resize_percentage = 100
                                elif resize_option == "75%":
                                    resize_percentage = 75
                                elif resize_option == "50%":
                                    resize_percentage = 50
                                elif resize_option == "33%":
                                    resize_percentage = 33
                                elif resize_option == "25%":
                                    resize_percentage = 25
                                else:  # Custom
                                    resize_percentage = custom_resize_percentage

                                # Calculate new dimensions based on percentage
                                resize_factor = resize_percentage / 100
                                new_width = int(source.width * resize_factor)
                                new_height = int(source.height * resize_factor)

                                # Ensure minimum size of 64 pixels
                                new_width = max(new_width, 64)
                                new_height = max(new_height, 64)

                                # Resize the image
                                source = source.resize(
                                    (new_width, new_height), Image.LANCZOS
                                )

                                # Calculate the overlap in pixels based on the percentage
                                overlap_x = int(new_width * (overlap_percentage / 100))
                                overlap_y = int(new_height * (overlap_percentage / 100))

                                # Ensure minimum overlap of 1 pixel
                                overlap_x = max(overlap_x, 1)
                                overlap_y = max(overlap_y, 1)

                                # Calculate margins based on alignment
                                if alignment == "Middle":
                                    margin_x = (target_size[0] - new_width) // 2
                                    margin_y = (target_size[1] - new_height) // 2
                                elif alignment == "Left":
                                    margin_x = 0
                                    margin_y = (target_size[1] - new_height) // 2
                                elif alignment == "Right":
                                    margin_x = target_size[0] - new_width
                                    margin_y = (target_size[1] - new_height) // 2
                                elif alignment == "Top":
                                    margin_x = (target_size[0] - new_width) // 2
                                    margin_y = 0
                                elif alignment == "Bottom":
                                    margin_x = (target_size[0] - new_width) // 2
                                    margin_y = target_size[1] - new_height

                                # Adjust margins to eliminate gaps
                                margin_x = max(
                                    0, min(margin_x, target_size[0] - new_width)
                                )
                                margin_y = max(
                                    0, min(margin_y, target_size[1] - new_height)
                                )

                                # Create a new background image and paste the resized source image
                                background = Image.new(
                                    "RGB", target_size, (255, 255, 255)
                                )
                                background.paste(source, (margin_x, margin_y))

                                # Create the mask
                                mask = Image.new("L", target_size, 255)
                                mask_draw = ImageDraw.Draw(mask)

                                # Calculate overlap areas
                                white_gaps_patch = 2

                                left_overlap = (
                                    margin_x + overlap_x
                                    if overlap_left
                                    else margin_x + white_gaps_patch
                                )
                                right_overlap = (
                                    margin_x + new_width - overlap_x
                                    if overlap_right
                                    else margin_x + new_width - white_gaps_patch
                                )
                                top_overlap = (
                                    margin_y + overlap_y
                                    if overlap_top
                                    else margin_y + white_gaps_patch
                                )
                                bottom_overlap = (
                                    margin_y + new_height - overlap_y
                                    if overlap_bottom
                                    else margin_y + new_height - white_gaps_patch
                                )

                                if alignment == "Left":
                                    left_overlap = (
                                        margin_x + overlap_x
                                        if overlap_left
                                        else margin_x
                                    )
                                elif alignment == "Right":
                                    right_overlap = (
                                        margin_x + new_width - overlap_x
                                        if overlap_right
                                        else margin_x + new_width
                                    )
                                elif alignment == "Top":
                                    top_overlap = (
                                        margin_y + overlap_y
                                        if overlap_top
                                        else margin_y
                                    )
                                elif alignment == "Bottom":
                                    bottom_overlap = (
                                        margin_y + new_height - overlap_y
                                        if overlap_bottom
                                        else margin_y + new_height
                                    )

                                # Draw the mask
                                mask_draw.rectangle(
                                    [
                                        (left_overlap, top_overlap),
                                        (right_overlap, bottom_overlap),
                                    ],
                                    fill=0,
                                )

                                return background, mask

                            def preview_image_and_mask(image, *mask_args):
                                """Show the outpaint canvas with the area outside the kept source tinted red.

                                Takes the same inputs as ``prepare_image_and_mask``. It
                                updates only ``mask_preview_img`` and makes it visible, so
                                the user's uploaded ``outpaint_img`` is left untouched.
                                """
                                if image is None:
                                    raise gr.Error("Upload an image to outpaint first.")
                                background, mask = prepare_image_and_mask(image, *mask_args)
                                red = Image.new("RGB", background.size, (255, 0, 0))
                                tinted = Image.blend(background, red, 0.4)
                                # mask is 255 outside the kept source rectangle (the region
                                # to outpaint), so the red tint is applied there only.
                                preview = Image.composite(tinted, background, mask)
                                return gr.update(value=preview, visible=True)

                            mask_preview_btn.click(
                                fn=preview_image_and_mask,
                                inputs=[
                                    outpaint_img,
                                    width_slider,
                                    height_slider,
                                    mask_overlap_slider,
                                    resize_options,
                                    resize_option_custom,
                                    mask_position,
                                    overlap_left,
                                    overlap_right,
                                    overlap_top,
                                    overlap_bottom,
                                ],
                                outputs=mask_preview_img,
                            )
                            # Hide the preview again when the user clears it.
                            mask_preview_img.clear(
                                fn=lambda: gr.update(visible=False),
                                outputs=mask_preview_img,
                            )

                        use_fill_pipe_outpaint = gr.Checkbox(
                            value=True,
                            label="Use Fill Pipeline 🧪",
                            interactive=True,
                        )

                        show_info()
                    with gr.Tab("In-Context"):
                        with gr.Group():
                            incontext_img = gr.Image(show_label=False, interactive=True)
                        # https://huggingface.co/spaces/Yuanshi/OminiControl
                        show_info(content="1024 res is in beta")
                    with gr.Tab("IP-Adapter"):
                        with gr.Group():
                            ip_adapter_img = gr.Image(
                                show_label=False, interactive=True
                            )
                            ip_adapter_img_scale = gr.Slider(
                                minimum=0,
                                maximum=1,
                                value=0.7,
                                step=0.1,
                                label="Scale",
                                interactive=True,
                            )
                        # https://huggingface.co/InstantX/FLUX.1-dev-IP-Adapter
                        show_info(content="1024 res is in beta")
                    with gr.Tab("Canny"):
                        with gr.Group():
                            canny_img = gr.Image(show_label=False, interactive=True)
                            with gr.Row(equal_height=True):
                                with gr.Column(scale=3):
                                    canny_controlnet_conditioning_scale = gr.Slider(
                                        minimum=0,
                                        maximum=1,
                                        value=0.65,
                                        step=0.05,
                                        label="ControlNet Conditioning Scale",
                                        interactive=True,
                                    )
                                with gr.Column(scale=1):
                                    canny_img_is_preprocessed = gr.Checkbox(
                                        value=True,
                                        label="Preprocessed",
                                        interactive=True,
                                    )
                    with gr.Tab("Tile"):
                        with gr.Group():
                            tile_img = gr.Image(show_label=False, interactive=True)
                            with gr.Row(equal_height=True):
                                with gr.Column(scale=3):
                                    tile_controlnet_conditioning_scale = gr.Slider(
                                        minimum=0,
                                        maximum=1,
                                        value=0.45,
                                        step=0.05,
                                        label="ControlNet Conditioning Scale",
                                        interactive=True,
                                    )
                                with gr.Column(scale=1):
                                    tile_img_is_preprocessed = gr.Checkbox(
                                        value=True,
                                        label="Preprocessed",
                                        interactive=True,
                                    )
                    with gr.Tab("Depth"):
                        with gr.Group():
                            depth_img = gr.Image(show_label=False, interactive=True)
                            with gr.Row(equal_height=True):
                                with gr.Column(scale=3):
                                    depth_controlnet_conditioning_scale = gr.Slider(
                                        minimum=0,
                                        maximum=1,
                                        value=0.55,
                                        step=0.05,
                                        label="ControlNet Conditioning Scale",
                                        interactive=True,
                                    )
                                with gr.Column(scale=1):
                                    depth_img_is_preprocessed = gr.Checkbox(
                                        value=True,
                                        label="Preprocessed",
                                        interactive=True,
                                    )
                    with gr.Tab("Blur"):
                        with gr.Group():
                            blur_img = gr.Image(show_label=False, interactive=True)
                            with gr.Row(equal_height=True):
                                with gr.Column(scale=3):
                                    blur_controlnet_conditioning_scale = gr.Slider(
                                        minimum=0,
                                        maximum=1,
                                        value=0.45,
                                        step=0.05,
                                        label="ControlNet Conditioning Scale",
                                        interactive=True,
                                    )
                                with gr.Column(scale=1):
                                    blur_img_is_preprocessed = gr.Checkbox(
                                        value=True,
                                        label="Preprocessed",
                                        interactive=True,
                                    )
                    with gr.Tab("Pose"):
                        with gr.Group():
                            pose_img = gr.Image(show_label=False, interactive=True)
                            with gr.Row(equal_height=True):
                                with gr.Column(scale=3):
                                    pose_controlnet_conditioning_scale = gr.Slider(
                                        minimum=0,
                                        maximum=1,
                                        value=0.55,
                                        step=0.05,
                                        label="ControlNet Conditioning Scale",
                                        interactive=True,
                                    )
                                with gr.Column(scale=1):
                                    pose_img_is_preprocessed = gr.Checkbox(
                                        value=True,
                                        label="Preprocessed",
                                        interactive=True,
                                    )
                    with gr.Tab("Gray"):
                        with gr.Group():
                            gray_img = gr.Image(show_label=False, interactive=True)
                            with gr.Row(equal_height=True):
                                with gr.Column(scale=3):
                                    gray_controlnet_conditioning_scale = gr.Slider(
                                        minimum=0,
                                        maximum=1,
                                        value=0.45,
                                        step=0.05,
                                        label="ControlNet Conditioning Scale",
                                        interactive=True,
                                    )
                                with gr.Column(scale=1):
                                    gray_img_is_preprocessed = gr.Checkbox(
                                        value=True,
                                        label="Preprocessed",
                                        interactive=True,
                                    )
                    with gr.Tab("Low Quality"):
                        with gr.Group():
                            low_quality_img = gr.Image(
                                show_label=False, interactive=True
                            )
                            with gr.Row(equal_height=True):
                                with gr.Column(scale=3):
                                    low_quality_controlnet_conditioning_scale = (
                                        gr.Slider(
                                            minimum=0,
                                            maximum=1,
                                            value=0.4,
                                            step=0.05,
                                            label="ControlNet Conditioning Scale",
                                            interactive=True,
                                        )
                                    )
                                with gr.Column(scale=1):
                                    low_quality_img_is_preprocessed = gr.Checkbox(
                                        value=True,
                                        label="Preprocessed",
                                        interactive=True,
                                    )
                    # with gr.Tab("Official Canny"):
                    #     with gr.Group():
                    #         gr.HTML(
                    #             """
                    #             <script
                    #             type="module"
                    #             src="https://gradio.s3-us-west-2.amazonaws.com/5.6.0/gradio.js"
                    #             ></script>

                    #             <gradio-app src="https://black-forest-labs-flux-1-canny-dev.hf.space"></gradio-app>
                    #         """
                    #         )
                    # with gr.Tab("Official Depth"):
                    #     with gr.Group():
                    #         gr.HTML(
                    #             """
                    #             <script
                    #             type="module"
                    #             src="https://gradio.s3-us-west-2.amazonaws.com/5.6.0/gradio.js"
                    #             ></script>

                    #             <gradio-app src="https://black-forest-labs-flux-1-depth-dev.hf.space"></gradio-app>
                    #             """
                    #         )
                    with gr.Tab("Auto Trainer"):
                        # Embed the hosted "train-flux-lora-ease" Space through the
                        # <gradio-app> web component from gradio.js 4.42.0.
                        # NOTE(review): Gradio may not execute <script> tags placed in
                        # gr.HTML; if the embed stays blank, load gradio.js via
                        # gr.Blocks(head=...) instead -- confirm.
                        gr.HTML(
                            value="""
                            <script
                            type="module"
                            src="https://gradio.s3-us-west-2.amazonaws.com/4.42.0/gradio.js"
                            ></script>

                            <gradio-app src="https://autotrain-projects-train-flux-lora-ease.hf.space"></gradio-app>
                            """,
                        )
                # Resize strategy selector (presumably applied to input images
                # before generation -- confirm against its consumer).
                resize_mode_radio = gr.Radio(
                    choices=["Crop & Resize", "Resize Only", "Resize & Fill"],
                    value="Resize & Fill",
                    label="Resize Mode",
                    interactive=True,
                )

            with gr.Accordion("Prompt Generator", open=False):
                # Embedded prompt-generator Space. NOTE(review): unlike the Auto
                # Trainer tab, this HTML carries no <script> for gradio.js, so the
                # <gradio-app> element relies on it being loaded elsewhere -- confirm.
                gr.HTML(
                    value="""
                    <gradio-app src="https://gokaygokay-flux-prompt-generator.hf.space"></gradio-app>
                    """,
                )

        with gr.Column(scale=1):

            # Loras
            with gr.Accordion("Loras", open=True):
                # Per-session list of Loras the user has added. Entries are dicts
                # with at least "title", "weights" (local .safetensors path) and
                # "info" keys, appended by the "Add Lora" handler.
                selected_loras = gr.State([])
                # Gallery of the bundled Flux.1 Lora presets (image + title).
                lora_selector = gr.Gallery(
                    show_label=False,
                    value=[(lora["image"], lora["title"]) for lora in flux1_loras],
                    container=False,
                    columns=3,
                    show_download_button=False,
                    show_fullscreen_button=False,
                    allow_preview=False,
                )
                with gr.Group():
                    lora_selected = gr.Textbox(
                        show_label=False,
                        placeholder="Select a Lora to apply...",
                        container=False,
                    )
                    add_lora_btn = gr.Button("Add Lora", size="sm")
                # Closing "*" added so the hint renders in italics instead of
                # showing a stray leading asterisk.
                gr.Markdown(
                    "*You can add a Lora by entering a URL or a Hugging Face repo path.*"
                )

                def _download_lora_file(download_url, file_name, headers):
                    """Stream ``download_url`` into the system temp directory.

                    Args:
                        download_url: URL of the file to fetch.
                        file_name: Remote file name. Only its basename is used so an
                            API-supplied name cannot escape the temp directory.
                        headers: Extra HTTP headers (e.g. Authorization).

                    Returns:
                        Path of the written file.

                    Raises:
                        requests.RequestException: on network or HTTP errors.
                    """
                    file_path = os.path.join(
                        tempfile.gettempdir(), os.path.basename(file_name)
                    )
                    # Stream in chunks: weight files may be large, so avoid holding
                    # the whole body in memory.
                    with requests.get(
                        download_url, headers=headers, stream=True, timeout=60
                    ) as response:
                        response.raise_for_status()
                        with open(file_path, "wb") as f:
                            for chunk in response.iter_content(chunk_size=1024 * 1024):
                                f.write(chunk)
                    return file_path

                def _lora_from_civitai(url):
                    """Resolve a CivitAI model URL into a Lora dict.

                    Uses the ``modelVersionId`` query parameter when present,
                    otherwise the first version listed for the model.

                    Raises:
                        ValueError: if the URL has no model id, the version is not
                            for a Flux base model, or it has no .safetensors file.
                        requests.RequestException: on network or HTTP errors.
                    """
                    model_match = re.search(r"/models/(\d+)", url)
                    if model_match is None:
                        raise ValueError("Could not find a model id in the URL")
                    version_match = re.search(r"modelVersionId=(\d+)", url)

                    # Get API token from env
                    api_token = os.getenv("CIVITAI_TOKEN")
                    headers = (
                        {"Authorization": f"Bearer {api_token}"} if api_token else {}
                    )

                    if version_match:
                        api_url = (
                            "https://civitai.com/api/v1/model-versions/"
                            f"{version_match.group(1)}"
                        )
                    else:
                        # Get latest version if no specific version
                        api_url = (
                            f"https://civitai.com/api/v1/models/{model_match.group(1)}"
                        )

                    response = requests.get(api_url, headers=headers, timeout=30)
                    response.raise_for_status()
                    data = response.json()

                    if "modelVersions" in data:
                        # models endpoint: versions are nested; take the first.
                        if not data["modelVersions"]:
                            raise ValueError("This model has no versions")
                        version_data = data["modelVersions"][0]
                        title = data.get("name")
                    else:
                        # model-versions endpoint: top-level "name" is the version
                        # name; prefer the parent model's name when provided.
                        version_data = data
                        title = (data.get("model") or {}).get("name") or data.get(
                            "name"
                        )

                    # Only "flux" decides compatibility; a bare "1" would also
                    # match non-Flux bases such as "SD 1.5" or "SDXL 1.0".
                    base_model = (version_data.get("baseModel") or "").lower()
                    if "flux" not in base_model:
                        raise ValueError(
                            "This LoRA is not compatible with Flux base model"
                        )

                    safetensor_file = next(
                        (
                            f
                            for f in (version_data.get("files") or [])
                            if (f.get("name") or "").endswith(".safetensors")
                        ),
                        None,
                    )
                    if not safetensor_file:
                        raise ValueError("No .safetensor file found")

                    download_url = safetensor_file["downloadUrl"]
                    if api_token:
                        # downloadUrl may already carry a query string.
                        separator = "&" if "?" in download_url else "?"
                        download_url += f"{separator}token={api_token}"
                    file_path = _download_lora_file(
                        download_url, safetensor_file["name"], headers
                    )

                    # Use a "strength: X" hint from the description if present;
                    # otherwise fall back to the slider default.
                    weight = 0.8
                    strength_match = re.search(
                        r"strength[:\s]+(\d*\.?\d+)",
                        version_data.get("description") or "",
                        re.IGNORECASE,
                    )
                    if strength_match:
                        weight = float(strength_match.group(1))

                    return {
                        "title": title or os.path.basename(file_path),
                        "weights": file_path,
                        "info": ", ".join(version_data.get("trainedWords") or []),
                        "weight": weight,
                    }

                def _lora_from_hf(repo_id):
                    """Resolve a Hugging Face ``user/repo`` path into a Lora dict.

                    Downloads the first top-level .safetensors file on ``main``.

                    Raises:
                        ValueError: if the repo has tags but not ``flux-lora``, or
                            has no top-level .safetensors file.
                        requests.RequestException: on network or HTTP errors.
                    """
                    # Get API token from env
                    api_token = os.getenv("HF_TOKEN")
                    headers = (
                        {"Authorization": f"Bearer {api_token}"} if api_token else {}
                    )

                    response = requests.get(
                        f"https://huggingface.co/api/models/{repo_id}",
                        headers=headers,
                        timeout=30,
                    )
                    response.raise_for_status()
                    data = response.json()

                    # Verify it's a LoRA for Flux (only checked when tags exist).
                    if "tags" in data and "flux-lora" not in data["tags"]:
                        raise ValueError("This model is not tagged as a Flux LoRA")

                    # List files on the same revision that is downloaded below.
                    response = requests.get(
                        f"https://huggingface.co/api/models/{repo_id}/tree/main",
                        headers=headers,
                        timeout=30,
                    )
                    response.raise_for_status()
                    safetensor_file = next(
                        (
                            f
                            for f in response.json()
                            if (f.get("path") or "").endswith(".safetensors")
                        ),
                        None,
                    )
                    if not safetensor_file:
                        raise ValueError("No .safetensor file found")

                    file_path = _download_lora_file(
                        f"https://huggingface.co/{repo_id}/resolve/main/"
                        f"{safetensor_file['path']}",
                        safetensor_file["path"],
                        headers,
                    )

                    # Model card may recommend a weight.
                    card = data.get("cardData") or {}
                    weight = 0.8
                    if "weight" in card:
                        try:
                            weight = float(card["weight"])
                        except (ValueError, TypeError):
                            pass  # non-numeric recommendation: keep the default

                    trigger_words = card.get("trigger_words") or []
                    if isinstance(trigger_words, str):
                        # A single string would otherwise be split into characters.
                        trigger_words = [trigger_words]
                    trigger_words = list(trigger_words)
                    # NOTE(review): repo tags are not necessarily trigger words;
                    # kept from the original behaviour.
                    trigger_words.extend(
                        t for t in data.get("tags", []) if not t.startswith("flux-")
                    )

                    return {
                        "title": data.get("name", repo_id.split("/")[-1]),
                        "weights": file_path,
                        "info": ", ".join(trigger_words) if trigger_words else None,
                        "weight": weight,
                    }

                def add_lora(lora_selected, selected_loras):
                    """Resolve a Lora and return the state list with it appended.

                    Args:
                        lora_selected: Textbox value -- a CivitAI model URL or a
                            Hugging Face ``user/repo`` path. An ``int`` is treated
                            as an index into ``flux1_loras``.
                        selected_loras: Current value of the ``selected_loras``
                            state (a list of Lora dicts).

                    Returns:
                        A new list: the existing entries plus one dict with
                        ``title``, ``weights`` (local .safetensors path), ``info``
                        and ``weight`` keys.

                    Raises:
                        gr.Error: if the input is not recognised or resolving /
                            downloading fails; the state is then left unchanged.
                    """
                    if isinstance(lora_selected, int):
                        # Preset from the bundled data. NOTE(review): key names
                        # follow the original lookup -- confirm flux1_loras entries
                        # carry "weights" and "trigger_word".
                        preset = flux1_loras[lora_selected]
                        lora = {
                            "title": preset["title"],
                            "weights": preset["weights"],
                            "info": preset["trigger_word"],
                            "weight": 0.8,
                        }
                    else:
                        source = (lora_selected or "").strip()
                        if source.startswith("http"):
                            if "civitai.com/models/" not in source:
                                raise gr.Error("Only CivitAI model URLs are supported.")
                            try:
                                lora = _lora_from_civitai(source)
                            except Exception as e:
                                raise gr.Error(
                                    f"Error processing CivitAI URL: {str(e)}"
                                ) from e
                        elif ".." not in source and re.match(
                            r"^[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+$", source
                        ):
                            # Hugging Face repo path (user/repo); dots are allowed
                            # in repo names, e.g. "FLUX.1-dev-...".
                            try:
                                lora = _lora_from_hf(source)
                            except Exception as e:
                                raise gr.Error(
                                    f"Error processing Hugging Face repo: {str(e)}"
                                ) from e
                        else:
                            raise gr.Error(
                                "Enter a CivitAI model URL or a Hugging Face repo "
                                "path (user/repo)."
                            )

                    # Return a new list so the State registers a change and the
                    # slider render re-runs.
                    return [*(selected_loras or []), lora]

                add_lora_btn.click(
                    fn=add_lora,
                    inputs=[lora_selected, selected_loras],
                    outputs=selected_loras,
                )

                # Render one weight slider per selected Lora; gr.render re-runs this
                # whenever the `selected_loras` state changes.
                @gr.render(inputs=[selected_loras])
                def render_selected_loras(loras):
                    """Create a weight slider for each Lora dict in ``loras``.

                    ``loras`` is the current *value* of the ``selected_loras``
                    state. It is deliberately not named ``selected_loras``: the
                    slider events below must reference the gr.State component,
                    which a same-named parameter would shadow.
                    """

                    def make_weight_updater(index):
                        """Return a handler that stores a slider value at ``index``."""

                        def update_lora_weight(weight, current_loras):
                            # Copy entries so the state receives a new value rather
                            # than an in-place mutation (and never None).
                            updated = [dict(lora) for lora in (current_loras or [])]
                            if 0 <= index < len(updated):
                                updated[index]["weight"] = weight
                            return updated

                        return update_lora_weight

                    for i, lora in enumerate(loras or []):
                        weight = lora.get("weight")
                        if weight is None:
                            weight = 0.8  # default slider value
                        lora_slider = gr.Slider(
                            # Gradio's default 0-100 range does not suit a Lora
                            # scale; widen only if a stored weight falls outside.
                            minimum=min(0.0, weight),
                            maximum=max(2.0, weight),
                            step=0.01,
                            value=weight,
                            label=lora["title"],
                            info=lora.get("info"),
                            interactive=True,
                        )
                        # `release` rather than `change`: each state update re-runs
                        # this render, which could replace the slider mid-drag if it
                        # fired on every intermediate value.
                        lora_slider.release(
                            fn=make_weight_updater(i),
                            inputs=[lora_slider, selected_loras],
                            outputs=selected_loras,
                        )


# Launch the Gradio app assembled in the `with gr.Blocks(...) as demo:` block
# above. This runs unconditionally at module execution time.
# NOTE(review): if this module is ever imported rather than run as a script,
# consider an `if __name__ == "__main__":` guard -- confirm deployment setup.
demo.launch()