JUNGU committed on
Commit
14bbe59
•
1 Parent(s): 89d0595

Update rag_system.py

Files changed (1)
  1. rag_system.py +60 -30
rag_system.py CHANGED
@@ -8,6 +8,10 @@ from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
import pdfplumber
from concurrent.futures import ThreadPoolExecutor
+ from langchain.retrievers import ContextualCompressionRetriever
+ from langchain.retrievers.document_compressors import LLMChainExtractor
+ from langgraph.graph import Graph
+ from langchain_core.runnables import RunnablePassthrough, RunnableLambda

# Load environment variables
load_dotenv()
@@ -28,10 +32,19 @@ def load_retrieval_qa_chain():
    # Initialize ChatOpenAI model
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0) # "gpt-4o-mini

-     # Create ConversationalRetrievalChain
+     # Create a compressor for re-ranking
+     compressor = LLMChainExtractor.from_llm(llm)
+
+     # Create a ContextualCompressionRetriever
+     compression_retriever = ContextualCompressionRetriever(
+         base_compressor=compressor,
+         base_retriever=vectorstore.as_retriever()
+     )
+
+     # Create ConversationalRetrievalChain with the new retriever
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
-         vectorstore.as_retriever(),
+         retriever=compression_retriever,
        return_source_documents=True
    )

@@ -82,41 +95,58 @@ def update_embeddings():
        documents.extend(result)
    vectorstore.add_documents(documents)

- # Generate answer for a query
- def get_answer(qa_chain, query, chat_history):
+ def create_rag_graph():
+     qa_chain = load_retrieval_qa_chain()
+
+     def retrieve_and_generate(inputs):
+         question = inputs["question"]
+         chat_history = inputs["chat_history"]
+         result = qa_chain({"question": question, "chat_history": chat_history})
+
+         # Ensure source documents have the correct metadata
+         sources = []
+         for doc in result.get("source_documents", []):
+             if "source" in doc.metadata and "page" in doc.metadata:
+                 sources.append(f"{os.path.basename(doc.metadata['source'])} (Page {doc.metadata['page']})")
+             else:
+                 print(f"Warning: Document missing metadata: {doc.metadata}")
+
+         return {
+             "answer": result["answer"],
+             "sources": sources
+         }
+
+     workflow = Graph()
+     workflow.add_node("retrieve_and_generate", retrieve_and_generate)
+     workflow.set_entry_point("retrieve_and_generate")
+
+     chain = workflow.compile()
+     return chain
+
+ rag_chain = create_rag_graph()
+
+ def get_answer(query, chat_history):
    formatted_history = [(q, a) for q, a in zip(chat_history[::2], chat_history[1::2])]

-     response = qa_chain.invoke({"question": query, "chat_history": formatted_history})
-
-     answer = response["answer"]
+     response = rag_chain.invoke({"question": query, "chat_history": formatted_history})

-     source_docs = response.get("source_documents", [])
-     source_texts = [f"{os.path.basename(doc.metadata['source'])} (Page {doc.metadata['page']})" for doc in source_docs]
+     # Validate response format
+     if "answer" not in response or "sources" not in response:
+         print("Warning: Unexpected response format")
+         return {"answer": "Error in processing", "sources": []}

-     return {"answer": answer, "sources": source_texts}
+     return {"answer": response["answer"], "sources": response["sources"]}

# Example usage
if __name__ == "__main__":
    update_embeddings() # Update embeddings with new documents
-     qa_chain = load_retrieval_qa_chain()
-     question = """You are an AI assistant based on RAG (Retrieval-Augmented Generation). Answer user questions according to the following guidelines:
-
-     1. Use the search results: analyze the provided search results and answer with the relevant information.
-
-     2. Stay accurate: verify the accuracy of the information and state clearly when it is uncertain.
-
-     3. Be concise: answer the question directly and focus on the key points.
-
-     4. Offer additional information: mention related additional information when it is available.
-
-     5. Consider ethics: maintain an objective and neutral stance.
-
-     6. Acknowledge limitations: admit honestly when you cannot answer.
-
-     7. Keep the conversation going: continue the dialogue naturally and suggest follow-up questions when needed.
-     Always aim to provide accurate and useful information."""
-
-     response = get_answer(qa_chain, question, [])
+     question = "Please explain the RAG system."
+     response = get_answer(question, [])
    print(f"Question: {question}")
    print(f"Answer: {response['answer']}")
-     print(f"Sources: {response['sources']}")
+     print(f"Sources: {response['sources']}")
+
+     # Validate source format
+     for source in response['sources']:
+         if not (source.endswith(')') and ' (Page ' in source):
+             print(f"Warning: Unexpected source format: {source}")