import os
from dotenv import load_dotenv
from langchain_openai import OpenAIEmbeddings
from langchain_chroma import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_openai import ChatOpenAI
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
import pdfplumber
from concurrent.futures import ThreadPoolExecutor
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langgraph.graph import Graph
from langchain_core.runnables import RunnablePassthrough, RunnableLambda

# Load environment variables
load_dotenv()

# Set OpenAI API key
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("API key not found. Please set the OPENAI_API_KEY environment variable.")
os.environ["OPENAI_API_KEY"] = api_key


def load_retrieval_qa_chain():
    # Load embeddings
    embeddings = OpenAIEmbeddings()

    # Load vector store
    vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)

    # Initialize ChatOpenAI model
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)

    # Create an LLM-based compressor that extracts only the relevant parts of retrieved documents
    compressor = LLMChainExtractor.from_llm(llm)

    # Wrap the base retriever in a ContextualCompressionRetriever
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=vectorstore.as_retriever()
    )

    # Create ConversationalRetrievalChain with the compression retriever
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=compression_retriever,
        return_source_documents=True
    )
    return qa_chain


def extract_text_from_pdf(file_path):
    documents = []
    with pdfplumber.open(file_path) as pdf:
        for page_num, page in enumerate(pdf.pages):
            text = page.extract_text()
            if text:
                # Split page text into chunks
                text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
                chunks = text_splitter.split_text(text)
                for chunk in chunks:
                    doc = Document(
                        page_content=chunk,
                        metadata={"source": os.path.basename(file_path), "page": page_num + 1}
                    )
                    documents.append(doc)
    return documents


def embed_documents():
    embeddings = OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY"))
    vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)

    pdf_files = [f for f in os.listdir("./documents") if f.endswith('.pdf')]

    documents = []
    with ThreadPoolExecutor() as executor:
        results = executor.map(extract_text_from_pdf, [f"./documents/{pdf_file}" for pdf_file in pdf_files])
        for result in results:
            documents.extend(result)

    vectorstore.add_documents(documents)


def update_embeddings():
    embeddings = OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY"))
    vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)

    # Collect the source filenames that are already stored in the vector store
    existing_files = set()
    for metadata in vectorstore.get()["metadatas"]:
        if "source" in metadata:
            existing_files.add(metadata["source"])

    pdf_files = [f for f in os.listdir("./documents") if f.endswith('.pdf')]
    new_files = [f for f in pdf_files if f not in existing_files]

    documents = []
    with ThreadPoolExecutor() as executor:
        results = executor.map(extract_text_from_pdf, [f"./documents/{pdf_file}" for pdf_file in new_files])
        for result in results:
            documents.extend(result)

    if documents:
        vectorstore.add_documents(documents)


def create_rag_graph():
    qa_chain = load_retrieval_qa_chain()

    def retrieve_and_generate(inputs):
        question = inputs["question"]
        chat_history = inputs["chat_history"]
        result = qa_chain.invoke({"question": question, "chat_history": chat_history})

        # Ensure source documents have the correct metadata
        sources = []
        for doc in result.get("source_documents", []):
            if "source" in doc.metadata and "page" in doc.metadata:
                sources.append(f"{os.path.basename(doc.metadata['source'])} (Page {doc.metadata['page']})")
            else:
                print(f"Warning: Document missing metadata: {doc.metadata}")

        return {
            "answer": result["answer"],
            "sources": sources
        }

    workflow = Graph()
    workflow.add_node("retrieve_and_generate", retrieve_and_generate)
    workflow.set_entry_point("retrieve_and_generate")
    workflow.set_finish_point("retrieve_and_generate")
    chain = workflow.compile()
    return chain


rag_chain = create_rag_graph()


def get_answer(query, chat_history):
    # chat_history is a flat [question, answer, question, answer, ...] list;
    # pair consecutive entries into (question, answer) tuples for the chain
    formatted_history = [(q, a) for q, a in zip(chat_history[::2], chat_history[1::2])]
    response = rag_chain.invoke({"question": query, "chat_history": formatted_history})

    # Validate response format
    if "answer" not in response or "sources" not in response:
        print("Warning: Unexpected response format")
        return {"answer": "Error in processing", "sources": []}

    return {"answer": response["answer"], "sources": response["sources"]}


# Example usage
if __name__ == "__main__":
    update_embeddings()  # Update embeddings with new documents
    question = "Please explain the RAG system."
    response = get_answer(question, [])
    print(f"Question: {question}")
    print(f"Answer: {response['answer']}")
    print(f"Sources: {response['sources']}")

    # Validate source format
    for source in response['sources']:
        if not (source.endswith(')') and ' (Page ' in source):
            print(f"Warning: Unexpected source format: {source}")
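    # Illustrative second turn (a minimal sketch; the follow-up question below is a
    # made-up example, not part of the original script): get_answer expects
    # chat_history as a flat [question, answer, question, answer, ...] list, which
    # it pairs into (question, answer) tuples before invoking the chain.
    follow_up = "Which documents did that answer come from?"
    follow_up_response = get_answer(follow_up, [question, response["answer"]])
    print(f"Follow-up question: {follow_up}")
    print(f"Follow-up answer: {follow_up_response['answer']}")
    print(f"Follow-up sources: {follow_up_response['sources']}")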