import json |
import os |
from datetime import datetime |
import dotenv |
import lancedb |
import requests |
from datasets import load_dataset |
from fasthtml.common import * |
from huggingface_hub import login, whoami |
# Host (IP address or name) of the remote server that serves the
# `/get_pages` and `/rerank` endpoints on port 80. It is empty in this file;
# NOTE(review): presumably it has to be filled in before the app is started.
server_ip: str = ""
def get_images(query: str, timeout: float = 60.0):
    """Fetch page images matching *query* from the remote image-search server.

    Issues ``GET http://{server_ip}:80/get_pages`` with ``query`` as a URL
    parameter and returns the decoded JSON body unchanged.

    Args:
        query: Free-text search query.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a stalled server would block the caller indefinitely.

    Returns:
        The parsed JSON response. The ``/search`` handler iterates over it
        and embeds each element as base64-encoded JPEG data.

    Raises:
        requests.HTTPError: The server answered with a 4xx/5xx status.
        requests.Timeout: The server did not answer within ``timeout``.
    """
    url = f"http://{server_ip}:80/get_pages"
    response = requests.get(url, params={"query": query}, timeout=timeout)
    # Fail loudly on error statuses instead of handing an error payload to
    # the caller as if it were a list of images.
    response.raise_for_status()
    return response.json()
def rerank_api(query, docs, timeout: float = 60.0):
    """Ask the remote reranker to order *docs* by relevance to *query*.

    Issues ``POST http://{server_ip}:80/rerank`` with the JSON body
    ``{"query": query, "docs": docs}`` and returns the decoded JSON reply.

    Args:
        query: Free-text search query.
        docs: List of document strings to be ranked.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a stalled server would block the caller indefinitely.

    Returns:
        The parsed JSON response. ``retrieve_and_rerank`` reads its
        ``"ranked_doc_ids"`` entry and uses the values as indices into
        ``docs``.

    Raises:
        requests.HTTPError: The server answered with a 4xx/5xx status.
        requests.Timeout: The server did not answer within ``timeout``.
    """
    url = f"http://{server_ip}:80/rerank"
    payload = {"query": query, "docs": docs}
    response = requests.post(url, json=payload, timeout=timeout)
    # Fail loudly on error statuses instead of returning an error payload
    # that lacks the keys the caller expects.
    response.raise_for_status()
    return response.json()
# --- Startup: authenticate, load the corpus, build the full-text index -----

# Load variables (notably HF_TOKEN) from a local .env file into os.environ.
dotenv.load_dotenv()

# Read the token once and reuse it for both Hugging Face calls.
_hf_token = os.environ.get("HF_TOKEN")
login(token=_hf_token)
hf_user = whoami(_hf_token)["name"]

# The text corpus lives in the authenticated user's own dataset repository,
# split into an "abstracts" and an "articles" configuration.
HF_REPO_ID_TXT = f"{hf_user}/zotero-answer-ai-texts"
abstract_ds = load_dataset(HF_REPO_ID_TXT, "abstracts")["train"]
article_ds = load_dataset(HF_REPO_ID_TXT, "articles")["train"]

# Local LanceDB database that holds the searchable table.
# NOTE: `db` is rebound to the search-log database further down in this
# file; the search code keeps working because it goes through `table`.
uri = "data/zotero-fts"
db = lancedb.connect(uri)

# Lookup tables keyed by arXiv id.
id2abstract = {example["arxiv_id"]: example["abstract"] for example in abstract_ds}
id2content = {example["arxiv_id"]: example["contents"] for example in article_ds}
id2title = {example["arxiv_id"]: example["title"] for example in article_ds}

# Dict keys are already unique, so no intermediate list is needed.
arxiv_ids = set(id2abstract)

data = []
for arxiv_id in arxiv_ids:
    # Every id from the abstracts split must also exist in the articles
    # split; a missing one raises KeyError here.
    title = id2title[arxiv_id]
    sections = [
        f"{item['title']}\n\n{item['content']}" for item in id2content[arxiv_id]
    ]
    # Separate the title and each section with a blank line. Plain
    # concatenation glued the last word of one piece to the first word of
    # the next, producing bogus tokens in the full-text index.
    full_text = "\n\n".join([title, *sections])
    data.append(
        {
            "arxiv_id": arxiv_id,
            "title": title,
            "abstract": id2abstract[arxiv_id],
            "full_text": full_text,
        }
    )

# Rebuild the table and its full-text index from scratch on every start.
table = db.create_table("articles", data=data, mode="overwrite")
table.create_fts_index("full_text", replace=True)
def _format_results(results): |
ret = [] |
for result in results: |
arx_id = result["arxiv_id"] |
title = result["title"] |
abstract = result["abstract"] |
if "Abstract\n\n" in abstract: |
abstract = abstract.split("Abstract\n\n")[-1] |
this_ex = { |
"title": title, |
"url": f"https://arxiv.org/abs/{arx_id}", |
"abstract": abstract, |
} |
ret.append(this_ex) |
return ret |
def retrieve_and_rerank(query, k=3, n_fetch=25):
    """Run a full-text search for *query* and return the top reranked hits.

    Up to ``n_fetch`` candidates are fetched from the ``articles`` table by
    full-text search, their "title abstract" strings are sent to the remote
    reranker, and the ``k`` best candidates are returned in ranked order.

    Args:
        query: Free-text search query.
        k: Number of results to return after reranking.
        n_fetch: Number of candidates to pull from the index before
            reranking.

    Returns:
        A list of dicts with the keys ``title``, ``url`` and ``abstract``
        (see ``_format_results``); empty when the index yields no candidate.
    """
    retrieved = (
        table.search(query, vector_column_name="", query_type="fts")
        .limit(n_fetch)
        .select(["arxiv_id", "title", "abstract"])
        .to_list()
    )
    print(f"Retrieved {len(retrieved)} documents")

    # Nothing to rank: skip the network round-trip to the reranker.
    if not retrieved:
        return []

    docs = [f"{item['title']} {item['abstract']}" for item in retrieved]
    # The reranker answers with indices into `docs`, best match first.
    ranked_doc_ids = rerank_api(query, docs)["ranked_doc_ids"][:k]
    return _format_results([retrieved[idx] for idx in ranked_doc_ids])
# Page-wide stylesheet: dark colour scheme, centred layout, the HTMX
# "searching" indicator with its spinner, log-entry cards and the image grid.
_PAGE_CSS = """
:root {
    color-scheme: dark;
}
body {
    max-width: 1200px;
    margin: 0 auto;
    padding: 20px;
    line-height: 1.6;
}
#query {
    width: 100%;
    margin-bottom: 1rem;
}
#search-form button {
    width: 100%;
}
#search-results, #log-entries {
    margin-top: 2rem;
}
.log-entry {
    border: 1px solid #ccc;
    padding: 10px;
    margin-bottom: 10px;
}
.log-entry pre {
    white-space: pre-wrap;
    word-wrap: break-word;
}
.htmx-indicator {
    display: none;
}
.htmx-request .htmx-indicator {
    display: inline-block;
}
.spinner {
    display: inline-block;
    width: 2.5em;
    height: 2.5em;
    border: 0.3em solid rgba(255,255,255,.3);
    border-radius: 50%;
    border-top-color: #fff;
    animation: spin 1s ease-in-out infinite;
    margin-left: 10px;
    vertical-align: middle;
}
@keyframes spin {
    to { transform: rotate(360deg); }
}
.searching-text {
    font-size: 1.2em;
    font-weight: bold;
    color: #fff;
    margin-right: 10px;
    vertical-align: middle;
}
.image-results {
    display: flex;
    flex-wrap: wrap;
    gap: 10px;
    margin-top: 20px;
}
.image-result {
    width: calc(33% - 10px);
    text-align: center;
}
.image-result img {
    max-width: 100%;
    height: auto;
    border-radius: 5px;
}
"""

# Wrapped in a <style> element and handed to `fast_app` as an extra header.
style = Style(_PAGE_CSS)
# Create the FastHTML application; `rt` is the decorator used below to
# register route handlers. The stylesheet is passed in as an extra header.
app, rt = fast_app(hdrs=(style,))

# Database holding the search log.
# NOTE: this rebinds the module-level name `db`, which earlier in this file
# referred to the LanceDB connection.
db = database("log_data/search_logs.db")
search_logs = db.t.search_logs

# Column layout of the log table; `id` is the primary key.
_LOG_COLUMNS = dict(id=int, timestamp=str, query=str, results=str)

# Create the table only when it does not exist yet.
if search_logs not in db.t:
    search_logs.create(**_LOG_COLUMNS, pk="id")

# Dataclass mirroring one row of the log table.
SearchLog = search_logs.dataclass()
def insert_log_entry(log_entry):
    """Store one search-log record and return the result of the insert.

    ``log_entry`` is a mapping with a ``timestamp`` (anything offering
    ``.isoformat()``), the ``query`` string and JSON-serialisable
    ``results``.
    """
    row = SearchLog(
        # Stored as text: ISO-8601 timestamp and a JSON dump of the results.
        timestamp=log_entry["timestamp"].isoformat(),
        query=log_entry["query"],
        results=json.dumps(log_entry["results"]),
    )
    return search_logs.insert(row)
@rt("/")
async def get():
    """Render the landing page: search form, results placeholder, logs link."""
    query_box = Textarea(id="query", name="query", placeholder="Enter your query...")
    submit_button = Button("Submit", type="submit")
    # Hidden by CSS until HTMX marks the indicator container as "in request".
    progress = Div(
        Span("Searching...", cls="searching-text htmx-indicator"),
        Span(cls="spinner htmx-indicator"),
        cls="indicator-container",
    )
    # The form posts to /search via HTMX and targets #search-results
    # with the reply.
    search_form = Form(
        query_box,
        submit_button,
        progress,
        id="search-form",
        hx_post="/search",
        hx_target="#search-results",
        hx_indicator=".indicator-container",
    )
    results_slot = Div(Div(id="search-results", cls="results-container"))
    logs_link = A("View Logs", href="/logs", cls="view-logs-link")
    page_body = Div(search_form, results_slot, logs_link, cls="container")
    return Titled("Zotero Search", page_body)
def SearchResult(result):
    """Build the card for one text hit.

    ``result`` is a dict with ``title``, ``url`` and ``abstract``. Both the
    heading and the footer link open ``url`` in a new tab.
    """
    target_url = result["url"]
    heading = H4(A(result["title"], href=target_url, target="_blank"))
    summary = P(result["abstract"])
    more_link = A("Read more →", href=target_url, target="_blank")
    return Card(heading, summary, footer=more_link)
def ImageResult(image):
    """Wrap one image in a grid cell.

    ``image`` is interpolated into a ``data:image/jpeg;base64,`` URI, so it
    is expected to be base64-encoded JPEG data.
    """
    data_uri = f"data:image/jpeg;base64,{image}"
    picture = Img(src=data_uri, alt="arxiv image")
    return Div(picture, cls="image-result")
def log_query_and_results(query, results):
    """Record *query* and a trimmed copy of *results* in the search log.

    Only the ``title`` and ``url`` of each result are kept; the entry is
    stamped with the current local time.
    """
    stamp = datetime.now()
    trimmed = [dict(title=hit["title"], url=hit["url"]) for hit in results]
    insert_log_entry({"timestamp": stamp, "query": query, "results": trimmed})
@rt("/search")
async def post(query: str):
    """Handle a search submission and return the results fragment.

    Runs the image search and the text retrieve-and-rerank pipeline for
    *query*, logs the query with the text results, and returns a ``Div``
    with id ``search-results`` containing both result sections.

    Args:
        query: The text entered in the search form.
    """
    import asyncio

    # Both lookups are blocking (HTTP requests and an index query). Running
    # them directly inside this coroutine would stall the event loop for
    # every other client, so they are pushed to worker threads; they are
    # independent of each other, so they run concurrently.
    image_results, results = await asyncio.gather(
        asyncio.to_thread(get_images, query),
        asyncio.to_thread(retrieve_and_rerank, query),
    )

    # Logging stays on the calling thread, as before.
    log_query_and_results(query, results)

    return Div(
        Br(),
        H3("Byaldi Results"),
        Div(*[ImageResult(img) for img in image_results], cls="image-results"),
        Br(),
        H3("Text Results"),
        Div(*[SearchResult(r) for r in results], id="text-results"),
        id="search-results",
    )
def LogEntry(entry):
    """Render one search-log row as a bordered block.

    ``entry`` exposes ``query``, ``timestamp`` and ``results`` attributes;
    ``results`` is shown verbatim inside a ``<pre>`` element.
    """
    parts = (
        H4(f"Query: {entry.query}"),
        P(f"Timestamp: {entry.timestamp}"),
        H5("Results:"),
        Pre(entry.results),
    )
    return Div(*parts, cls="log-entry")
@rt("/logs")
async def get():
    """Render the log page with up to 50 log rows, ordered by ``-id``."""
    recent = search_logs(order_by="-id", limit=50)
    rendered = [LogEntry(row) for row in recent]
    page_body = Div(
        H2("Recent Search Logs"),
        Div(*rendered, id="log-entries"),
        A("Back to Search", href="/", cls="back-link"),
        cls="container",
    )
    return Titled("Logs", page_body)
# Script entry point: serve the app with uvicorn.
if __name__ == "__main__":
    import uvicorn

    # Port comes from the PORT environment variable, defaulting to 7860.
    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="", port=listen_port)