Spaces:
Sleeping
Sleeping
from dotenv import load_dotenv | |
import pandas as pd | |
import streamlit as st | |
import streamlit_authenticator as stauth | |
from streamlit_modal import Modal | |
from utils import new_file, clear_memory, append_documentation_to_sidebar, load_authenticator_config, init_qa, \ | |
append_header | |
from haystack.document_stores.in_memory import InMemoryDocumentStore | |
from haystack import Document | |
# Run python-dotenv's loader at import time, before any model or API setup below
# (presumably reads a local .env file into the process environment — confirm).
load_dotenv()
# OpenAI chat models; these are the entries that get a slot in the
# per-model API-key dictionary.
OPENAI_MODELS = [
    "gpt-3.5-turbo",
    "gpt-4",
    "gpt-4-1106-preview",
]

# Open-weight instruct models, identified by their `org/name` ids.
OPEN_MODELS = [
    "mistralai/Mistral-7B-Instruct-v0.1",
    "HuggingFaceH4/zephyr-7b-beta",
]
def reset_chat_memory():
    """Render the "Reset chat memory" button.

    Clicking the button invokes ``clear_memory`` as its ``on_click`` callback.
    """
    tooltip = "Clear the conversational memory. Currently implemented to retain the 4 most recent messages."
    st.button(
        'Reset chat memory',
        help=tooltip,
        on_click=clear_memory,
        key="reset-memory-button",
        disabled=False,
    )
def manage_files(modal, document_store):
    """Render the "Manage Files" sidebar button and the modal it opens.

    The modal contains a PDF uploader and an editable table backed by
    ``st.session_state['files']``. While a file is present in the uploader it
    is added to the table, written to ``document_store`` and ingested by the
    active QA model.

    Args:
        modal: Dialog object exposing ``open()``, ``is_open()`` and
            ``container()``.
        document_store: Store handed through to ``store_file_in_table``.
    """
    if st.sidebar.button("Manage Files", use_container_width=True):
        modal.open()

    if not modal.is_open():
        return

    with modal.container():
        uploaded_file = st.file_uploader(
            "Upload a CV in PDF format",
            type=("pdf",),
            # Hand Streamlit the callback itself. Writing `new_file()` would run
            # it on every render and register its return value instead, so it
            # would never fire as an on-change handler.
            on_change=new_file,
            # No upload is possible until a QA model has been initialised.
            disabled=st.session_state['document_qa_model'] is None,
            label_visibility="collapsed",
            help="The document is used to answer your questions. The system will process the document and store it in a RAG to answer your questions.",
        )
        edited_df = st.data_editor(use_container_width=True, data=st.session_state['files'],
                                   num_rows='dynamic',
                                   column_order=['name', 'size', 'is_active'],
                                   column_config={'name': {'editable': False}, 'size': {'editable': False},
                                                  'is_active': {'editable': True, 'type': 'checkbox',
                                                                'width': 100}}
                                   )
        # NOTE(review): the table is emptied on every render and only rebuilt
        # from the editor contents when the uploader holds a file — confirm
        # this is the intended lifecycle for the file list.
        st.session_state['files'] = pd.DataFrame(columns=['name', 'content', 'size', 'is_active'])

        if uploaded_file:
            st.session_state['file_uploaded'] = True
            st.session_state['files'] = pd.concat([st.session_state['files'], edited_df])
            with st.spinner('Processing the CV content...'):
                store_file_in_table(document_store, uploaded_file)
                ingest_document(uploaded_file)
def ingest_document(uploaded_file):
    """Feed an uploaded PDF to the active QA model and report the outcome.

    Shows a warning when no model is configured. On failure the error is shown
    in the UI and ``st.session_state['file_uploaded']`` is reset to ``False``.

    Args:
        uploaded_file: The uploaded file object passed to ``ingest_pdf``.
    """
    qa_model = st.session_state['document_qa_model']
    if not qa_model:
        st.warning('Please select a model to start asking questions')
        return

    try:
        qa_model.ingest_pdf(uploaded_file)
        st.success('Document processed successfully')
    except Exception as e:
        # UI boundary: surface any ingestion failure to the user instead of crashing.
        st.error(f"Error processing the document: {e}")
        st.session_state['file_uploaded'] = False
def store_file_in_table(document_store, uploaded_file):
    """Record an uploaded file in session state and in the document store.

    The raw bytes are cached under ``st.session_state['pdf_content']``, the
    chat history is cleared, every existing row of the files table is marked
    inactive, and the new file is appended as the single active row.

    Args:
        document_store: Object exposing ``write_documents(list_of_documents)``.
        uploaded_file: Uploaded file exposing ``getvalue()`` and ``name``.
    """
    pdf_content = uploaded_file.getvalue()
    st.session_state['pdf_content'] = pdf_content
    # A new document starts a fresh conversation.
    st.session_state.messages = []

    # NOTE(review): `pdf_content` is whatever `getvalue()` returns (presumably
    # raw bytes) — confirm `Document.content` is meant to receive that.
    document = Document(content=pdf_content, meta={"name": uploaded_file.name})

    files = pd.DataFrame(st.session_state['files'])
    files['is_active'] = False
    new_row = pd.DataFrame([{
        "name": uploaded_file.name,
        "content": pdf_content,
        "size": len(pdf_content),
        "is_active": True,
    }])
    # ignore_index renumbers the rows; without it every appended row keeps
    # index label 0 and the table accumulates duplicate index labels.
    st.session_state['files'] = pd.concat([files, new_row], ignore_index=True)

    document_store.write_documents([document])
def init_session_state():
    """Seed ``st.session_state`` with a default for every key the app uses.

    Keys that already hold a value are left untouched.
    """
    defaults = {
        'files': pd.DataFrame(columns=['name', 'content', 'size', 'is_active']),
        'models': [],
        'api_keys': {},
        'current_selected_model': 'gpt-3.5-turbo',
        'current_api_key': '',
        'messages': [],
        'pdf_content': None,
        'memory': None,
        'pdf': None,
        'document_qa_model': None,
        'file_uploaded': False,
    }
    for key, initial in defaults.items():
        st.session_state.setdefault(key, initial)
def set_page_config():
    """Apply the page-level Streamlit settings: title, icon, layout and menu."""
    menu_items = {
        'Get Help': 'https://www.extremelycoolapp.com/help',
        'Report a bug': "https://www.extremelycoolapp.com/bug",
        'About': "# This is a header. This is an *extremely* cool app!",
    }
    st.set_page_config(
        page_title="CV Insights AI Assistant",
        page_icon=":shark:",
        layout="wide",
        initial_sidebar_state="expanded",
        menu_items=menu_items,
    )
def update_running_model(api_key, model):
    """Remember ``api_key`` for ``model`` and rebuild the active QA model.

    Args:
        api_key: Key stored under ``st.session_state['api_keys'][model]``.
        model: Model name forwarded to ``init_qa`` together with the key.
    """
    stored_keys = st.session_state['api_keys']
    stored_keys[model] = api_key
    qa_model = init_qa(model, api_key)
    st.session_state['document_qa_model'] = qa_model
def init_api_key_dict():
    """Publish the selectable models and ensure each OpenAI model has a key slot.

    ``st.session_state['models']`` is set to the OpenAI models, the open
    models and the ``'local LLM'`` entry, in that order. Every OpenAI model
    gets an ``api_keys`` entry defaulting to ``None``.
    """
    st.session_state['models'] = OPENAI_MODELS + list(OPEN_MODELS) + ['local LLM']
    api_keys = st.session_state['api_keys']
    for model_name in OPENAI_MODELS:
        # This runs on every script execution (it is called from the app
        # constructor), so only fill in missing slots: an unconditional
        # assignment would wipe keys saved by update_running_model.
        api_keys.setdefault(model_name, None)
def display_chat_messages(chat_box, chat_input):
    """Replay the chat history and answer the newly submitted question.

    Nothing is rendered unless ``chat_input`` is non-empty. The new exchange
    is appended to ``st.session_state.messages`` after the answer is shown.

    Args:
        chat_box: Container (context manager) the messages are drawn into.
        chat_input: Text submitted by the user, or a falsy value when there
            is no new submission.
    """
    with chat_box:
        if not chat_input:
            return

        # NOTE(review): stored messages are replayed with raw HTML enabled,
        # while fresh ones below are not — confirm this asymmetry is intended.
        for past in st.session_state.messages:
            with st.chat_message(past["role"]):
                st.markdown(past["content"], unsafe_allow_html=True)

        st.chat_message("user").markdown(chat_input)

        with st.chat_message("assistant"):
            # Ask the QA model, passing the history recorded so far.
            qa_model = st.session_state['document_qa_model']
            response = qa_model.inference(chat_input, st.session_state.messages)
            st.markdown(response)

        st.session_state.messages.extend([
            {"role": "user", "content": chat_input},
            {"role": "assistant", "content": response},
        ])
def setup_model_selection():
    """Render the model selector and API-key field; keep session state in sync.

    Selecting ``'local LLM'`` initialises the QA model immediately without a
    key. For any model, entering a key that differs from the stored one
    rebuilds the QA model through ``update_running_model``.

    Returns:
        The model name chosen in the select box.
    """
    state = st.session_state
    model = st.selectbox(
        "Model:",
        options=state['models'],
        index=0,  # default to the first model in the list gpt-3.5-turbo
        placeholder="Select model",
        help="Select an LLM:",
    )
    if not model:
        return model

    if state['current_selected_model'] != model:
        state['current_selected_model'] = model
        if model == 'local LLM':
            state['document_qa_model'] = init_qa(model)

    # The key field is disabled for the local LLM.
    api_key = st.sidebar.text_input(
        "Enter LLM-authorization Key:",
        type="password",
        disabled=state['current_selected_model'] == 'local LLM',
    )

    # NOTE(review): switching models while the key text is unchanged does not
    # re-run update_running_model — confirm whether that is intended.
    key_is_new = bool(api_key) and api_key != state['current_api_key']
    if key_is_new:
        update_running_model(api_key, model)
        state['current_api_key'] = api_key

    return model
def setup_task_selection(model):
    """Render the task radio buttons in the sidebar.

    'Generative' is offered only for the local LLM or for a model that has a
    truthy entry in ``st.session_state['api_keys']``; 'Extractive' is always
    offered.

    Args:
        model: Name of the currently selected model.
    """
    has_generative_backend = model == 'local LLM' or st.session_state['api_keys'].get(model)
    task_options = ['Extractive', 'Generative'] if has_generative_backend else ['Extractive']
    task_selection = st.sidebar.radio('Select the task:', task_options)
    # TODO: Add the task selection logic here (initializing the model based on the task)
def setup_page_body():
    """Lay out the chat area: a message container above a chat input box.

    The input is disabled until ``st.session_state['file_uploaded']`` is
    truthy; only then are messages rendered.
    """
    chat_box = st.container(height=350, border=False)
    has_document = st.session_state['file_uploaded']
    chat_input = st.chat_input(
        placeholder="Upload a document to start asking questions...",
        disabled=not has_document,
    )
    if not has_document:
        return
    display_chat_messages(chat_box, chat_input)
class StreamlitApp:
    """Application object tying together authentication, sidebar and chat body."""

    def __init__(self):
        """Load configuration, configure the page and initialise session state."""
        self.authenticator_config = load_authenticator_config()
        self.document_store = InMemoryDocumentStore()
        # Page configuration is applied before the authenticator and the
        # session-state helpers run.
        set_page_config()
        self.authenticator = self.init_authenticator()
        init_session_state()
        init_api_key_dict()

    def init_authenticator(self):
        """Build the authenticator from the loaded credentials and cookie settings."""
        config = self.authenticator_config
        credentials = config['credentials']
        cookie = config['cookie']
        return stauth.Authenticate(
            credentials,
            cookie['name'],
            cookie['key'],
            cookie['expiry_days'],
        )

    def setup_sidebar(self):
        """Populate the sidebar: logo, model/task options, logout, memory and files."""
        with st.sidebar:
            st.sidebar.image("resources/ml_logo.png", use_column_width=True)

            # Sidebar for Task Selection
            st.sidebar.header('Options:')
            selected_model = setup_model_selection()
            setup_task_selection(selected_model)
            st.divider()

            self.authenticator.logout()
            reset_chat_memory()

            files_modal = Modal("Manage Files", key="demo-modal")
            manage_files(files_modal, self.document_store)
            st.divider()

            append_documentation_to_sidebar()

    def run(self):
        """Show the login form, then the app or an authentication message."""
        _, authentication_status, _ = self.authenticator.login()
        if authentication_status:
            self.run_authenticated_app()
            return

        session_status = st.session_state["authentication_status"]
        if session_status is False:
            st.error('Username/password is incorrect')
        elif session_status is None:
            st.warning('Please enter your username and password')

    def run_authenticated_app(self):
        """Render the full application for a logged-in user."""
        self.setup_sidebar()
        append_header()
        setup_page_body()
# Script entry point: build the application object and render the page.
app = StreamlitApp()
app.run()