|
import asyncio
import json
import os
from datetime import datetime

import dotenv
import lancedb
import requests
from datasets import load_dataset
from fasthtml.common import *
from huggingface_hub import login, whoami
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Host of the remote image-retrieval / reranking service (plain HTTP, port 80).
# Override with the SEARCH_SERVER_IP environment variable; it must be set in the
# process environment, because the .env file is only loaded further down.
server_ip = os.environ.get("SEARCH_SERVER_IP", "147.189.194.113")
|
|
|
|
|
def get_images(query: str, timeout: float = 60.0):
    """Fetch page images relevant to *query* from the remote retrieval server.

    Args:
        query: Free-text search query, sent as the ``query`` URL parameter.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a stalled server would hang the request forever.

    Returns:
        The decoded JSON body of ``GET /get_pages``. The /search page renders
        each item as a base64 JPEG, so this is presumably a list of base64
        strings — TODO confirm against the server.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If no response arrives within *timeout* seconds.
    """
    url = f"http://{server_ip}:80/get_pages"
    response = requests.get(url, params={"query": query}, timeout=timeout)
    # Fail loudly on server errors instead of handing an error payload (or a
    # non-JSON body) to the caller.
    response.raise_for_status()
    return response.json()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def rerank_api(query, docs, timeout: float = 60.0):
    """Ask the remote reranking service to order *docs* by relevance to *query*.

    Args:
        query: The user's search query.
        docs: Candidate documents as plain strings. The ids in the response
            are positions into this list.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a stalled server would hang the request forever.

    Returns:
        The decoded JSON body of ``POST /rerank``. Callers read its
        ``ranked_doc_ids`` key, presumably a best-first list of indices into
        *docs* — TODO confirm against the server.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If no response arrives within *timeout* seconds.
    """
    url = f"http://{server_ip}:80/rerank"
    payload = {"query": query, "docs": docs}
    response = requests.post(url, json=payload, timeout=timeout)
    # Surface server errors as HTTPError rather than returning an error body
    # that lacks "ranked_doc_ids".
    response.raise_for_status()
    return response.json()
|
|
|
|
|
# Pull variables such as HF_TOKEN from a local .env file into os.environ.
dotenv.load_dotenv()

# Authenticate against the Hugging Face Hub, then look up the account name;
# it namespaces the dataset repo that stores the paper texts.
login(token=os.environ.get("HF_TOKEN"))
hf_user = whoami(token=os.environ.get("HF_TOKEN"))["name"]
HF_REPO_ID_TXT = f"{hf_user}/zotero-answer-ai-texts"

# Two configurations of the same repo, each read from its "train" split:
# one with per-paper abstracts, one with per-paper full contents.
abstract_ds, article_ds = (
    load_dataset(HF_REPO_ID_TXT, config_name)["train"]
    for config_name in ("abstracts", "articles")
)
|
|
|
|
|
|
|
|
|
# Build a local LanceDB table of all papers plus a full-text-search index over
# their concatenated text; retrieve_and_rerank() searches `table`.
uri = "data/zotero-fts"
# Deliberately not called `db`: the search-log SQLite database below is bound
# to that name, which would otherwise silently replace this handle.
lance_db = lancedb.connect(uri)

id2abstract = {example["arxiv_id"]: example["abstract"] for example in abstract_ds}
id2content = {example["arxiv_id"]: example["contents"] for example in article_ds}
id2title = {example["arxiv_id"]: example["title"] for example in article_ds}

# Every paper that has an abstract. Each one must also be present in the
# "articles" config; otherwise the lookups below raise KeyError.
arxiv_ids = set(id2abstract)


def _build_full_text(title, sections):
    """Join a paper title and its sections into one searchable document.

    Each section is a mapping with ``title`` and ``content`` keys. Parts are
    separated by blank lines so the last word of one part is never glued to
    the first word of the next, which would corrupt full-text tokenisation.
    """
    parts = [title]
    parts.extend(f"{section['title']}\n\n{section['content']}" for section in sections)
    return "\n\n".join(parts)


# sorted() makes the table's row order reproducible across restarts
# (set iteration order over strings changes with hash randomisation).
data = [
    {
        "arxiv_id": arxiv_id,
        "title": id2title[arxiv_id],
        "abstract": id2abstract[arxiv_id],
        "full_text": _build_full_text(id2title[arxiv_id], id2content[arxiv_id]),
    }
    for arxiv_id in sorted(arxiv_ids)
]

# mode="overwrite" / replace=True: table and index are rebuilt on every start.
table = lance_db.create_table("articles", data=data, mode="overwrite")
table.create_fts_index("full_text", replace=True)
|
|
|
|
|
|
|
def _format_results(results): |
|
ret = [] |
|
|
|
for result in results: |
|
arx_id = result["arxiv_id"] |
|
title = result["title"] |
|
abstract = result["abstract"] |
|
|
|
if "Abstract\n\n" in abstract: |
|
abstract = abstract.split("Abstract\n\n")[-1] |
|
|
|
this_ex = { |
|
"title": title, |
|
"url": f"https://arxiv.org/abs/{arx_id}", |
|
"abstract": abstract, |
|
} |
|
|
|
ret.append(this_ex) |
|
|
|
return ret |
|
|
|
|
|
def retrieve_and_rerank(query, k=3, n_fetch=25):
    """Return the *k* most relevant papers for *query*.

    Two stages: a full-text search over ``table`` fetches up to *n_fetch*
    candidates, then the remote reranker orders them and the top *k* are kept.

    Args:
        query: The user's search text. Empty or whitespace-only queries
            return ``[]`` without touching the index or the reranker.
        k: Number of results to return after reranking.
        n_fetch: Number of full-text candidates passed to the reranker.

    Returns:
        Up to *k* dicts with ``title``, ``url`` and ``abstract`` keys, best
        first (as produced by ``_format_results``).
    """
    if not query or not query.strip():
        # Nothing to search for: don't send an empty query to FTS or the reranker.
        return []

    retrieved = (
        table.search(query, vector_column_name="", query_type="fts")
        .limit(n_fetch)
        .select(["arxiv_id", "title", "abstract"])
        .to_list()
    )
    print(f"Retrieved {len(retrieved)} documents")

    if not retrieved:
        # No full-text hits: skip the network round-trip to the reranker.
        return []

    docs = [f"{item['title']} {item['abstract']}" for item in retrieved]
    ranked_doc_ids = rerank_api(query, docs)["ranked_doc_ids"][:k]

    # The reranker's ids are positions into `docs`, which parallels `retrieved`.
    final_results = [retrieved[idx] for idx in ranked_doc_ids]
    return _format_results(final_results)
|
|
|
|
|
|
|
|
|
|
|
|
|
# Page-wide stylesheet, passed to fast_app(hdrs=...) so it is sent with every
# page. Notable rules:
#   * .htmx-indicator elements (the "Searching..." text and spinner) stay hidden
#     unless an ancestor carries the htmx-request class, which htmx sets while
#     the search request is in flight.
#   * .image-result tiles are ~33% wide, giving roughly three images per row.
style = Style("""
:root {
    color-scheme: dark;
}
body {
    max-width: 1200px;
    margin: 0 auto;
    padding: 20px;
    line-height: 1.6;
}
#query {
    width: 100%;
    margin-bottom: 1rem;
}
#search-form button {
    width: 100%;
}
#search-results, #log-entries {
    margin-top: 2rem;
}
.log-entry {
    border: 1px solid #ccc;
    padding: 10px;
    margin-bottom: 10px;
}
.log-entry pre {
    white-space: pre-wrap;
    word-wrap: break-word;
}
.htmx-indicator {
    display: none;
}
.htmx-request .htmx-indicator {
    display: inline-block;
}
.spinner {
    display: inline-block;
    width: 2.5em;
    height: 2.5em;
    border: 0.3em solid rgba(255,255,255,.3);
    border-radius: 50%;
    border-top-color: #fff;
    animation: spin 1s ease-in-out infinite;
    margin-left: 10px;
    vertical-align: middle;
}
@keyframes spin {
    to { transform: rotate(360deg); }
}
.searching-text {
    font-size: 1.2em;
    font-weight: bold;
    color: #fff;
    margin-right: 10px;
    vertical-align: middle;
}
.image-results {
    display: flex;
    flex-wrap: wrap;
    gap: 10px;
    margin-top: 20px;
}
.image-result {
    width: calc(33% - 10px);
    text-align: center;
}
.image-result img {
    max-width: 100%;
    height: auto;
    border-radius: 5px;
}
""")
|
|
|
|
|
# FastHTML application and its route decorator; `style` is attached as a page header.
app, rt = fast_app(hdrs=(style,))

# SQLite store for the search history shown on the /logs page.
db = database("log_data/search_logs.db")
search_logs = db.t.search_logs

# Create the table on first run only; later runs reuse the existing one.
if search_logs not in db.t:
    search_logs.create(id=int, timestamp=str, query=str, results=str, pk="id")

# Row dataclass derived from the table schema (used when inserting log rows).
SearchLog = search_logs.dataclass()
|
|
|
|
|
def insert_log_entry(log_entry):
    """Store one search in the ``search_logs`` table and return the insert result.

    *log_entry* needs ``timestamp`` (a datetime, stored as ISO-8601 text),
    ``query`` and ``results`` (JSON-serialisable, stored as a JSON string).
    """
    row = SearchLog(
        timestamp=log_entry["timestamp"].isoformat(),
        query=log_entry["query"],
        results=json.dumps(log_entry["results"]),
    )
    return search_logs.insert(row)
|
|
|
|
|
@rt("/")
async def get():
    """Render the search page.

    The page contains a query form that posts to /search via htmx and swaps
    the response into the #search-results container, a "Searching..."
    indicator for the in-flight request, and a link to the /logs page.
    """
    query_form = Form(
        # required=True: the browser refuses to submit a blank query (htmx
        # honours HTML5 form validation), so empty searches never reach the
        # image/text backends or the search log.
        Textarea(
            id="query",
            name="query",
            placeholder="Enter your query...",
            required=True,
        ),
        Button("Submit", type="submit"),
        Div(
            Span("Searching...", cls="searching-text htmx-indicator"),
            Span(cls="spinner htmx-indicator"),
            cls="indicator-container",
        ),
        id="search-form",
        hx_post="/search",
        hx_target="#search-results",
        hx_indicator=".indicator-container",
    )

    results_div = Div(Div(id="search-results", cls="results-container"))

    view_logs_link = A("View Logs", href="/logs", cls="view-logs-link")

    return Titled(
        "Zotero Search", Div(query_form, results_div, view_logs_link, cls="container")
    )
|
|
|
|
|
def SearchResult(result):
    """Render one formatted text hit (dict with title/url/abstract) as a Card.

    Both the heading and the footer link open the paper URL in a new tab.
    """
    paper_url = result["url"]
    heading = H4(A(result["title"], href=paper_url, target="_blank"))
    summary = P(result["abstract"])
    more_link = A("Read more →", href=paper_url, target="_blank")
    return Card(heading, summary, footer=more_link)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ImageResult(image):
    """Wrap one base64-encoded JPEG from the image backend in an <img> tile."""
    data_uri = f"data:image/jpeg;base64,{image}"
    picture = Img(src=data_uri, alt="arxiv image")
    return Div(picture, cls="image-result")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def log_query_and_results(query, results):
    """Record a search: the query, the current local time, and for each text
    hit only its title and url (abstracts are not stored)."""
    now = datetime.now()
    summary = []
    for hit in results:
        summary.append({"title": hit["title"], "url": hit["url"]})
    insert_log_entry({"timestamp": now, "query": query, "results": summary})
|
|
|
|
|
@rt("/search")
async def post(query: str):
    """Run a search and return the HTML fragment for the results area.

    Fetches image hits from the remote image backend and reranked text hits
    from the local full-text index, logs the query, and returns both lists.
    """
    # Both backends do blocking I/O (requests calls, LanceDB search). Calling
    # them directly inside this coroutine would freeze the event loop, and with
    # it every other client, for the whole round-trip. Run them in worker
    # threads instead, concurrently, so their latencies overlap.
    image_results, results = await asyncio.gather(
        asyncio.to_thread(get_images, query),
        asyncio.to_thread(retrieve_and_rerank, query),
    )

    # Logging stays on the event-loop thread, as before.
    log_query_and_results(query, results)

    # No id on this Div: it is swapped *into* #search-results (the form's
    # hx_target, default innerHTML swap), so reusing that id would duplicate it.
    return Div(
        Br(),
        H3("Byaldi Results"),
        Div(*[ImageResult(img) for img in image_results], cls="image-results"),
        Br(),
        H3("Text Results"),
        Div(*[SearchResult(r) for r in results], id="text-results"),
    )
|
|
|
|
|
|
|
|
|
def LogEntry(entry):
    """Render one stored search-log row as a bordered block for /logs.

    ``entry.results`` is the stored results text and is shown verbatim in a <pre>.
    """
    heading = H4(f"Query: {entry.query}")
    stamp = P(f"Timestamp: {entry.timestamp}")
    return Div(
        heading,
        stamp,
        H5("Results:"),
        Pre(entry.results),
        cls="log-entry",
    )
|
|
|
|
|
@rt("/logs")
async def get():
    """Render the 50 most recent search-log entries, newest first."""
    # order_by is passed through to the SQL ORDER BY clause, so use SQL syntax.
    # Django-style "-id" would become the expression ORDER BY -id, which only
    # sorts newest-first by accident and cannot walk the primary key.
    logs = search_logs(order_by="id desc", limit=50)
    log_entries = [LogEntry(log) for log in logs]
    return Titled(
        "Logs",
        Div(
            H2("Recent Search Logs"),
            Div(*log_entries, id="log-entries"),
            A("Back to Search", href="/", cls="back-link"),
            cls="container",
        ),
    )
|
|
|
|
|
if __name__ == "__main__":
    # Local/dev entry point: serve on all interfaces; $PORT overrides 7860.
    import uvicorn

    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
|
|
|
|
|
|