# app.py — last updated by rbiswasfc in commit 72876f6 ("Update app.py", verified).
import json
import os
from datetime import datetime
import dotenv
import lancedb
import requests
from datasets import load_dataset
from fasthtml.common import * # noqa
from huggingface_hub import login, whoami
# def get_images(query: str):
# url = "http://147.189.194.113:80/get_pages"
# response = requests.get(url, params={"query": query})
# return response.json()
# Host of the remote service exposing the /get_pages (image retrieval) and
# /rerank endpoints. It can be overridden with the SERVER_IP environment
# variable. Note that this is read at import time, before dotenv.load_dotenv()
# runs further down, so only variables already in the process environment apply.
# Previously used alternative host: "47.47.180.31".
server_ip = os.environ.get("SERVER_IP", "147.189.194.113")
def get_images(query: str, timeout: float = 60.0):
    """Fetch page images matching *query* from the remote /get_pages endpoint.

    Args:
        query: Free-text search query, sent as the ``query`` URL parameter.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` can block indefinitely on an unresponsive host.

    Returns:
        The decoded JSON body of the response. The /search handler iterates it
        and embeds each item as base64 JPEG data (assumed server contract —
        TODO confirm).

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within *timeout*.
    """
    url = f"http://{server_ip}:80/get_pages"
    response = requests.get(url, params={"query": query}, timeout=timeout)
    # Fail loudly on server errors instead of trying to decode an error page.
    response.raise_for_status()
    return response.json()
# def rerank_api(query, docs):
# url = "http://47.47.180.31:80/rerank"
# params = {"query": query, "docs": docs}
# response = requests.get(url, params=params)
# return response.json()
def rerank_api(query, docs, timeout: float = 60.0):
    """Re-rank *docs* against *query* via the remote /rerank endpoint.

    Args:
        query: Search query used to score the documents.
        docs: List of document strings to re-rank.
        timeout: Seconds to wait for the server. Without one, ``requests`` can
            block forever on an unresponsive host.

    Returns:
        The decoded JSON response. ``retrieve_and_rerank`` reads its
        ``"ranked_doc_ids"`` key as a list of indices into *docs*.

    Raises:
        requests.HTTPError: On a 4xx/5xx response.
        requests.Timeout: If the server does not answer within *timeout*.
    """
    url = f"http://{server_ip}:80/rerank"
    payload = {"query": query, "docs": docs}
    # POST with a JSON body rather than URL params (presumably so long
    # document lists are not limited by URL length).
    response = requests.post(url, json=payload, timeout=timeout)
    response.raise_for_status()
    return response.json()
dotenv.load_dotenv()

# Authenticate with the Hugging Face Hub. The dataset repo lives under the
# authenticated user's namespace.
login(token=os.environ.get("HF_TOKEN"))
hf_user = whoami(os.environ.get("HF_TOKEN"))["name"]
HF_REPO_ID_TXT = f"{hf_user}/zotero-answer-ai-texts"

abstract_ds = load_dataset(HF_REPO_ID_TXT, "abstracts")["train"]
article_ds = load_dataset(HF_REPO_ID_TXT, "articles")["train"]

# Local LanceDB database that holds the full-text-search table.
uri = "data/zotero-fts"
db = lancedb.connect(uri)

id2abstract = {example["arxiv_id"]: example["abstract"] for example in abstract_ds}
id2content = {example["arxiv_id"]: example["contents"] for example in article_ds}
id2title = {example["arxiv_id"]: example["title"] for example in article_ds}
# Papers are keyed by the abstracts split. Each id must also exist in the
# articles split (a missing id raises KeyError in the loop below).
arxiv_ids = set(id2abstract)


def _build_full_text(title, sections):
    """Join a paper title and its sections into one searchable document.

    Each section (a mapping with ``title`` and ``content`` keys) is rendered as
    ``"<section title>\\n\\n<section content>"``. All parts are separated by
    blank lines so that words at section boundaries are not glued together,
    which would corrupt full-text-search tokens.
    """
    parts = [title]
    parts.extend(f"{item['title']}\n\n{item['content']}" for item in sections)
    return "\n\n".join(parts)


data = []
for arxiv_id in arxiv_ids:
    abstract = id2abstract[arxiv_id]
    title = id2title[arxiv_id]
    full_text = _build_full_text(title, id2content[arxiv_id])
    data.append(
        {
            "arxiv_id": arxiv_id,
            "title": title,
            "abstract": abstract,
            "full_text": full_text,
        }
    )

# Rebuild the table from scratch on every start-up, then (re)create the
# full-text index on `full_text` used by retrieve_and_rerank.
table = db.create_table("articles", data=data, mode="overwrite")
table.create_fts_index("full_text", replace=True)
# format results ----
def _format_results(results):
ret = []
for result in results:
arx_id = result["arxiv_id"]
title = result["title"]
abstract = result["abstract"]
if "Abstract\n\n" in abstract:
abstract = abstract.split("Abstract\n\n")[-1]
this_ex = {
"title": title,
"url": f"https://arxiv.org/abs/{arx_id}",
"abstract": abstract,
}
ret.append(this_ex)
return ret
def retrieve_and_rerank(query, k=3, n_fetch=25):
    """Retrieve candidates by full-text search, then re-rank them remotely.

    Args:
        query: Free-text search query.
        k: Number of top re-ranked results to return.
        n_fetch: Number of full-text-search candidates fetched before re-ranking.

    Returns:
        Up to *k* dicts with ``title``, ``url`` and ``abstract`` keys, as built
        by ``_format_results``. The list is empty when the search finds nothing.
    """
    # retrieve --- keyword search over the FTS index on `full_text`
    retrieved = (
        table.search(query, vector_column_name="", query_type="fts")
        .limit(n_fetch)
        .select(["arxiv_id", "title", "abstract"])
        .to_list()
    )
    print(f"Retrieved {len(retrieved)} documents")
    if not retrieved:
        # Nothing matched: skip the remote re-ranker instead of sending it an
        # empty document list.
        return []

    # re-rank --- the service returns indices into `docs` (and therefore into
    # `retrieved`). The first k are kept, presumably best-first.
    docs = [f"{item['title']} {item['abstract']}" for item in retrieved]
    ranked_doc_ids = rerank_api(query, docs)["ranked_doc_ids"][:k]
    final_results = [retrieved[idx] for idx in ranked_doc_ids]
    return _format_results(final_results)
###########################################################################
# FastHTML app -----
###########################################################################
# Page-wide CSS, passed to fast_app(hdrs=(style,)) below. It covers:
# - the dark colour scheme and the search form layout;
# - log entry boxes;
# - the htmx loading indicator and spinner (shown only while a request is in
#   flight, via the .htmx-request class);
# - the image grid (three cells of ~33% width per row).
style = Style("""
:root {
color-scheme: dark;
}
body {
max-width: 1200px;
margin: 0 auto;
padding: 20px;
line-height: 1.6;
}
#query {
width: 100%;
margin-bottom: 1rem;
}
#search-form button {
width: 100%;
}
#search-results, #log-entries {
margin-top: 2rem;
}
.log-entry {
border: 1px solid #ccc;
padding: 10px;
margin-bottom: 10px;
}
.log-entry pre {
white-space: pre-wrap;
word-wrap: break-word;
}
.htmx-indicator {
display: none;
}
.htmx-request .htmx-indicator {
display: inline-block;
}
.spinner {
display: inline-block;
width: 2.5em;
height: 2.5em;
border: 0.3em solid rgba(255,255,255,.3);
border-radius: 50%;
border-top-color: #fff;
animation: spin 1s ease-in-out infinite;
margin-left: 10px;
vertical-align: middle;
}
@keyframes spin {
to { transform: rotate(360deg); }
}
.searching-text {
font-size: 1.2em;
font-weight: bold;
color: #fff;
margin-right: 10px;
vertical-align: middle;
}
.image-results {
display: flex;
flex-wrap: wrap;
gap: 10px;
margin-top: 20px;
}
.image-result {
width: calc(33% - 10px);
text-align: center;
}
.image-result img {
max-width: 100%;
height: auto;
border-radius: 5px;
}
""")
# get the fast app and route
# `app` is the application served by uvicorn in __main__. `rt` is the route
# decorator used by the handlers below.
app, rt = fast_app(hdrs=(style,))
# Initialize a database to store search logs --
# NOTE: this rebinds the module-level name `db`, which previously held the
# LanceDB connection. The LanceDB `table` was already created from it above.
db = database("log_data/search_logs.db")
search_logs = db.t.search_logs
# Create the log table only on first run.
if search_logs not in db.t:
    search_logs.create(
        id=int,
        timestamp=str,
        query=str,
        results=str,
        pk="id",
    )
# Row dataclass matching the table's columns; used by insert_log_entry.
SearchLog = search_logs.dataclass()
def insert_log_entry(log_entry):
    """Store one search log entry in the ``search_logs`` table.

    *log_entry* is a mapping with ``timestamp`` (an object with ``isoformat()``,
    e.g. a ``datetime``), ``query`` and JSON-serialisable ``results``. The
    timestamp is stored as an ISO string and the results as a JSON string.
    Returns whatever ``search_logs.insert`` returns.
    """
    record = SearchLog(
        timestamp=log_entry["timestamp"].isoformat(),
        query=log_entry["query"],
        results=json.dumps(log_entry["results"]),
    )
    return search_logs.insert(record)
@rt("/")
async def get():
    """Render the search page: query form, results area and a link to the logs."""
    # "Searching..." text plus spinner, made visible by htmx during a request.
    busy_indicator = Div(
        Span("Searching...", cls="searching-text htmx-indicator"),
        Span(cls="spinner htmx-indicator"),
        cls="indicator-container",
    )
    search_form = Form(
        Textarea(id="query", name="query", placeholder="Enter your query..."),
        Button("Submit", type="submit"),
        busy_indicator,
        id="search-form",
        hx_post="/search",
        hx_target="#search-results",
        hx_indicator=".indicator-container",
    )
    results_area = Div(Div(id="search-results", cls="results-container"))
    logs_link = A("View Logs", href="/logs", cls="view-logs-link")
    page = Div(search_form, results_area, logs_link, cls="container")
    return Titled("Zotero Search", page)
def SearchResult(result):
    """Render one text result as a card: linked title, abstract, "Read more" footer."""
    link_target = result["url"]
    heading = H4(A(result["title"], href=link_target, target="_blank"))
    body = P(result["abstract"])
    more_link = A("Read more →", href=link_target, target="_blank")
    return Card(heading, body, footer=more_link)
# def base64_to_pil(base64_string):
# # Remove the "data:image/png;base64," part if it exists
# if "base64," in base64_string:
# base64_string = base64_string.split("base64,")[1]
# # Decode the base64 string
# img_data = base64.b64decode(base64_string)
# # Open the image using PIL
# img = Image.open(BytesIO(img_data))
# return img
# def process_image(image, max_size=(500, 500), quality=85):
# pil_image = base64_to_pil(image)
# img_byte_arr = io.BytesIO()
# pil_image.thumbnail(max_size)
# pil_image.save(img_byte_arr, format="JPEG", quality=quality, optimize=True)
# return f"data:image/jpeg;base64,{base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')}"
def ImageResult(image):
    """Render one base64-encoded image (embedded as a JPEG data URI) as a grid cell."""
    data_uri = f"data:image/jpeg;base64,{image}"
    picture = Img(src=data_uri, alt="arxiv image")
    return Div(picture, cls="image-result")
# def ImageResult(image):
# return Div(
# Img(src=process_image(image), alt="arxiv image"),
# cls="image-result",
# )
def log_query_and_results(query, results):
    """Log a query, the current time and the title/url of each result."""
    now = datetime.now()
    summary = [dict(title=hit["title"], url=hit["url"]) for hit in results]
    insert_log_entry({"timestamp": now, "query": query, "results": summary})
@rt("/search")
async def post(query: str):
    """Handle a search: fetch page images and text results, log them, render both.

    Args:
        query: The submitted ``query`` form field.

    Returns:
        A ``Div`` with the image results followed by the text results. htmx
        swaps it into ``#search-results`` on the index page.
    """
    import asyncio

    # get_images and retrieve_and_rerank make blocking HTTP calls (requests).
    # Calling them directly in this coroutine would stall the event loop, and
    # every other client with it. Run them in the default executor instead,
    # still one after the other as before.
    loop = asyncio.get_running_loop()
    image_results = await loop.run_in_executor(None, get_images, query)
    results = await loop.run_in_executor(None, retrieve_and_rerank, query)
    # Logging stays on the event-loop thread, as before.
    log_query_and_results(query, results)
    # No id on this wrapper. htmx's default innerHTML swap places it *inside*
    # the existing #search-results element, so reusing that id would duplicate
    # it in the DOM and apply the #search-results margin twice.
    return Div(
        Br(),
        H3("Byaldi Results"),
        Div(*[ImageResult(img) for img in image_results], cls="image-results"),
        Br(),
        H3("Text Results"),
        Div(*[SearchResult(r) for r in results], id="text-results"),
    )
# return Div(*[SearchResult(r) for r in results], id="search-results")
def LogEntry(entry):
    """Render a stored search log row: query, timestamp and raw JSON results."""
    pieces = (
        H4(f"Query: {entry.query}"),
        P(f"Timestamp: {entry.timestamp}"),
        H5("Results:"),
        Pre(entry.results),
    )
    return Div(*pieces, cls="log-entry")
@rt("/logs")
async def get():
    """Show the 50 most recent search log entries (ordered by descending id)."""
    recent_rows = search_logs(order_by="-id", limit=50)
    rendered = [LogEntry(row) for row in recent_rows]
    entries_box = Div(*rendered, id="log-entries")
    back_link = A("Back to Search", href="/", cls="back-link")
    content = Div(
        H2("Recent Search Logs"),
        entries_box,
        back_link,
        cls="container",
    )
    return Titled("Logs", content)
if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces. The port comes from $PORT and defaults to 7860.
    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)
    # run_uv()