ak3ra committed
Commit d0a03de · 1 Parent: c55ca9a

added source citations
Files changed (4):
  1. .gitignore +5 -0
  2. app.py +75 -32
  3. rag/rag_pipeline.py +27 -34
  4. study_files.json +9 -1
.gitignore CHANGED
@@ -131,6 +131,11 @@ ENV/
 env.bak/
 venv.bak/
 
+
+yes
+*.pub
+
+
 # Spyder project settings
 .spyderproject
 .spyproject
app.py CHANGED
@@ -494,7 +494,11 @@ def create_gr_interface() -> gr.Blocks:
 
         # Right column: PDF Preview and Upload
         with gr.Column(scale=3):
-            pdf_preview = gr.Image(label="Source Page", height=600)
+            # pdf_preview = gr.Image(label="Source Page", height=600)
+            source_info = gr.Markdown(
+                label="Sources",
+                value="No sources available yet."
+            )
             with gr.Row():
                 pdf_files = gr.File(
                     file_count="multiple",
@@ -572,6 +576,31 @@ def create_gr_interface() -> gr.Blocks:
         history = history + [(message, None)]
         return history, "", None
 
+    def format_source_info(source_nodes) -> str:
+        """Format source information into a markdown string."""
+        if not source_nodes:
+            return "No source information available"
+
+        sources_md = "### Sources\n\n"
+        seen_sources = set()  # To track unique sources
+
+        for idx, node in enumerate(source_nodes, 1):
+            metadata = node.metadata
+            if not metadata:
+                continue
+
+            source_key = (metadata.get('source_file', ''), metadata.get('page_number', 0))
+            if source_key in seen_sources:
+                continue
+
+            seen_sources.add(source_key)
+            title = metadata.get('title', os.path.basename(metadata.get('source_file', 'Unknown')))
+            page = metadata.get('page_number', 'N/A')
+
+            sources_md += f"{idx}. **{title}** - Page {page}\n"
+
+        return sources_md
+
     def generate_chat_response(history, collection_id, pdf_processor):
         """Generate response for the last message in history."""
         if not collection_id:
@@ -583,41 +612,55 @@ def create_gr_interface() -> gr.Blocks:
         try:
            # Get response and source info
            rag = get_rag_pipeline(collection_id)
-            response, source_info = rag.query(last_message)
-
-            # Generate preview if source information is available
-            preview_image = None
-            if (
-                source_info
-                and source_info.get("source_file")
-                and source_info.get("page_number") is not None
-            ):
-                try:
-                    page_num = source_info["page_number"]
-                    logger.info(f"Attempting to render page {page_num}")
-                    preview_image = pdf_processor.render_page(
-                        source_info["source_file"], page_num
+            response_text, source_nodes = rag.query(last_message)
+
+            # Format sources info
+            sources_md = "### Top Sources\n\n"
+            if source_nodes and len(source_nodes) > 0:
+                seen_sources = set()
+                source_count = 0
+
+                # Only process up to 3 sources
+                for node in source_nodes:
+                    if source_count >= 3:  # Stop after 3 sources
+                        break
+
+                    if not hasattr(node, 'metadata'):
+                        continue
+
+                    metadata = node.metadata
+                    source_key = (
+                        metadata.get('source_file', ''),
+                        metadata.get('page_number', 0)
                     )
-                    if preview_image:
-                        logger.info(
-                            f"Successfully generated preview for page {page_num}"
-                        )
-                    else:
-                        logger.warning(
-                            f"Failed to generate preview for page {page_num}"
-                        )
-                except Exception as e:
-                    logger.error(f"Error generating PDF preview: {str(e)}")
-                    preview_image = None
+
+                    if source_key in seen_sources:
+                        continue
+
+                    seen_sources.add(source_key)
+                    source_count += 1
+
+                    title = metadata.get('title', 'Unknown')
+                    if not title or title == 'Unknown':
+                        title = os.path.basename(metadata.get('source_file', 'Unknown Document'))
+
+                    page = metadata.get('page_number', 'N/A')
+                    sources_md += f"{source_count}. **{title}** - Page {page}\n"
+
+                if source_count == 0:
+                    sources_md = "No source information available"
+            else:
+                sources_md = "No source information available"
 
             # Update history with response
-            history[-1] = (last_message, response)
-            return history, preview_image
+            history[-1] = (last_message, response_text)
+            return history, sources_md
 
         except Exception as e:
             logger.error(f"Error in generate_chat_response: {str(e)}")
             history[-1] = (last_message, f"Error: {str(e)}")
-            return history, None
+            return history, "Error retrieving sources"
+
 
         # Update PDF event handlers
         upload_btn.click(  # Change from pdf_files.upload to upload_btn.click
@@ -630,11 +673,11 @@ def create_gr_interface() -> gr.Blocks:
         chat_submit_btn.click(
             add_message,
             inputs=[chat_history, query_input],
-            outputs=[chat_history, query_input, pdf_preview],
+            outputs=[chat_history, query_input],
         ).success(
-            lambda h, c: generate_chat_response(h, c, pdf_processor),
+            generate_chat_response,
             inputs=[chat_history, current_collection],
-            outputs=[chat_history, pdf_preview],
+            outputs=[chat_history, source_info],
         )
 
     return demo
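
The net effect of the app.py changes is that the chat callback now hands the UI a markdown string of up to three deduplicated sources instead of a rendered PDF page. A minimal standalone sketch of that formatting logic, using a hypothetical `SimpleNamespace` stand-in for the LlamaIndex source nodes that `response.source_nodes` would normally provide (the helper name `format_top_sources` is illustrative, not from the commit):

```python
import os
from types import SimpleNamespace


def format_top_sources(source_nodes, limit=3) -> str:
    """Deduplicate (source_file, page_number) pairs and render up to `limit` of them as markdown."""
    sources_md = "### Top Sources\n\n"
    seen, count = set(), 0
    for node in source_nodes:
        metadata = getattr(node, "metadata", None) or {}
        key = (metadata.get("source_file", ""), metadata.get("page_number", 0))
        if key in seen:
            continue
        seen.add(key)
        count += 1
        # Fall back to the file name when the title is missing or empty.
        title = metadata.get("title") or os.path.basename(metadata.get("source_file", "Unknown Document"))
        sources_md += f"{count}. **{title}** - Page {metadata.get('page_number', 'N/A')}\n"
        if count >= limit:
            break
    return sources_md if count else "No source information available"


# Hypothetical stand-ins for LlamaIndex source nodes (the real objects expose .metadata the same way).
nodes = [
    SimpleNamespace(metadata={"source_file": "data/a.pdf", "page_number": 3, "title": "Paper A"}),
    SimpleNamespace(metadata={"source_file": "data/a.pdf", "page_number": 3, "title": "Paper A"}),  # duplicate
    SimpleNamespace(metadata={"source_file": "data/b.pdf", "page_number": 7, "title": ""}),
]
print(format_top_sources(nodes))
```

Running the sketch prints a numbered markdown list in which the duplicate (source_file, page_number) pair is collapsed and the missing title falls back to the file name, which is the same behaviour the new `source_info` Markdown panel displays.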
rag/rag_pipeline.py CHANGED
@@ -152,29 +152,36 @@ class RAGPipeline:
         self.index = VectorStoreIndex(
             nodes, vector_store=vector_store, embed_model=self.embedding_model
         )
+
 
     def query(
         self, context: str, prompt_template: PromptTemplate = None
-    ) -> Tuple[str, Optional[Dict[str, Any]]]:
+    ) -> Tuple[str, List[Any]]:
         if prompt_template is None:
             prompt_template = PromptTemplate(
-                "Context information is below.\n"
-                "---------------------\n"
-                "{context_str}\n"
-                "---------------------\n"
-                "Given this information, please answer the question: {query_str}\n"
-                "Provide a detailed answer using the content from the context above. "
-                "If the question asks about specific page content, make sure to include that information. "
-                "Cite sources using square brackets for EVERY piece of information, e.g. [1], [2], etc. "
-                "If you're unsure about something, say so rather than making assumptions."
-            )
+                "Context information is below.\n"
+                "---------------------\n"
+                "{context_str}\n"
+                "---------------------\n"
+                "Given this information, please answer the question: {query_str}\n"
+                "Follow these guidelines for your response:\n"
+                "1. If the answer contains multiple pieces of information (e.g., author names, dates, statistics), "
+                "present it in a markdown table format.\n"
+                "2. For single piece information or simple answers, respond in a clear sentence.\n"
+                "3. Always cite sources using square brackets for EVERY piece of information, e.g. [1], [2], etc.\n"
+                "4. If the information spans multiple documents or pages, organize it by source.\n"
+                "5. If you're unsure about something, say so rather than making assumptions.\n"
+                "\nFormat tables like this:\n"
+                "| Field | Information | Source |\n"
+                "|-------|-------------|--------|\n"
+                "| Title | Example Title | [1] |\n"
+            )
 
         # Extract page number for PDF documents
         requested_page = (
             self.extract_page_number_from_query(context) if self.is_pdf else None
         )
 
-        # This is a hack to index all the documents in the store :)
         n_documents = len(self.index.docstore.docs)
         print(f"n_documents: {n_documents}")
         query_engine = self.index.as_query_engine(
@@ -185,25 +192,11 @@ class RAGPipeline:
         )
 
         response = query_engine.query(context)
-
-        # Handle source information based on document type
-        source_info = None
-        if hasattr(response, "source_nodes") and response.source_nodes:
-            source_node = response.source_nodes[0]
-            metadata = source_node.metadata
-
-            if self.is_pdf:
-                page_number = (
-                    requested_page
-                    if requested_page is not None
-                    else metadata.get("page_number", 0)
-                )
-                source_info = {
-                    "source_file": metadata.get("source_file"),
-                    "page_number": page_number,
-                    "title": metadata.get("title"),
-                    "authors": metadata.get("authors"),
-                    "content": source_node.text,
-                }
-
-        return response.response, source_info
+
+        # Debug logging
+        print(f"Response type: {type(response)}")
+        print(f"Has source_nodes: {hasattr(response, 'source_nodes')}")
+        if hasattr(response, 'source_nodes'):
+            print(f"Number of source nodes: {len(response.source_nodes)}")
+
+        return response.response, getattr(response, 'source_nodes', [])
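
With the new signature, `query()` returns the raw LlamaIndex source nodes instead of a single pre-built `source_info` dict, so callers pull metadata themselves. A rough usage sketch, assuming `get_rag_pipeline` (the helper app.py already uses) is importable from `app` and "Ebola Virus" is one of the configured collections; both are assumptions, not part of the commit:

```python
# Rough caller sketch for the new (response_text, source_nodes) return value.
from app import get_rag_pipeline  # assumed import path

rag = get_rag_pipeline("Ebola Virus")  # example collection id
response_text, source_nodes = rag.query("Summarise the main findings.")

print(response_text)
for node in source_nodes:  # LlamaIndex source nodes; may be an empty list
    meta = getattr(node, "metadata", {}) or {}
    print(meta.get("title"), meta.get("page_number"), meta.get("source_file"))
```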
study_files.json CHANGED
@@ -1,5 +1,13 @@
 {
     "Ebola Virus": "data/ebola-virus_zotero_items.json",
     "GeneXpert": "data/genexpert_zotero_items.json",
-    "Vaccine coverage": "data/vaccine-coverage_zotero_items.json"
+    "Vaccine coverage": "data/vaccine-coverage_zotero_items.json",
+    "Concept": "data/concept_zotero_items.json",
+    "Zotero Collection Pastan": "data/zotero-collection-pastan_zotero_items.json",
+    "pdf_thequickone": "data/thequickone_20250108_111913_documents.json",
+    "pdf_aforapples": "data/aforapples_20250108_113044_documents.json",
+    "pdf_bforbinance": "data/bforbinance_20250108_114459_documents.json",
+    "pdf_cforcongo": "data/cforcongo_20250108_115233_documents.json",
+    "pdf_hjhj": "data/hjhj_20250108_115714_documents.json",
+    "pdf_schooldropouts": "data/schooldropouts_20250108_140257_documents.json"
 }
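
study_files.json is a simple name-to-path registry mapping each collection to its Zotero export or processed PDF document set. A small sketch of how such a mapping can be read; the loading code itself is illustrative and not taken from the repo:

```python
import json

# Illustrative only: load the collection-name -> data-file registry.
with open("study_files.json") as f:
    study_files = json.load(f)

print(study_files["Vaccine coverage"])  # data/vaccine-coverage_zotero_items.json
for name, path in study_files.items():
    print(f"{name}: {path}")
```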