Spaces:
Sleeping
Sleeping
File size: 3,436 Bytes
be785b0 3238b05 be785b0 52b042a dd10fbb 52b042a dd10fbb 52b042a dd10fbb 52b042a dd10fbb 52b042a dd10fbb da8206b dd10fbb da8206b dd10fbb 52b042a dd10fbb da8206b 52b042a da8206b dd10fbb da8206b dd10fbb da8206b 52b042a dc4d5eb 52b042a dd10fbb 52b042a da8206b 52b042a dd10fbb 52b042a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import functools
import html
import json
import os

# Keep this assignment above the pyserini import: pyserini needs a JVM and
# presumably locates it when the module is first imported.
os.environ["JAVA_HOME"] = "/usr/lib/jvm/default-java"

import gradio as gr
from pyserini.search.lucene import LuceneSearcher
@functools.lru_cache(maxsize=4)
def initialize_searcher(index_name):
    """Return a BM25-configured LuceneSearcher for a Pyserini prebuilt index.

    Searchers are cached per index name (at most 4 at a time), so repeated
    queries against the same index reuse one opened index instead of
    reloading it on every search.

    Args:
        index_name: Name of a Pyserini prebuilt index, e.g. "msmarco-passage".
            It comes from a user-editable textbox, so it is only ever handed
            to Pyserini as data and never interpolated into a shell command.

    Returns:
        A LuceneSearcher with BM25 parameters k1=0.9, b=0.4.

    Raises:
        ValueError: if Pyserini cannot resolve ``index_name``.
    """
    # from_prebuilt_index fetches the index itself when it is not cached
    # locally, so no separate download step (or subprocess) is needed.
    searcher = LuceneSearcher.from_prebuilt_index(index_name)
    # Pyserini can return None (e.g. for an unrecognized name) instead of
    # raising; fail here with a clear message rather than on set_bm25 below.
    if searcher is None:
        raise ValueError(f"Unknown prebuilt index: {index_name!r}")
    searcher.set_bm25(k1=0.9, b=0.4)
    return searcher
def _extract_contents(doc):
    """Return displayable text for a Pyserini document, tolerating non-JSON storage.

    Uses the "contents" field when the raw document is a JSON object that has
    one; otherwise falls back to the raw stored text, so indexes whose raw
    documents use a different layout still display something.
    """
    if doc is None:
        return ""
    raw = doc.raw()
    try:
        parsed = json.loads(raw)
    except (TypeError, ValueError):  # JSONDecodeError subclasses ValueError
        return raw or ""
    if isinstance(parsed, dict) and "contents" in parsed:
        return parsed["contents"]
    return raw


def search_pyserini(query, top_k, index_name):
    """Run a BM25 search and render the hits as HTML for the Gradio output.

    Args:
        query: Free-text query typed by the user.
        top_k: Number of hits to return. A Gradio slider may deliver this as
            a float, so it is coerced to int before reaching Pyserini.
        index_name: Name of a Pyserini prebuilt index; surrounding whitespace
            is ignored.

    Returns:
        An HTML string: the formatted result list, or an error <div>. This
        function does not raise, so the UI always has something to display.
    """
    if not query or not query.strip():
        return "<div class='error'>Please enter a search query.</div>"
    try:
        searcher = initialize_searcher(index_name.strip())
        hits = searcher.search(query, k=int(top_k))
        results = [
            {
                "rank": rank,
                "doc_id": hit.docid,
                "score": hit.score,
                "content": _extract_contents(searcher.doc(hit.docid)),
            }
            for rank, hit in enumerate(hits, start=1)
        ]
        return format_results(results)
    except Exception as e:
        # UI boundary: report any failure to the user instead of crashing the
        # handler. The message may echo user input, so escape it.
        return f"<div class='error'>An error occurred: {html.escape(str(e))}</div>"
def format_results(results):
    """Render search results as an HTML fragment for the gr.HTML output.

    Args:
        results: Iterable of dicts with keys "rank", "doc_id", "score"
            (a number, shown with 4 decimals) and "content".

    Returns:
        A <div class='results-container'> holding one <div class='result-item'>
        per result, or a "No results found." note when ``results`` is empty.
        Document IDs and contents are HTML-escaped because they are arbitrary
        indexed text and would otherwise be interpreted as markup.
    """
    items = [
        "<div class='result-item'>"
        f"<h3>Rank {result['rank']} (Score: {result['score']:.4f})</h3>"
        f"<p class='doc-id'>Doc ID: {html.escape(str(result['doc_id']))}</p>"
        f"<p class='content'>{html.escape(str(result['content']))}</p>"
        "</div>"
        for result in results
    ]
    if not items:
        items.append("<p class='content'>No results found.</p>")
    # One join instead of repeated string += keeps this linear in output size.
    return "<div class='results-container'>" + "".join(items) + "</div>"
# Stylesheet passed to gr.Blocks(css=...). The results-container,
# result-item, doc-id and content classes match the markup built by
# format_results; .error matches the error <div> returned by search_pyserini.
# .gradio-container presumably targets Gradio's outer page wrapper.
css = """
.gradio-container {
font-family: 'Arial', sans-serif;
}
.results-container {
display: flex;
flex-direction: column;
gap: 20px;
}
.result-item {
border: 1px solid #ddd;
border-radius: 8px;
padding: 15px;
width: 100%;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.result-item h3 {
margin-top: 0;
color: #333;
}
.doc-id {
font-size: 0.9em;
color: #666;
margin-bottom: 10px;
}
.content {
font-size: 0.95em;
line-height: 1.4;
}
.error {
color: red;
font-weight: bold;
}
"""
# Gradio UI: index name, result count and query inputs feed search_pyserini,
# whose HTML output is shown in the results panel.
with gr.Blocks(css=css) as iface:
    gr.Markdown("# Pyserini Search Interface")
    gr.Markdown(
        "Enter a query to search using Pyserini with BM25 scoring (k1=0.9, b=0.4). "
        "See all possible prebuilt index names at "
        "[https://github.com/castorini/pyserini/blob/master/docs/prebuilt-indexes.md#standard-lucene-indexes]"
        "(https://github.com/castorini/pyserini/blob/master/docs/prebuilt-indexes.md#standard-lucene-indexes)"
    )
    with gr.Row():
        index_input = gr.Textbox(
            value="msmarco-passage",
            lines=1,
            label="Prebuilt Index Name",
            placeholder="Enter the name of the prebuilt index",
        )
    with gr.Row():
        top_k_slider = gr.Slider(
            minimum=1,
            maximum=100,
            value=10,
            step=1,
            label="Number of top results to return",
        )
    with gr.Row():
        query_input = gr.Textbox(
            lines=1,
            placeholder="Enter your search query here...",
            label="Search Query",
        )
    with gr.Row():
        search_button = gr.Button("Search", variant="primary")
    with gr.Row():
        output = gr.HTML(label="Search Results")

    # Argument order must match search_pyserini(query, top_k, index_name).
    search_inputs = [query_input, top_k_slider, index_input]
    search_button.click(fn=search_pyserini, inputs=search_inputs, outputs=output)
    # Pressing Enter in the query box runs the same search as the button.
    query_input.submit(fn=search_pyserini, inputs=search_inputs, outputs=output)

if __name__ == "__main__":
    iface.launch()