Update functions.py
Browse files- functions.py +179 -171
functions.py
CHANGED
@@ -139,6 +139,15 @@ def load_models():
|
|
139 |
|
140 |
return sent_pipe, sum_pipe, ner_pipe, cross_encoder, kg_model, kg_tokenizer, emb_tokenizer, sbert
|
141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
@st.cache_data
|
143 |
def get_yt_audio(url):
|
144 |
|
@@ -161,6 +170,14 @@ def load_whisper_api(audio):
|
|
161 |
|
162 |
return transcript
|
163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
def inference(link, upload, _asr_model):
|
165 |
'''Convert Youtube video or Audio upload to text'''
|
166 |
|
@@ -257,19 +274,53 @@ def inference(link, upload, _asr_model):
|
|
257 |
return results['text'], title
|
258 |
|
259 |
@st.cache_data
|
260 |
-
def
|
|
|
261 |
|
262 |
-
|
|
|
|
|
|
|
|
|
263 |
|
264 |
-
|
265 |
-
|
266 |
-
texts = text_splitter.split_text(corpus)
|
267 |
|
268 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
269 |
|
270 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
271 |
|
272 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
273 |
|
274 |
@st.cache_data
|
275 |
def chunk_and_preprocess_text(text,thresh=500):
|
@@ -296,114 +347,6 @@ def chunk_and_preprocess_text(text,thresh=500):
|
|
296 |
chunks[chunk_id] = " ".join(chunks[chunk_id])
|
297 |
|
298 |
return chunks
|
299 |
-
|
300 |
-
@st.cache_resource
def gen_embeddings(embedding_model):
    '''Generate embeddings for given model

    Args:
        embedding_model: Model name string. Names containing 'hkunlp' are
            loaded as instruction-tuned embeddings with finance-specific
            query/passage instructions; any other name is loaded as plain
            HuggingFace embeddings.

    Returns:
        The constructed embeddings object.
    '''
    if 'hkunlp' in embedding_model:
        return HuggingFaceInstructEmbeddings(
            model_name=embedding_model,
            query_instruction='Represent the Financial question for retrieving supporting paragraphs: ',
            embed_instruction='Represent the Financial paragraph for retrieval: ',
        )

    return HuggingFaceEmbeddings(model_name=embedding_model)
|
316 |
-
|
317 |
-
def embed_text(query,embedding_model,_docsearch):
    '''Answer a question with a conversational retrieval chain over _docsearch

    Args:
        query: The user's question.
        embedding_model: Not used in the body; kept so existing callers
            keep working.
        _docsearch: Vector store exposing ``as_retriever``; the top 3
            matches are handed to the chain.

    Returns:
        The chain's result for ``{"question": query}``; source documents are
        requested via ``return_source_documents=True``.
    '''
    chat_llm = ChatOpenAI(streaming=True,
                          model_name='gpt-4',
                          callbacks=[StdOutCallbackHandler()],
                          verbose=True,
                          temperature=0)

    # `memory` is a module-level object defined outside this function.
    chain = ConversationalRetrievalChain.from_llm(
        chat_llm,
        retriever=_docsearch.as_retriever(search_kwargs={"k": 3}),
        get_chat_history=lambda h : h,
        memory=memory,
        return_source_documents=True,
    )

    # Swap the first message of the default prompt for the project's own prompt.
    chain.combine_docs_chain.llm_chain.prompt.messages[0] = load_prompt()

    return chain({"question": query})
|
344 |
-
|
345 |
-
@st.cache_data
def gen_sentiment(text):
    '''Generate sentiment of given text

    Returns the 'label' field of the first result produced by the
    module-level ``sent_pipe`` pipeline for ``text``.
    '''
    return sent_pipe(text)[0]['label']
|
349 |
-
|
350 |
-
@st.cache_data
|
351 |
-
def gen_annotated_text(df):
    '''Generate annotated text

    Args:
        df: DataFrame whose first column holds the text and whose second
            column holds the sentiment label.

    Returns:
        A list of ``(text, label, colour)`` tuples, one per row: green for
        'Positive', red for 'Negative', black for any other label.
    '''
    label_colours = {'Positive': '#8fce00', 'Negative': '#f44336'}

    # itertuples() puts the index at position 0, so the first two columns
    # are at positions 1 and 2.
    return [
        (row[1], row[2], label_colours.get(row[2], '#000000'))
        for row in df.itertuples()
    ]
|
366 |
-
|
367 |
-
@st.cache_data
def generate_eval(raw_text, N, chunk):
    '''Generate N questions from context of chunk chars

    Args:
        raw_text: Text to draw question contexts from.
        N: Number of random contexts to sample.
        chunk: Size (in characters) of each sampled context.

    Returns:
        Eval set as a flat list built from the chain output of every
        context that succeeded.
    '''
    st.info("`Generating sample questions ...`")

    # Clamp the upper bound: when the text is shorter than `chunk`,
    # randint(0, negative) would raise ValueError. With the clamp every
    # sample simply starts at 0 and covers the whole text.
    last_start = max(0, len(raw_text) - chunk)
    starting_indices = [random.randint(0, last_start) for _ in range(N)]
    sub_sequences = [raw_text[i:i+chunk] for i in starting_indices]

    chain = QAGenerationChain.from_llm(ChatOpenAI(temperature=0))
    eval_set = []
    for i, b in enumerate(sub_sequences):
        try:
            qa = chain.run(b)
            eval_set.append(qa)
            st.write("Creating Question:",i+1)
        except Exception:
            # Best effort: a failed context is reported and skipped so the
            # remaining questions are still generated.
            st.warning('Error generating question %s.' % str(i+1), icon="⚠️")

    return list(itertools.chain.from_iterable(eval_set))
|
392 |
-
|
393 |
-
@st.cache_resource
def get_spacy():
    '''Load the spaCy ``en_core_web_lg`` pipeline

    Returns:
        The object returned by ``en_core_web_lg.load()``.
    '''
    return en_core_web_lg.load()
|
397 |
-
|
398 |
-
|
399 |
-
@st.cache_data
def sentiment_pipe(earnings_text):
    '''Determine the sentiment of the text

    The text is split with ``chunk_long_text`` using a 150-word threshold
    and a window and stride of 1, then the module-level ``sent_pipe``
    pipeline is run over the resulting passages.

    Returns:
        ``(earnings_sentiment, earnings_sentences)``: the pipeline output
        and the list of passages it was run on.
    '''
    earnings_sentences = chunk_long_text(earnings_text, 150, 1, 1)
    earnings_sentiment = sent_pipe(earnings_sentences)

    return earnings_sentiment, earnings_sentences
|
407 |
|
408 |
@st.cache_data
|
409 |
def summarize_text(text_to_summarize,max_len,min_len):
|
@@ -416,56 +359,7 @@ def summarize_text(text_to_summarize,max_len,min_len):
|
|
416 |
early_stopping=True)
|
417 |
summarized_text = ' '.join([summ['summary_text'] for summ in summarized_text])
|
418 |
|
419 |
-
return summarized_text
|
420 |
-
|
421 |
-
@st.cache_data
|
422 |
-
def clean_text(text):
    '''Clean all text

    Drops non-ASCII characters, replaces URLs, @mentions and #hashtags with
    a space, then collapses every run of two or more whitespace characters
    into a single space. Leading/trailing single spaces are not stripped.

    Args:
        text: The string to clean.

    Returns:
        The cleaned string.
    '''
    text = text.encode("ascii", "ignore").decode()  # drop non-ASCII characters

    substitutions = (
        r"https*\S+",  # urls (any token starting with "http")
        r"@\S+",       # mentions
        r"#\S+",       # hashtags
        r"\s{2,}",     # runs of whitespace
    )
    for pattern in substitutions:
        text = re.sub(pattern, " ", text)

    return text
|
432 |
-
|
433 |
-
@st.cache_data
def chunk_long_text(text,threshold,window_size=3,stride=2):
    '''Preprocess text and chunk for sentiment analysis

    Args:
        text: Cleaned text to split.
        threshold: Maximum number of words per piece; longer sentences are
            cut into consecutive pieces of at most this many words.
        window_size: Number of consecutive pieces joined into one passage.
        stride: Step between the starts of consecutive passages.

    Returns:
        A list of passages (strings).
    '''
    pieces = []

    # Convert the text into sentences and limit each one to `threshold` words.
    for sentence in sent_tokenize(text):
        words = sentence.split()
        if len(words) < threshold:
            pieces.append(sentence)
        else:
            # Step over the words themselves so that a sentence whose length
            # is an exact multiple of `threshold` does not produce a
            # trailing empty piece.
            pieces.extend(
                ' '.join(words[i:i + threshold])
                for i in range(0, len(words), threshold)
            )

    # Combine pieces into windows of `window_size`; slicing clamps the last
    # window to the end of the list.
    return [
        " ".join(pieces[start:start + window_size])
        for start in range(0, len(pieces), stride)
    ]
|
460 |
-
|
461 |
-
|
462 |
-
def summary_downloader(raw_text):
    '''Render a download link for the generated summary

    The text is base64-encoded into a ``data:`` URI and shown as an HTML
    anchor through ``st.markdown``. The file name embeds the module-level
    ``time_str`` value.

    Args:
        raw_text: The summary text to offer for download.
    '''
    b64 = base64.b64encode(raw_text.encode()).decode()
    new_filename = "new_text_file_{}_.txt".format(time_str)
    st.markdown("#### Download Summary as a File ###")
    # "file/txt" is not a valid media type; the payload is UTF-8 plain text.
    href = f'<a href="data:text/plain;charset=utf-8;base64,{b64}" download="{new_filename}">Click to Download!!</a>'
    st.markdown(href,unsafe_allow_html=True)
|
469 |
|
470 |
@st.cache_data
|
471 |
def get_all_entities_per_sentence(text):
|
@@ -489,7 +383,7 @@ def get_all_entities_per_sentence(text):
|
|
489 |
entities_all_sentences.append(entities_this_sentence)
|
490 |
|
491 |
return entities_all_sentences
|
492 |
-
|
493 |
@st.cache_data
|
494 |
def get_all_entities(text):
|
495 |
all_entities_per_sentence = get_all_entities_per_sentence(text)
|
@@ -569,6 +463,124 @@ def highlight_entities(article_content,summary_output):
|
|
569 |
soup = BeautifulSoup(summary_output, features="html.parser")
|
570 |
|
571 |
return HTML_WRAPPER.format(soup)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
572 |
|
573 |
|
574 |
def display_df_as_table(model,top_k,score='score'):
|
@@ -909,7 +921,3 @@ def save_network_html(kb, filename="network.html"):
|
|
909 |
)
|
910 |
net.set_edge_smooth('dynamic')
|
911 |
net.show(filename)
|
912 |
-
|
913 |
-
# Load the shared models once at import time. The pipelines are read as
# module-level globals (e.g. `sent_pipe` in gen_sentiment and sentiment_pipe).
nlp = get_spacy()
sent_pipe, sum_pipe, ner_pipe, cross_encoder, kg_model, kg_tokenizer, emb_tokenizer, sbert = load_models()
|
|
|
139 |
|
140 |
return sent_pipe, sum_pipe, ner_pipe, cross_encoder, kg_model, kg_tokenizer, emb_tokenizer, sbert
|
141 |
|
142 |
+
@st.cache_resource
def get_spacy():
    '''Load the spaCy ``en_core_web_lg`` pipeline

    Returns:
        The object returned by ``en_core_web_lg.load()``.
    '''
    return en_core_web_lg.load()
|
146 |
+
|
147 |
+
# Load the shared models once at import time. The pipelines are read as
# module-level globals (e.g. `sent_pipe` in gen_sentiment and sentiment_pipe).
nlp = get_spacy()
sent_pipe, sum_pipe, ner_pipe, cross_encoder, kg_model, kg_tokenizer, emb_tokenizer, sbert = load_models()
|
150 |
+
|
151 |
@st.cache_data
|
152 |
def get_yt_audio(url):
|
153 |
|
|
|
170 |
|
171 |
return transcript
|
172 |
|
173 |
+
@st.cache_resource
def load_asr_model(model_name):
    '''Load the open source whisper model in cases where the API is not working

    Cached with ``st.cache_resource`` rather than ``st.cache_data``: the
    return value is a model object, and Streamlit's resource cache is the
    one intended for ML models because it hands back the same object
    instead of serialising and copying it on every call.

    Args:
        model_name: Name passed straight to ``whisper.load_model``.

    Returns:
        The model returned by ``whisper.load_model``.
    '''
    return whisper.load_model(model_name)
|
180 |
+
|
181 |
def inference(link, upload, _asr_model):
|
182 |
'''Convert Youtube video or Audio upload to text'''
|
183 |
|
|
|
274 |
return results['text'], title
|
275 |
|
276 |
@st.cache_data
|
277 |
+
def clean_text(text):
    '''Clean all text after inference

    Drops non-ASCII characters, replaces URLs, @mentions and #hashtags with
    a space, then collapses every run of two or more whitespace characters
    into a single space. Leading/trailing single spaces are not stripped.

    Args:
        text: The string to clean.

    Returns:
        The cleaned string.
    '''
    text = text.encode("ascii", "ignore").decode()  # drop non-ASCII characters

    substitutions = (
        r"https*\S+",  # urls (any token starting with "http")
        r"@\S+",       # mentions
        r"#\S+",       # hashtags
        r"\s{2,}",     # runs of whitespace
    )
    for pattern in substitutions:
        text = re.sub(pattern, " ", text)

    return text
|
|
|
|
|
287 |
|
288 |
+
@st.cache_data
def chunk_long_text(text,threshold,window_size=3,stride=2):
    '''Preprocess text and chunk for sentiment analysis

    Args:
        text: Cleaned text to split.
        threshold: Maximum number of words per piece; longer sentences are
            cut into consecutive pieces of at most this many words.
        window_size: Number of consecutive pieces joined into one passage.
        stride: Step between the starts of consecutive passages.

    Returns:
        A list of passages (strings).
    '''
    pieces = []

    # Convert the text into sentences and limit each one to `threshold` words.
    for sentence in sent_tokenize(text):
        words = sentence.split()
        if len(words) < threshold:
            pieces.append(sentence)
        else:
            # Step over the words themselves so that a sentence whose length
            # is an exact multiple of `threshold` does not produce a
            # trailing empty piece.
            pieces.extend(
                ' '.join(words[i:i + threshold])
                for i in range(0, len(words), threshold)
            )

    # Combine pieces into windows of `window_size`; slicing clamps the last
    # window to the end of the list.
    return [
        " ".join(pieces[start:start + window_size])
        for start in range(0, len(pieces), stride)
    ]
|
315 |
|
316 |
+
@st.cache_data
def sentiment_pipe(earnings_text):
    '''Determine the sentiment of the text

    The text is split with ``chunk_long_text`` using a 150-word threshold
    and a window and stride of 1, then the module-level ``sent_pipe``
    pipeline is run over the resulting passages.

    Returns:
        ``(earnings_sentiment, earnings_sentences)``: the pipeline output
        and the list of passages it was run on.
    '''
    earnings_sentences = chunk_long_text(earnings_text, 150, 1, 1)
    earnings_sentiment = sent_pipe(earnings_sentences)

    return earnings_sentiment, earnings_sentences
|
324 |
|
325 |
@st.cache_data
|
326 |
def chunk_and_preprocess_text(text,thresh=500):
|
|
|
347 |
chunks[chunk_id] = " ".join(chunks[chunk_id])
|
348 |
|
349 |
return chunks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
350 |
|
351 |
@st.cache_data
|
352 |
def summarize_text(text_to_summarize,max_len,min_len):
|
|
|
359 |
early_stopping=True)
|
360 |
summarized_text = ' '.join([summ['summary_text'] for summ in summarized_text])
|
361 |
|
362 |
+
return summarized_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
363 |
|
364 |
@st.cache_data
|
365 |
def get_all_entities_per_sentence(text):
|
|
|
383 |
entities_all_sentences.append(entities_this_sentence)
|
384 |
|
385 |
return entities_all_sentences
|
386 |
+
|
387 |
@st.cache_data
|
388 |
def get_all_entities(text):
|
389 |
all_entities_per_sentence = get_all_entities_per_sentence(text)
|
|
|
463 |
soup = BeautifulSoup(summary_output, features="html.parser")
|
464 |
|
465 |
return HTML_WRAPPER.format(soup)
|
466 |
+
|
467 |
+
def summary_downloader(raw_text):
    '''Download the summary generated

    The text is base64-encoded into a ``data:`` URI and shown as an HTML
    anchor through ``st.markdown``. The file name embeds the module-level
    ``time_str`` value.

    Args:
        raw_text: The summary text to offer for download.
    '''
    b64 = base64.b64encode(raw_text.encode()).decode()
    new_filename = "new_text_file_{}_.txt".format(time_str)
    st.markdown("#### Download Summary as a File ###")
    # "file/txt" is not a valid media type; the payload is UTF-8 plain text.
    href = f'<a href="data:text/plain;charset=utf-8;base64,{b64}" download="{new_filename}">Click to Download!!</a>'
    st.markdown(href,unsafe_allow_html=True)
|
476 |
+
|
477 |
+
@st.cache_data
def generate_eval(raw_text, N, chunk):
    '''Generate N questions from context of chunk chars

    Args:
        raw_text: Text to draw question contexts from.
        N: Number of random contexts to sample.
        chunk: Size (in characters) of each sampled context.

    Returns:
        Eval set as a flat list built from the chain output of every
        context that succeeded.
    '''
    st.info("`Generating sample questions ...`")

    # Clamp the upper bound: when the text is shorter than `chunk`,
    # randint(0, negative) would raise ValueError. With the clamp every
    # sample simply starts at 0 and covers the whole text.
    last_start = max(0, len(raw_text) - chunk)
    starting_indices = [random.randint(0, last_start) for _ in range(N)]
    sub_sequences = [raw_text[i:i+chunk] for i in starting_indices]

    chain = QAGenerationChain.from_llm(ChatOpenAI(temperature=0))
    eval_set = []
    for i, b in enumerate(sub_sequences):
        try:
            qa = chain.run(b)
            eval_set.append(qa)
            st.write("Creating Question:",i+1)
        except Exception:
            # Best effort: a failed context is reported and skipped so the
            # remaining questions are still generated.
            st.warning('Error generating question %s.' % str(i+1), icon="⚠️")

    return list(itertools.chain.from_iterable(eval_set))
|
502 |
+
|
503 |
+
@st.cache_resource
def gen_embeddings(embedding_model):
    '''Generate embeddings for given model

    Args:
        embedding_model: Model name string. Names containing 'hkunlp' are
            loaded as instruction-tuned embeddings with finance-specific
            query/passage instructions; any other name is loaded as plain
            HuggingFace embeddings.

    Returns:
        The constructed embeddings object.
    '''
    if 'hkunlp' in embedding_model:
        return HuggingFaceInstructEmbeddings(
            model_name=embedding_model,
            query_instruction='Represent the Financial question for retrieving supporting paragraphs: ',
            embed_instruction='Represent the Financial paragraph for retrieval: ',
        )

    return HuggingFaceEmbeddings(model_name=embedding_model)
|
519 |
+
|
520 |
+
@st.cache_data
def process_corpus(corpus, title, embedding_model, chunk_size=1000, overlap=50):
    '''Process text for Semantic Search

    Splits the corpus into overlapping chunks, embeds them with the model
    selected by ``gen_embeddings`` and indexes them in a FAISS vector store.

    Args:
        corpus: Full text to index.
        title: Not used in the body; it still takes part in the cache key
            because it is an argument of a cached function.
        embedding_model: Model name forwarded to ``gen_embeddings``.
        chunk_size: Chunk size given to the text splitter.
        overlap: Chunk overlap given to the text splitter.

    Returns:
        The FAISS vector store; each chunk carries ``{"source": <position>}``
        metadata.
    '''
    # NOTE(review): st.cache_data serialises return values; confirm the
    # vector store survives that, otherwise st.cache_resource is the fit.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,chunk_overlap=overlap)
    texts = text_splitter.split_text(corpus)

    embeddings = gen_embeddings(embedding_model)
    chunk_metadata = [{"source": i} for i in range(len(texts))]

    return FAISS.from_texts(texts, embeddings, metadatas=chunk_metadata)
|
534 |
+
|
535 |
+
def embed_text(query,_docsearch):
    '''Answer a question with a conversational retrieval chain over _docsearch

    Args:
        query: The user's question.
        _docsearch: Vector store exposing ``as_retriever``; the top 3
            matches are handed to the chain.

    Returns:
        The chain's result for ``{"question": query}``; source documents are
        requested via ``return_source_documents=True``.
    '''
    chat_llm = ChatOpenAI(streaming=True,
                          model_name='gpt-4',
                          callbacks=[StdOutCallbackHandler()],
                          verbose=True,
                          temperature=0)

    # `memory` is a module-level object defined outside this function.
    chain = ConversationalRetrievalChain.from_llm(
        chat_llm,
        retriever=_docsearch.as_retriever(search_kwargs={"k": 3}),
        get_chat_history=lambda h : h,
        memory=memory,
        return_source_documents=True,
    )

    # Swap the first message of the default prompt for the project's own prompt.
    chain.combine_docs_chain.llm_chain.prompt.messages[0] = load_prompt()

    return chain({"question": query})
|
562 |
+
|
563 |
+
@st.cache_data
def gen_sentiment(text):
    '''Generate sentiment of given text

    Returns the 'label' field of the first result produced by the
    module-level ``sent_pipe`` pipeline for ``text``.
    '''
    return sent_pipe(text)[0]['label']
|
567 |
+
|
568 |
+
@st.cache_data
|
569 |
+
def gen_annotated_text(df):
    '''Generate annotated text

    Args:
        df: DataFrame whose first column holds the text and whose second
            column holds the sentiment label.

    Returns:
        A list of ``(text, label, colour)`` tuples, one per row: green for
        'Positive', red for 'Negative', black for any other label.
    '''
    label_colours = {'Positive': '#8fce00', 'Negative': '#f44336'}

    # itertuples() puts the index at position 0, so the first two columns
    # are at positions 1 and 2.
    return [
        (row[1], row[2], label_colours.get(row[2], '#000000'))
        for row in df.itertuples()
    ]
|
584 |
|
585 |
|
586 |
def display_df_as_table(model,top_k,score='score'):
|
|
|
921 |
)
|
922 |
net.set_edge_smooth('dynamic')
|
923 |
net.show(filename)
|
|
|
|
|
|
|
|