Spaces:
Sleeping
Sleeping
import os | |
import streamlit as st | |
from langchain_community.vectorstores import FAISS | |
from langchain_core.messages import AIMessage, HumanMessage | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_openai.chat_models.azure import ChatOpenAI | |
from langchain_openai.embeddings.azure import OpenAIEmbeddings | |
from langchain_text_splitters import RecursiveCharacterTextSplitter | |
from boto_client import extract_text_from_pdf | |
# Folder name of the on-disk FAISS index.
vector_database_name = "Adina_Vector_Database"
# Scratch folder where uploaded PDFs are written before text extraction.
temp_pdf_folder = "temp-pdf-files"
# Index location: $VECTOR_DATABASE_PATH (default ".") joined to the index
# name with a literal "/".
vector_database_path = "/".join(
    (os.environ.get("VECTOR_DATABASE_PATH", "."), vector_database_name)
)
# Module-wide retriever; assigned under the __main__ guard at startup and
# rebound by append_to_vector_db after new documents are indexed.
RETRIEVER = None
def delete_temp_files():
    """Remove every entry directly inside ``temp_pdf_folder``.

    Each entry is deleted with ``os.remove``, so the folder is expected to
    hold only plain files (``os.remove`` raises on a directory). Listing a
    folder that does not exist raises ``FileNotFoundError``.
    """
    folder = temp_pdf_folder
    for entry_name in os.listdir(folder):
        os.remove(os.path.join(folder, entry_name))
def load_and_split(file):
    """Extract the text of an uploaded PDF and split it into chunk Documents.

    The upload is written into ``temp_pdf_folder`` so that
    ``extract_text_from_pdf`` can read it from disk. That temporary copy is
    always removed afterwards, even when writing or extraction raises.

    Args:
        file: Uploaded file object; only its ``name`` attribute and
            ``getvalue()`` (the raw bytes) are used.

    Returns:
        A list of Documents produced by a ``RecursiveCharacterTextSplitter``
        (chunk_size=512, chunk_overlap=100), each carrying
        ``{"file_name": file.name}`` as metadata. Empty when no text could be
        extracted.
    """
    os.makedirs(temp_pdf_folder, exist_ok=True)
    # The name comes from the client: keep only its final component so it
    # cannot point outside the temp folder (e.g. "../x" or "/etc/x").
    local_filepath = os.path.join(temp_pdf_folder, os.path.basename(file.name))
    try:
        with open(local_filepath, "wb") as f:
            f.write(file.getvalue())
        text = extract_text_from_pdf(file_path=local_filepath, file_name=file.name)
    finally:
        # Remove only this upload's copy. Wiping the whole shared folder could
        # delete a file that another session is still processing.
        if os.path.isfile(local_filepath):
            os.remove(local_filepath)

    if not text:
        return []

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=100)
    # create_documents runs split_text on each input itself, so pass the whole
    # text once rather than pre-splitting and having every chunk split again.
    return text_splitter.create_documents(
        texts=[text], metadatas=[{"file_name": file.name}]
    )
def initialize_vector_db():
    """Create a fresh FAISS index seeded with one placeholder text and persist it.

    Returns:
        The newly built FAISS store, already saved to ``vector_database_path``.
    """
    seed_texts = ["Adina Cosmetic Ingredients"]
    fresh_store = FAISS.from_texts(seed_texts, OpenAIEmbeddings())
    fresh_store.save_local(vector_database_path)
    return fresh_store
def load_vector_db():
    """Return the persisted FAISS index, creating and saving a new one if absent."""
    if not os.path.exists(vector_database_path):
        return initialize_vector_db()
    # NOTE(review): allow_dangerous_deserialization=True lets load_local
    # unpickle the stored index data. This is acceptable only because the
    # folder is written by this app; never point VECTOR_DATABASE_PATH at an
    # index from an untrusted source.
    embedder = OpenAIEmbeddings()
    return FAISS.load_local(
        vector_database_path,
        embedder,
        allow_dangerous_deserialization=True,
    )
def append_to_vector_db(docs: "list | None" = None):
    """Embed ``docs``, merge them into the persisted FAISS index and save it.

    Also rebinds the module-level ``RETRIEVER`` to the updated index, so later
    queries see the new chunks.

    Args:
        docs: Documents to add. ``None`` or an empty list is a no-op. (The
            previous shared mutable ``[]`` default has been replaced by ``None``.)
    """
    global RETRIEVER
    # FAISS.from_documents needs at least one document to build an index
    # (it sizes the index from the first embedding), so bail out early.
    if not docs:
        return
    existing_vector_db = load_vector_db()
    new_vector_db = FAISS.from_documents(docs, OpenAIEmbeddings())
    existing_vector_db.merge_from(new_vector_db)
    existing_vector_db.save_local(vector_database_path)
    RETRIEVER = existing_vector_db.as_retriever()
def create_embeddings(files: "list | None" = None):
    """Chunk, embed and index each uploaded file, reporting each outcome.

    A file that is indexed successfully is recorded in
    ``st.session_state.last_uploaded_files``, so ``main`` skips it on later
    reruns. A file that yields no chunks is only reported, which means it is
    attempted again on the next rerun while it is still uploaded.

    Args:
        files: Uploaded file objects accepted by ``load_and_split``. ``None``
            (the default, replacing a shared mutable ``[]``) means no files.
    """
    for file in files or []:
        docs = load_and_split(file)
        if docs:
            append_to_vector_db(docs=docs)
            st.session_state.last_uploaded_files.append(file.name)
            message = f"{file.name} processed successfully"
        else:
            message = f"{file.name} could not be processed"
        st.toast(message)
        print(message)
# System prompt for ADINA; {chat_history}, {retrieved_info} and
# {user_question} are filled in by get_response.
_ADINA_PROMPT_TEMPLATE = """
Your name is ADINA, who provides helpful information about Adina Cosmetic Ingredients.
<rules>
- Answer the question based on the context only.
- If the question can not be answered, simply say you can not answer it.
</rules>
Execute the below mandatory considerations when responding to the inquiries:
--- Tone - Respectful, Patient, and Encouraging:
Maintain a tone that is not only polite but also encouraging. Positive language can help build confidence, especially when they are trying to learn something new.
Be mindful of cultural references or idioms that may not be universally understood or may date back to a different era, ensuring relatability.
--- Clarity - Simple, Direct, and Unambiguous:
Avoid abbreviations, slang, or colloquialisms that might be confusing. Stick to standard language.
Use bullet points or numbered lists to break down instructions or information, which can aid in comprehension.
--- Structure - Organized, Consistent, and Considerate:
Include relevant examples or analogies that relate to experiences common in their lifetime, which can aid in understanding complex topics.
--- Empathy and Understanding - Compassionate and Responsive:
Recognize and validate their feelings or concerns. Phrases like, “It’s completely normal to find this challenging,” can be comforting.
Be aware of the potential need for more frequent repetition or rephrasing of information for clarity.
Answer the following questions considering the context and/or history of the conversation.
Chat history: {chat_history}
Context: {retrieved_info}
User question: {user_question}
"""


def _retrieve_documents(user_query, chat_history):
    """Return chunks relevant to the question and to the user's earlier turns.

    The results for ``user_query`` come first. Chunks found by searching with
    all human messages in ``chat_history`` joined together are then appended,
    unless a chunk with the same text is already present.
    """
    docs = list(RETRIEVER.invoke(user_query))
    seen_contents = {doc.page_content for doc in docs}
    history_query = " ".join(
        message.content
        for message in chat_history
        if isinstance(message, HumanMessage)
    )
    # Skip the second lookup when it has no text, or when it would repeat the
    # first one. The repeat case happens on the first turn, where main has
    # appended only the current question to the history.
    if history_query.strip() and history_query != user_query:
        for doc in RETRIEVER.invoke(history_query):
            # Track what is appended too, so duplicates within these results
            # are not added twice.
            if doc.page_content not in seen_contents:
                seen_contents.add(doc.page_content)
                docs.append(doc)
    return docs


def get_response(user_query, chat_history):
    """Stream ADINA's answer to ``user_query``.

    Requires the module-level ``RETRIEVER`` to have been set beforehand.

    Args:
        user_query: The user's current question.
        chat_history: LangChain messages of the conversation. ``main`` passes
            the history with the current question already appended.

    Returns:
        The iterator returned by ``chain.stream``. It yields string chunks
        (via ``StrOutputParser``) and is consumed by ``st.write_stream``.
    """
    docs = _retrieve_documents(user_query, chat_history)
    prompt = ChatPromptTemplate.from_template(_ADINA_PROMPT_TEMPLATE)
    llm = ChatOpenAI(model="gpt-3.5-turbo-0125", streaming=True)
    chain = prompt | llm | StrOutputParser()
    return chain.stream(
        {
            "chat_history": chat_history,
            "retrieved_info": docs,
            "user_question": user_query,
        }
    )
def main():
    """Render the ADINA chat page and index any newly uploaded PDFs.

    Runs on every Streamlit script rerun. Per-session state is kept in
    ``st.session_state``:
      * ``chat_history``: the conversation as LangChain messages, seeded with
        a greeting from the assistant.
      * ``last_uploaded_files``: names of files that were indexed successfully.
    """
    st.set_page_config(page_title="Adina Cosmetic Ingredients", page_icon="")
    st.title("Adina Cosmetic Ingredients")

    state = st.session_state
    if "last_uploaded_files" not in state:
        state.last_uploaded_files = []
    if "chat_history" not in state:
        state.chat_history = [
            AIMessage(content="Hello, I am Adina. How can I help you?"),
        ]

    # Replay the stored conversation; messages of any other type are skipped.
    for past in state.chat_history:
        if isinstance(past, AIMessage):
            speaker = "AI"
        elif isinstance(past, HumanMessage):
            speaker = "Human"
        else:
            continue
        with st.chat_message(speaker):
            st.write(past.content)

    user_query = st.chat_input("Type your message here...")
    # Act only on a non-empty submitted string (chat_input may return None).
    if user_query:
        state.chat_history.append(HumanMessage(content=user_query))
        with st.chat_message("Human"):
            st.markdown(user_query)
        with st.chat_message("AI"):
            answer = st.write_stream(
                get_response(user_query=user_query, chat_history=state.chat_history)
            )
        state.chat_history.append(AIMessage(content=answer))

    uploaded_files = st.sidebar.file_uploader(
        label="Upload files", type="pdf", accept_multiple_files=True
    )
    already_indexed = state.last_uploaded_files
    pending = [
        upload for upload in uploaded_files if upload.name not in already_indexed
    ]
    if pending:
        create_embeddings(pending)
if __name__ == "__main__":
    # Open (or create) the persisted index so that get_response has a
    # retriever, then render the page.
    _vector_db = load_vector_db()
    RETRIEVER = _vector_db.as_retriever()
    main()