added source citations

Files changed:
- .gitignore (+5 -0)
- app.py (+75 -32)
- rag/rag_pipeline.py (+27 -34)
- study_files.json (+9 -1)
.gitignore CHANGED

@@ -131,6 +131,11 @@ ENV/
 env.bak/
 venv.bak/
 
+
+yes
+*.pub
+
+
 # Spyder project settings
 .spyderproject
 .spyproject
app.py CHANGED

@@ -494,7 +494,11 @@ def create_gr_interface() -> gr.Blocks:
 
         # Right column: PDF Preview and Upload
         with gr.Column(scale=3):
-            pdf_preview = gr.Image(label="Source Page", height=600)
+            # pdf_preview = gr.Image(label="Source Page", height=600)
+            source_info = gr.Markdown(
+                label="Sources",
+                value="No sources available yet."
+            )
             with gr.Row():
                 pdf_files = gr.File(
                     file_count="multiple",

@@ -572,6 +576,31 @@ def create_gr_interface() -> gr.Blocks:
         history = history + [(message, None)]
         return history, "", None
 
+    def format_source_info(source_nodes) -> str:
+        """Format source information into a markdown string."""
+        if not source_nodes:
+            return "No source information available"
+
+        sources_md = "### Sources\n\n"
+        seen_sources = set()  # To track unique sources
+
+        for idx, node in enumerate(source_nodes, 1):
+            metadata = node.metadata
+            if not metadata:
+                continue
+
+            source_key = (metadata.get('source_file', ''), metadata.get('page_number', 0))
+            if source_key in seen_sources:
+                continue
+
+            seen_sources.add(source_key)
+            title = metadata.get('title', os.path.basename(metadata.get('source_file', 'Unknown')))
+            page = metadata.get('page_number', 'N/A')
+
+            sources_md += f"{idx}. **{title}** - Page {page}\n"
+
+        return sources_md
+
     def generate_chat_response(history, collection_id, pdf_processor):
         """Generate response for the last message in history."""
         if not collection_id:

@@ -583,41 +612,55 @@ def create_gr_interface() -> gr.Blocks:
         try:
             # Get response and source info
             rag = get_rag_pipeline(collection_id)
+            response_text, source_nodes = rag.query(last_message)
+
+            # Format sources info
+            sources_md = "### Top Sources\n\n"
+            if source_nodes and len(source_nodes) > 0:
+                seen_sources = set()
+                source_count = 0
+
+                # Only process up to 3 sources
+                for node in source_nodes:
+                    if source_count >= 3:  # Stop after 3 sources
+                        break
+
+                    if not hasattr(node, 'metadata'):
+                        continue
+
+                    metadata = node.metadata
+                    source_key = (
+                        metadata.get('source_file', ''),
+                        metadata.get('page_number', 0)
+                    )
+
+                    if source_key in seen_sources:
+                        continue
+
+                    seen_sources.add(source_key)
+                    source_count += 1
+
+                    title = metadata.get('title', 'Unknown')
+                    if not title or title == 'Unknown':
+                        title = os.path.basename(metadata.get('source_file', 'Unknown Document'))
+
+                    page = metadata.get('page_number', 'N/A')
+                    sources_md += f"{source_count}. **{title}** - Page {page}\n"
+
+                if source_count == 0:
+                    sources_md = "No source information available"
+            else:
+                sources_md = "No source information available"
 
             # Update history with response
-            history[-1] = (last_message,
-            return history,
+            history[-1] = (last_message, response_text)
+            return history, sources_md
 
         except Exception as e:
            logger.error(f"Error in generate_chat_response: {str(e)}")
            history[-1] = (last_message, f"Error: {str(e)}")
-            return history,
+            return history, "Error retrieving sources"
+
 
     # Update PDF event handlers
     upload_btn.click(  # Change from pdf_files.upload to upload_btn.click

@@ -630,11 +673,11 @@ def create_gr_interface() -> gr.Blocks:
     chat_submit_btn.click(
         add_message,
         inputs=[chat_history, query_input],
-        outputs=[chat_history, query_input
+        outputs=[chat_history, query_input],
     ).success(
+        generate_chat_response,
         inputs=[chat_history, current_collection],
-        outputs=[chat_history,
+        outputs=[chat_history, source_info],
    )

    return demo
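For readers unfamiliar with Gradio's event chaining, here is a minimal, self-contained sketch of the `.click(...).success(...)` pattern the last hunk wires up. This is not the app's code: the handler bodies are placeholder stubs, and it assumes the tuple-style `Chatbot` history used in the diff. The point is that the response handler runs only if `add_message` raised no error, and that a single event can update both the chat history and the `source_info` Markdown panel.

```python
import gradio as gr

def add_message(history, message):
    # Append the user turn with a pending (None) bot reply, then clear the box.
    return history + [(message, None)], ""

def generate_chat_response(history):
    # Placeholder for the RAG call: fill in the pending turn and the sources panel.
    last_message = history[-1][0]
    history[-1] = (last_message, f"Echo: {last_message}")
    return history, "### Top Sources\n\n1. **Example Title** - Page 3\n"

with gr.Blocks() as demo:
    chat_history = gr.Chatbot()
    query_input = gr.Textbox(placeholder="Ask a question...")
    source_info = gr.Markdown(value="No sources available yet.")
    chat_submit_btn = gr.Button("Send")

    # .success() only fires if the preceding handler completed without error,
    # so a failed add_message never triggers a response attempt.
    chat_submit_btn.click(
        add_message,
        inputs=[chat_history, query_input],
        outputs=[chat_history, query_input],
    ).success(
        generate_chat_response,
        inputs=[chat_history],
        outputs=[chat_history, source_info],
    )

demo.launch()
```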
rag/rag_pipeline.py CHANGED

@@ -152,29 +152,36 @@ class RAGPipeline:
         self.index = VectorStoreIndex(
             nodes, vector_store=vector_store, embed_model=self.embedding_model
         )
+
 
     def query(
         self, context: str, prompt_template: PromptTemplate = None
-    ) -> Tuple[str,
+    ) -> Tuple[str, List[Any]]:
         if prompt_template is None:
             prompt_template = PromptTemplate(
+                "Context information is below.\n"
+                "---------------------\n"
+                "{context_str}\n"
+                "---------------------\n"
+                "Given this information, please answer the question: {query_str}\n"
+                "Follow these guidelines for your response:\n"
+                "1. If the answer contains multiple pieces of information (e.g., author names, dates, statistics), "
+                "present it in a markdown table format.\n"
+                "2. For single piece information or simple answers, respond in a clear sentence.\n"
+                "3. Always cite sources using square brackets for EVERY piece of information, e.g. [1], [2], etc.\n"
+                "4. If the information spans multiple documents or pages, organize it by source.\n"
+                "5. If you're unsure about something, say so rather than making assumptions.\n"
+                "\nFormat tables like this:\n"
+                "| Field | Information | Source |\n"
+                "|-------|-------------|--------|\n"
+                "| Title | Example Title | [1] |\n"
             )
 
         # Extract page number for PDF documents
         requested_page = (
             self.extract_page_number_from_query(context) if self.is_pdf else None
         )
 
-        # This is a hack to index all the documents in the store :)
         n_documents = len(self.index.docstore.docs)
         print(f"n_documents: {n_documents}")
         query_engine = self.index.as_query_engine(

@@ -185,25 +192,11 @@ class RAGPipeline:
         )
 
         response = query_engine.query(context)
-
-        page_number = (
-            requested_page
-            if requested_page is not None
-            else metadata.get("page_number", 0)
-        )
-        source_info = {
-            "source_file": metadata.get("source_file"),
-            "page_number": page_number,
-            "title": metadata.get("title"),
-            "authors": metadata.get("authors"),
-            "content": source_node.text,
-        }
-
-        return response.response, source_info
+
+        # Debug logging
+        print(f"Response type: {type(response)}")
+        print(f"Has source_nodes: {hasattr(response, 'source_nodes')}")
+        if hasattr(response, 'source_nodes'):
+            print(f"Number of source nodes: {len(response.source_nodes)}")
+
+        return response.response, getattr(response, 'source_nodes', [])
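To make the new contract concrete: `query()` now returns the raw answer string plus whatever `source_nodes` the response object carries, and the app deduplicates and formats those nodes itself. Below is a rough sketch of that handoff using hypothetical stand-in classes (`FakeNode`, `FakeResponse` are illustrations, not LlamaIndex types), with the same `getattr` fallback and (source_file, page_number) dedup as the diffs above.

```python
import os
from typing import Any, List, Tuple

class FakeNode:
    """Hypothetical stand-in for a retrieved source node (metadata only)."""
    def __init__(self, metadata: dict):
        self.metadata = metadata

class FakeResponse:
    """Hypothetical stand-in for the object returned by query_engine.query()."""
    def __init__(self, response: str, source_nodes: List[FakeNode]):
        self.response = response
        self.source_nodes = source_nodes

def unpack(response: Any) -> Tuple[str, List[Any]]:
    # Same defensive pattern as the diff: fall back to an empty list
    # if the engine ever returns an object without source_nodes.
    return response.response, getattr(response, "source_nodes", [])

resp = FakeResponse(
    "Example answer citing a source [1].",
    [
        FakeNode({"source_file": "data/a.pdf", "page_number": 2, "title": "Paper A"}),
        FakeNode({"source_file": "data/a.pdf", "page_number": 2, "title": "Paper A"}),  # duplicate
    ],
)
text, nodes = unpack(resp)

# Dedup on (source_file, page_number), as format_source_info does.
seen, lines = set(), []
for node in nodes:
    key = (node.metadata.get("source_file", ""), node.metadata.get("page_number", 0))
    if key in seen:
        continue
    seen.add(key)
    title = node.metadata.get("title") or os.path.basename(key[0])
    lines.append(f"{len(lines) + 1}. **{title}** - Page {node.metadata.get('page_number', 'N/A')}")

print(text)
print("\n".join(lines))  # -> 1. **Paper A** - Page 2
```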
study_files.json CHANGED

@@ -1,5 +1,13 @@
 {
   "Ebola Virus": "data/ebola-virus_zotero_items.json",
   "GeneXpert": "data/genexpert_zotero_items.json",
-  "Vaccine coverage": "data/vaccine-coverage_zotero_items.json"
+  "Vaccine coverage": "data/vaccine-coverage_zotero_items.json",
+  "Concept": "data/concept_zotero_items.json",
+  "Zotero Collection Pastan": "data/zotero-collection-pastan_zotero_items.json",
+  "pdf_thequickone": "data/thequickone_20250108_111913_documents.json",
+  "pdf_aforapples": "data/aforapples_20250108_113044_documents.json",
+  "pdf_bforbinance": "data/bforbinance_20250108_114459_documents.json",
+  "pdf_cforcongo": "data/cforcongo_20250108_115233_documents.json",
+  "pdf_hjhj": "data/hjhj_20250108_115714_documents.json",
+  "pdf_schooldropouts": "data/schooldropouts_20250108_140257_documents.json"
 }
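The keys in this registry are the study/collection names surfaced in the UI, and the values point at the document dumps each RAG index is built from. A minimal sketch of how a lookup over this file might work (the helper name is illustrative, not from the repo):

```python
import json

def get_study_file(study_name: str, registry_path: str = "study_files.json") -> str:
    """Resolve a study/collection name to its registered document dump."""
    with open(registry_path) as f:
        registry = json.load(f)
    try:
        return registry[study_name]
    except KeyError:
        raise KeyError(f"No study file registered for {study_name!r}") from None

# Example: get_study_file("GeneXpert") -> "data/genexpert_zotero_items.json"
```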
|