import copy import streamlit as st import pandas as pd from sentence_transformers import SentenceTransformer, util from sentence_transformers.cross_encoder import CrossEncoder from st_aggrid import GridOptionsBuilder, AgGrid import pickle import torch from transformers import DPRQuestionEncoderTokenizer, AutoModel from pathlib import Path import base64 import regex import tokenizers st.set_page_config(layout="wide") DATAFRAME_FILE_ORIGINAL = 'policyQA_original.csv' DATAFRAME_FILE_BSBS = 'policyQA_bsbs_sentence.csv' selectbox_selections = { 'Retrieve - Rerank (with fine-tuned cross-encoder)': 1, 'Dense Passage Retrieval':2, 'Retrieve - Reranking with DPR':3, 'Retrieve - Rerank':4 } imagebox_selections = { 'Retrieve - Rerank (with fine-tuned cross-encoder)': 'Retrieve-rerank-trained-cross-encoder.png', 'Dense Passage Retrieval': 'DPR_pipeline.png', 'Retrieve - Reranking with DPR': 'Retrieve-rerank-DPR.png', 'Retrieve - Rerank': 'retrieve-rerank.png' } def retrieve_rerank(question): # Semantic Search (Retrieve) question_embedding = bi_encoder.encode(question, convert_to_tensor=True) hits = util.semantic_search(question_embedding, context_embeddings, top_k=100) if len(hits) == 0: return [] hits = hits[0] # Rerank - score all retrieved passages with cross-encoder cross_inp = [[question, contexes[hit['corpus_id']]] for hit in hits] cross_scores = cross_encoder.predict(cross_inp) # Sort results by the cross-encoder scores for idx in range(len(cross_scores)): hits[idx]['cross-score'] = cross_scores[idx] # Output of top-5 hits from re-ranker hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True) top_5_contexes = [] top_5_scores = [] for hit in hits[0:20]: top_5_contexes.append(contexes[hit['corpus_id']]) top_5_scores.append(hit['cross-score']) return top_5_contexes, top_5_scores @st.cache(show_spinner=False, allow_output_mutation=True) def load_paragraphs(path): with open(path, "rb") as fIn: cache_data = pickle.load(fIn) corpus_sentences = cache_data['contexes'] corpus_embeddings = cache_data['embeddings'] return corpus_embeddings, corpus_sentences @st.cache(show_spinner=False) def load_dataframes(): data_original = pd.read_csv(DATAFRAME_FILE_ORIGINAL, index_col=0, sep='|') data_bsbs = pd.read_csv(DATAFRAME_FILE_BSBS, index_col=0, sep='|') data_original = data_original.sample(frac=1).reset_index(drop=True) data_bsbs = data_bsbs.sample(frac=1).reset_index(drop=True) return data_original, data_bsbs def dot_product(question_output, context_output): mat1 = torch.unsqueeze(question_output, dim=1) mat2 = torch.unsqueeze(context_output, dim=2) result = torch.bmm(mat1, mat2) result = torch.squeeze(result, dim=1) result = torch.squeeze(result, dim=1) return result def retrieve_rerank_DPR(question): hits = retrieve_with_dpr_embeddings(question) return rerank_with_DPR(hits, question) def DPR_reranking(question, selected_contexes, selected_embeddings): scores = [] tokenized_question = question_tokenizer(question, padding=True, truncation=True, return_tensors="pt", add_special_tokens=True) question_output = dpr_trained.model.question_model(**tokenized_question) question_output = question_output['pooler_output'] for context_embedding in selected_embeddings: score = dot_product(question_output, context_embedding) scores.append(score.detach().cpu()) scores_index = sorted(range(len(scores)), key=lambda x: scores[x], reverse=True) contexes_list = [] scores_final = [] for i, idx in enumerate(scores_index[:5]): scores_final.append(scores[idx]) contexes_list.append(selected_contexes[idx]) return scores_final, contexes_list def search_pipeline(question, search_method): if search_method == 1: #Retrieve - rerank with fine-tuned cross encoder return retrieve_rerank_with_trained_cross_encoder(question) if search_method == 2: return custom_dpr_pipeline(question) # DPR only if search_method == 3: return retrieve_rerank_DPR(question) if search_method == 4: return retrieve_rerank(question) def custom_dpr_pipeline(question): #paragraphs tokenized_question = question_tokenizer(question, padding=True, truncation=True, return_tensors="pt", add_special_tokens=True) question_embedding = dpr_trained.model.question_model(**tokenized_question) question_embedding = question_embedding['pooler_output'] results_list = [] for i,context_embedding in enumerate(dpr_context_embeddings): score = dot_product(question_embedding, context_embedding) results_list.append(score.detach().cpu().numpy()[0]) hits = sorted(range(len(results_list)), key=lambda i: results_list[i], reverse=True) top_5_contexes = [] top_5_scores = [] for j in hits[0:5]: top_5_contexes.append(dpr_contexes[j]) top_5_scores.append(results_list[j]) return top_5_contexes, top_5_scores def retrieve(question, corpus_embeddings): # Semantic Search (Retrieve) question_embedding = bi_encoder.encode(question, convert_to_tensor=True) hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=100) if len(hits) == 0: return [] hits = hits[0] return hits def retrieve_with_dpr_embeddings(question): # Semantic Search (Retrieve) question_tokens = question_tokenizer(question, padding=True, truncation=True, return_tensors="pt", add_special_tokens=True) question_embedding = dpr_trained.model.question_model(**question_tokens)['pooler_output'] question_embedding = torch.squeeze(question_embedding, dim=0) corpus_embeddings = torch.stack(dpr_context_embeddings) corpus_embeddings = torch.squeeze(corpus_embeddings, dim=1) hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=100) if len(hits) == 0: return [] hits = hits[0] return hits def rerank_with_DPR(hits, question): # Rerank - score all retrieved passages with cross-encoder selected_contexes = [dpr_contexes[hit['corpus_id']] for hit in hits] selected_embeddings = [dpr_context_embeddings[hit['corpus_id']] for hit in hits] top_5_scores, top_5_contexes = DPR_reranking(question, selected_contexes, selected_embeddings) return top_5_contexes, top_5_scores def DPR_reranking(question, selected_contexes, selected_embeddings): scores = [] tokenized_question = question_tokenizer(question, padding=True, truncation=True, return_tensors="pt", add_special_tokens=True) question_output = dpr_trained.model.question_model(**tokenized_question) question_output = question_output['pooler_output'] for context_embedding in selected_embeddings: score = dot_product(question_output, context_embedding) scores.append(score.detach().cpu().numpy()[0]) scores_index = sorted(range(len(scores)), key=lambda x: scores[x], reverse=True) contexes_list = [] scores_final = [] for i, idx in enumerate(scores_index[:5]): scores_final.append(scores[idx]) contexes_list.append(selected_contexes[idx]) return scores_final, contexes_list def retrieve_rerank_with_trained_cross_encoder(question): hits = retrieve(question, context_embeddings) cross_inp = [(question, contexes[hit['corpus_id']]) for hit in hits] cross_scores = trained_cross_encoder.predict(cross_inp) # Sort results by the cross-encoder scores for idx in range(len(cross_scores)): hits[idx]['cross-score'] = cross_scores[idx][0] # Output of top-5 hits from re-ranker hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True) top_5_contexes = [] top_5_scores = [] for hit in hits[0:5]: top_5_contexes.append(contexes[hit['corpus_id']]) top_5_scores.append(hit['cross-score']) return top_5_contexes, top_5_scores def interactive_table(dataframe): gb = GridOptionsBuilder.from_dataframe(dataframe) gb.configure_pagination(paginationAutoPageSize=True) gb.configure_side_bar() gb.configure_selection('single', rowMultiSelectWithClick=True, groupSelectsChildren="Group checkbox select children") # Enable multi-row selection gridOptions = gb.build() grid_response = AgGrid( dataframe, gridOptions=gridOptions, data_return_mode='AS_INPUT', update_mode='SELECTION_CHANGED', enable_enterprise_modules=False, fit_columns_on_grid_load=False, theme='streamlit', # Add theme color to the table height=350, width='100%', reload_data=False ) return grid_response def img_to_bytes(img_path): img_bytes = Path(img_path).read_bytes() encoded = base64.b64encode(img_bytes).decode() return encoded def qa_main_widgetsv2(): st.title("Question Answering Demo") st.markdown("""---""") option = st.selectbox("Select a search method:", list(selectbox_selections.keys())) header_html = "