eli02 committed on
Commit
f0539b9
·
1 Parent(s): dd010bf

update: Create a retrieval app.

Browse files
Files changed (2) hide show
  1. app.py +57 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch as t
3
+ import pandas as pd
4
+ from sentence_transformers import SentenceTransformer, util
5
+ from time import perf_counter as timer
6
+
7
def load_data(database_file):
    """Load the chunk database and stack its embeddings into one tensor.

    Args:
        database_file: Anything ``pandas.read_parquet`` accepts (a path or a
            file-like object such as a Streamlit upload), or an
            already-loaded ``DataFrame``. It must provide a
            ``chunk_embeddings`` column holding one numeric vector per row,
            all of the same length.

    Returns:
        A ``(df, chunk_embeddings)`` tuple. ``chunk_embeddings`` is a float32
        tensor with one row per DataFrame row, in positional order, so
        tensor row ``i`` corresponds to ``df.iloc[i]``.
    """
    if isinstance(database_file, pd.DataFrame):
        df = database_file
    else:
        df = pd.read_parquet(database_file)

    # Iterate the column positionally instead of looking rows up by index
    # label: label lookups return a Series (not a single vector) when the
    # index contains duplicates. The embedding width is taken from the data
    # rather than being fixed in advance.
    vectors = [t.tensor(vec, dtype=t.float32) for vec in df["chunk_embeddings"]]

    if not vectors:
        # torch.stack rejects an empty list; keep the (0, 768) shape the
        # previous implementation produced for an empty database.
        return df, t.zeros((0, 768))

    return df, t.stack(vectors)
13
+
14
@st.cache_resource
def _load_embedding_model(device):
    """Build the sentence-embedding model once and reuse it across reruns."""
    return SentenceTransformer(model_name_or_path="all-mpnet-base-v2", device=device)


def main():
    """Run the Streamlit semantic-retrieval page.

    The page lets the user upload a Parquet database, type a query, and see
    the highest dot-product matches between the query embedding and the
    stored chunk embeddings.
    """
    st.title("Semantic Text Retrieval App")

    # Select device
    device = "cuda" if t.cuda.is_available() else "cpu"
    st.write(f"Using device: {device}")

    # Streamlit re-executes this function on every widget interaction, so
    # the model is fetched through a cached loader instead of being rebuilt.
    embedding_model = _load_embedding_model(device)

    # File upload for the database
    database_file = st.file_uploader("Upload the Parquet database file", type=["parquet"])
    if database_file is None:
        return

    df, chunk_embeddings = load_data(database_file)
    st.success("Database loaded successfully!")

    query = st.text_area("Enter your query:")

    # The slider is created outside the button branch on purpose:
    # st.button is only True on the rerun caused by its own click, so a
    # slider nested inside that branch disappears (and loses its value) as
    # soon as the user moves it.
    top_k = st.slider("Select number of top results to display", min_value=1, max_value=10, value=5)

    if not (st.button("Search") and query):
        return

    query_embedding = embedding_model.encode(query)

    # Compute dot product scores
    start_time = timer()
    dot_scores = util.dot_score(query_embedding, chunk_embeddings)[0]
    end_time = timer()

    st.write(f"Time taken to compute scores: {end_time - start_time:.5f} seconds")

    # torch.topk raises when k exceeds the number of scores, so clamp k to
    # the size of the database.
    k = min(top_k, dot_scores.shape[0])
    top_scores, top_indices = t.topk(dot_scores, k=k)

    st.subheader("Query Results")
    st.write(f"Query: {query}")

    for score, idx in zip(top_scores, top_indices):
        row = df.iloc[int(idx)]
        st.write(f"### Score: {score:.4f}")
        # NOTE(review): the column name "ext" is kept exactly as it was;
        # confirm against the database schema that it is not a typo.
        # Single quotes inside the f-string keep this valid before Python 3.12.
        st.write(f"**Text:** {row['ext']}")
        st.write(f"**Number of tokens:** {row['tokens']}")
        st.write("---")
55
+
56
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ pandas
3
+ sentence-transformers