import html
from functools import lru_cache

import duckdb
import gradio as gr
import polars as pl
from datasets import load_dataset
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from model2vec import StaticModel

# Load a model from the HuggingFace hub (in this case the potion-base-8M model).
# It is created once at import time and shared by every callback below that
# embeds text (dataset rows and search queries alike).
model_name = "minishlab/potion-base-8M"
model = StaticModel.from_pretrained(model_name)


def get_iframe(hub_repo_id):
    """Return an HTML snippet embedding the Hub dataset viewer for *hub_repo_id*.

    Args:
        hub_repo_id: Dataset id on the Hub, e.g. ``"user/dataset"``.

    Returns:
        An ``<iframe>`` HTML string pointing at the dataset's embedded viewer.

    Raises:
        ValueError: if *hub_repo_id* is empty or ``None``.
    """
    if not hub_repo_id:
        raise ValueError("Hub repo id is required")
    url = f"https://huggingface.co/datasets/{hub_repo_id}/embed/viewer"
    # The repo id comes from a free-text search box; escape it so a stray quote
    # or angle bracket cannot break out of the src attribute and inject markup.
    safe_url = html.escape(url, quote=True)
    iframe = f"""
    <iframe
  src="{safe_url}"
  frameborder="0"
  width="100%"
  height="600px"
></iframe>
"""
    return iframe


def load_dataset_from_hub(hub_repo_id: str):
    """Load *hub_repo_id* once up front, notifying the user while it runs.

    The loaded dataset is not returned. This step presumably exists so the
    later ``load_dataset`` calls in this file are served from the local
    ``datasets`` cache (TODO confirm).

    Args:
        hub_repo_id: Dataset id on the Hub, e.g. ``"user/dataset"``.

    Raises:
        gr.Error: with a user-facing message if the dataset cannot be loaded.
    """
    gr.Info(message="Loading dataset...")
    try:
        # Result intentionally discarded; only the side effect of loading matters.
        load_dataset(hub_repo_id)
    except Exception as e:
        raise gr.Error(f"Error loading dataset: {e}") from e


def get_columns(hub_repo_id: str, split: str):
    """Build a visible column-picker dropdown for one split of a Hub dataset.

    Args:
        hub_repo_id: Dataset id on the Hub.
        split: Name of the split whose columns are listed.

    Returns:
        A ``gr.Dropdown`` listing the split's column names, with the first
        column preselected.
    """
    column_names = load_dataset(hub_repo_id)[split].column_names
    return gr.Dropdown(
        label="Select a column",
        choices=column_names,
        value=column_names[0],
        visible=True,
    )


def get_splits(hub_repo_id: str):
    """Build a visible split-picker dropdown for a Hub dataset.

    Args:
        hub_repo_id: Dataset id on the Hub.

    Returns:
        A ``gr.Dropdown`` listing the dataset's split names, with the first
        one preselected.
    """
    dataset_dict = load_dataset(hub_repo_id)
    split_names = [*dataset_dict.keys()]
    first_split = split_names[0]
    return gr.Dropdown(
        choices=split_names,
        value=first_split,
        label="Select a split",
        visible=True,
    )


# Each cache entry holds an embedding matrix for an entire split, so keep only a
# handful; the implicit default (128 entries) can exhaust memory on big datasets.
@lru_cache(maxsize=8)
def vectorize_dataset(hub_repo_id: str, split: str, column: str):
    """Embed every value of *column* in *split* of *hub_repo_id* with ``model``.

    Results are memoised per ``(hub_repo_id, split, column)``. Repeated searches
    against the same column therefore do not re-embed the whole split.

    Args:
        hub_repo_id: Dataset id on the Hub.
        split: Split to embed.
        column: Column whose values are cast to strings and embedded.

    Returns:
        Whatever ``model.encode`` returns for the column: one embedding per row,
        in row order.
    """
    gr.Info("Vectorizing dataset...")
    ds = load_dataset(hub_repo_id)
    df = ds[split].to_polars()
    # Null cells stay null after the cast and would reach model.encode as None,
    # which it presumably cannot tokenize. Embed them as "" instead, so there is
    # still exactly one embedding per row.
    texts = df[column].cast(str).fill_null("")
    embeddings = model.encode(texts, max_length=512)
    return embeddings


def run_query(hub_repo_id: str, query: str, split: str, column: str):
    """Return the 5 rows of a dataset split that are most similar to *query*.

    Every row's *column* is embedded with ``model`` via ``vectorize_dataset``,
    and *query* is embedded the same way. DuckDB then orders the rows by
    ``array_cosine_distance`` to the query vector.

    Args:
        hub_repo_id: Dataset id on the Hub.
        query: Free-text search query.
        split: Split to search.
        column: Column whose text is compared against the query.

    Returns:
        A visible ``gr.Dataframe`` with the 5 closest rows. It contains the
        split's own columns only; the helper embedding column is excluded.

    Raises:
        gr.Error: if *query* is blank, or if loading, embedding or querying fails.
    """
    # A blank query embeds to a degenerate vector, so every cosine distance is
    # meaningless and the "top 5" would be arbitrary.
    if not query or not query.strip():
        raise gr.Error("Please enter a query.")

    # Private name so we never overwrite a dataset column that happens to be
    # called "embeddings", and so it can be dropped from the output.
    embedding_col = "_vector_search_embedding"
    try:
        embeddings = vectorize_dataset(hub_repo_id, split, column)
        ds = load_dataset(hub_repo_id)
        # NOTE: `df` must stay a local of this frame; DuckDB resolves
        # `FROM df` by looking up Python variables in the calling scope.
        df = ds[split].to_polars()
        df = df.with_columns(pl.Series(embeddings).alias(embedding_col))

        vector = model.encode(query)
        # Derive the array width from the model output instead of hard-coding
        # it, so switching `model` does not break the cast.
        dim = vector.shape[-1]
        sql = f"""
            SELECT * EXCLUDE ("{embedding_col}")
            FROM df
            ORDER BY array_cosine_distance("{embedding_col}", {vector.tolist()}::FLOAT[{dim}])
            LIMIT 5
            """
        df_results = duckdb.sql(query=sql).to_df()
    except Exception as e:
        raise gr.Error(f"Error running query: {e}") from e
    return gr.Dataframe(df_results, visible=True)


def hide_components():
    """Return invisible updates for every control below the dataset viewer.

    There are five updates, in this order: two dropdowns, the query textbox,
    the search button and the results dataframe. This matches the outputs list
    this callback is wired to.
    """
    hidden_dropdowns = [gr.Dropdown(visible=False) for _ in range(2)]
    return hidden_dropdowns + [
        gr.Textbox(visible=False),
        gr.Button(visible=False),
        gr.Dataframe(visible=False),
    ]


def partial_hide_components():
    """Return invisible updates for the query textbox, search button and results.

    The split/column dropdowns are left untouched.
    """
    component_types = (gr.Textbox, gr.Button, gr.Dataframe)
    return [component(visible=False) for component in component_types]


def show_components():
    """Return visible updates for the query textbox and the search button."""
    query_box = gr.Textbox(visible=True, label="Query")
    search_button = gr.Button(visible=True, value="Search")
    return [query_box, search_button]


with gr.Blocks() as demo:
    gr.HTML(
        """
        <h1>Vector Search any Hugging Face Dataset</h1>
        <p>
            This app allows you to vector search any Hugging Face dataset.
            You can search for the nearest neighbors of a query vector, or
            perform a similarity search on a dataframe.
        </p>
        """
    )
    with gr.Row():
        with gr.Column():
            search_in = HuggingfaceHubSearch(
                label="Search Huggingface Hub",
                placeholder="Search for datasets on Huggingface",
                search_type="dataset",
                # NOTE(review): "sumbit" looks like a typo but appears to match
                # the component's own keyword name; verify before renaming.
                sumbit_on_select=True,
            )
    with gr.Row():
        search_out = gr.HTML(label="Search Results")

    with gr.Row():
        split_dropdown = gr.Dropdown(label="Select a split", visible=False)
        column_dropdown = gr.Dropdown(label="Select a column", visible=False)
    with gr.Row():
        query_input = gr.Textbox(label="Query", visible=False)

    btn_run = gr.Button("Search", visible=False)

    results_output = gr.Dataframe(label="Results", visible=False)

    # Picking a dataset runs these steps in order:
    #   1. render the viewer
    #   2. load the dataset
    #   3. reset all controls
    #   4. repopulate the split and column choices
    #   5. reveal the query box
    search_in.submit(get_iframe, inputs=search_in, outputs=search_out).then(
        fn=load_dataset_from_hub,
        inputs=search_in,
        show_progress="full",
    ).then(
        fn=hide_components,
        outputs=[split_dropdown, column_dropdown, query_input, btn_run, results_output],
    ).then(fn=get_splits, inputs=search_in, outputs=split_dropdown).then(
        fn=get_columns, inputs=[search_in, split_dropdown], outputs=column_dropdown
    ).then(
        # hide_components just hid the query box. If the new dataset's first
        # column has the same name as the previous one, the dropdown value does
        # not change, so column_dropdown.change may never fire to re-show it.
        fn=show_components,
        outputs=[query_input, btn_run],
    )

    split_dropdown.change(
        fn=get_columns, inputs=[search_in, split_dropdown], outputs=column_dropdown
    )

    column_dropdown.change(
        fn=partial_hide_components,
        outputs=[query_input, btn_run, results_output],
    ).then(fn=show_components, outputs=[query_input, btn_run])

    btn_run.click(
        fn=run_query,
        inputs=[search_in, query_input, split_dropdown, column_dropdown],
        outputs=results_output,
    )

if __name__ == "__main__":
    demo.launch()