File size: 4,186 Bytes
7f8ded9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import logging
import os
import tempfile
from langchain.chains import ConversationalRetrievalChain, ConversationChain
from langchain_openai import AzureOpenAIEmbeddings, AzureChatOpenAI
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.schema import BaseRetriever, Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import DocArrayInMemorySearch
from langchain.agents import initialize_agent, AgentType
from langchain_community.agent_toolkits.load_tools import load_tools
from utils import MEMORY, load_document
import streamlit as st

# Configure the root logger once at import time: INFO level, utf-8 log files.
logging.basicConfig(encoding="utf-8", level=logging.INFO)
# NOTE(review): getLogger() with no name returns the ROOT logger, so third-party
# log records also flow through LOGGER; logging.getLogger(__name__) would be the
# conventional per-module logger — confirm intent before changing.
LOGGER = logging.getLogger()

def config_retriever(docs: list[Document], use_compression=False, chunk_size=1500,
                     chunk_overlap=200, similarity_threshold=0.2):
    """Build a retriever over *docs* backed by an in-memory vector store.

    Args:
        docs: Documents to split, embed, and index.
        use_compression: If True, wrap the base retriever in a
            ContextualCompressionRetriever that drops chunks whose embedding
            similarity to the query is below *similarity_threshold*.
        chunk_size: Maximum characters per split chunk.
        chunk_overlap: Characters of overlap between adjacent chunks
            (previously hard-coded to 200; kept as the default).
        similarity_threshold: Cut-off for the EmbeddingsFilter when
            compression is enabled (previously hard-coded to 0.2).

    Returns:
        A retriever: plain MMR retriever, or a compression-wrapped one.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    splits = text_splitter.split_documents(docs)

    # Azure credentials and deployment names come from Streamlit secrets.
    embeddings = AzureOpenAIEmbeddings(
        api_key=st.secrets['key'],
        azure_deployment=st.secrets['embedding_name'],
        openai_api_version=st.secrets['embedding_version'],
        azure_endpoint=st.secrets['endpoint'],
    )

    vectorDB = DocArrayInMemorySearch.from_documents(splits, embeddings)
    # MMR search: fetch 7 candidates, return the 5 most diverse.
    retriever = vectorDB.as_retriever(
        search_type='mmr',
        search_kwargs={
            "k": 5,
            "fetch_k": 7,
            "include_metadata": True
        }
    )
    if not use_compression:
        return retriever

    embeddings_filter = EmbeddingsFilter(
        embeddings=embeddings, similarity_threshold=similarity_threshold
    )
    return ContextualCompressionRetriever(
        base_compressor=embeddings_filter,
        base_retriever=retriever
    )

def config_baseretrieval_chain(retriever: BaseRetriever, temperature=0.1):
    """Create a ConversationalRetrievalChain over *retriever*.

    Uses the Azure chat deployment configured in Streamlit secrets and the
    module-wide MEMORY for conversation history.
    """
    chat_llm = AzureChatOpenAI(
        api_key=st.secrets['key'],
        openai_api_version=st.secrets['chat_version'],
        azure_deployment=st.secrets['chat_name'],
        azure_endpoint=st.secrets['endpoint'],
        temperature=temperature,
    )

    # The chain emits several output fields; tell the shared memory which
    # one to record as the assistant turn.
    MEMORY.output_key = 'answer'
    return ConversationalRetrievalChain.from_llm(
        llm=chat_llm,
        retriever=retriever,
        memory=MEMORY,
        verbose=True,
    )
    
def ddg_search_agent(temperature=0.1):
    """Build a zero-shot ReAct agent with DuckDuckGo web search.

    Args:
        temperature: Sampling temperature for the Azure chat model.

    Returns:
        An AgentExecutor that answers queries via the 'ddg-search' tool.
    """
    LLM = AzureChatOpenAI(
        api_key=st.secrets['key'],
        openai_api_version=st.secrets['chat_version'],
        azure_deployment=st.secrets['chat_name'],
        azure_endpoint=st.secrets['endpoint'],
        temperature=temperature,
    )

    # Dropped the previous `model="gpt-4o-mini"` kwarg: load_tools forwards
    # extra kwargs to the tools, and it is not a ddg-search parameter — the
    # model is determined by the Azure deployment configured above.
    tools = load_tools(tool_names=['ddg-search'], llm=LLM)
    return initialize_agent(
        tools=tools,
        llm=LLM,
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
        # Keep the agent running when the LLM emits malformed ReAct output.
        handle_parsing_errors=True,
    )
    
def config_retrieval_chain(
        upload_files,
        use_compression=False,
        use_chunksize=1500,
        use_temperature=0.1,
        use_zeroshoot=False
):
    """Build the chat chain for a set of uploaded files.

    Args:
        upload_files: Streamlit UploadedFile objects to index.
        use_compression: Forwarded to config_retriever.
        use_chunksize: Splitter chunk size forwarded to config_retriever.
        use_temperature: LLM sampling temperature.
        use_zeroshoot: If True, return a web-search agent instead of a
            retrieval chain (the uploaded files are not used).

    Returns:
        A ConversationalRetrievalChain, or an agent when use_zeroshoot.
    """
    # The agent path ignores the documents entirely, so return early and
    # skip the expensive file-writing/embedding/indexing work that the
    # previous version performed and then discarded.
    if use_zeroshoot:
        return ddg_search_agent(temperature=use_temperature)

    docs = []
    # Context manager guarantees the temp directory is removed once the
    # documents have been parsed into memory (previously it lingered until GC).
    with tempfile.TemporaryDirectory() as temp_dir:
        for file in upload_files:
            temp_filepath = os.path.join(temp_dir, file.name)
            with open(temp_filepath, "wb") as f:
                f.write(file.getvalue())
            docs.extend(load_document(temp_filepath))

    retriever = config_retriever(
        docs=docs, use_compression=use_compression, chunk_size=use_chunksize
    )
    return config_baseretrieval_chain(retriever=retriever, temperature=use_temperature)
    
def config_noretrieval_chain(use_temperature=0.1, use_zeroshoot=False):
    """Build a plain conversation chain with no document retrieval.

    Args:
        use_temperature: LLM sampling temperature.
        use_zeroshoot: If True, return a web-search agent instead.

    Returns:
        A ConversationChain, or an agent when use_zeroshoot is True.
    """
    # Guard first: the agent constructs its own LLM, so don't build one here
    # (the previous version created and discarded an unused AzureChatOpenAI).
    if use_zeroshoot:
        return ddg_search_agent(temperature=use_temperature)

    LLM = AzureChatOpenAI(
        api_key=st.secrets['key'],
        openai_api_version=st.secrets['chat_version'],
        azure_deployment=st.secrets['chat_name'],
        azure_endpoint=st.secrets['endpoint'],
        temperature=use_temperature,
    )
    return ConversationChain(llm=LLM)