|
import streamlit as st |
|
import torch as t |
|
import pandas as pd |
|
from sentence_transformers import SentenceTransformer, util |
|
from time import perf_counter as timer |
|
|
|
def load_data(database_file): |
|
df = pd.read_parquet(database_file) |
|
chunk_embeddings = t.zeros((df.__len__(), 768)) |
|
for idx in range(len(chunk_embeddings)): |
|
chunk_embeddings[idx] = t.tensor(df.loc[df.index[idx], "chunk_embeddings"]) |
|
return df, chunk_embeddings |
|
|
|
def _load_embedding_model(device):
    """Instantiate the sentence-embedding model on *device*."""
    return SentenceTransformer(model_name_or_path="all-mpnet-base-v2", device=device)


def main():
    """Run the Streamlit semantic-retrieval page.

    The page loads the embedding model, lets the user upload a Parquet
    database (read through ``load_data``), and on "Search" ranks the stored
    chunk embeddings against the query by dot product, printing the top
    matches with their score, text and token count.
    """
    st.title("Semantic Text Retrieval App")

    device = "cuda" if t.cuda.is_available() else "cpu"
    st.write(f"Using device: {device}")

    # Streamlit re-executes the whole script on every widget interaction.
    # Cache the model so it is built once instead of on every rerun; fall
    # back to a plain call on Streamlit versions without `cache_resource`.
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        embedding_model = cache_resource(_load_embedding_model)(device)
    else:
        embedding_model = _load_embedding_model(device)

    database_file = st.file_uploader("Upload the Parquet database file", type=["parquet"])
    if database_file is None:
        return

    df, chunk_embeddings = load_data(database_file)
    st.success("Database loaded successfully!")

    query = st.text_area("Enter your query:")

    # The slider must live outside the button branch: `st.button` is only
    # True on the rerun triggered by the click, so a slider created inside
    # that branch disappears (with the results) as soon as it is moved and
    # its value can never take effect.
    top_k = st.slider("Select number of top results to display", min_value=1, max_value=10, value=5)

    if not (st.button("Search") and query):
        return

    query_embedding = embedding_model.encode(query)

    start_time = timer()
    dot_scores = util.dot_score(query_embedding, chunk_embeddings)[0]
    end_time = timer()

    st.write(f"Time taken to compute scores: {end_time - start_time:.5f} seconds")

    # `torch.topk` raises when k exceeds the number of candidates, so clamp
    # it to the number of chunks actually stored in the database.
    result_count = min(top_k, dot_scores.shape[0])
    top_results_dot_product = t.topk(dot_scores, k=result_count)

    st.subheader("Query Results")
    st.write(f"Query: {query}")

    if result_count == 0:
        st.warning("The database contains no chunks to search.")
        return

    scores = top_results_dot_product.values.tolist()
    indices = top_results_dot_product.indices.tolist()
    for score, idx in zip(scores, indices):
        row = df.iloc[idx]
        st.write(f"### Score: {score:.4f}")
        # NOTE(review): the column key 'ext' is kept exactly as found; it may
        # be a typo for 'text' — confirm against the database schema.
        st.write(f"**Text:** {row['ext']}")
        st.write(f"**Number of tokens:** {row['tokens']}")
        st.write("---")
|
|
|
# Script entry point: build the Streamlit page only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|