# q-and-a-tool / app.py
# Last commit 8b2d8aa by karshreya98: "resolving merge conflicts" (9.37 kB).
from operator import index
import streamlit as st
import logging
import os
from annotated_text import annotation
from json import JSONDecodeError
from markdown import markdown
from utils.config import parser
from utils.haystack import start_document_store, query, initialize_pipeline, start_preprocessor_node, start_retriever, start_reader
from utils.ui import reset_results, set_initial_state
import pandas as pd
import haystack
def env_flag(name, default=False):
    """Read the environment variable *name* as a boolean flag.

    Unset or blank values yield *default*. The spellings "0", "false", "no"
    and "off" (case-insensitive) are false; any other non-empty value is
    true. This replaces the naive ``bool(os.getenv(...))`` pattern, under
    which ``VAR=0`` or ``VAR=false`` would still count as set.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is unset or blank.

    Returns:
        bool: The parsed flag (or *default*).
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    return raw.strip().lower() not in {"0", "false", "no", "off"}


# When true, the sidebar "File Upload" section is skipped entirely
# (e.g. run with DISABLE_FILE_UPLOAD=1).
DISABLE_FILE_UPLOAD = env_flag("DISABLE_FILE_UPLOAD")
def upload_files():
    """Render the sidebar file uploader and return the user's selection.

    Multiple files may be chosen; the widget label is kept but hidden.

    NOTE(review): pdf and docx are accepted here, but ``process_file`` decodes
    uploads as UTF-8 text, so those formats will presumably fail to ingest.

    Returns:
        Whatever ``st.sidebar.file_uploader`` returns for a multi-file
        uploader.
    """
    accepted_extensions = ["pdf", "txt", "docx"]
    selection = st.sidebar.file_uploader(
        "upload",
        type=accepted_extensions,
        accept_multiple_files=True,
        label_visibility="hidden",
    )
    return selection
def process_file(data_file, preprocesor, document_store, retriever=None):
    """Ingest one uploaded file into *document_store* unless it is already there.

    If no stored document carries ``meta['name'] == data_file.name``, the file
    is decoded as UTF-8 text, run through *preprocesor*, written to the store,
    and the store's embeddings are refreshed with *retriever*.

    Args:
        data_file: Uploaded file object exposing ``.name`` and a ``.read()``
            that returns bytes.
        preprocesor: Object with a ``process(docs)`` method. The misspelt name
            is kept for compatibility with existing keyword callers.
        document_store: Store offering ``get_all_documents()``,
            ``write_documents(docs)`` and ``update_embeddings(retriever)``.
        retriever: Retriever passed to ``update_embeddings``. Defaults to the
            module-level ``retriever`` created during app start-up.

    Returns:
        None.

    Raises:
        Exception: Any error while checking, decoding, preprocessing or writing
            is logged and re-raised, so the caller can report the failure
            instead of showing a false success.
    """
    if retriever is None:
        # Backward-compatible fallback to the global built in the app setup.
        retriever = globals()["retriever"]
    try:
        stored_names = {doc.meta.get('name') for doc in document_store.get_all_documents()}
        if data_file.name in stored_names:
            logging.info("%s already processed", data_file.name)
            return
        # Decode only when the file actually needs ingesting; this is
        # re-executed on every script rerun for every uploaded file.
        file_contents = data_file.read().decode("utf-8")
        docs = [{
            'content': file_contents,
            'meta': {'name': str(data_file.name)},
        }]
        logging.info("preprocessing uploaded doc %s", data_file.name)
        preprocessed_docs = preprocesor.process(docs)
        logging.info("writing to document store")
        document_store.write_documents(preprocessed_docs)
        logging.info("updating embeddings")
        document_store.update_embeddings(retriever)
    except Exception:
        logging.exception("Failed to process uploaded file %s", getattr(data_file, "name", data_file))
        raise
# Reader score below which every extractive answer is flagged as unreliable.
ANSWER_SCORE_THRESHOLD = 0.2
# Number of characters of each retrieved document shown in the results table.
PREVIEW_CHARS = 150


def _execute_pipeline(pipeline, question, task, results_key, api_key_hint=False):
    """Run *pipeline* on *question* and store the outcome in session state.

    Clears previous results, remembers *question*, and on success writes the
    output of ``query`` to ``st.session_state[results_key]`` and records
    *task*. Errors are logged and shown in the UI rather than propagated.

    Args:
        pipeline: Pipeline object handed to ``query``.
        question: The user's question text.
        task: Task label recorded in ``st.session_state.task`` on success.
        results_key: Session-state key that receives the results.
        api_key_hint: When true, an exception mentioning "API key is invalid"
            gets a dedicated message pointing to the OpenAI key page.
    """
    reset_results()
    st.session_state.question = question
    with st.spinner("🔎    Running your pipeline"):
        try:
            st.session_state[results_key] = query(pipeline, question)
            st.session_state.task = task
        except JSONDecodeError as err:
            logging.exception(err)
            st.error(
                "👓    An error occurred reading the results. Is the document store working?"
            )
        except Exception as err:
            logging.exception(err)
            if api_key_hint and "API key is invalid" in str(err):
                st.error("🐞    incorrect API key provided. You can find your API key at https://platform.openai.com/account/api-keys.")
            else:
                st.error("🐞    An error occurred during the request.")


def _highlight_answer(context, text, score):
    """Return *context* as markdown/HTML with *text* highlighted and labelled with *score*."""
    context = context or ""
    highlighted = str(annotation(body=text, label=f'SCORE {score}', background='#964448', color='#ffffff'))
    start_idx = context.find(text)
    if start_idx == -1:
        # The answer is not a verbatim substring of its context. Slicing with
        # -1 would duplicate and garble text, so show the highlight first.
        return highlighted + " " + context
    end_idx = start_idx + len(text)
    return context[:start_idx] + highlighted + context[end_idx:]


def _show_retrieved_documents(documents):
    """Render *documents* as a ranked table with truncated content previews."""
    st.subheader("Retriever Results:")
    rows = []
    for rank, document in enumerate(documents, start=1):
        content = document.content
        if len(content) > PREVIEW_CHARS:
            content = content[:PREVIEW_CHARS] + '...'
        rows.append([rank, document.meta.get('name'), content])
    st.table(pd.DataFrame(rows, columns=['Ranked Context', 'Document Name', 'Content']))


try:
    args = parser.parse_args()
    preprocesor = start_preprocessor_node()
    document_store = start_document_store(type=args.store)
    retriever = start_retriever(document_store)
    reader = start_reader()

    st.set_page_config(
        page_title="MLReplySearch",
        layout="centered",
        page_icon=":shark:",
        menu_items={
            'Get Help': 'https://www.extremelycoolapp.com/help',
            'Report a bug': "https://www.extremelycoolapp.com/bug",
            'About': "# This is a header. This is an *extremely* cool app!"
        }
    )
    st.sidebar.image("ml_logo.png", use_column_width=True)

    # Sidebar task selection; the generative (RAG) task requires an OpenAI key.
    st.sidebar.header('Options:')
    openai_key = st.sidebar.text_input("Enter OpenAI Key:", type="password")
    task_options = ['Extractive', 'Generative'] if openai_key else ['Extractive']
    task_selection = st.sidebar.radio('Select the task:', task_options)

    # Build only the pipeline needed by the selected task.
    if task_selection == 'Extractive':
        pipeline_extractive = initialize_pipeline("extractive", document_store, retriever, reader)
    elif task_selection == 'Generative' and openai_key:
        pipeline_rag = initialize_pipeline("rag", document_store, retriever, reader, openai_key=openai_key)

    set_initial_state()
    st.write('# ' + args.name)

    # File upload block
    if not DISABLE_FILE_UPLOAD:
        st.sidebar.write("## File Upload:")
        for data_file in upload_files() or []:
            if not data_file:
                continue
            # Uploads are only ingested into the in-memory store; other stores
            # are left untouched, so no success mark is shown for them.
            if args.store != 'inmemory':
                continue
            try:
                process_file(data_file, preprocesor, document_store)
                st.sidebar.write(str(data_file.name) + "    ✅ ")
            except Exception:
                logging.exception("Could not process uploaded file %s", data_file.name)
                st.sidebar.write(str(data_file.name) + "    ❌ ")
                st.sidebar.write("_This file could not be parsed, see the logs for more information._")

    if "question" not in st.session_state:
        st.session_state.question = ""
    # Search bar
    question = st.text_input("", value=st.session_state.question, max_chars=100, on_change=reset_results)
    run_pressed = st.button("Run")
    # Query when the button is pressed or the question text has changed.
    run_query = run_pressed or question != st.session_state.question

    # Get results for query
    if run_query and question:
        if task_selection == 'Extractive':
            _execute_pipeline(pipeline_extractive, question, task_selection, "results_extractive")
        elif task_selection == 'Generative':
            _execute_pipeline(pipeline_rag, question, task_selection, "results_generative", api_key_hint=True)

    # Display results
    if (st.session_state.results_extractive or st.session_state.results_generative) and run_query:
        results = {}
        if task_selection == 'Extractive':
            results = st.session_state.results_extractive or {}
            st.subheader("Extracted Answers:")
            if 'answers' in results:
                answers = results['answers']
                if not any(ans.score > ANSWER_SCORE_THRESHOLD for ans in answers):
                    st.markdown(
                        "<span style='color:red'>Please note none of the answers achieved a score higher than "
                        f"{ANSWER_SCORE_THRESHOLD:.0%}. Which probably means that the desired answer is not in the searched documents.</span>",
                        unsafe_allow_html=True,
                    )
                for count, answer in enumerate(answers, start=1):
                    if answer.answer:
                        st.markdown(f"**Answer {count}:**")
                        st.markdown(
                            _highlight_answer(answer.context, answer.answer, round(answer.score, 3)),
                            unsafe_allow_html=True,
                        )
                    else:
                        st.info(
                            "🤔 &nbsp;&nbsp; Haystack is unsure whether any of the documents contain an answer to your question. Try to reformulate it!"
                        )
        elif task_selection == 'Generative':
            results = st.session_state.results_generative or {}
            st.subheader("Generated Answer:")
            if 'results' in results and results['results']:
                st.markdown("**Answer:**")
                st.write(results['results'][0])

        # Retrieved documents are shown for both tasks.
        if 'documents' in results:
            _show_retrieved_documents(results['documents'])
except SystemExit as e:
    # Presumably here so that argparse errors / --help terminate the whole
    # process. os._exit() only accepts an int, so map sys.exit() semantics:
    # None -> 0, int -> itself, anything else -> 1.
    exit_code = e.code
    if exit_code is None:
        exit_code = 0
    elif not isinstance(exit_code, int):
        exit_code = 1
    os._exit(exit_code)