import os
from concurrent.futures import ThreadPoolExecutor

import pdfplumber
from dotenv import load_dotenv
from langchain.chains import ConversationalRetrievalChain
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

load_dotenv()

api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("API key not found. Please set the OPENAI_API_KEY environment variable.")
os.environ["OPENAI_API_KEY"] = api_key
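
# load_dotenv() above reads variables from a local .env file; a minimal one
# would contain just the key (placeholder value shown):
#   OPENAI_API_KEY=sk-...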


def load_retrieval_qa_chain():
    """Build a conversational RAG chain on top of the persisted Chroma index."""
    embeddings = OpenAIEmbeddings()
    vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        vectorstore.as_retriever(),
        return_source_documents=True,
    )
    return qa_chain
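
# ConversationalRetrievalChain runs in two steps: it first condenses the new
# question plus the chat history into a standalone query, then retrieves
# matching chunks from Chroma and answers from them. return_source_documents=True
# is what lets get_answer() below report page-level sources.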


def extract_text_from_pdf(file_path):
    """Split each page of a PDF into overlapping chunks tagged with source file and page number."""
    documents = []
    # One splitter can be reused for every page; no need to rebuild it per page.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    with pdfplumber.open(file_path) as pdf:
        for page_num, page in enumerate(pdf.pages):
            text = page.extract_text()
            if not text:
                continue  # skip pages with no extractable text (e.g. scanned images)
            for chunk in text_splitter.split_text(text):
                documents.append(
                    Document(
                        page_content=chunk,
                        metadata={"source": os.path.basename(file_path), "page": page_num + 1},
                    )
                )
    return documents
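
# For a hypothetical ./documents/report.pdf, each returned item looks like:
#   Document(page_content="<up to ~1,000 chars of page text>",
#            metadata={"source": "report.pdf", "page": 1})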


def embed_documents():
    """Embed every PDF under ./documents into the persistent Chroma store."""
    embeddings = OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY"))
    vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)

    pdf_files = [f for f in os.listdir("./documents") if f.endswith(".pdf")]
    documents = []
    # PDF parsing is the slow part, so extract the files concurrently.
    with ThreadPoolExecutor() as executor:
        results = executor.map(extract_text_from_pdf, [f"./documents/{pdf_file}" for pdf_file in pdf_files])
        for result in results:
            documents.extend(result)
    if documents:  # guard against an empty batch
        vectorstore.add_documents(documents)
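
# Note that embed_documents() never checks what is already indexed: running it
# twice stores every chunk twice. update_embeddings() below is the incremental
# variant that skips files Chroma has already seen.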


def update_embeddings():
    """Embed only the PDFs that are not already in the Chroma store."""
    embeddings = OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY"))
    vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)

    # Collect the source filenames already indexed. A similarity search only
    # returns the top-k hits, so walk the whole collection with .get() instead.
    existing = vectorstore.get(include=["metadatas"])
    existing_files = {
        meta["source"] for meta in existing["metadatas"] if meta and "source" in meta
    }

    pdf_files = [f for f in os.listdir("./documents") if f.endswith(".pdf")]
    new_files = [f for f in pdf_files if f not in existing_files]

    documents = []
    with ThreadPoolExecutor() as executor:
        results = executor.map(extract_text_from_pdf, [f"./documents/{pdf_file}" for pdf_file in new_files])
        for result in results:
            documents.extend(result)
    if documents:  # guard against an empty batch
        vectorstore.add_documents(documents)
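
# The diff above is by filename only, so a PDF whose contents changed but
# whose name did not will not be re-embedded.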


def get_answer(qa_chain, query, chat_history):
    """Run one conversational turn and return the answer plus page-level sources."""
    # chat_history arrives as a flat [question, answer, question, answer, ...]
    # list; the chain expects it as (question, answer) pairs.
    formatted_history = list(zip(chat_history[::2], chat_history[1::2]))

    response = qa_chain.invoke({"question": query, "chat_history": formatted_history})
    answer = response["answer"]

    # metadata["source"] is already a basename (see extract_text_from_pdf).
    source_docs = response.get("source_documents", [])
    source_texts = [f"{doc.metadata['source']} (Page {doc.metadata['page']})" for doc in source_docs]

    return {"answer": answer, "sources": source_texts}
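
# Illustrative return value (hypothetical document name):
#   {"answer": "The report covers Q3 revenue...",
#    "sources": ["report.pdf (Page 3)", "report.pdf (Page 4)"]}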


if __name__ == "__main__":
    update_embeddings()
    qa_chain = load_retrieval_qa_chain()
    question = """You are a RAG (Retrieval-Augmented Generation) based AI assistant. Answer user questions according to the following guidelines:

1. Use the search results: analyze the provided search results and answer using the relevant information.

2. Maintain accuracy: verify the accuracy of the information and state explicitly when it is uncertain.

3. Respond concisely: answer the question directly and focus on the core content.

4. Suggest additional information: mention any related additional information that exists.

5. Consider ethics: maintain an objective and neutral attitude.

6. Acknowledge limitations: admit honestly when you cannot answer.

7. Maintain the conversation: continue the dialogue naturally and suggest follow-up questions when needed.

Always aim to provide accurate and useful information."""

    response = get_answer(qa_chain, question, [])
    print(f"Question: {question}")
    print(f"Answer: {response['answer']}")
    print(f"Sources: {response['sources']}")