|
import base64 |
|
import streamlit as st |
|
from streamlit_chat import message |
|
from streamlit_extras.colored_header import colored_header |
|
from backend import QnASystem |
|
from schema import TransformType, EmbeddingTypes, IndexerType, BotType |
|
|
|
# Page metadata shown in the browser tab.
st.set_page_config(page_title="PDFChat - An LLM-powered experimentation app")

# Options gathered from the sidebar widgets; they are forwarded to
# ``QnASystem.build_chain`` as keyword arguments once a PDF is uploaded.
kwargs: dict = {}

# Source documents returned together with the most recent answer; rendered
# in the "Related Documents" expander of the "Ask" tab.
source_docs: list = []

# Streamlit re-executes this script on every interaction, so the backend
# object is kept in the session state and only created when it is missing.
if "qna_system" not in st.session_state:
    st.session_state["qna_system"] = QnASystem()
|
|
|
def show_pdf(f):
    """Embed a PDF in the page as a base64 ``data:`` URI inside an iframe.

    Args:
        f: Seekable binary file-like object containing the PDF. It is
            rewound to the start and read in full.
    """
    f.seek(0)  # the stream may already have been consumed elsewhere
    raw_bytes = f.read()
    encoded = base64.b64encode(raw_bytes).decode('utf-8')
    iframe_html = (
        f'<iframe src="data:application/pdf;base64,{encoded}" width="700" height="800" '
        f'type="application/pdf"></iframe>'
    )
    # Raw HTML is required here; Streamlit escapes markup by default.
    st.markdown(iframe_html, unsafe_allow_html=True)
|
|
|
def model_settings():
    """Render the generation-parameter widgets and record their values.

    Draws a "Temperature" slider (0.0 to 1.0) and a "Max Token" number input
    (minimum 0, default 512) in the current Streamlit container, then stores
    the chosen values in the module-level ``kwargs`` dict under the keys
    ``"temperature"`` and ``"max_tokens"``.
    """
    temperature = st.slider("Temperature", min_value=0.0, max_value=1.0)
    max_tokens = st.number_input("Max Token", min_value=0, value=512)
    kwargs.update(temperature=temperature, max_tokens=max_tokens)
|
|
|
st.title("PDF Question and Answering") |
|
|
|
tab1, tab2, tab3 = st.tabs(["Upload and Ingest PDF", "Ask", "Show PDF"]) |
|
|
|
with st.sidebar: |
|
st.header("Advance Setting βοΈ") |
|
require_pdf = st.checkbox("Show PDF", value=1) |
|
st.markdown('---') |
|
kwargs["bot_type"] = st.selectbox("Bot Type", options=BotType) |
|
st.markdown("---") |
|
st.text("Model Parameters") |
|
kwargs["return_documents"] = st.checkbox("Require Source Documents", value=True) |
|
text_transform = st.selectbox("Text Transformer", options=TransformType) |
|
st.markdown("---") |
|
selected_model = st.selectbox("Select Model", options=EmbeddingTypes) |
|
match selected_model: |
|
case EmbeddingTypes.OPENAI: |
|
api_key = st.text_input("OpenAI API Key", placeholder="sk-...", type="password") |
|
if not api_key.startswith('sk-'): |
|
st.warning('Please enter your OpenAI API key!', icon='β ') |
|
model_settings() |
|
case EmbeddingTypes.HUGGING_FACE: |
|
api_key = st.text_input("Hugging Face API Key", placeholder="hg-...", type="password") |
|
if not api_key.startswith('hg-'): |
|
st.warning('Please enter your HuggingFace API key!', icon='β ') |
|
kwargs["model_name"] = st.selectbox("Choose Model", options=["google/flan-t5-xl"]) |
|
model_settings() |
|
case EmbeddingTypes.COHERE: |
|
api_key = st.text_input("Cohere API Key", placeholder="...", type="password") |
|
if not api_key: |
|
st.warning('Please enter your Cohere API key!', icon='β ') |
|
model_settings() |
|
case _: |
|
api_key = None |
|
kwargs["api_key"] = api_key |
|
st.markdown("---") |
|
|
|
vector_indexer = st.selectbox("Vector Indexer", options=IndexerType) |
|
match vector_indexer: |
|
case IndexerType.ELASTICSEARCH: |
|
kwargs["elasticsearch_url"] = st.text_input("Elastic Search URL: ") |
|
if not kwargs.get("elasticsearch_url"): |
|
st.warning("Please enter your elastic search url", icon='β ') |
|
kwargs["elasticsearch_index"] = st.text_input("Elastic Search Index: ") |
|
if not kwargs.get("elasticsearch_index"): |
|
st.warning("Please enter your elastic search index", icon='β ') |
|
|
|
st.markdown("---") |
|
st.text("Chain Settings") |
|
kwargs["chain_type"] = st.selectbox("Chain Type", options=["stuff", "map_reduce"]) |
|
kwargs["search_type"] = st.selectbox("Search Type", options=["similarity"]) |
|
st.markdown("---") |
|
|
|
# Tab 1: upload a PDF, load it into the QnA system and build the chain from
# the settings currently selected in the sidebar.
with tab1:
    # NOTE(review): the last character of this label looks like a mis-encoded
    # emoji whose original glyph cannot be recovered here; left untouched.
    uploaded_file = st.file_uploader("Upload and Ingest PDF π", type="pdf")
    if uploaded_file:
        with st.spinner("Uploading and Ingesting"):
            documents = st.session_state.qna_system.read_and_load_pdf(uploaded_file)
            if selected_model == EmbeddingTypes.NA:
                # The chain cannot be built without a model; ``icon`` must be
                # a real emoji or Streamlit rejects the call.
                st.warning("Please select the model", icon='⚠️')
            else:
                st.session_state.qna_system.build_chain(
                    transform_type=text_transform,
                    embedding_type=selected_model,
                    indexer_type=vector_indexer,
                    **kwargs,
                )
|
|
|
def generate_response(prompt):
    """Ask the QnA system a question about the uploaded PDF.

    Reads the module-level ``uploaded_file`` and the ``qna_system`` stored in
    the Streamlit session state.

    Args:
        prompt: The question text entered by the user.

    Returns:
        A ``(answer, source_documents)`` tuple. ``answer`` comes from the
        response's ``"answer"`` key, falling back to ``"result"`` and then
        to ``""``. ``source_documents`` is always a list: a missing or
        ``None`` value in the response is normalised to ``[]`` so callers
        can iterate it unconditionally. ``("", [])`` is returned when the
        prompt is empty or no PDF has been uploaded.
    """
    if not (prompt and uploaded_file):
        return "", []

    response = st.session_state.qna_system.ask_question(prompt)
    answer = response.get("answer", response.get("result", ""))
    # Keep the return type consistent with the early-exit path above.
    sources = response.get("source_documents") or []
    return answer, sources
|
|
|
with tab2: |
|
if not uploaded_file: |
|
st.warning("Please upload PDF", icon='β ') |
|
else: |
|
match kwargs["bot_type"]: |
|
case BotType.qna: |
|
with st.container(): |
|
with st.form('my_form'): |
|
text = st.text_area("", placeholder='Ask me...') |
|
submitted = st.form_submit_button('Submit') |
|
if text: |
|
st.write(f"Question:\n{text}") |
|
response, source_docs = generate_response(text) |
|
st.write(response) |
|
case BotType.conversational: |
|
|
|
|
|
if 'generated' not in st.session_state: |
|
st.session_state['generated'] = ["Hi! I'm PDF Assistant π€, How may I help you?"] |
|
|
|
if 'past' not in st.session_state: |
|
st.session_state['past'] = ['Hi!'] |
|
|
|
input_container = st.container() |
|
colored_header(label='', description='', color_name='blue-30') |
|
response_container = st.container() |
|
response = "" |
|
|
|
def get_text(): |
|
input_text = st.text_input("You: ", "", key="input") |
|
return input_text |
|
|
|
with input_container: |
|
user_input = get_text() |
|
if st.button("Clear"): |
|
st.session_state.generated.clear() |
|
st.session_state.past.clear() |
|
|
|
with response_container: |
|
if user_input: |
|
response, source_docs = generate_response(user_input) |
|
st.session_state.past.append(user_input) |
|
st.session_state.generated.append(response) |
|
|
|
if st.session_state['generated']: |
|
for i in range(len(st.session_state['generated'])): |
|
message(st.session_state['past'][i], is_user=True, key=str(i) + '_user') |
|
message(st.session_state["generated"][i], key=str(i)) |
|
|
|
require_document = st.container() |
|
if kwargs["return_documents"]: |
|
with require_document: |
|
with st.expander("Related Documents", expanded=False): |
|
for source in source_docs: |
|
metadata = source.metadata |
|
st.write("{source} - {page_no}".format(source=metadata.get("source"), |
|
page_no=metadata.get("page_no"))) |
|
st.write(source.page_content) |
|
st.markdown("---") |
|
|
|
# Tab 3: inline preview of the uploaded PDF, gated by the sidebar checkbox.
with tab3:
    if not uploaded_file:
        st.warning("Please upload PDF", icon='⚠️')
    elif require_pdf:
        show_pdf(uploaded_file)
    else:
        # A file is present but the "Show PDF" option is switched off.
        st.warning("Feature not enabled.", icon='⚠️')
|
|