import logging
import os
import tempfile
from langchain.agents import AgentType, initialize_agent
from langchain.chains import ConversationalRetrievalChain, ConversationChain
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.schema import BaseRetriever, Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.agent_toolkits.load_tools import load_tools
from langchain_community.vectorstores import DocArrayInMemorySearch
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
import streamlit as st

from utils import MEMORY, load_document

logging.basicConfig(encoding="utf-8", level=logging.INFO)
LOGGER = logging.getLogger()


def config_retriever(docs: list[Document], use_compression=False, chunk_size=1500):
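    """Split the documents, embed the chunks, and return an MMR retriever.

    With use_compression=True, the retriever is wrapped so that chunks whose
    embedding similarity to the query falls below 0.2 are filtered out.
    """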
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=200)
    splits = text_splitter.split_documents(docs)
    embeddings = AzureOpenAIEmbeddings(
        api_key=st.secrets['key'],
        azure_deployment=st.secrets['embedding_name'],
        openai_api_version=st.secrets['embedding_version'],
        azure_endpoint=st.secrets['endpoint'],
    )
    vector_db = DocArrayInMemorySearch.from_documents(splits, embeddings)
    retriever = vector_db.as_retriever(
        search_type='mmr',
        search_kwargs={
            "k": 5,
            "fetch_k": 7,
            "include_metadata": True,
        },
    )
    if not use_compression:
        return retriever
    embeddings_filter = EmbeddingsFilter(
        embeddings=embeddings, similarity_threshold=0.2
    )
    return ContextualCompressionRetriever(
        base_compressor=embeddings_filter,
        base_retriever=retriever,
    )


def config_baseretrieval_chain(retriever: BaseRetriever, temperature=0.1):
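    """Build a ConversationalRetrievalChain over the given retriever, backed
    by an Azure OpenAI chat model and the shared conversation MEMORY."""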
    llm = AzureChatOpenAI(
        api_key=st.secrets['key'],
        openai_api_version=st.secrets['chat_version'],
        azure_deployment=st.secrets['chat_name'],
        azure_endpoint=st.secrets['endpoint'],
        temperature=temperature,
    )
    # Tell the shared memory which output field of the chain to record.
    MEMORY.output_key = 'answer'
    params = dict(
        llm=llm,
        retriever=retriever,
        memory=MEMORY,
        verbose=True,
    )
    return ConversationalRetrievalChain.from_llm(**params)


def ddg_search_agent(temperature=0.1):
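    """Build a zero-shot ReAct agent that answers via DuckDuckGo web search."""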
    llm = AzureChatOpenAI(
        api_key=st.secrets['key'],
        openai_api_version=st.secrets['chat_version'],
        azure_deployment=st.secrets['chat_name'],
        azure_endpoint=st.secrets['endpoint'],
        temperature=temperature,
    )
    # The stray model="gpt-4o-mini" kwarg is dropped: load_tools does not
    # forward it to the DuckDuckGo tool, so it had no effect.
    tools = load_tools(tool_names=['ddg-search'], llm=llm)
    return initialize_agent(
        tools=tools,
        llm=llm,
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
        handle_parsing_errors=True,
    )


def config_retrieval_chain(
    upload_files,
    use_compression=False,
    use_chunksize=1500,
    use_temperature=0.1,
    use_zeroshoot=False,
):
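    """Persist the uploaded files, load them as documents, and return a
    conversational retrieval chain; with use_zeroshoot=True, return the
    web-search agent instead."""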
    if use_zeroshoot:
        # Return the agent up front so the files are not embedded needlessly;
        # previously the retriever and chain were built and then discarded.
        return ddg_search_agent(temperature=use_temperature)
    docs = []
    temp_dir = tempfile.TemporaryDirectory()
    for file in upload_files:
        temp_filepath = os.path.join(temp_dir.name, file.name)
        with open(temp_filepath, "wb") as f:
            f.write(file.getvalue())
        docs.extend(load_document(temp_filepath))
    retriever = config_retriever(docs=docs, use_compression=use_compression, chunk_size=use_chunksize)
    return config_baseretrieval_chain(retriever=retriever, temperature=use_temperature)


def config_noretrieval_chain(use_temperature=0.1, use_zeroshoot=False):
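    """Return a plain ConversationChain, or the web-search agent when
    use_zeroshoot=True."""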
    if use_zeroshoot:
        return ddg_search_agent(temperature=use_temperature)
    llm = AzureChatOpenAI(
        api_key=st.secrets['key'],
        openai_api_version=st.secrets['chat_version'],
        azure_deployment=st.secrets['chat_name'],
        azure_endpoint=st.secrets['endpoint'],
        temperature=use_temperature,
    )
    return ConversationChain(llm=llm)