srinidhidevaraj's picture
Update app.py
d378475 verified
raw
history blame
5.13 kB
import streamlit as st
import os
from langchain_groq import ChatGroq
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.embeddings import OllamaEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
# from langchain.vectorstores.cassandra import Cassandra
from langchain_community.vectorstores import Cassandra
from langchain_community.llms import Ollama
from cassandra.auth import PlainTextAuthProvider
import tempfile
import cassio
from PyPDF2 import PdfReader
from cassandra.cluster import Cluster
import warnings
warnings.filterwarnings("ignore")
from dotenv import load_dotenv
import time
load_dotenv()
# Path to the Astra DB secure-connect bundle; passed to cassio.init() below
# and used again by doc_loader() when it opens its own cluster session.
ASTRA_DB_SECURE_BUNDLE_PATH ='secure-connect-pdf-query-db.zip'
# Set the LANGCHAIN_TRACING_V2 environment variable for this process.
os.environ["LANGCHAIN_TRACING_V2"]="true"
# Settings read from the process environment (load_dotenv() has already been
# called above). os.getenv() yields None for any variable that is not set.
LANGCHAIN_API_KEY=os.getenv("LANGCHAIN_API_KEY")
LANGCHAIN_PROJECT=os.getenv("LANGCHAIN_PROJECT")
LANGCHAIN_ENDPOINT=os.getenv("LANGCHAIN_ENDPOINT")
ASTRA_DB_APPLICATION_TOKEN=os.getenv("ASTRA_DB_APPLICATION_TOKEN")
ASTRA_DB_ID=os.getenv("ASTRA_DB_ID")
ASTRA_DB_KEYSPACE=os.getenv("ASTRA_DB_KEYSPACE")
ASTRA_DB_API_ENDPOINT=os.getenv("ASTRA_DB_API_ENDPOINT")
ASTRA_DB_CLIENT_ID=os.getenv("ASTRA_DB_CLIENT_ID")
ASTRA_DB_CLIENT_SECRET=os.getenv("ASTRA_DB_CLIENT_SECRET")
ASTRA_DB_TABLE=os.getenv("ASTRA_DB_TABLE")
# NOTE: this key is looked up under the lowercase name 'groq_api_key'.
groq_api_key=os.getenv('groq_api_key')
# Initialise cassio at import time with the application token, database id
# and secure-connect bundle.
cassio.init(token=ASTRA_DB_APPLICATION_TOKEN,database_id=ASTRA_DB_ID,secure_connect_bundle=ASTRA_DB_SECURE_BUNDLE_PATH)
# NOTE(review): cloud_config is not referenced anywhere else in the visible
# file (doc_loader() builds its own equivalent dict) - candidate for removal.
cloud_config = {
'secure_connect_bundle': ASTRA_DB_SECURE_BUNDLE_PATH
}
def doc_loader(pdf_reader):
    """Index one PDF into the Astra DB vector store and return the store.

    The PDF is loaded and split into overlapping chunks, the store's table
    is emptied, and the chunks are embedded and written to it, so the
    returned store only contains the document given here.

    Args:
        pdf_reader: Path of the PDF file to index (handed to ``PyPDFLoader``).

    Returns:
        The ``Cassandra`` vector store holding the document's chunks. It keeps
        using the cluster session opened here, so that session is
        intentionally left open.
    """
    embeddings = HuggingFaceBgeEmbeddings(
        model_name='BAAI/bge-small-en-v1.5',
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True},
    )

    documents = PyPDFLoader(pdf_reader).load_and_split()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    final_documents = text_splitter.split_documents(documents)

    astrasession = Cluster(
        cloud={"secure_connect_bundle": ASTRA_DB_SECURE_BUNDLE_PATH},
        auth_provider=PlainTextAuthProvider("token", ASTRA_DB_APPLICATION_TOKEN),
    ).connect()

    astra_vector_store = Cassandra(
        embedding=embeddings,
        table_name="qa_mini_demo",
        session=astrasession,
        keyspace=ASTRA_DB_KEYSPACE,
    )

    # Empty the table this store actually writes to. The previous code ran
    # TRUNCATE on {ASTRA_DB_KEYSPACE}.{ASTRA_DB_TABLE} before the store was
    # built: that can name a different table than "qa_mini_demo" (leaving old
    # chunks behind), and it does not depend on the store's table existing.
    # NOTE(review): relies on the store constructor provisioning the table
    # before clear() is called - confirm for the installed LangChain version.
    astra_vector_store.clear()

    astra_vector_store.add_documents(final_documents)
    return astra_vector_store
def prompt_temp():
    """Return the chat prompt used for question answering.

    The template exposes two placeholders: ``{context}`` for the retrieved
    document text and ``{input}`` for the user's question.
    """
    # The template text is runtime data sent to the model; keep it verbatim.
    template_text = """
Answer the question based on provided context only.
Your context retrieval mechanism works correclty but your are not providing answer from context.
Please provide the most accurate response based on question.
{context},
Questions:{input}
"""
    return ChatPromptTemplate.from_template(template_text)
def generate_response(llm,prompt,user_input,vectorstore):
    """Answer ``user_input`` with a retrieval chain over ``vectorstore``.

    Args:
        llm: Chat model used to produce the answer.
        prompt: Prompt template the retrieved documents are stuffed into.
        user_input: The user's question; sent to the chain under "input".
        vectorstore: Store queried for the 5 most similar chunks.

    Returns:
        The full result of invoking the chain. The caller in this file reads
        its 'answer' and 'context' entries.
    """
    stuff_chain = create_stuff_documents_chain(llm, prompt)
    top_k_retriever = vectorstore.as_retriever(
        search_type="similarity",
        search_kwargs={"k": 5},
    )
    return create_retrieval_chain(top_k_retriever, stuff_chain).invoke(
        {"input": user_input}
    )
def main():
    """Render the Streamlit page and answer a question about an uploaded PDF.

    The page shows a prompt box, a PDF uploader and a Submit button. On
    submit, the uploaded PDF is written to a temporary file, indexed with
    ``doc_loader``, and the question is answered with ``generate_response``.
    The answer is displayed, followed by an expander listing the retrieved
    chunks.
    """
    st.set_page_config(page_title='Chat Groq Demo')
    st.header('Chat Groq Demo')
    user_input = st.text_input('Enter the Prompt here')
    uploaded_file = st.file_uploader('Choose Invoice File', type='pdf')
    submit = st.button("Submit")

    # Reflects only the current script run: True when Submit was just clicked.
    st.session_state.submit_clicked = bool(submit)
    if not submit:
        return

    if not (user_input and uploaded_file):
        # Previously a submit with a missing prompt or file did nothing at all.
        st.warning('Please enter a prompt and upload a PDF file.')
        return

    # The loader needs a filesystem path, so persist the upload to a temp
    # file that outlives the `with` block (delete=False).
    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
        temp_file.write(uploaded_file.getbuffer())
        file_path = temp_file.name

    try:
        llm = ChatGroq(groq_api_key=groq_api_key, model_name="gemma-7b-it")
        prompt = prompt_temp()
        vectorstore = doc_loader(file_path)
    finally:
        # The file is only needed while it is being indexed; remove it so
        # each submission does not leave a temp file behind.
        os.remove(file_path)

    response = generate_response(llm, prompt, user_input, vectorstore)
    st.write(response['answer'])

    with st.expander("Document Similarity Search"):
        for doc in response['context']:
            st.write(doc.page_content)
            st.write('--------------------------------')


if __name__=="__main__":
    main()