Arthur-75 committed on
Commit
5bcbfd6
1 Parent(s): d665c2c

Update app.py

Files changed (1)
  1. app.py +9 -90
app.py CHANGED
@@ -1,83 +1,10 @@
-"https://sites.google.com/airliquide.com/sis-processsafety/knowledge/lessons-learned-sources"
 from sentence_transformers import SentenceTransformer
-
-#model=SentenceTransformer("all-mpnet-base-v2")
+from utils_st import load_models,load_data,clean_whitespace
 import streamlit as st
-import pickle
-import faiss
-from llama_index.core import VectorStoreIndex,StorageContext
-from llama_index.embeddings.huggingface import HuggingFaceEmbedding
-from llama_index.vector_stores.faiss import FaissVectorStore
-from llama_index.core import VectorStoreIndex
-from llama_index.retrievers.bm25 import BM25Retriever
-from llama_index.core.postprocessor import SentenceTransformerRerank
-from llama_index.core import QueryBundle
-from llama_index.core.schema import NodeWithScore
-from llama_index.core.retrievers import BaseRetriever
-from transformers import AutoTokenizer, AutoModel
-
-
-@st.cache_resource(show_spinner=False)
-def load_data():
-    with open('nodes_clean.pkl', 'rb') as file:
-        nodes=pickle.load( file)
-    d = 768
-    faiss_index = faiss.IndexFlatL2(d)
-    vector_store = FaissVectorStore(faiss_index=faiss_index )
-    storage_context = StorageContext.from_defaults(vector_store=vector_store)
-    # use later nodes_clean
-    index = VectorStoreIndex(nodes,embed_model=embed_model,storage_context=storage_context)
-    retriever_dense = index.as_retriever(similarity_top_k=25,embedding=True)
-    retrieverBM25 = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=25)
-    hybrid_retriever = HybridRetriever(retriever_dense, retrieverBM25)
-    return hybrid_retriever
-
-
-
-@st.cache_resource(show_spinner=False)
-class models():
-    def __init__(self):
-        EMBEDDING_MODEL="BAAI/llm-embedder"
-        self.embed_model = HuggingFaceEmbedding(EMBEDDING_MODEL,device='cpu',)
-        self.reranker = SentenceTransformerRerank(top_n=25, model="BAAI/bge-reranker-base",device='cpu',)
-
-mod=models()
-embed_model=mod.embed_model
-reranker= mod.reranker
-
-class HybridRetriever(BaseRetriever):
-    def __init__(self, vector_retriever, bm25_retriever):
-        self.vector_retriever = vector_retriever
-        self.bm25_retriever = bm25_retriever
-        super().__init__()
-    def _retrieve(self, query, **kwargs):
-        bm25_nodes = self.bm25_retriever.retrieve(query, **kwargs)
-        vector_nodes = self.vector_retriever.retrieve(query, **kwargs)
-        # combine the two lists of nodes
-        all_nodes = []
-        node_ids = set()
-        for n in bm25_nodes + vector_nodes:
-            if n.node.node_id not in node_ids:
-                all_nodes.append(n)
-                node_ids.add(n.node.node_id)
-        return all_nodes
-hybrid_retriever = load_data()
-
-
-
-
-
-import re
-def clean_whitespace(text,k=5):
-    text = text.strip()
-    text=" ".join([i for i in text.split("\n")[:k] if len(i.strip())>25]+text.split("\n")[k:])
-    text = re.sub(r"\.EU", "", text)
-    #text = re.sub(r"\n+", "\n", text)
-    text = re.sub(r"\s+", " ", text)
-    return text.lower()
-


+embed_model, reranker = load_models()
+hybrid_retriever = load_data(embed_model)


 def stream(reranked_nodes,text_size=700):
@@ -92,16 +19,11 @@ def stream(reranked_nodes,text_size=700):
         file_name = i_di[0].metadata['file_name']
         summary = i_di[0].metadata['text']
         url = i_di[0].metadata['doc_url']
-
         st.write(f"**Rank {rank+1}:** {file_name} ")
         st.write(f"- Tittle: [{title}](%s)"% url)
         #st.write("check out this [link](%s)" % url)
         with st.expander(f"Summary"):
             st.write(f"{summary}")
-        #st.write(f"- Summary: {summery}")
-        #st.link_button("Link", url)
-
-        #st.write(f"- URL: {url}")
         with st.expander(f"Extra Text(s) "):
             for n_extra,t in enumerate(i_di[:5]):
                 page_n=t.metadata['page_label'] if "page_label" in t.metadata else 'Unknown'
@@ -110,17 +32,16 @@ def stream(reranked_nodes,text_size=700):
         st.markdown("""---""")
         st.markdown("""---""")

-#stream(reranked_nodes,150)
+

 # Function to perform search and return sorted documents
 def perform_search(query):
     if query:
         retrieved_nodes = hybrid_retriever.retrieve(query)
-        reranked_nodes = reranker.postprocess_nodes(
+        reranked_nodes = reranker.predict(
             retrieved_nodes,
-            query_bundle=QueryBundle(
-                query
-            ),)
+            query_bundle=query
+        )
         return reranked_nodes
     else:
         return []
@@ -133,11 +54,9 @@ def main():
     st.title("Information Retrieval System")
     query = st.text_input("Enter your question:")

-
-
     if st.button("Search") or query:
         sorted_docs = perform_search(query)
-        #st.session_state.sorted_docs = sorted_docs
+        st.session_state.sorted_docs = sorted_docs

     else:
         sorted_docs = st.session_state.get("sorted_docs", [])
@@ -145,7 +64,7 @@ def main():


     if sorted_docs:
-        stream(sorted_docs,500)
+        stream(sorted_docs,700)
         #st.write(f"Current Page Number: {page_number}")


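The helpers deleted from app.py are now imported from a new utils_st module whose contents are not part of this commit. Below is a minimal sketch of what utils_st.py could look like, assembled from the code removed above; the names load_models, load_data and clean_whitespace come from the new import line, and everything else is the old app.py code repackaged under assumed signatures, not the actual file.

# utils_st.py -- hypothetical reconstruction, assembled from the code removed in this commit.
# The function names come from the new import in app.py; the bodies are the deleted
# app.py code repackaged, not the real contents of utils_st.
import pickle
import re

import faiss
import streamlit as st
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.retrievers import BaseRetriever
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.vector_stores.faiss import FaissVectorStore


class HybridRetriever(BaseRetriever):
    """Merge BM25 and dense retrieval results, deduplicating by node id."""

    def __init__(self, vector_retriever, bm25_retriever):
        self.vector_retriever = vector_retriever
        self.bm25_retriever = bm25_retriever
        super().__init__()

    def _retrieve(self, query, **kwargs):
        bm25_nodes = self.bm25_retriever.retrieve(query, **kwargs)
        vector_nodes = self.vector_retriever.retrieve(query, **kwargs)
        all_nodes, node_ids = [], set()
        for n in bm25_nodes + vector_nodes:
            if n.node.node_id not in node_ids:
                all_nodes.append(n)
                node_ids.add(n.node.node_id)
        return all_nodes


@st.cache_resource(show_spinner=False)
def load_models():
    # Embedding model and reranker, both on CPU as in the removed code.
    embed_model = HuggingFaceEmbedding("BAAI/llm-embedder", device="cpu")
    reranker = SentenceTransformerRerank(top_n=25, model="BAAI/bge-reranker-base", device="cpu")
    return embed_model, reranker


@st.cache_resource(show_spinner=False)
def load_data(_embed_model):
    # Leading underscore keeps Streamlit from trying to hash the embedding model.
    # Nodes were pickled offline; a flat L2 FAISS index (d=768 for BAAI/llm-embedder) backs the store.
    with open("nodes_clean.pkl", "rb") as file:
        nodes = pickle.load(file)
    faiss_index = faiss.IndexFlatL2(768)
    vector_store = FaissVectorStore(faiss_index=faiss_index)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    index = VectorStoreIndex(nodes, embed_model=_embed_model, storage_context=storage_context)
    retriever_dense = index.as_retriever(similarity_top_k=25)
    retriever_bm25 = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=25)
    return HybridRetriever(retriever_dense, retriever_bm25)


def clean_whitespace(text, k=5):
    # Drop short lines among the first k, strip the ".EU" artefact, collapse whitespace.
    text = text.strip()
    lines = text.split("\n")
    text = " ".join([i for i in lines[:k] if len(i.strip()) > 25] + lines[k:])
    text = re.sub(r"\.EU", "", text)
    text = re.sub(r"\s+", " ", text)
    return text.lower()

Note that the removed code applied the reranker via postprocess_nodes, while the new perform_search calls reranker.predict(..., query_bundle=query); whether the real load_models returns SentenceTransformerRerank as sketched here or wraps it to expose that predict interface is not visible in this diff.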