# Hugging Face Spaces status at capture time: Runtime error
import gradio as gr
from wordllama import WordLlama
# Single shared WordLlama instance built from the library's default model.
# It is created once, when this module is first executed, and the wrapper
# functions below all call methods on it.
wl: WordLlama = WordLlama.load()
def calculate_similarity(sentence1, sentence2):
    """Score how similar two sentences are according to WordLlama.

    Args:
        sentence1: First text to compare.
        sentence2: Second text to compare.

    Returns:
        The value produced by ``wl.similarity`` for the pair (shown in a
        ``gr.Number`` by the UI; presumably a float score).
    """
    return wl.similarity(sentence1, sentence2)
def rank_documents(query, candidates):
    """Rank candidate documents against a query using WordLlama.

    Args:
        query: Text the candidates are compared to.
        candidates: List of candidate strings.

    Returns:
        Whatever ``wl.rank`` returns. The UI shows it in a two-column
        ``Document``/``Score`` table, so presumably a list of
        ``(document, score)`` pairs (confirm against the WordLlama API).
    """
    return wl.rank(query, candidates)
def deduplicate_candidates(candidates, threshold):
    """Remove near-duplicate candidates via WordLlama.

    Args:
        candidates: List of candidate strings.
        threshold: Similarity cut-off forwarded positionally to
            ``wl.deduplicate``. The UI supplies values in [0.0, 1.0].

    Returns:
        The result of ``wl.deduplicate``.
    """
    return wl.deduplicate(candidates, threshold)
def filter_candidates(query, candidates, threshold):
    """Keep the candidates WordLlama judges close enough to the query.

    Args:
        query: Text the candidates are compared to.
        candidates: List of candidate strings.
        threshold: Similarity cut-off forwarded positionally to
            ``wl.filter``. The UI supplies values in [0.0, 1.0].

    Returns:
        The result of ``wl.filter``.
    """
    return wl.filter(query, candidates, threshold)
def topk_candidates(query, candidates, k):
    """Return the ``k`` candidates WordLlama rates most similar to ``query``.

    Args:
        query: Text the candidates are compared to.
        candidates: List of candidate strings.
        k: Number of results wanted. Gradio sliders can deliver this as a
            float (e.g. ``3.0``), so it is coerced to ``int`` before use.

    Returns:
        The result of ``wl.topk``.
    """
    # NOTE(review): wl.topk treats k as an integer count; a float value
    # coming straight from a slider would not be a valid count/index there.
    return wl.topk(query, candidates, int(k))
def _split_candidates(text):
    """Split comma-separated text into trimmed, non-empty candidate strings.

    ``"a, b,,c "`` becomes ``["a", "b", "c"]``; ``None`` or ``""`` gives ``[]``.
    A bare ``text.split(",")`` would keep leading spaces and empty strings,
    which would then be embedded and scored as if they were real candidates.
    """
    if not text:
        return []
    return [part.strip() for part in text.split(",") if part.strip()]


def _join_lines(items):
    """Render a sequence of results one per line for a ``Textbox`` output."""
    return "\n".join(str(item) for item in items)


def create_gradio_interface():
    """Build the tabbed WordLlama demo UI.

    The UI has five tabs: similarity, ranking, deduplication, filtering
    and top-k. Candidates are entered as comma-separated text and parsed with
    ``_split_candidates``. List results shown in a ``Textbox`` are rendered
    one per line. When no candidates are entered, the handlers return an
    empty result instead of calling the model.

    Returns:
        The constructed ``gr.Blocks`` app (not launched).
    """

    def on_rank(query_text, candidates_text):
        docs = _split_candidates(candidates_text)
        return rank_documents(query_text, docs) if docs else []

    def on_deduplicate(candidates_text, threshold):
        docs = _split_candidates(candidates_text)
        if not docs:
            return ""
        return _join_lines(deduplicate_candidates(docs, threshold))

    def on_filter(query_text, candidates_text, threshold):
        docs = _split_candidates(candidates_text)
        if not docs:
            return ""
        return _join_lines(filter_candidates(query_text, docs, threshold))

    def on_topk(query_text, candidates_text, top_k):
        docs = _split_candidates(candidates_text)
        if not docs:
            return ""
        # Slider values may arrive as floats; top-k needs an integer count.
        return _join_lines(topk_candidates(query_text, docs, int(top_k)))

    with gr.Blocks() as demo:
        gr.Markdown("## WordLlama Gradio Demo")

        with gr.Tab("Similarity"):
            with gr.Row():
                sentence1 = gr.Textbox(
                    label="Sentence 1",
                    placeholder="Enter the first sentence here...",
                )
                sentence2 = gr.Textbox(
                    label="Sentence 2",
                    placeholder="Enter the second sentence here...",
                )
            similarity_output = gr.Number(label="Similarity Score")
            gr.Button("Calculate Similarity").click(
                fn=calculate_similarity,
                inputs=[sentence1, sentence2],
                outputs=[similarity_output],
            )

        with gr.Tab("Rank Documents"):
            query = gr.Textbox(label="Query", placeholder="Enter the query here...")
            candidates = gr.Textbox(
                label="Candidates (comma separated)",
                placeholder="Enter candidate sentences here...",
            )
            ranked_docs_output = gr.Dataframe(headers=["Document", "Score"])
            gr.Button("Rank Documents").click(
                fn=on_rank,
                inputs=[query, candidates],
                outputs=[ranked_docs_output],
            )

        with gr.Tab("Deduplicate Candidates"):
            candidates_dedup = gr.Textbox(
                label="Candidates (comma separated)",
                placeholder="Enter candidate sentences here...",
            )
            threshold_dedup = gr.Slider(
                label="Threshold", minimum=0.0, maximum=1.0, step=0.01, value=0.8
            )
            deduplicated_output = gr.Textbox(label="Deduplicated Candidates")
            gr.Button("Deduplicate").click(
                fn=on_deduplicate,
                inputs=[candidates_dedup, threshold_dedup],
                outputs=[deduplicated_output],
            )

        with gr.Tab("Filter Candidates"):
            filter_query = gr.Textbox(label="Query", placeholder="Enter the query here...")
            candidates_filter = gr.Textbox(
                label="Candidates (comma separated)",
                placeholder="Enter candidate sentences here...",
            )
            threshold_filter = gr.Slider(
                label="Threshold", minimum=0.0, maximum=1.0, step=0.01, value=0.3
            )
            filtered_output = gr.Textbox(label="Filtered Candidates")
            gr.Button("Filter Candidates").click(
                fn=on_filter,
                inputs=[filter_query, candidates_filter, threshold_filter],
                outputs=[filtered_output],
            )

        with gr.Tab("Top-k Candidates"):
            topk_query = gr.Textbox(label="Query", placeholder="Enter the query here...")
            candidates_topk = gr.Textbox(
                label="Candidates (comma separated)",
                placeholder="Enter candidate sentences here...",
            )
            k_slider = gr.Slider(label="Top-k", minimum=1, maximum=10, step=1, value=3)
            topk_output = gr.Textbox(label="Top-k Candidates")
            gr.Button("Get Top-k Candidates").click(
                fn=on_topk,
                inputs=[topk_query, candidates_topk, k_slider],
                outputs=[topk_output],
            )

    return demo
# Build the app at module level so tools that import this file (e.g. Gradio's
# reload mode, which looks for a variable named `demo`) can find it, but only
# start the server when the file is executed as a script.
demo = create_gradio_interface()

if __name__ == "__main__":
    demo.launch()