# ThaiCodex / rag.py
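"""RAG pipeline for ThaiCodex: a history-aware retrieval chain over a Pinecone
vector index of book documents, using a sentence-transformers embedding model
and the Typhoon chat model served through an OpenAI-compatible endpoint."""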
import os
from uuid import uuid4

from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langchain_pinecone import PineconeVectorStore
from pinecone import Pinecone

class Rag:
def __init__(self):
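        # Sentence-transformers model used to embed both indexed documents and queries.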
self.embedding = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
        self.model = ChatOpenAI(
            base_url="https://api.opentyphoon.ai/v1",
            model="typhoon-v1.5-instruct",
            # Secrets should not be hardcoded; TYPHOON_API_KEY is an assumed
            # environment-variable name for the opentyphoon.ai key.
            api_key=os.environ["TYPHOON_API_KEY"],
        )
        self.system_prompt = (
            """
            You are a helpful librarian named ThaiCodex. A user has requested book recommendations.
            We have retrieved {num_docs} document(s) based on the user's request, listed below:
            {context}
            Please list ALL and ONLY the books found above, in the order they were retrieved.
            For each book, provide:
            1. The title.
            2. A brief summary of its content.
            3. A reference for locating the book (e.g., a link, university, organization, or other relevant details).
            Format your response as a numbered list, matching the order in which the documents were retrieved.
            Results:
            """
        )
self.contextualize_q_system_prompt = (
"Given a chat history and the latest user question "
"which might reference context in the chat history, "
"formulate a standalone question which can be understood "
"without the chat history. Do NOT answer the question, "
"just reformulate it if needed and otherwise return it as is."
)
self.contextualize_q_prompt = ChatPromptTemplate.from_messages(
[
("system", self.contextualize_q_system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
]
)
self.qa_prompt = ChatPromptTemplate.from_messages(
[
("system", self.system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
]
)
        # Require the Pinecone key from the environment rather than falling back
        # to a hardcoded secret.
        pinecone_api_key = os.environ.get("PINECONE_API_KEY")
        if not pinecone_api_key:
            raise EnvironmentError("PINECONE_API_KEY is not set")
pc = Pinecone(api_key=pinecone_api_key)
index_name = "thaicodex"
index = pc.Index(index_name)
self.vectorstore = PineconeVectorStore(index=index, embedding=self.embedding)
def storeDocumentsInVectorstore(self, documents):
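        """Insert the given documents into the Pinecone vector store, assigning each a random UUID as its ID."""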
uuids = [str(uuid4()) for _ in range(len(documents))]
self.vectorstore.add_documents(documents=documents, ids=uuids)
def createRagChain(self):
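        """Assemble the conversational RAG chain: the history-aware retriever first
        rewrites the latest user question into a standalone query using the chat
        history, the stuff-documents chain formats the retrieved documents into the
        prompt's {context} variable, and create_retrieval_chain ties the two together."""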
self.question_answer_chain = create_stuff_documents_chain(self.model, self.qa_prompt)
self.history_aware_retriever = create_history_aware_retriever(self.model, self.vectorstore.as_retriever(), self.contextualize_q_prompt)
self.rag_chain = create_retrieval_chain(self.history_aware_retriever, self.question_answer_chain)
def generateResponse(self, question, chat_history):
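        """Answer a book-recommendation question via the RAG chain, threading the
        chat history through for question reformulation. createRagChain() must be
        called beforehand."""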
        # Retrieve up front only to count hits for the {num_docs} prompt variable;
        # note that create_retrieval_chain performs its own retrieval below and
        # overwrites the "context" value passed to invoke().
        retrieved_docs = self.vectorstore.as_retriever().invoke(question)
        num_docs = len(retrieved_docs)
        docs = "\n\n".join(
            f"{i + 1}. Title: {doc.metadata.get('source')}\nContent: {doc.page_content}"
            for i, doc in enumerate(retrieved_docs)
        )
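        # Debug output: number of retrieved documents and their formatted text.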
print(num_docs)
print(docs)
ai_msg = self.rag_chain.invoke({
"context": docs,
"num_docs": num_docs,
"input": question,
"chat_history": chat_history
})
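        # create_retrieval_chain returns a dict; the model's reply is under the "answer" key.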
return ai_msg
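
# A minimal usage sketch (assumed driver code, not part of the original module).
# It presumes PINECONE_API_KEY and the assumed TYPHOON_API_KEY are set in the
# environment and that the "thaicodex" index already exists in Pinecone.
if __name__ == "__main__":
    from langchain_core.documents import Document

    rag = Rag()
    rag.storeDocumentsInVectorstore([
        Document(
            page_content="An introductory survey of classical Thai literature.",
            metadata={"source": "Thai Literature 101"},  # hypothetical document
        ),
    ])
    rag.createRagChain()
    result = rag.generateResponse("Recommend a book about Thai literature.", chat_history=[])
    print(result["answer"])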