rag-retrieval / app.py
import streamlit as st
import torch as t
import pandas as pd
from sentence_transformers import SentenceTransformer, util
from time import perf_counter as timer

def load_data(database_file):
    # Read the pre-computed chunk database and copy its embeddings into a single tensor.
    df = pd.read_parquet(database_file)
    # all-mpnet-base-v2 produces 768-dimensional embeddings.
    chunk_embeddings = t.zeros((len(df), 768))
    for idx in range(len(chunk_embeddings)):
        chunk_embeddings[idx] = t.tensor(df.loc[df.index[idx], "chunk_embeddings"])
    return df, chunk_embeddings
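
# Optional alternative (a sketch, not part of the original app): build the embedding
# matrix in one shot instead of row by row, and let Streamlit cache the result so the
# Parquet file is not re-parsed on every rerun. Assumes the "chunk_embeddings" column
# holds equal-length vectors, as the loop above already requires.
# Usage in main(): df, chunk_embeddings = load_data_cached(database_file.getvalue())
@st.cache_data
def load_data_cached(parquet_bytes: bytes):
    import io
    import numpy as np  # local imports; only needed for this variant

    df = pd.read_parquet(io.BytesIO(parquet_bytes))
    chunk_embeddings = t.tensor(np.stack(df["chunk_embeddings"].tolist()), dtype=t.float32)
    return df, chunk_embeddings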

def main():
    st.title("Semantic Text Retrieval App")

    # Select device
    device = "cuda" if t.cuda.is_available() else "cpu"
    st.write(f"Using device: {device}")

    # Load embedding model
    embedding_model = SentenceTransformer(model_name_or_path="all-mpnet-base-v2", device=device)

    # File upload for the database
    database_file = st.file_uploader("Upload the Parquet database file", type=["parquet"])

    if database_file is not None:
        df, chunk_embeddings = load_data(database_file)
        st.success("Database loaded successfully!")

        query = st.text_area("Enter your query:")
        # Keep the slider outside the button block so its value persists across reruns.
        top_k = st.slider("Select number of top results to display", min_value=1, max_value=10, value=5)

        if st.button("Search") and query:
            query_embedding = embedding_model.encode(query)

            # Compute dot product scores between the query and every chunk embedding
            start_time = timer()
            dot_scores = util.dot_score(query_embedding, chunk_embeddings)[0]
            end_time = timer()
            st.write(f"Time taken to compute scores: {end_time - start_time:.5f} seconds")
            # Get top results
            top_results_dot_product = t.topk(dot_scores, k=top_k)

            st.subheader("Query Results")
            st.write(f"Query: {query}")
            for score, idx in zip(top_results_dot_product[0], top_results_dot_product[1]):
                st.write(f"### Score: {score:.4f}")
                st.write(f"**Text:** {df.iloc[int(idx)]['text']}")
                st.write(f"**Number of tokens:** {df.iloc[int(idx)]['tokens']}")
                st.write("---")

if __name__ == "__main__":
    main()
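
# To run locally (requires streamlit, torch, pandas, and sentence-transformers):
#   streamlit run app.py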