Update functions.py

functions.py  CHANGED  (+4 −95)
@@ -119,15 +119,6 @@ def load_asr_model(asr_model_name):
     asr_model = whisper.load_model(asr_model_name)
 
     return asr_model
-
-# @st.experimental_singleton(suppress_st_warning=True)
-# def load_sbert(model_name):
-#     if 'hkunlp' in model_name:
-#         sbert = INSTRUCTOR(model_name)
-#     else:
-#         sbert = SentenceTransformer(model_name)
-
-#     return sbert
 
 @st.experimental_singleton(suppress_st_warning=True)
 def process_corpus(corpus, tok, title, embeddings, chunk_size=200, overlap=50):
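For reference, the commented-out loader deleted above cached a sentence-embedding model per session. A minimal runnable sketch of the same idea, assuming the sentence-transformers and InstructorEmbedding packages (the decorator mirrors the one already used in this file):

```python
# Sketch of the deleted (already commented-out) loader.
# Assumes: pip install streamlit sentence-transformers InstructorEmbedding
import streamlit as st
from sentence_transformers import SentenceTransformer
from InstructorEmbedding import INSTRUCTOR

@st.experimental_singleton(suppress_st_warning=True)
def load_sbert(model_name):
    # Instructor checkpoints (e.g. 'hkunlp/instructor-base') use their own class
    if 'hkunlp' in model_name:
        return INSTRUCTOR(model_name)
    return SentenceTransformer(model_name)
```

`experimental_singleton` caches one instance across Streamlit reruns, so the model loads once per process.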
@@ -185,7 +176,7 @@ def embed_text(query,corpus,title,embedding_model,emb_tok,chain_type='stuff'):
 
     docs = [d[0] for d in docs]
 
-    if chain_type == '
+    if chain_type == 'Normal':
 
         PROMPT = PromptTemplate(template=template,
                                 input_variables=["summaries", "question"],
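The renamed 'Normal' branch appears to build LangChain's stuff-type QA-with-sources chain (the prompt variables `summaries`/`question` match that chain). A hedged sketch, assuming the classic langchain API; the template text, OpenAI LLM, and sample document below are illustrative placeholders, not the file's actual values:

```python
# Hedged sketch of a stuff-type QA-with-sources chain like the one this
# branch builds; template, LLM, and document are illustrative placeholders.
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.docstore.document import Document

template = """Given the following extracted parts of a document and a question, create a final answer.

{summaries}

QUESTION: {question}
FINAL ANSWER:"""

PROMPT = PromptTemplate(template=template,
                        input_variables=["summaries", "question"])

docs = [Document(page_content="Revenue grew 12% year over year.",
                 metadata={"source": "earnings-call"})]

chain = load_qa_with_sources_chain(OpenAI(temperature=0),  # needs OPENAI_API_KEY
                                   chain_type="stuff", prompt=PROMPT)
answer = chain({"input_documents": docs, "question": "How fast did revenue grow?"},
               return_only_outputs=True)
print(answer['output_text'])
```

The stuff chain concatenates all retrieved chunks into a single prompt, so it only works while the chunks fit the model's context window.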
@@ -200,7 +191,7 @@ def embed_text(query,corpus,title,embedding_model,emb_tok,chain_type='stuff'):
 
         return answer['output_text']
 
-    elif chain_type == '
+    elif chain_type == 'Refined':
 
         initial_qa_prompt = PromptTemplate(
             input_variables=["context_str", "question"], template=initial_qa_template
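Likewise, the 'Refined' branch matches LangChain's refine chain, which answers from the first chunk and then revises that answer once per remaining chunk. A minimal sketch under the same assumptions; both templates here are illustrative, and the refine chain also needs a second prompt over `question`, `existing_answer`, and `context_str`:

```python
# Hedged sketch of a refine-type QA chain; both templates are illustrative.
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.chains.question_answering import load_qa_chain
from langchain.docstore.document import Document

initial_qa_template = """Context:
{context_str}

Given the context, answer the question: {question}"""

refine_template = """The original question is: {question}
We have an existing answer: {existing_answer}
Refine it (only if needed) using the additional context below.
------------
{context_str}
------------"""

initial_qa_prompt = PromptTemplate(
    input_variables=["context_str", "question"], template=initial_qa_template)
refine_prompt = PromptTemplate(
    input_variables=["question", "existing_answer", "context_str"],
    template=refine_template)

docs = [Document(page_content="Revenue grew 12% year over year."),
        Document(page_content="Growth was driven by cloud demand.")]

chain = load_qa_chain(OpenAI(temperature=0), chain_type="refine",
                      question_prompt=initial_qa_prompt,
                      refine_prompt=refine_prompt)
answer = chain({"input_documents": docs, "question": "What drove growth?"},
               return_only_outputs=True)
print(answer['output_text'])  # read from 'output_text', as in the diff
```

Refine trades one big prompt for one LLM call per chunk, which scales to more chunks at the cost of latency.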
@@ -211,62 +202,6 @@ def embed_text(query,corpus,title,embedding_model,emb_tok,chain_type='stuff'):
 
         return answer['output_text']
 
-# @st.experimental_memo(suppress_st_warning=True)
-# def embed_text(query,corpus,embedding_model):
-
-#     '''Embed text and generate semantic search scores'''
-
-#     #If model is e5 then apply prefixes to query and passage
-#     if embedding_model == 'intfloat/e5-base':
-#         search_input = 'query: '+ query
-#         passages_emb = ['passage: ' + sentence for sentence in corpus]
-
-#     elif embedding_model == 'hkunlp/instructor-base':
-#         search_input = [['Represent the Financial question for retrieving supporting paragraphs: ', query]]
-#         passages_emb = [['Represent the Financial paragraph for retrieval: ',sentence] for sentence in corpus]
-
-#     else:
-#         search_input = query
-#         passages_emb = corpus
-
-
-#     #Embed corpus and question
-#     corpus_embedding = sbert.encode(passages_emb, convert_to_tensor=True)
-#     question_embedding = sbert.encode(search_input, convert_to_tensor=True)
-#     question_embedding = question_embedding.cpu()
-#     corpus_embedding = corpus_embedding.cpu()
-
-#     # #Calculate similarity scores and rank
-#     hits = util.semantic_search(question_embedding, corpus_embedding, top_k=2)
-#     hits = hits[0]  # Get the hits for the first query
-
-#     # ##### Re-Ranking #####
-#     # Now, score all retrieved passages with the cross_encoder
-#     cross_inp = [[search_input, corpus[hit['corpus_id']]] for hit in hits]
-
-#     if embedding_model == 'hkunlp/instructor-base':
-#         result = []
-
-#         for sublist in cross_inp:
-#             question = sublist[0][0][1]
-#             document = sublist[1][1]
-#             result.append([question, document])
-
-#         cross_inp = result
-
-#     cross_scores = cross_encoder.predict(cross_inp)
-
-#     # Sort results by the cross-encoder scores
-#     for idx in range(len(cross_scores)):
-#         hits[idx]['cross-score'] = cross_scores[idx]
-
-#     # Output of top-3 hits from re-ranker
-#     # st.markdown("\n-------------------------\n")
-#     # st.subheader(f"Top-{top_k} Cross-Encoder Re-ranker hits")
-#     hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
-
-#     return hits
-
 @st.experimental_singleton(suppress_st_warning=True)
 def get_spacy():
     nlp = en_core_web_lg.load()
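The comment block deleted above was an earlier bi-encoder retrieval plus cross-encoder re-ranking path. A runnable distillation of the same pattern, assuming sentence-transformers; the model names, sample corpus, and top_k are illustrative:

```python
# Standalone sketch of the retrieve-then-re-rank pattern the deleted
# comments described; model names and inputs are illustrative choices.
from sentence_transformers import SentenceTransformer, CrossEncoder, util

sbert = SentenceTransformer('all-MiniLM-L12-v2')
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

def search_and_rerank(query, corpus, top_k=2):
    # Bi-encoder retrieval: embed everything, then cosine-similarity search
    corpus_emb = sbert.encode(corpus, convert_to_tensor=True)
    query_emb = sbert.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(query_emb, corpus_emb, top_k=top_k)[0]

    # Cross-encoder re-ranking: score each (query, passage) pair jointly
    cross_inp = [[query, corpus[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, score in zip(hits, cross_scores):
        hit['cross-score'] = float(score)

    return sorted(hits, key=lambda x: x['cross-score'], reverse=True)

hits = search_and_rerank("What drove revenue growth?",
                         ["Revenue grew 12% on cloud demand.",
                          "Headcount was flat year over year."])
```

The bi-encoder keeps retrieval cheap; the cross-encoder reads each query-passage pair together and is only run on the few retrieved hits.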
@@ -350,32 +285,7 @@ def chunk_long_text(text,threshold,window_size=3,stride=2):
         end_idx = min(start_idx+window_size, len(paragraph))
         passages.append(" ".join(paragraph[start_idx:end_idx]))
 
-    return passages
-
-@st.experimental_memo(suppress_st_warning=True)
-def chunk_and_preprocess_text(text,thresh=500):
-
-    """Chunk text longer than n tokens for summarization"""
-
-    sentences = sent_tokenize(text)
-
-    current_chunk = 0
-    chunks = []
-
-    for sentence in sentences:
-        if len(chunks) == current_chunk + 1:
-            if len(chunks[current_chunk]) + len(sentence.split(" ")) <= thresh:
-                chunks[current_chunk].extend(sentence.split(" "))
-            else:
-                current_chunk += 1
-                chunks.append(sentence.split(" "))
-        else:
-            chunks.append(sentence.split(" "))
-
-    for chunk_id in range(len(chunks)):
-        chunks[chunk_id] = " ".join(chunks[chunk_id])
-
-    return chunks
+    return passages
 
 
 def summary_downloader(raw_text):
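The removed `chunk_and_preprocess_text` packed whole sentences into chunks of at most `thresh` whitespace-split words for summarization. A standalone, slightly simplified version of that logic, assuming NLTK's punkt tokenizer is installed:

```python
# Cleaned-up equivalent of the removed word-count chunker; the index
# bookkeeping of the original is replaced with a last-chunk check.
# Assumes: nltk.download('punkt') has been run.
from nltk.tokenize import sent_tokenize

def chunk_and_preprocess_text(text, thresh=500):
    """Greedily pack whole sentences into chunks of <= thresh words."""
    chunks = []
    for sentence in sent_tokenize(text):
        words = sentence.split(" ")
        # Extend the current chunk while it stays under the word cap,
        # otherwise start a new chunk with this sentence
        if chunks and len(chunks[-1]) + len(words) <= thresh:
            chunks[-1].extend(words)
        else:
            chunks.append(words)
    return [" ".join(chunk) for chunk in chunks]
```

Note the cap is in whitespace-split words, not model tokens, so `thresh` should stay comfortably below the summarizer's true token limit.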
@@ -830,5 +740,4 @@ def save_network_html(kb, filename="network.html"):
 
 
 nlp = get_spacy()
-sent_pipe, sum_pipe, ner_pipe, cross_encoder, kg_model, kg_tokenizer, emb_tokenizer = load_models()
-sbert = load_sbert('all-MiniLM-L12-v2')
+sent_pipe, sum_pipe, ner_pipe, cross_encoder, kg_model, kg_tokenizer, emb_tokenizer = load_models()