import streamlit as st
import os
from langchain_groq import ChatGroq
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.embeddings import OllamaEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
# from langchain.vectorstores.cassandra import Cassandra
from langchain_community.vectorstores import Cassandra
from langchain_community.llms import Ollama
from cassandra.auth import PlainTextAuthProvider
import tempfile
import cassio
from PyPDF2 import PdfReader
from cassandra.cluster import Cluster
import warnings
warnings.filterwarnings("ignore")
from dotenv import load_dotenv
import time

load_dotenv()
ASTRA_DB_SECURE_BUNDLE_PATH = 'secure-connect-pdf-query-db.zip'
os.environ["LANGCHAIN_TRACING_V2"] = "true"
LANGCHAIN_API_KEY = os.getenv("LANGCHAIN_API_KEY")
LANGCHAIN_PROJECT = os.getenv("LANGCHAIN_PROJECT")
LANGCHAIN_ENDPOINT = os.getenv("LANGCHAIN_ENDPOINT")
ASTRA_DB_APPLICATION_TOKEN = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
ASTRA_DB_ID = os.getenv("ASTRA_DB_ID")
ASTRA_DB_KEYSPACE = os.getenv("ASTRA_DB_KEYSPACE")
ASTRA_DB_API_ENDPOINT = os.getenv("ASTRA_DB_API_ENDPOINT")
ASTRA_DB_CLIENT_ID = os.getenv("ASTRA_DB_CLIENT_ID")
ASTRA_DB_CLIENT_SECRET = os.getenv("ASTRA_DB_CLIENT_SECRET")
ASTRA_DB_TABLE = os.getenv("ASTRA_DB_TABLE")
groq_api_key = os.getenv('groq_api_key')
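# A minimal sketch of the .env file this app expects, assuming the same
# variable names read above (values are placeholders, not real credentials):
#
#   LANGCHAIN_API_KEY=...
#   LANGCHAIN_PROJECT=...
#   LANGCHAIN_ENDPOINT=...
#   ASTRA_DB_APPLICATION_TOKEN=AstraCS:...
#   ASTRA_DB_ID=...
#   ASTRA_DB_KEYSPACE=...
#   ASTRA_DB_API_ENDPOINT=...
#   ASTRA_DB_CLIENT_ID=...
#   ASTRA_DB_CLIENT_SECRET=...
#   ASTRA_DB_TABLE=...
#   groq_api_key=...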
cassio.init(token=ASTRA_DB_APPLICATION_TOKEN, database_id=ASTRA_DB_ID, secure_connect_bundle=ASTRA_DB_SECURE_BUNDLE_PATH)
cloud_config = {
    'secure_connect_bundle': ASTRA_DB_SECURE_BUNDLE_PATH
}
def doc_loader(pdf_path):
    """Embed a PDF and load its chunks into an Astra DB (Cassandra) vector store."""
    # BGE embeddings on CPU; normalizing makes cosine-similarity search behave consistently.
    encode_kwargs = {'normalize_embeddings': True}
    huggingface_embeddings = HuggingFaceBgeEmbeddings(
        model_name='BAAI/bge-small-en-v1.5',
        # model_name='sentence-transformers/all-MiniLM-L6-v2',
        model_kwargs={'device': 'cpu'},
        encode_kwargs=encode_kwargs)
    # Load the PDF from disk and split it into overlapping chunks for retrieval.
    loader = PyPDFLoader(pdf_path)
    documents = loader.load_and_split()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    final_documents = text_splitter.split_documents(documents)
    # Connect to Astra DB through the secure connect bundle.
    astrasession = Cluster(
        cloud={"secure_connect_bundle": ASTRA_DB_SECURE_BUNDLE_PATH},
        auth_provider=PlainTextAuthProvider("token", ASTRA_DB_APPLICATION_TOKEN),
    ).connect()
    # Truncate the existing table so each upload starts from a clean store.
    astrasession.execute(f'TRUNCATE {ASTRA_DB_KEYSPACE}.{ASTRA_DB_TABLE}')
    # Write into the same table that was just truncated.
    astra_vector_store = Cassandra(
        embedding=huggingface_embeddings,
        table_name=ASTRA_DB_TABLE,
        session=astrasession,
        keyspace=ASTRA_DB_KEYSPACE
    )
    astra_vector_store.add_documents(final_documents)
    return astra_vector_store
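# Minimal usage sketch (hypothetical file name), relying on the standard
# LangChain vector-store API that the returned Cassandra store exposes:
#
#   vectorstore = doc_loader("invoice.pdf")
#   top_chunks = vectorstore.similarity_search("What is the invoice total?", k=5)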
def prompt_temp():
    prompt = ChatPromptTemplate.from_template(
        """
        Answer the question based on the provided context only.
        Your context retrieval mechanism works correctly, but you are not providing answers from the context.
        Please provide the most accurate response to the question.
        {context}
        Question: {input}
        """
    )
    return prompt
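# How the template variables get filled by the chains built below:
# create_stuff_documents_chain formats the retrieved chunks into {context},
# and create_retrieval_chain passes the user's question through as {input}.
# Rough illustration with hypothetical values:
#
#   prompt_temp().invoke({"context": "<retrieved chunks>", "input": "What is the invoice total?"})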
def generate_response(llm, prompt, user_input, vectorstore):
    # Stuff the retrieved documents into the prompt and let the LLM answer.
    document_chain = create_stuff_documents_chain(llm, prompt)
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})
    retrieval_chain = create_retrieval_chain(retriever, document_chain)
    # The response dict contains the generated 'answer' and the retrieved 'context' documents.
    response = retrieval_chain.invoke({"input": user_input})
    return response
def main():
    st.set_page_config(page_title='Chat Groq Demo')
    st.header('Chat Groq Demo')
    user_input = st.text_input('Enter the Prompt here')
    file = st.file_uploader('Choose Invoice File', type='pdf')
    submit = st.button("Submit")
    # Initialise the flag only once; resetting it unconditionally would wipe it on every rerun.
    if "submit_clicked" not in st.session_state:
        st.session_state.submit_clicked = False
    if submit:
        st.session_state.submit_clicked = True
        if user_input and file:
            # Persist the uploaded PDF to a temporary file so PyPDFLoader can read it from disk.
            with tempfile.NamedTemporaryFile(delete=False) as temp_file:
                temp_file.write(file.getbuffer())
                file_path = temp_file.name
            # with open(file.name, mode='wb') as w:
            #     # w.write(file.getvalue())
            #     w.write(file.getbuffer())
            llm = ChatGroq(groq_api_key=groq_api_key, model_name="gemma-7b-it")
            prompt = prompt_temp()
            vectorstore = doc_loader(file_path)
            response = generate_response(llm, prompt, user_input, vectorstore)
            st.write(response['answer'])
            # Show the retrieved chunks that were used as context for the answer.
            with st.expander("Document Similarity Search"):
                for i, doc in enumerate(response['context']):
                    st.write(doc.page_content)
                    st.write('--------------------------------')


if __name__ == "__main__":
    main()
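# To run locally (assuming this file is saved as app.py and the .env file and
# secure-connect bundle sit in the working directory):
#
#   streamlit run app.py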