import os
import uuid
from typing import List

import chainlit as cl
from chainlit.types import AskFileResponse
from dotenv import load_dotenv
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.history_aware_retriever import create_history_aware_retriever
from langchain.chains.retrieval import create_retrieval_chain
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.document_loaders import PyMuPDFLoader, TextLoader
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.documents import Document
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_experimental.text_splitter import SemanticChunker
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
# from chainlit.input_widget import Select, Switch, Slider

# Loads OPENAI_API_KEY plus QDRANT_URL, QDRANT_API_KEY and COLLECTION_NAME
# (read in connect_to_qdrant() below) from a .env file.
load_dotenv()
BOR_FILE_PATH = "https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf"
NIST_FILE_PATH = "https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf"
SMALL_DOC = "https://arxiv.org/pdf/1908.10084"  # 11 pages, Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks

documents_to_preload = [
    BOR_FILE_PATH,
    NIST_FILE_PATH,
    # SMALL_DOC
]

collection_name = "ai-safety"

welcome_message = """
Welcome to the chatbot for all your AI Safety related queries.
Preloading the following documents:
1. Blueprint for an AI Bill of Rights
2. NIST AI Standards
Please wait a moment while the documents load.
"""

chat_model_name = "gpt-4o"
embedding_model_name = "Snowflake/snowflake-arctic-embed-l"

chat_model = ChatOpenAI(model=chat_model_name, temperature=0)
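# Note (assumption worth checking): snowflake-arctic-embed-l produces
# 1024-dimensional embeddings, so any Qdrant collection paired with it
# (the remote one, or the in-memory one created below) must be configured
# with a matching vector size.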
async def connect_to_qdrant():
    """Connect to the remote Qdrant collection and return a thresholded retriever."""
    embedding_model = HuggingFaceEmbeddings(model_name=embedding_model_name)
    qdrant_url = os.environ["QDRANT_URL"]
    qdrant_api_key = os.environ["QDRANT_API_KEY"]
    collection_name = os.environ["COLLECTION_NAME"]
    qdrant_client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
    vector_store = QdrantVectorStore(
        client=qdrant_client,
        collection_name=collection_name,
        embedding=embedding_model,
    )
    # Return up to 10 matches, keeping only reasonably confident ones.
    return vector_store.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"k": 10, "score_threshold": 0.8},
    )
async def get_contextual_compressed_retriever(retriever):
    """Wrap a base retriever with LLM-based contextual compression."""
    compressor_llm = ChatOpenAI(temperature=0, model_name="gpt-4o", max_tokens=4000)
    compressor = LLMChainExtractor.from_llm(compressor_llm)
    # Combine the retriever with the compressor.
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=retriever,
    )
    return compression_retriever
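# Design note: LLMChainExtractor makes an extra LLM call per retrieved
# document to extract only the passages relevant to the query, trading
# latency and token cost for a tighter context window downstream.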
def initialize_vectorstore(
    collection_name: str,
    embedding_model,
    dimension,
    distance_metric: Distance = Distance.COSINE,
):
    """Create an in-memory Qdrant collection and return its vector store."""
    client = QdrantClient(":memory:")
    client.create_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(size=dimension, distance=distance_metric),
    )
    vector_store = QdrantVectorStore(
        client=client,
        collection_name=collection_name,
        embedding=embedding_model,
    )
    return vector_store
def get_text_splitter(strategy, embedding_model):
    if strategy == "semantic":
        return SemanticChunker(
            embedding_model,
            breakpoint_threshold_type="percentile",
            breakpoint_threshold_amount=90,
        )
    # Fail loudly instead of silently returning None for unknown strategies.
    raise ValueError(f"Unknown text splitting strategy: {strategy}")
def process_file(file: AskFileResponse, text_splitter):
    if file.type == "text/plain":
        Loader = TextLoader
    elif file.type == "application/pdf":
        Loader = PyMuPDFLoader
    else:
        # Avoid an UnboundLocalError on Loader for unsupported uploads.
        raise ValueError(f"Unsupported file type: {file.type}")
    loader = Loader(file.path)
    documents = loader.load()
    title = documents[0].metadata.get("title")
    docs = text_splitter.split_documents(documents)
    for i, doc in enumerate(docs):
        doc.metadata["source"] = f"source_{i}"
        doc.metadata["title"] = title
    return docs
def populate_vectorstore(vector_store, docs: List[Document]):
    vector_store.add_documents(docs)
    return vector_store
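# A minimal sketch, not part of the original flow (which queries an already
# populated remote Qdrant instance), of how the ingestion helpers above could
# be wired together to preload `documents_to_preload` into an in-memory
# collection. `preload_documents` is a hypothetical name, and the 1024
# dimension assumes the snowflake-arctic-embed-l embedding size.
def preload_documents():
    embedding_model = HuggingFaceEmbeddings(model_name=embedding_model_name)
    text_splitter = get_text_splitter("semantic", embedding_model)
    vector_store = initialize_vectorstore(collection_name, embedding_model, dimension=1024)
    for url in documents_to_preload:
        # PyMuPDFLoader (via BasePDFLoader) downloads URL paths to a temp file.
        docs = PyMuPDFLoader(url).load()
        populate_vectorstore(vector_store, text_splitter.split_documents(docs))
    return vector_store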
def create_history_aware_retriever_self(chat_model, retriever):
    """Build a retriever that first rewrites the question as a standalone query."""
    contextualize_q_system_prompt = (
        "Given a chat history and the latest user question which might reference context in the chat history, "
        "formulate a standalone question which can be understood without the chat history. Do NOT answer the question, "
        "just reformulate it if needed and otherwise return it as is."
    )
    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    return create_history_aware_retriever(chat_model, retriever, contextualize_q_prompt)
def create_qa_chain(chat_model):
    qa_system_prompt = (
        "You are a helpful assistant named 'Shield' and your task is to answer any questions related to AI Safety from the given context. "
        "Use the following pieces of retrieved context to answer the question. "
        # "If a question is asked outside the AI Safety context, just say that you are a specialist in AI Safety and can't answer it."
        # f"When introducing yourself, say that you are an AI assistant powered by embedding model {embedding_model_name} and chat model {chat_model_name}, and that your knowledge is limited to the 'Blueprint for an AI Bill of Rights' and 'NIST AI Standards' documents."
        "If you don't know the answer, just say that you don't know.\n\n"
        "{context}"
    )
    qa_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", qa_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    return create_stuff_documents_chain(chat_model, qa_prompt)
def create_rag_chain(chat_model, retriever):
    history_aware_retriever = create_history_aware_retriever_self(chat_model, retriever)
    question_answer_chain = create_qa_chain(chat_model)
    return create_retrieval_chain(history_aware_retriever, question_answer_chain)
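# The retrieval chain returns a dict whose "answer" and "context" keys are
# consumed in main() below; "context" holds the retrieved source Documents.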
def create_session_id():
    return str(uuid.uuid4())
@cl.on_chat_start
async def start():
    msg = cl.Message(content=welcome_message)
    await msg.send()
    # Create a session id
    session_id = create_session_id()
    cl.user_session.set("session_id", session_id)
    retriever = await connect_to_qdrant()
    contextual_compressed_retriever = await get_contextual_compressed_retriever(retriever)
    rag_chain = create_rag_chain(chat_model, contextual_compressed_retriever)
    # In-memory, per-process store of chat histories keyed by session id.
    store = {}
    def get_session_history(session_id: str) -> BaseChatMessageHistory:
        if session_id not in store:
            store[session_id] = ChatMessageHistory()
        return store[session_id]
    conversational_rag_chain = RunnableWithMessageHistory(
        rag_chain,
        get_session_history,
        input_messages_key="input",
        history_messages_key="chat_history",
        output_messages_key="answer",
    )
    # Let the user know that the system is ready
    msg.content = msg.content + "\nReady to answer your questions!"
    await msg.update()
    cl.user_session.set("conversational_rag_chain", conversational_rag_chain)
@cl.on_message
async def main(message: cl.Message):
    session_id = cl.user_session.get("session_id")
    conversational_rag_chain = cl.user_session.get("conversational_rag_chain")
    response = await conversational_rag_chain.ainvoke(
        {"input": message.content},
        config={
            "configurable": {"session_id": session_id},
            "callbacks": [cl.AsyncLangchainCallbackHandler()],
        },
    )
    answer = response["answer"]
    source_documents = response["context"]
    text_elements = []
    unique_pages = set()
    if source_documents:
        for source_doc in source_documents:
            page_number = source_doc.metadata.get("page", "NA")  # NA or any default value
            page = f"Page {page_number}"
            text_element_content = source_doc.page_content or "No Content"
            # De-duplicate sources that point at the same page.
            if page not in unique_pages:
                unique_pages.add(page)
                text_elements.append(cl.Text(content=text_element_content, name=page))
        source_names = [text_el.name for text_el in text_elements]
        if source_names:
            answer += f"\n\nSources: {', '.join(source_names)}"
        else:
            answer += "\n\nNo sources found"
    await cl.Message(content=answer, elements=text_elements).send()
if __name__ == "__main__":
    from chainlit.cli import run_chainlit
    run_chainlit(__file__)