amine-01 committed
Commit d66126a · verified · 1 Parent(s): b6ba543

Update app.py

Files changed (1)
  app.py +58 -115
app.py CHANGED
@@ -1,136 +1,79 @@
 import streamlit as st
-from langchain.prompts import PromptTemplate
-from langchain.chains.question_answering import load_qa_chain
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.vectorstores import Chroma
-from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
-from dotenv import load_dotenv
-import PyPDF2
-import os
-import io
 from langchain.document_loaders import PyPDFDirectoryLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.vectorstores import Chroma
 from langchain.embeddings import SentenceTransformerEmbeddings
+from langchain import hub
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.runnables import RunnablePassthrough
+from langchain_google_genai import ChatGoogleGenerativeAI
+import os
 
-# Define SPEAKER_TYPES to distinguish between user and bot roles
-SPEAKER_TYPES = {
-    "USER": "user",
-    "BOT": "bot"
-}
-
-# Define the initial prompt to show when the app starts
-initial_prompt = {
-    'role': SPEAKER_TYPES["BOT"],
-    'content': "Hello! I am your Gemini Pro RAG chatbot. You can ask me questions after uploading a PDF."
-}
-
-# --- Your RAG chatbot logic ---
-source_data_folder = "MyData"
-text_splitter = RecursiveCharacterTextSplitter(
-    separators=["\n\n", "\n", ". ", " ", ""],
-    chunk_size=2000,
-    chunk_overlap=200
-)
+# Set up the directories for data and vector DB
+DATA_DIR = "/content/MyData"
+DB_DIR = "/content/VectorDB"
+
+# Create directories if they don't exist
+os.makedirs(DATA_DIR, exist_ok=True)
+os.makedirs(DB_DIR, exist_ok=True)
+
+# Initialize the embeddings model
 embeddings_model = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
-path_db = "/content/VectorDB"
-llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro", google_api_key="AIzaSyAnsIVS4x_7lJLe9AYXGLV8FRwUTQkB-1w")
 
-# --- Streamlit app starts here ---
-# Set up the Streamlit app configuration
-st.set_page_config(
-    page_title="Gemini Pro RAG App",
-    page_icon="🔍",
-    layout="wide",
-    initial_sidebar_state="expanded",
-)
-
-# Initialize session state for chat history and vectorstore (PDF context)
-if 'chat_history' not in st.session_state:
-    st.session_state.chat_history = [initial_prompt]
-if 'vectorstore' not in st.session_state:
-    st.session_state.vectorstore = None
-
-# Function to clear chat history
-def clear_chat_history():
-    st.session_state.chat_history = [initial_prompt]
-
-# Extract text from PDF
-def extract_text_from_pdf(pdf_bytes):
-    pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
-    text = ""
-    for page in pdf_reader.pages:
-        text += page.extract_text()
-    return text
-
-# Initialize vectorstore
-def initialize_vector_index(text):
-    docs = [{'page_content': text}]
-    splits = text_splitter.split_documents(docs)
-    vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings_model, persist_directory=path_db)
-    return vectorstore
-
-# Sidebar configuration
-with st.sidebar:
-    st.title('🔍 Gemini RAG Chatbot')
-    st.write('This chatbot uses the Gemini Pro API with RAG capabilities.')
-    st.button('Clear Chat History', on_click=clear_chat_history, type='primary')
-    uploaded_file = st.file_uploader("Upload a PDF file", type=["pdf"], help="Upload your PDF file here to start the analysis.")
-    if uploaded_file is not None:
-        st.success("PDF File Uploaded Successfully!")
-        text = extract_text_from_pdf(uploaded_file.read())
-        vectorstore = initialize_vector_index(text)
-        st.session_state.vectorstore = vectorstore
-
-# Main interface
-st.header('Gemini Pro RAG Chatbot')
-st.subheader('Upload a PDF and ask questions about its content!')
-
-# Display the welcome prompt if chat history is only the initial prompt
-if len(st.session_state.chat_history) == 1:
-    with st.chat_message(SPEAKER_TYPES["BOT"], avatar="🔍"):
-        st.write(initial_prompt['content'])
-
-# Get user input
-prompt = st.chat_input("Ask a question about the PDF content:", key="user_input")
-
-# Function to get a response from RAG chain
-def get_rag_response(prompt):
-    retriever = st.session_state.vectorstore.as_retriever()  # Use the stored vectorstore retriever
-    rag_chain = (
-        {"context": retriever | format_docs, "question": RunnablePassthrough()}
-        | prompt
-        | llm
-        | StrOutputParser()
-    )
-    response = rag_chain.invoke(prompt)
-    return response
-
-# Handle the user prompt and generate response
-if prompt:
-    # Add user prompt to chat history
-    st.session_state.chat_history.append({'role': SPEAKER_TYPES["USER"], 'content': prompt})
-
-    # Display chat messages from the chat history
-    for message in st.session_state.chat_history[1:]:
-        with st.chat_message(message["role"], avatar="👤" if message['role'] == SPEAKER_TYPES["USER"] else "🔍"):
-            st.write(message["content"])
-
-    # Get the response using the RAG chain
-    with st.spinner(text='Generating response...'):
-        response_text = get_rag_response(prompt)
-        st.session_state.chat_history.append({'role': SPEAKER_TYPES["BOT"], 'content': response_text})
-
-    # Display the bot response
-    with st.chat_message(SPEAKER_TYPES["BOT"], avatar="🔍"):
-        st.write(response_text)
-
-# Add footer for additional information or credits
-st.markdown("""
-<hr>
-<div style="text-align: center;">
-    <small>Powered by Gemini Pro API | Developed by Christian Thomas BADOLO</small>
-</div>
-""", unsafe_allow_html=True)
+# Load and process PDF documents
+def load_data():
+    loader = PyPDFDirectoryLoader(DATA_DIR)
+    data_on_pdf = loader.load()
+    text_splitter = RecursiveCharacterTextSplitter(
+        separators=["\n\n", "\n", ". ", " ", ""],
+        chunk_size=1000,
+        chunk_overlap=200
+    )
+    splits = text_splitter.split_documents(data_on_pdf)
+    vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings_model, persist_directory=DB_DIR)
+    return vectorstore
+
+# Set up the generative AI model
+llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro", google_api_key="YOUR_GOOGLE_API_KEY")
+
+# Load vector store
+vectorstore = load_data()
+
+# Streamlit interface
+st.title("RAG App: Question-Answering with PDFs")
+
+# File uploader for PDF documents
+uploaded_files = st.file_uploader("Upload PDF files", accept_multiple_files=True, type=["pdf"])
+
+if uploaded_files:
+    for uploaded_file in uploaded_files:
+        with open(os.path.join(DATA_DIR, uploaded_file.name), "wb") as f:
+            f.write(uploaded_file.getbuffer())
+    st.success("PDF files uploaded successfully!")
+
+    # Reload vector store after uploading new files
+    vectorstore = load_data()
+
+# User input for question
+question = st.text_input("Ask a question about the documents:")
+
+if st.button("Submit"):
+    if question:
+        retriever = vectorstore.as_retriever()
+        prompt = hub.pull("rlm/rag-prompt")
+
+        def format_docs(docs):
+            return "\n\n".join(doc.page_content for doc in docs)
+
+        rag_chain = (
+            {"context": retriever | format_docs, "question": RunnablePassthrough()}
+            | prompt
+            | llm
+            | StrOutputParser()
+        )
+
+        response = rag_chain.invoke(question)
+        st.markdown(response)
+    else:
+        st.warning("Please enter a question.")
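
Note on the API key: the previous revision hardcoded a Google API key in app.py, and the new revision ships the placeholder string "YOUR_GOOGLE_API_KEY". A minimal sketch of reading the key from the environment instead, assuming a GOOGLE_API_KEY variable is exported before launching the app with streamlit run app.py:

import os
from langchain_google_genai import ChatGoogleGenerativeAI

# Read the key from the environment so it never lands in version control.
llm = ChatGoogleGenerativeAI(
    model="gemini-1.5-pro",
    google_api_key=os.environ["GOOGLE_API_KEY"],
)

A key that has already appeared in a public commit should be treated as compromised and rotated.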
 
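A related performance note: the new app calls load_data() at import time and again after each upload, so every Streamlit rerun re-embeds all PDFs under /content/MyData. One possible mitigation, sketched here on the assumption that load_data() stays as defined above (get_vectorstore is a hypothetical helper, not part of the commit), is to cache the store with st.cache_resource:

import streamlit as st

# Cache the vector store across reruns; Streamlit rebuilds it only when the
# cache is cleared explicitly.
@st.cache_resource
def get_vectorstore():
    return load_data()

vectorstore = get_vectorstore()

# After writing newly uploaded PDFs to DATA_DIR, invalidate the cache so the
# next access re-indexes the new files:
# get_vectorstore.clear()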
 
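For readers new to LCEL, the rag_chain pipeline composes the retriever, the rlm/rag-prompt template pulled from the LangChain hub, the Gemini model, and a string parser. Roughly, and assuming a langchain-core version where retrievers are runnables, it behaves like the explicit sketch below; this reuses the retriever, prompt, and llm objects from app.py, and answer is a hypothetical name, not code from the commit:

from langchain_core.output_parsers import StrOutputParser

def answer(question: str) -> str:
    docs = retriever.invoke(question)  # fetch the most relevant chunks
    context = "\n\n".join(doc.page_content for doc in docs)  # what format_docs does
    messages = prompt.invoke({"context": context, "question": question})  # fill the template
    reply = llm.invoke(messages)  # call Gemini
    return StrOutputParser().invoke(reply)  # extract plain text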