import html

import duckdb
import gradio as gr
import polars as pl
from datasets import load_dataset
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from model2vec import StaticModel

# Module-level state shared between the Gradio callbacks below:
#   ds - the dataset most recently returned by `load_dataset` (indexed by split name)
#   df - the selected split as a polars DataFrame plus an "<column>_embeddings" column
ds = None
df = None

# Load a model from the HuggingFace hub (in this case the potion-base-8M model)
model_name = "minishlab/potion-base-8M"
model = StaticModel.from_pretrained(model_name)


def _quote_identifier(name: str) -> str:
    """Quote *name* as a DuckDB identifier (embedded double quotes are doubled).

    Dataset column names may contain spaces, punctuation or SQL keywords, so they
    must not be spliced into a query unquoted.
    """
    return '"' + name.replace('"', '""') + '"'


def get_iframe(hub_repo_id: str) -> str:
    """Return HTML that embeds the Hub dataset viewer for *hub_repo_id*.

    Raises:
        ValueError: if *hub_repo_id* is empty.
    """
    if not hub_repo_id:
        raise ValueError("Hub repo id is required")
    url = f"https://huggingface.co/datasets/{hub_repo_id}/embed/viewer"
    # Escape the URL so an unexpected character in the repo id cannot break out
    # of the src attribute.
    iframe = f"""
<iframe
  src="{html.escape(url, quote=True)}"
  frameborder="0"
  width="100%"
  height="560px"
></iframe>
"""
    return iframe


def load_dataset_from_hub(hub_repo_id: str) -> None:
    """Load *hub_repo_id* with `datasets.load_dataset` into the global `ds`."""
    global ds
    gr.Info("Loading dataset...")
    ds = load_dataset(hub_repo_id)


def get_columns(split: str) -> gr.Dropdown:
    """Return a visible dropdown of the columns in split *split* of `ds`.

    The first column is preselected.
    """
    column_names = ds[split].column_names
    return gr.Dropdown(
        choices=column_names,
        value=column_names[0],
        label="Select a column",
        visible=True,
    )


def get_splits() -> gr.Dropdown:
    """Return a visible dropdown of the split names in `ds`, first one preselected."""
    splits = list(ds.keys())
    return gr.Dropdown(
        choices=splits, value=splits[0], label="Select a split", visible=True
    )


def vectorize_dataset(split: str, column: str) -> None:
    """Embed *column* of split *split* with `model` and store the result in `df`.

    Every value of *column* is cast to text before encoding. The vectors are
    appended as a new column named ``f"{column}_embeddings"``.
    """
    global df
    gr.Info("Vectorizing dataset...")
    frame = ds[split].to_polars()
    # Nulls would reach the tokenizer as None; encode them as empty text instead,
    # and hand the model a plain list of strings.
    texts = frame[column].cast(pl.Utf8).fill_null("").to_list()
    embeddings = model.encode(texts, max_length=512)
    # Publish `df` only once it is complete, so a failure part-way through does
    # not leave a frame without the embeddings column behind.
    df = frame.with_columns(pl.Series(embeddings).alias(f"{column}_embeddings"))


def run_query(query: str, column: str) -> gr.Dataframe:
    """Return the 5 rows of `df` nearest to *query* by cosine distance.

    The query is embedded with `model` and compared against the
    ``f"{column}_embeddings"`` column created by `vectorize_dataset`.

    Raises:
        gr.Error: if no dataset has been vectorized yet, the query is empty,
            or embedding / the DuckDB query fails.
    """
    if df is None:
        raise gr.Error("Select a dataset, split and column before searching.")
    if not query or not query.strip():
        raise gr.Error("Query is required")
    try:
        vector = model.encode(query)
        embedding_column = _quote_identifier(f"{column}_embeddings")
        # DuckDB resolves `df` to the module-level polars frame by name. The
        # query vector is inlined as an ARRAY literal whose size comes from the
        # model output instead of a hard-coded dimension.
        df_results = duckdb.sql(
            query=f"""
            SELECT * FROM df
            ORDER BY array_cosine_distance({embedding_column}, {vector.tolist()}::FLOAT[{len(vector)}])
            LIMIT 5
            """
        ).to_df()
        return gr.Dataframe(df_results, visible=True)
    except Exception as e:
        raise gr.Error(f"Error running query: {e}") from e


def hide_components() -> list:
    """Hide every control below the search box (used when a new dataset is picked)."""
    return [
        gr.Dropdown(visible=False),
        gr.Dropdown(visible=False),
        gr.Textbox(visible=False),
        gr.Button(visible=False),
        gr.Dataframe(visible=False),
    ]


def partial_hide_components() -> list:
    """Hide the query box, search button and results (used when the column changes)."""
    return [
        gr.Textbox(visible=False),
        gr.Button(visible=False),
        gr.Dataframe(visible=False),
    ]


def show_components() -> list:
    """Reveal the query box and search button once vectorization has finished."""
    return [
        gr.Textbox(visible=True, label="Query"),
        gr.Button(visible=True, value="Search"),
    ]


with gr.Blocks() as demo:
    gr.HTML(
        """
        <h1>Vector Search any Hugging Face Dataset</h1>
        <p>
        This app allows you to vector search any Hugging Face dataset. You can search for the nearest neighbors of a query vector, or perform a similarity search on a dataframe.
        </p>
        """
    )
    with gr.Row():
        with gr.Column():
            search_in = HuggingfaceHubSearch(
                label="Search Huggingface Hub",
                placeholder="Search for datasets on Huggingface",
                search_type="dataset",
                # NOTE(review): spelling presumably mirrors the component's own
                # parameter name; confirm against the installed version before
                # renaming.
                sumbit_on_select=True,
            )
    with gr.Row():
        search_out = gr.HTML(label="Search Results")
    with gr.Row():
        split_dropdown = gr.Dropdown(label="Select a split", visible=False)
        column_dropdown = gr.Dropdown(label="Select a column", visible=False)
    with gr.Row():
        query_input = gr.Textbox(label="Query", visible=False)
        btn_run = gr.Button("Search", visible=False)
    results_output = gr.Dataframe(label="Results", visible=False)

    # Picking a dataset: show the viewer, load it, reset the UI, then populate
    # the split and column dropdowns. The explicit get_columns step is needed
    # because the split value may not change between datasets (e.g. "train"),
    # in which case split_dropdown.change would not refresh the columns.
    search_in.submit(get_iframe, inputs=search_in, outputs=search_out).then(
        fn=load_dataset_from_hub,
        inputs=search_in,
        show_progress=True,
    ).then(
        fn=hide_components,
        outputs=[split_dropdown, column_dropdown, query_input, btn_run, results_output],
    ).then(fn=get_splits, outputs=split_dropdown).then(
        fn=get_columns, inputs=split_dropdown, outputs=column_dropdown
    )

    split_dropdown.change(
        fn=get_columns, inputs=split_dropdown, outputs=column_dropdown
    )

    # Changing the column re-embeds the split before the search UI reappears.
    column_dropdown.change(
        fn=partial_hide_components,
        outputs=[query_input, btn_run, results_output],
    ).then(fn=vectorize_dataset, inputs=[split_dropdown, column_dropdown]).then(
        fn=show_components, outputs=[query_input, btn_run]
    )

    btn_run.click(
        fn=run_query, inputs=[query_input, column_dropdown], outputs=results_output
    )

demo.launch()