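"""RAG question-answering pipeline over local PDF files.

Extracts text from PDFs in ./documents, splits it into chunks, embeds the
chunks into a persistent Chroma vector store, and answers questions with a
ConversationalRetrievalChain backed by an OpenAI chat model.
"""
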
import os
from dotenv import load_dotenv
from langchain_openai import OpenAIEmbeddings
from langchain_chroma import Chroma  # updated to the langchain_chroma package
from langchain.chains import ConversationalRetrievalChain
from langchain_openai import ChatOpenAI
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
import pdfplumber
from concurrent.futures import ThreadPoolExecutor

# Load environment variables
load_dotenv()

# Set OpenAI API key
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("API key not found. Please set the OPENAI_API_KEY environment variable.")
os.environ["OPENAI_API_KEY"] = api_key

def load_retrieval_qa_chain():
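    """Build a ConversationalRetrievalChain over the persisted Chroma vector store."""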
    # Load embeddings
    embeddings = OpenAIEmbeddings()
    
    # Load vector store
    vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)

    # Initialize ChatOpenAI model
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)

    # Create ConversationalRetrievalChain
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        vectorstore.as_retriever(),
        return_source_documents=True
    )

    return qa_chain

def extract_text_from_pdf(file_path):
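    """Extract text from each page of a PDF and split it into chunked Document objects."""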
    documents = []
    # Reuse a single splitter for all pages instead of recreating it per page
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    with pdfplumber.open(file_path) as pdf:
        for page_num, page in enumerate(pdf.pages):
            text = page.extract_text()
            if text:
                # Split the page text into chunks
                chunks = text_splitter.split_text(text)
                for chunk in chunks:
                    doc = Document(page_content=chunk, metadata={"source": os.path.basename(file_path), "page": page_num + 1})
                    documents.append(doc)
    return documents

def embed_documents():
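    """Extract and embed every PDF under ./documents into the persistent Chroma store."""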
    embeddings = OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY"))
    vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
    
    pdf_files = [f for f in os.listdir("./documents") if f.endswith('.pdf')]
    documents = []
    with ThreadPoolExecutor() as executor:
        results = executor.map(extract_text_from_pdf, [f"./documents/{pdf_file}" for pdf_file in pdf_files])
        for result in results:
            documents.extend(result)
    # Skip the call when no documents were extracted (Chroma rejects an empty batch)
    if documents:
        vectorstore.add_documents(documents)

def update_embeddings():
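    """Embed only PDFs under ./documents that are not already in the Chroma store."""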
    embeddings = OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY"))
    vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
    
    # Collect source filenames already in the store; similarity_search("") only
    # returns the top-k nearest chunks, so read all stored metadata instead
    existing = vectorstore.get(include=["metadatas"])
    existing_files = {
        metadata["source"]
        for metadata in existing.get("metadatas", [])
        if metadata and "source" in metadata
    }
    
    pdf_files = [f for f in os.listdir("./documents") if f.endswith('.pdf')]
    
    new_files = [f for f in pdf_files if f not in existing_files]
    documents = []
    with ThreadPoolExecutor() as executor:
        results = executor.map(extract_text_from_pdf, [f"./documents/{pdf_file}" for pdf_file in new_files])
        for result in results:
            documents.extend(result)
    # Skip the call when there is nothing new (Chroma rejects an empty batch)
    if documents:
        vectorstore.add_documents(documents)

# Generate answer for a query
def get_answer(qa_chain, query, chat_history):
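    """Run the QA chain on a query with prior chat history and return the answer plus source references."""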
    # Pair the flat [question, answer, question, answer, ...] history into (question, answer) tuples
    formatted_history = [(q, a) for q, a in zip(chat_history[::2], chat_history[1::2])]
    
    response = qa_chain.invoke({"question": query, "chat_history": formatted_history})
    
    answer = response["answer"]
    
    source_docs = response.get("source_documents", [])
    source_texts = [f"{os.path.basename(doc.metadata['source'])} (Page {doc.metadata['page']})" for doc in source_docs]
    
    return {"answer": answer, "sources": source_texts}

# Example usage
if __name__ == "__main__":
    update_embeddings()  # Update embeddings with new documents
    qa_chain = load_retrieval_qa_chain()
    question = """당신은 RAG(Retrieval-Augmented Generation) 기반 AI μ–΄μ‹œμŠ€ν„΄νŠΈμž…λ‹ˆλ‹€. λ‹€μŒ 지침을 따라 μ‚¬μš©μž μ§ˆλ¬Έμ— λ‹΅ν•˜μ„Έμš”:

1. 검색 κ²°κ³Ό ν™œμš©: 제곡된 검색 κ²°κ³Όλ₯Ό λΆ„μ„ν•˜κ³  κ΄€λ ¨ 정보λ₯Ό μ‚¬μš©ν•΄ λ‹΅λ³€ν•˜μ„Έμš”.

2. μ •ν™•μ„± μœ μ§€: μ •λ³΄μ˜ 정확성을 ν™•μΈν•˜κ³ , λΆˆν™•μ‹€ν•œ 경우 이λ₯Ό λͺ…μ‹œν•˜μ„Έμš”.

3. κ°„κ²°ν•œ 응닡: μ§ˆλ¬Έμ— 직접 λ‹΅ν•˜κ³  핡심 λ‚΄μš©μ— μ§‘μ€‘ν•˜μ„Έμš”.

4. μΆ”κ°€ 정보 μ œμ•ˆ: κ΄€λ ¨λœ μΆ”κ°€ 정보가 μžˆλ‹€λ©΄ μ–ΈκΈ‰ν•˜μ„Έμš”.

5. μœ€λ¦¬μ„± κ³ λ €: 객관적이고 쀑립적인 νƒœλ„λ₯Ό μœ μ§€ν•˜μ„Έμš”.

6. ν•œκ³„ 인정: λ‹΅λ³€ν•  수 μ—†λŠ” 경우 μ†”μ§νžˆ μΈμ •ν•˜μ„Έμš”.

7. λŒ€ν™” μœ μ§€: μžμ—°μŠ€λŸ½κ²Œ λŒ€ν™”λ₯Ό 이어가고, ν•„μš”μ‹œ 후속 μ§ˆλ¬Έμ„ μ œμ•ˆν•˜μ„Έμš”.
항상 μ •ν™•ν•˜κ³  μœ μš©ν•œ 정보λ₯Ό μ œκ³΅ν•˜λŠ” 것을 λͺ©ν‘œλ‘œ ν•˜μ„Έμš”."""

    response = get_answer(qa_chain, question, [])
    print(f"Question: {question}")
    print(f"Answer: {response['answer']}")
    print(f"Sources: {response['sources']}")