# NOTE(review): the three lines below this file's original header were
# extraction artifacts (byte size "4,910", blob hash "aa8e01a", and a
# line-number gutter 1..122) — converted to this comment so the module parses.
import os
from dotenv import load_dotenv
from langchain_openai import OpenAIEmbeddings
from langchain_chroma import Chroma # μ΄ μ€μ μμ
from langchain.chains import ConversationalRetrievalChain
from langchain_openai import ChatOpenAI
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
import pdfplumber
from concurrent.futures import ThreadPoolExecutor
# Load environment variables from a .env file, if one is present.
load_dotenv()

# Fail fast when the OpenAI key is missing, rather than letting a later
# API call die with a less obvious authentication error.
_openai_key = os.getenv("OPENAI_API_KEY")
if not _openai_key:
    raise ValueError("API key not found. Please set the OPENAI_API_KEY environment variable.")
os.environ["OPENAI_API_KEY"] = _openai_key
def load_retrieval_qa_chain():
    """Build a conversational retrieval chain over the persisted Chroma store.

    Returns:
        ConversationalRetrievalChain: a chain that retrieves from the
        ./chroma_db directory, answers with gpt-4o-mini at temperature 0,
        and includes the retrieved source documents in each response.
    """
    store = Chroma(
        persist_directory="./chroma_db",
        embedding_function=OpenAIEmbeddings(),
    )
    chat_model = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    return ConversationalRetrievalChain.from_llm(
        chat_model,
        store.as_retriever(),
        return_source_documents=True,
    )
def extract_text_from_pdf(file_path):
    """Extract text from one PDF and split it into chunked Documents.

    Args:
        file_path: Path to the PDF file to read.

    Returns:
        list[Document]: one Document per ~1000-char chunk (100-char overlap),
        with metadata {"source": <basename>, "page": <1-based page number>}.
        Pages with no extractable text (e.g. image-only pages) are skipped.
    """
    # The splitter holds only configuration — build it once instead of
    # re-instantiating it for every page, as the original did.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    source_name = os.path.basename(file_path)
    documents = []
    with pdfplumber.open(file_path) as pdf:
        for page_num, page in enumerate(pdf.pages):
            text = page.extract_text()
            if not text:
                continue  # nothing extractable on this page
            for chunk in text_splitter.split_text(text):
                documents.append(
                    Document(
                        page_content=chunk,
                        metadata={"source": source_name, "page": page_num + 1},
                    )
                )
    return documents
def embed_documents():
    """Embed every PDF under ./documents into the persistent Chroma store.

    PDF text extraction is fanned out across a thread pool; all resulting
    chunks are then added to the store in a single batch.
    """
    embeddings = OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY"))
    vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
    pdf_files = [f for f in os.listdir("./documents") if f.endswith('.pdf')]
    documents = []
    with ThreadPoolExecutor() as executor:
        results = executor.map(
            extract_text_from_pdf,
            [f"./documents/{pdf_file}" for pdf_file in pdf_files],
        )
        for result in results:
            documents.extend(result)
    # Guard the empty case: adding an empty batch raises in Chroma, and
    # there is nothing to do when no text was extracted.
    if documents:
        vectorstore.add_documents(documents)
def update_embeddings():
    """Embed only the PDFs in ./documents not already present in the store.

    A file is considered "already embedded" when any stored chunk's
    metadata carries its basename as "source".
    """
    embeddings = OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY"))
    vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
    # BUG FIX: the original used similarity_search(""), which returns only
    # the top-k (default 4) documents — most already-embedded files went
    # undetected and were re-added on every run. Chroma's get() returns
    # the metadata of *all* stored documents.
    stored = vectorstore.get()
    existing_files = {
        meta["source"]
        for meta in stored.get("metadatas", [])
        if meta and "source" in meta
    }
    pdf_files = [f for f in os.listdir("./documents") if f.endswith('.pdf')]
    new_files = [f for f in pdf_files if f not in existing_files]
    documents = []
    with ThreadPoolExecutor() as executor:
        results = executor.map(
            extract_text_from_pdf,
            [f"./documents/{pdf_file}" for pdf_file in new_files],
        )
        for result in results:
            documents.extend(result)
    # Skip the store call entirely when there is nothing new to embed.
    if documents:
        vectorstore.add_documents(documents)
# Generate answer for a query
def get_answer(qa_chain, query, chat_history):
    """Run one turn of the conversational chain and format its output.

    Args:
        qa_chain: a chain exposing .invoke({"question", "chat_history"}).
        query: the user's question.
        chat_history: flat alternating list [q1, a1, q2, a2, ...]; it is
            regrouped into the (question, answer) pairs the chain expects.

    Returns:
        dict: {"answer": <str>, "sources": ["<file> (Page <n>)", ...]}.
    """
    pairs = list(zip(chat_history[::2], chat_history[1::2]))
    result = qa_chain.invoke({"question": query, "chat_history": pairs})
    sources = [
        f"{os.path.basename(doc.metadata['source'])} (Page {doc.metadata['page']})"
        for doc in result.get("source_documents", [])
    ]
    return {"answer": result["answer"], "sources": sources}
# Example usage
if __name__ == "__main__":
    # Pick up any PDFs added since the last run before answering.
    update_embeddings()
    qa_chain = load_retrieval_qa_chain()
    # System-style prompt used as the question (text preserved verbatim).
    question = """λΉμ μ RAG(Retrieval-Augmented Generation) κΈ°λ° AI μ΄μμ€ν΄νΈμ
λλ€. λ€μ μ§μΉ¨μ λ°λΌ μ¬μ©μ μ§λ¬Έμ λ΅νμΈμ:
1. κ²μ κ²°κ³Ό νμ©: μ 곡λ κ²μ κ²°κ³Όλ₯Ό λΆμνκ³ κ΄λ ¨ μ 보λ₯Ό μ¬μ©ν΄ λ΅λ³νμΈμ.
2. μ νμ± μ μ§: μ 보μ μ νμ±μ νμΈνκ³ , λΆνμ€ν κ²½μ° μ΄λ₯Ό λͺ
μνμΈμ.
3. κ°κ²°ν μλ΅: μ§λ¬Έμ μ§μ λ΅νκ³ ν΅μ¬ λ΄μ©μ μ§μ€νμΈμ.
4. μΆκ° μ 보 μ μ: κ΄λ ¨λ μΆκ° μ λ³΄κ° μλ€λ©΄ μΈκΈνμΈμ.
5. μ€λ¦¬μ± κ³ λ €: κ°κ΄μ μ΄κ³ μ€λ¦½μ μΈ νλλ₯Ό μ μ§νμΈμ.
6. νκ³ μΈμ : λ΅λ³ν μ μλ κ²½μ° μμ§ν μΈμ νμΈμ.
7. λν μ μ§: μμ°μ€λ½κ² λνλ₯Ό μ΄μ΄κ°κ³ , νμμ νμ μ§λ¬Έμ μ μνμΈμ.
νμ μ ννκ³ μ μ©ν μ 보λ₯Ό μ 곡νλ κ²μ λͺ©νλ‘ νμΈμ."""
    response = get_answer(qa_chain, question, [])
    print(f"Question: {question}")
    print(f"Answer: {response['answer']}")
    print(f"Sources: {response['sources']}")