ak3ra committed
Commit 286d467 · 1 Parent(s): b8698fb

Refactor chat_function to unpack response tuple and return only the response

Files changed (2)
  1. app.py +2 -2
  2. rag/rag_pipeline.py +75 -36
app.py CHANGED
@@ -134,8 +134,8 @@ def chat_function(message: str, study_name: str, prompt_type: str) -> str:
         "Evidence-based": evidence_based_prompt,
     }.get(prompt_type)

-    response = rag.query(message, prompt_template=prompt)
-    return response.response
+    response, _ = rag.query(message, prompt_template=prompt)  # Unpack the tuple
+    return response


 def process_zotero_library_items(
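
Note: rag.query now returns a (response_text, source_info) tuple instead of a Response object, and chat_function keeps only the text. A caller that also wants the citation details could keep both elements; the sketch below is illustrative only (the helper name and formatting are assumptions, not part of this commit):

# Hypothetical caller that surfaces the source metadata returned by
# RAGPipeline.query; source_info is None for non-PDF (Zotero) collections.
def chat_with_citation(message: str, prompt) -> str:
    answer, source_info = rag.query(message, prompt_template=prompt)
    if source_info:  # populated only for PDF collections
        return (
            f"{answer}\n\nSource: {source_info.get('source_file')}, "
            f"p. {source_info.get('page_number')}"
        )
    return answer
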
rag/rag_pipeline.py CHANGED
@@ -10,7 +10,7 @@ from llama_index.embeddings.openai import OpenAIEmbedding
 from llama_index.llms.openai import OpenAI
 from llama_index.vector_stores.chroma import ChromaVectorStore
 import chromadb
-from typing import Dict, Any, List, Tuple
+from typing import Dict, Any, List, Tuple, Optional
 import re
 import logging

@@ -33,9 +33,23 @@ class RAGPipeline:
         self.client = chromadb.Client()
         self.collection = self.client.get_or_create_collection(self.collection_name)
         self.embedding_model = OpenAIEmbedding(model_name="text-embedding-ada-002")
+        self.is_pdf = self._check_if_pdf_collection()
         self.load_documents()
         self.build_index()

+    def _check_if_pdf_collection(self) -> bool:
+        """Check if this is a PDF collection based on the JSON structure."""
+        try:
+            with open(self.study_json, "r") as f:
+                data = json.load(f)
+            # Check first document for PDF-specific fields
+            if data and isinstance(data, list) and len(data) > 0:
+                return "pages" in data[0] and "source_file" in data[0]
+            return False
+        except Exception as e:
+            logger.error(f"Error checking collection type: {str(e)}")
+            return False
+
     def extract_page_number_from_query(self, query: str) -> int:
         """Extract page number from query text."""
         # Look for patterns like "page 3", "p3", "p. 3", etc.
@@ -59,14 +73,45 @@ class RAGPipeline:
             self.data = json.load(f)

         self.documents = []
-        for index, doc_data in enumerate(self.data):
-            # Process each page's content separately
-            pages = doc_data.get("pages", {})
-            for page_num, page_content in pages.items():
+        if self.is_pdf:
+            # Handle PDF documents
+            for index, doc_data in enumerate(self.data):
+                pages = doc_data.get("pages", {})
+                for page_num, page_content in pages.items():
+                    if isinstance(page_content, dict):
+                        content = page_content.get("text", "")
+                    else:
+                        content = page_content
+
+                    doc_content = (
+                        f"Title: {doc_data['title']}\n"
+                        f"Page {page_num} Content:\n{content}\n"
+                        f"Authors: {', '.join(doc_data['authors'])}\n"
+                    )
+
+                    metadata = {
+                        "title": doc_data.get("title"),
+                        "authors": ", ".join(doc_data.get("authors", [])),
+                        "year": doc_data.get("date"),
+                        "source_file": doc_data.get("source_file"),
+                        "page_number": int(page_num),
+                        "total_pages": doc_data.get("page_count"),
+                    }
+
+                    self.documents.append(
+                        Document(
+                            text=doc_content,
+                            id_=f"doc_{index}_page_{page_num}",
+                            metadata=metadata,
+                        )
+                    )
+        else:
+            # Handle Zotero documents
+            for index, doc_data in enumerate(self.data):
                 doc_content = (
-                    f"Title: {doc_data['title']}\n"
-                    f"Page {page_num} Content:\n{page_content}\n"
-                    f"Authors: {', '.join(doc_data['authors'])}\n"
+                    f"Title: {doc_data.get('title', '')}\n"
+                    f"Abstract: {doc_data.get('abstract', '')}\n"
+                    f"Authors: {', '.join(doc_data.get('authors', []))}\n"
                 )

                 metadata = {
@@ -74,16 +119,11 @@
                     "authors": ", ".join(doc_data.get("authors", [])),
                     "year": doc_data.get("date"),
                     "doi": doc_data.get("doi"),
-                    "source_file": doc_data.get("source_file"),
-                    "page_number": int(page_num),  # Store as integer
-                    "total_pages": len(pages),
                 }

                 self.documents.append(
                     Document(
-                        text=doc_content,
-                        id_=f"doc_{index}_page_{page_num}",
-                        metadata=metadata,
+                        text=doc_content, id_=f"doc_{index}", metadata=metadata
                     )
                 )

@@ -113,7 +153,7 @@

     def query(
         self, context: str, prompt_template: PromptTemplate = None
-    ) -> Tuple[str, Dict[str, Any]]:
+    ) -> Tuple[str, Optional[Dict[str, Any]]]:
         if prompt_template is None:
             prompt_template = PromptTemplate(
                 "Context information is below.\n"
@@ -123,13 +163,14 @@
                 "Given this information, please answer the question: {query_str}\n"
                 "Provide a detailed answer using the content from the context above. "
                 "If the question asks about specific page content, make sure to include that information. "
-                "Cite sources using square brackets for EVERY piece of information, e.g. [1, p.3], [2, p.5], etc. "
+                "Cite sources using square brackets for EVERY piece of information, e.g. [1], [2], etc. "
                 "If you're unsure about something, say so rather than making assumptions."
             )

-        # Extract page number from query if present
-        requested_page = self.extract_page_number_from_query(context)
-        logger.info(f"Requested page number: {requested_page}")
+        # Extract page number for PDF documents
+        requested_page = (
+            self.extract_page_number_from_query(context) if self.is_pdf else None
+        )

         query_engine = self.index.as_query_engine(
             text_qa_template=prompt_template,
@@ -140,26 +181,24 @@

         response = query_engine.query(context)

-        # Extract source information from the response nodes
-        source_info = {}
+        # Handle source information based on document type
+        source_info = None
         if hasattr(response, "source_nodes") and response.source_nodes:
             source_node = response.source_nodes[0]
             metadata = source_node.metadata

-            # Use requested page number if available, otherwise use the page from metadata
-            page_number = (
-                requested_page
-                if requested_page is not None
-                else metadata.get("page_number", 0)
-            )
-
-            source_info = {
-                "source_file": metadata.get("source_file"),
-                "page_number": page_number,
-                "title": metadata.get("title"),
-                "authors": metadata.get("authors"),
-                "content": source_node.text,
-            }
-            logger.info(f"Source info page number: {page_number}")
+            if self.is_pdf:
+                page_number = (
+                    requested_page
+                    if requested_page is not None
+                    else metadata.get("page_number", 0)
+                )
+                source_info = {
+                    "source_file": metadata.get("source_file"),
+                    "page_number": page_number,
+                    "title": metadata.get("title"),
+                    "authors": metadata.get("authors"),
+                    "content": source_node.text,
+                }

         return response.response, source_info
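
Note: _check_if_pdf_collection only inspects the first entry of the study JSON. A minimal sketch of the two shapes it distinguishes (field values are illustrative; only keys actually read in this diff are shown):

# PDF-style entry: has "pages" and "source_file", so is_pdf is True and
# one Document is built per page.
pdf_entry = {
    "title": "Example study",
    "authors": ["A. Author"],
    "date": "2023",
    "source_file": "example.pdf",
    "page_count": 12,
    "pages": {"1": "Plain text of page 1", "2": {"text": "Dict-style page content"}},
}

# Zotero-style entry: no "pages"/"source_file", so is_pdf is False and a single
# Document is built from the title, abstract, and authors.
zotero_entry = {
    "title": "Example study",
    "abstract": "Short abstract text",
    "authors": ["A. Author"],
    "date": "2023",
    "doi": "10.1234/example",
}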