# RAGBot-gpt / chat.py — initial commit (7f8ded9) by Jaspertw177
import logging
import os
import tempfile
from langchain.chains import ConversationalRetrievalChain, ConversationChain
from langchain_openai import AzureOpenAIEmbeddings, AzureChatOpenAI
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.schema import BaseRetriever, Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import DocArrayInMemorySearch
from langchain.agents import initialize_agent, AgentType
from langchain_community.agent_toolkits.load_tools import load_tools
from utils import MEMORY, load_document
import streamlit as st
# Module-wide logging: INFO level, UTF-8 encoding for any file handlers.
logging.basicConfig(encoding="utf-8", level=logging.INFO)
# NOTE(review): this binds the *root* logger; the usual convention is
# logging.getLogger(__name__) — confirm nothing relies on root-logger behavior.
LOGGER = logging.getLogger()
def config_retriever(docs: list[Document], use_compression=False, chunk_size=1500):
    """Build an MMR retriever over *docs* using Azure OpenAI embeddings.

    The documents are split into overlapping chunks, embedded, and indexed
    in an in-memory DocArray vector store. When ``use_compression`` is true,
    the retriever is wrapped in a contextual-compression layer that drops
    chunks below a 0.2 embedding-similarity threshold.
    """
    chunks = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=200
    ).split_documents(docs)
    # Credentials and deployment names come from Streamlit secrets.
    embeddings = AzureOpenAIEmbeddings(
        api_key=st.secrets['key'],
        azure_deployment=st.secrets['embedding_name'],
        openai_api_version=st.secrets['embedding_version'],
        azure_endpoint=st.secrets['endpoint'],
    )
    store = DocArrayInMemorySearch.from_documents(chunks, embeddings)
    # MMR search: fetch 7 candidates, keep the 5 most diverse.
    base_retriever = store.as_retriever(
        search_type='mmr',
        search_kwargs={"k": 5, "fetch_k": 7, "include_metadata": True},
    )
    if not use_compression:
        return base_retriever
    similarity_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.2)
    return ContextualCompressionRetriever(
        base_compressor=similarity_filter,
        base_retriever=base_retriever,
    )
def config_baseretrieval_chain(retriever: BaseRetriever, temperature=0.1):
    """Create a conversational retrieval chain on top of *retriever*.

    Uses an Azure chat deployment configured from Streamlit secrets.
    Side effect: sets ``output_key`` on the shared MEMORY object to
    'answer' so the chain's result is stored under that key.
    """
    chat_model = AzureChatOpenAI(
        api_key=st.secrets['key'],
        openai_api_version=st.secrets['chat_version'],
        azure_deployment=st.secrets['chat_name'],
        azure_endpoint=st.secrets['endpoint'],
        temperature=temperature,
    )
    MEMORY.output_key = 'answer'
    return ConversationalRetrievalChain.from_llm(
        llm=chat_model,
        retriever=retriever,
        memory=MEMORY,
        verbose=True,
    )
def ddg_search_agent(temperature=0.1):
    """Build a zero-shot ReAct agent equipped with DuckDuckGo web search.

    The underlying model is determined entirely by the Azure deployment
    named in Streamlit secrets; ``temperature`` is forwarded to it.

    Returns:
        An AgentExecutor that can answer queries via 'ddg-search',
        with parsing-error recovery enabled.
    """
    llm = AzureChatOpenAI(
        api_key=st.secrets['key'],
        openai_api_version=st.secrets['chat_version'],
        azure_deployment=st.secrets['chat_name'],
        azure_endpoint=st.secrets['endpoint'],
        temperature=temperature,
    )
    # Fix: dropped the stray model="gpt-4o-mini" kwarg — load_tools accepts
    # no such argument for 'ddg-search', and the model is already fixed by
    # the Azure deployment above.
    tools = load_tools(tool_names=['ddg-search'], llm=llm)
    return initialize_agent(
        tools=tools,
        llm=llm,
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
        handle_parsing_errors=True,
    )
def config_retrieval_chain(
    upload_files,
    use_compression=False,
    use_chunksize=1500,
    use_temperature=0.1,
    use_zeroshoot=False
):
    """Build the chat backend for a set of uploaded files.

    Args:
        upload_files: Streamlit UploadedFile objects to index.
        use_compression: wrap the retriever in a compression filter.
        use_chunksize: chunk size for document splitting.
        use_temperature: sampling temperature for the chat model.
        use_zeroshoot: if True, return a DuckDuckGo search agent instead
            of a retrieval chain (the uploads are not indexed at all).

    Returns:
        Either a ConversationalRetrievalChain over the uploads, or a
        zero-shot search agent when ``use_zeroshoot`` is set.
    """
    # Fix: the original built the full retriever/chain (writing files to
    # disk and making remote embedding calls) even when use_zeroshoot was
    # True and the chain was then discarded. Short-circuit instead.
    if use_zeroshoot:
        return ddg_search_agent(temperature=use_temperature)
    docs = []
    # Fix: use TemporaryDirectory as a context manager so the scratch
    # directory is cleaned up deterministically, not left to GC.
    with tempfile.TemporaryDirectory() as temp_dir:
        for file in upload_files:
            temp_filepath = os.path.join(temp_dir, file.name)
            with open(temp_filepath, "wb") as f:
                f.write(file.getvalue())
            docs.extend(load_document(temp_filepath))
    retriever = config_retriever(
        docs=docs, use_compression=use_compression, chunk_size=use_chunksize
    )
    return config_baseretrieval_chain(retriever=retriever, temperature=use_temperature)
def config_noretrieval_chain(use_temperature=0.1, use_zeroshoot=False):
    """Build a chat backend that does no document retrieval.

    Args:
        use_temperature: sampling temperature for the chat model.
        use_zeroshoot: if True, return a DuckDuckGo search agent;
            otherwise a plain ConversationChain.

    Returns:
        A zero-shot search agent or a ConversationChain.
    """
    # Fix: short-circuit before constructing the AzureChatOpenAI client —
    # the original built it unconditionally and then threw it away on the
    # use_zeroshoot path.
    if use_zeroshoot:
        return ddg_search_agent(temperature=use_temperature)
    llm = AzureChatOpenAI(
        api_key=st.secrets['key'],
        openai_api_version=st.secrets['chat_version'],
        azure_deployment=st.secrets['chat_name'],
        azure_endpoint=st.secrets['endpoint'],
        temperature=use_temperature,
    )
    return ConversationChain(llm=llm)