Spaces:
Runtime error
Runtime error
import copy | |
import streamlit as st | |
import pandas as pd | |
from sentence_transformers import SentenceTransformer, util | |
from sentence_transformers.cross_encoder import CrossEncoder | |
from st_aggrid import GridOptionsBuilder, AgGrid | |
import pickle | |
import torch | |
from transformers import DPRQuestionEncoderTokenizer, AutoModel | |
from pathlib import Path | |
import base64 | |
import regex | |
import tokenizers | |
# Called first, before any other Streamlit element is created.
st.set_page_config(layout="wide")

# '|'-separated CSV files holding the question/context/answer tables shown in the UI.
DATAFRAME_FILE_ORIGINAL = 'policyQA_original.csv'
DATAFRAME_FILE_BSBS = 'policyQA_bsbs_sentence.csv'

# One row per search method:
# (label shown in the selectbox, id understood by search_pipeline, pipeline diagram image).
_SEARCH_METHOD_TABLE = (
    ('Retrieve - Rerank (with fine-tuned cross-encoder)', 1, 'Retrieve-rerank-trained-cross-encoder.png'),
    ('Dense Passage Retrieval', 2, 'DPR_pipeline.png'),
    ('Retrieve - Reranking with DPR', 3, 'Retrieve-rerank-DPR.png'),
    ('Retrieve - Rerank', 4, 'retrieve-rerank.png'),
)
# Label -> pipeline id (insertion order drives the selectbox option order).
selectbox_selections = {label: method_id for label, method_id, _ in _SEARCH_METHOD_TABLE}
# Label -> diagram image file rendered above the columns.
imagebox_selections = {label: image_file for label, _, image_file in _SEARCH_METHOD_TABLE}
def retrieve_rerank(question, top_k=20):
    """Retrieve candidate contexts with the bi-encoder, then rerank them with the cross-encoder.

    Stage 1 embeds ``question`` with ``bi_encoder`` and fetches the 100 nearest
    entries of ``context_embeddings`` via ``util.semantic_search``.  Stage 2
    scores every ``[question, context]`` pair with ``cross_encoder`` and sorts
    the hits by that score.

    Args:
        question: The user's question text.
        top_k: How many reranked contexts to return.  Defaults to 20, the
            number this pipeline has always returned.
            NOTE(review): the other pipelines return 5; confirm 20 is intended.

    Returns:
        A ``(contexts, scores)`` pair of equal-length lists ordered by
        descending cross-encoder score.  Both lists are empty when nothing is
        retrieved, so callers can always unpack two values.
    """
    question_embedding = bi_encoder.encode(question, convert_to_tensor=True)
    # semantic_search returns one hit list per query; a single query was passed.
    results = util.semantic_search(question_embedding, context_embeddings, top_k=100)
    hits = results[0] if results else []
    if not hits:
        # Two empty lists (not a bare []): the caller unpacks two values.
        return [], []

    # Rerank: score every retrieved passage with the cross-encoder.
    cross_inp = [[question, contexes[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, score in zip(hits, cross_scores):
        hit['cross-score'] = score

    ranked = sorted(hits, key=lambda hit: hit['cross-score'], reverse=True)[:top_k]
    top_contexes = [contexes[hit['corpus_id']] for hit in ranked]
    top_scores = [hit['cross-score'] for hit in ranked]
    return top_contexes, top_scores
def load_paragraphs(path):
    """Load a pickled cache of context passages and their embeddings.

    The pickle at ``path`` must hold a mapping with the keys ``'contexes'``
    and ``'embeddings'``.

    Returns:
        ``(embeddings, contexes)`` -- note that the embeddings come first.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only point this at
    # trusted, locally produced cache files.
    with open(path, "rb") as cache_file:
        cached = pickle.load(cache_file)
    passages = cached['contexes']
    vectors = cached['embeddings']
    return vectors, passages
def load_dataframes():
    """Read both question tables and return them with their rows shuffled.

    Each CSV is '|'-separated with its first column used as the index; after
    shuffling (``sample(frac=1)``) the index is reset to 0..n-1.

    Returns:
        ``(data_original, data_bsbs)`` DataFrames.
    """
    frames = []
    for csv_path in (DATAFRAME_FILE_ORIGINAL, DATAFRAME_FILE_BSBS):
        frames.append(pd.read_csv(csv_path, index_col=0, sep='|'))
    return tuple(frame.sample(frac=1).reset_index(drop=True) for frame in frames)
def dot_product(question_output, context_output):
    """Row-wise dot product of two equally sized batches of vectors.

    ``torch.bmm`` needs 3-D inputs with matching batch size, so both arguments
    must be ``(batch, dim)`` tensors with the same ``batch``.  Entry ``i`` of
    the result is ``question_output[i] . context_output[i]``; shape ``(batch,)``.
    """
    # (batch, 1, dim) @ (batch, dim, 1) -> (batch, 1, 1)
    products = torch.bmm(question_output.unsqueeze(1), context_output.unsqueeze(2))
    # Remove the two singleton axes one after the other -> (batch,)
    return products.squeeze(1).squeeze(1)
def retrieve_rerank_DPR(question):
    """Two-stage DPR search: retrieve with DPR embeddings, then rerank by DPR dot product.

    Returns the ``(contexts, scores)`` pair produced by ``rerank_with_DPR``.
    """
    candidate_hits = retrieve_with_dpr_embeddings(question)
    contexts_and_scores = rerank_with_DPR(candidate_hits, question)
    return contexts_and_scores
# DPR_reranking is defined once, below rerank_with_DPR. Python binds a name to
# the last `def` executed, so a second copy here would only be dead code that
# silently drifts from the version actually used.
def search_pipeline(question, search_method):
    """Run the search pipeline identified by ``search_method`` on ``question``.

    ``search_method`` is one of the integer ids stored in ``selectbox_selections``:

    * 1 -- bi-encoder retrieval, reranked by the trained cross-encoder
    * 2 -- DPR only: dot product against every DPR context embedding
    * 3 -- retrieval over DPR embeddings, reranked by DPR dot product
    * 4 -- bi-encoder retrieval, reranked by the stock cross-encoder

    Returns:
        The selected pipeline's ``(contexts, scores)`` pair.

    Raises:
        ValueError: if ``search_method`` is not one of the ids above.
    """
    if search_method == 1:
        return retrieve_rerank_with_trained_cross_encoder(question)
    elif search_method == 2:
        return custom_dpr_pipeline(question)
    elif search_method == 3:
        return retrieve_rerank_DPR(question)
    elif search_method == 4:
        return retrieve_rerank(question)
    # Fail loudly: returning None here would surface in the caller as an opaque
    # "cannot unpack non-iterable NoneType" error.
    raise ValueError(f"Unknown search method: {search_method!r}")
def custom_dpr_pipeline(question, top_k=5):
    """Score ``question`` against every DPR context embedding and return the best matches.

    The question is tokenized with ``question_tokenizer`` and encoded by
    ``dpr_trained.model.question_model`` (its ``pooler_output``).  Each entry of
    ``dpr_context_embeddings`` is then scored with ``dot_product``.

    Args:
        question: The user's question text.
        top_k: Number of contexts to return (default 5, the previous fixed value).

    Returns:
        ``(contexts, scores)`` lists ordered by descending dot-product score.
    """
    tokenized_question = question_tokenizer(question, padding=True, truncation=True,
                                            return_tensors="pt", add_special_tokens=True)
    # Inference only: no_grad stops autograd from recording a graph on every
    # query, saving memory and time without changing the scores.
    with torch.no_grad():
        question_embedding = dpr_trained.model.question_model(**tokenized_question)['pooler_output']
        # A single question is a batch of one, so each score is a 1-element
        # tensor; keep its scalar.
        # NOTE(review): assumes each stored context embedding also has batch
        # size 1 -- dot_product uses bmm, which does not broadcast.
        scores = [dot_product(question_embedding, context_embedding).detach().cpu().numpy()[0]
                  for context_embedding in dpr_context_embeddings]
    ranking = sorted(range(len(scores)), key=lambda idx: scores[idx], reverse=True)[:top_k]
    top_contexes = [dpr_contexes[idx] for idx in ranking]
    top_scores = [scores[idx] for idx in ranking]
    return top_contexes, top_scores
def retrieve(question, corpus_embeddings):
    """Return the bi-encoder's top-100 semantic-search hits for ``question``.

    Each hit is a dict produced by ``util.semantic_search``; callers in this
    file read its ``'corpus_id'`` key.  Returns ``[]`` if ``semantic_search``
    produced no result lists.
    """
    query_vector = bi_encoder.encode(question, convert_to_tensor=True)
    per_query_hits = util.semantic_search(query_vector, corpus_embeddings, top_k=100)
    # Only one query was searched, so only its (first) hit list matters.
    return per_query_hits[0] if per_query_hits else []
def retrieve_with_dpr_embeddings(question):
    """Retrieve up to 100 candidate hits for ``question`` from the DPR context embeddings.

    The question is encoded with ``dpr_trained.model.question_model``
    (``pooler_output``).  The stored ``dpr_context_embeddings`` are stacked
    into one matrix, and ``util.semantic_search`` returns the 100 best hits.

    Returns:
        The question's hit list (dicts with at least ``'corpus_id'``), or
        ``[]`` if ``semantic_search`` returned no result lists.
    """
    question_tokens = question_tokenizer(question, padding=True, truncation=True,
                                         return_tensors="pt", add_special_tokens=True)
    # Inference only: avoid building an autograd graph for the question encoder.
    with torch.no_grad():
        question_embedding = dpr_trained.model.question_model(**question_tokens)['pooler_output']
    # Drop the batch axis of the single question.
    question_embedding = torch.squeeze(question_embedding, dim=0)
    # Stack the stored embeddings and drop their singleton axis 1.
    # NOTE(review): assumes each stored embedding is shaped (1, dim) -- confirm.
    # The stack is recomputed on every query; it could be cached once at load time.
    corpus_embeddings = torch.squeeze(torch.stack(dpr_context_embeddings), dim=1)
    # NOTE(review): semantic_search scores with its default score function, while
    # DPR_reranking / custom_dpr_pipeline rank by raw dot product -- confirm intended.
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=100)
    if not hits:
        return []
    return hits[0]
def rerank_with_DPR(hits, question):
    """Rerank retrieved ``hits`` by DPR dot-product score against ``question``.

    Each hit's ``'corpus_id'`` indexes both ``dpr_contexes`` and
    ``dpr_context_embeddings``.  Scoring is done by ``DPR_reranking`` (a DPR
    dot product, not a cross-encoder).

    Returns:
        ``(contexts, scores)`` for the top hits chosen by ``DPR_reranking``.
        ``DPR_reranking`` itself returns ``(scores, contexts)``; the pair is
        swapped here to match the other pipelines.
    """
    corpus_ids = [hit['corpus_id'] for hit in hits]
    candidate_contexes = [dpr_contexes[cid] for cid in corpus_ids]
    candidate_embeddings = [dpr_context_embeddings[cid] for cid in corpus_ids]
    ranked_scores, ranked_contexes = DPR_reranking(question, candidate_contexes, candidate_embeddings)
    return ranked_contexes, ranked_scores
def DPR_reranking(question, selected_contexes, selected_embeddings, top_k=5):
    """Score candidate contexts against ``question`` with the DPR dot product.

    Args:
        question: The user's question text.
        selected_contexes: Candidate context strings.
        selected_embeddings: DPR embeddings aligned index-for-index with
            ``selected_contexes``.
        top_k: Number of results to keep (default 5, the previous fixed value).

    Returns:
        ``(scores, contexts)`` -- note the order -- for the ``top_k`` highest
        dot-product scores, in descending order.
    """
    tokenized_question = question_tokenizer(question, padding=True, truncation=True,
                                            return_tensors="pt", add_special_tokens=True)
    # Inference only: no autograd graph is needed for scoring.
    with torch.no_grad():
        question_output = dpr_trained.model.question_model(**tokenized_question)['pooler_output']
        # dot_product yields a 1-element tensor for the single question; keep
        # the plain scalar so it formats cleanly in the UI.
        scores = [dot_product(question_output, context_embedding).detach().cpu().numpy()[0]
                  for context_embedding in selected_embeddings]
    order = sorted(range(len(scores)), key=lambda idx: scores[idx], reverse=True)[:top_k]
    scores_final = [scores[idx] for idx in order]
    contexes_list = [selected_contexes[idx] for idx in order]
    return scores_final, contexes_list
def retrieve_rerank_with_trained_cross_encoder(question, top_k=5):
    """Retrieve with the bi-encoder, then rerank with the trained cross-encoder.

    Args:
        question: The user's question text.
        top_k: Number of reranked contexts to return (default 5, as before).

    Returns:
        ``(contexts, scores)`` ordered by descending cross-encoder score; two
        empty lists when retrieval finds nothing.
    """
    hits = retrieve(question, context_embeddings)
    if not hits:
        # Nothing to rerank: skip sending an empty batch to predict, and keep
        # the two-list shape the caller unpacks.
        return [], []
    cross_inp = [(question, contexes[hit['corpus_id']]) for hit in hits]
    cross_scores = trained_cross_encoder.predict(cross_inp)
    # The first entry of each prediction is taken as the relevance score.
    # NOTE(review): this assumes the model outputs a vector per pair; with a
    # single-label head, predict yields scalars and [0] would fail. If the head
    # has two labels, confirm index 0 is the "relevant" class.
    for hit, prediction in zip(hits, cross_scores):
        hit['cross-score'] = prediction[0]
    ranked = sorted(hits, key=lambda hit: hit['cross-score'], reverse=True)[:top_k]
    top_contexes = [contexes[hit['corpus_id']] for hit in ranked]
    top_scores = [hit['cross-score'] for hit in ranked]
    return top_contexes, top_scores
def interactive_table(dataframe):
    """Render ``dataframe`` as a paginated AgGrid table with single-row selection.

    Returns:
        The response object from ``AgGrid``; callers read its
        ``'selected_rows'`` entry.
    """
    builder = GridOptionsBuilder.from_dataframe(dataframe)
    builder.configure_pagination(paginationAutoPageSize=True)
    builder.configure_side_bar()
    # Selection mode is 'single': at most one row can be selected.
    # NOTE(review): groupSelectsChildren receives a (truthy) string rather than a
    # bool; kept verbatim so the grid options sent to the front end are unchanged.
    builder.configure_selection('single', rowMultiSelectWithClick=True,
                                groupSelectsChildren="Group checkbox select children")
    grid_settings = dict(
        gridOptions=builder.build(),
        data_return_mode='AS_INPUT',
        update_mode='SELECTION_CHANGED',
        enable_enterprise_modules=False,
        fit_columns_on_grid_load=False,
        theme='streamlit',  # match the Streamlit colour theme
        height=350,
        width='100%',
        reload_data=False,
    )
    return AgGrid(dataframe, **grid_settings)
def img_to_bytes(img_path):
    """Return the contents of the file at ``img_path`` base64-encoded, as a ``str``.

    The result is meant for embedding in a ``data:`` URI.
    """
    raw = Path(img_path).read_bytes()
    return base64.b64encode(raw).decode("utf-8")
def _render_search_form(option):
    """Render the question form and, once submitted, the contexts found by the selected pipeline."""
    form = st.form(key='first_form')
    question = form.text_area("What is your question?:", height=200)
    submit = form.form_submit_button('Submit')
    if "form_submit" not in st.session_state:
        st.session_state.form_submit = False
    if submit:
        st.session_state.form_submit = True
    # The flag stays True after a successful submit, so the search is re-run
    # (and its results stay visible) on later reruns of the script.
    if st.session_state.form_submit and question != '':
        with st.spinner(text='Related context search in progress..'):
            top_5_contexes, top_5_scores = search_pipeline(question.strip(), selectbox_selections[option])
        if len(top_5_contexes) == 0:
            st.error("Related context not found!")
            st.session_state.form_submit = False
        else:
            for rank, (context, score) in enumerate(zip(top_5_contexes, top_5_scores), start=1):
                st.markdown(f"## Related Context - {rank} (score: {score:.2f})")
                st.markdown(context)
                st.markdown("""---""")


def _render_question_grid(title, dataframe):
    """Show ``dataframe`` as a selectable grid and, when a row is selected, its context/question/answer."""
    st.markdown(title)
    grid_response = interactive_table(dataframe)
    selected_rows = grid_response['selected_rows']
    # Decide directly from the current selection. A session-state flag that is
    # set and reset within one run can be left stale if the run is interrupted,
    # which would make the next run index an empty selection.
    # NOTE(review): assumes selected_rows is a list of row dicts (as in the
    # st_aggrid version this app targets).
    if len(selected_rows) > 0:
        selection = selected_rows[0]
        st.markdown("### Context:")
        st.write(selection['context'])
        st.markdown("### Question:")
        st.write(selection['question'])
        st.markdown("### Answer:")
        st.write(selection['answer'])


def qa_main_widgetsv2():
    """Render the QA demo page.

    Layout: a search-method selectbox with its pipeline diagram, then three
    columns -- the question form with search results, the "Original
    Questions" table, and the "Our Questions" table.
    """
    st.title("Question Answering Demo")
    st.markdown("""---""")
    option = st.selectbox("Select a search method:", list(selectbox_selections.keys()))
    header_html = (
        "<center> <img src='data:image/png;base64,{}' class='img-fluid' "
        "width='60%' height='40%'> </center>"
    ).format(img_to_bytes(imagebox_selections[option]))
    st.markdown(header_html, unsafe_allow_html=True)
    st.markdown("""---""")
    col1, col2, col3 = st.columns([2, 1, 1])
    with col1:
        _render_search_form(option)
    with col2:
        _render_question_grid("## Original Questions", dataframe_original)
    with col3:
        _render_question_grid("## Our Questions", dataframe_bsbs)
def load_models(dpr_model_path, auth_token, cross_encoder_model_path):
    """Load every model the demo uses.

    Args:
        dpr_model_path: Path/id of the trained DPR model, loaded with
            ``AutoModel.from_pretrained`` and ``trust_remote_code=True``.
        auth_token: Token passed as ``use_auth_token`` when loading that model.
        cross_encoder_model_path: Path/id of the trained cross-encoder.

    Returns:
        ``(dpr_trained, bi_encoder, cross_encoder, trained_cross_encoder,
        question_tokenizer)``.
    """
    # NOTE(review): trust_remote_code=True runs model code from dpr_model_path;
    # only use a repository you control.
    dpr_model = AutoModel.from_pretrained(dpr_model_path, use_auth_token=auth_token,
                                          trust_remote_code=True)
    retriever = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
    retriever.max_seq_length = 500
    reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    fine_tuned_reranker = CrossEncoder(cross_encoder_model_path)
    tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
    return dpr_model, retriever, reranker, fine_tuned_reranker, tokenizer
# Module-level setup: load the retrieval caches, the question tables and the
# models (as globals used by the functions above), then render the page.
context_embeddings, contexes = load_paragraphs('context-embeddings.pkl')
dpr_context_embeddings, dpr_contexes = load_paragraphs('custom-dpr-context-embeddings.pkl')
dataframe_original, dataframe_bsbs = load_dataframes()
# Use the freshly loaded models directly: a deepcopy would duplicate every model
# in memory for no benefit, since nothing else holds a reference to them.
# NOTE(review): Streamlit re-executes this script on each widget interaction, so
# everything above is reloaded (and the tables reshuffled) per rerun; consider a
# Streamlit resource/data cache supported by the deployed Streamlit version.
dpr_trained, bi_encoder, cross_encoder, trained_cross_encoder, question_tokenizer = load_models(
    st.secrets["DPR_MODEL_PATH"], st.secrets["AUTH_TOKEN"], st.secrets["CROSS_ENCODER_MODEL_PATH"]
)
qa_main_widgetsv2()