ak3ra commited on
Commit
cfb1a62
·
1 Parent(s): b117341
app.py CHANGED
@@ -1,62 +1,72 @@
 
 
1
  import gradio as gr
2
- import os
3
- from database.vaccine_coverage_db import VaccineCoverageDB
4
  from rag.rag_pipeline import RAGPipeline
5
- from utils.helpers import process_response
6
- from config import DB_PATH, METADATA_FILE, PDF_DIR
7
- from initialize_db import initialize_database, populate_database
8
 
9
- # Initialize database if it doesn't exist
10
- if not os.path.exists(DB_PATH):
11
- print("Database not found. Initializing...")
12
- initialize_database()
13
- populate_database()
 
14
 
15
- # Initialize database and RAG pipeline
16
- db = VaccineCoverageDB(DB_PATH)
17
- rag = RAGPipeline(METADATA_FILE, PDF_DIR, use_semantic_splitter=True)
18
 
 
 
19
 
20
- def query_rag(question, prompt_type):
21
  if prompt_type == "Highlight":
22
- response = rag.query(question, prompt_type="highlight")
 
 
23
  else:
24
- response = rag.query(question, prompt_type="evidence_based")
25
 
26
- processed = process_response(response)
27
- return processed["markdown"]
28
 
29
 
30
- def save_pdf(item_key):
31
- attachments = db.get_attachments_for_item(item_key)
32
- if attachments:
33
- attachment_key = attachments[0]["key"]
34
- output_path = os.path.join(PDF_DIR, f"{attachment_key}.pdf")
35
- if db.save_pdf_to_file(attachment_key, output_path):
36
- return f"PDF saved successfully to {output_path}"
37
- return "Failed to save PDF or no attachments found"
38
 
39
 
40
- # Gradio interface
41
  with gr.Blocks() as demo:
42
- gr.Markdown("# Vaccine Coverage Study RAG System")
43
 
44
- with gr.Tab("Query"):
45
- question_input = gr.Textbox(label="Enter your question")
46
- prompt_type = gr.Radio(["Highlight", "Evidence-based"], label="Prompt Type")
47
- query_button = gr.Button("Submit Query")
48
- output = gr.Markdown(label="Response")
 
 
49
 
50
- query_button.click(
51
- query_rag, inputs=[question_input, prompt_type], outputs=output
 
 
 
 
52
  )
53
 
54
- with gr.Tab("Save PDF"):
55
- item_key_input = gr.Textbox(label="Enter item key")
56
- save_button = gr.Button("Save PDF")
57
- save_output = gr.Textbox(label="Save Result")
58
 
59
- save_button.click(save_pdf, inputs=item_key_input, outputs=save_output)
 
 
 
 
60
 
61
  if __name__ == "__main__":
62
  demo.launch()
 
1
+ # app.py
2
+
3
  import gradio as gr
4
+ import json
 
5
  from rag.rag_pipeline import RAGPipeline
6
+ from utils.prompts import highlight_prompt, evidence_based_prompt
7
+ from config import STUDY_FILES
8
+
9
 
10
+ def load_rag_pipeline(study_name):
11
+ study_file = STUDY_FILES.get(study_name)
12
+ if study_file:
13
+ return RAGPipeline(study_file)
14
+ else:
15
+ raise ValueError(f"Invalid study name: {study_name}")
16
 
 
 
 
17
 
18
+ def query_rag(study_name, question, prompt_type):
19
+ rag = load_rag_pipeline(study_name)
20
 
 
21
  if prompt_type == "Highlight":
22
+ prompt = highlight_prompt
23
+ elif prompt_type == "Evidence-based":
24
+ prompt = evidence_based_prompt
25
  else:
26
+ prompt = None
27
 
28
+ response = rag.query(question, prompt)
29
+ return response.response
30
 
31
 
32
+ def get_study_info(study_name):
33
+ study_file = STUDY_FILES.get(study_name)
34
+ if study_file:
35
+ with open(study_file, "r") as f:
36
+ data = json.load(f)
37
+ return f"Number of documents: {len(data)}\nFirst document title: {data[0]['title']}"
38
+ else:
39
+ return "Invalid study name"
40
 
41
 
 
42
  with gr.Blocks() as demo:
43
+ gr.Markdown("# RAG Pipeline Demo")
44
 
45
+ with gr.Row():
46
+ study_dropdown = gr.Dropdown(
47
+ choices=list(STUDY_FILES.keys()), label="Select Study"
48
+ )
49
+ study_info = gr.Textbox(label="Study Information", interactive=False)
50
+
51
+ study_dropdown.change(get_study_info, inputs=[study_dropdown], outputs=[study_info])
52
 
53
+ with gr.Row():
54
+ question_input = gr.Textbox(label="Enter your question")
55
+ prompt_type = gr.Radio(
56
+ ["Default", "Highlight", "Evidence-based"],
57
+ label="Prompt Type",
58
+ value="Default",
59
  )
60
 
61
+ submit_button = gr.Button("Submit")
62
+
63
+ answer_output = gr.Textbox(label="Answer")
 
64
 
65
+ submit_button.click(
66
+ query_rag,
67
+ inputs=[study_dropdown, question_input, prompt_type],
68
+ outputs=[answer_output],
69
+ )
70
 
71
  if __name__ == "__main__":
72
  demo.launch()
config.py CHANGED
@@ -1,20 +1,9 @@
1
  import os
2
 
3
- # Base directory
4
- BASE_DIR = os.path.dirname(os.path.abspath(__file__))
5
-
6
- # Database configuration
7
- DB_NAME = "vaccine_coverage_study.db"
8
- DB_PATH = os.path.join(BASE_DIR, DB_NAME)
9
-
10
- # RAG Pipeline configuration
11
- DATA_DIR = os.path.join(BASE_DIR, "data")
12
- METADATA_FILE = os.path.join(DATA_DIR, "metadata_map.json")
13
- PDF_DIR = os.path.join(DATA_DIR, "pdfs")
14
-
15
- # Create directories if they don't exist
16
- os.makedirs(DATA_DIR, exist_ok=True)
17
- os.makedirs(PDF_DIR, exist_ok=True)
18
-
19
- # OpenAI configuration
20
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
 
 
 
 
 
 
 
1
  import os
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
4
+
5
+ STUDY_FILES = {
6
+ "Vaccine Coverage": "data/vaccine_coverage_zotero_items.json",
7
+ "Ebola Virus": "data/ebola_virus_zotero_items.json",
8
+ "Gene Xpert": "data/gene_xpert_zotero_items.json",
9
+ }
database/__init__.py DELETED
File without changes
database/vaccine_coverage_db.py DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:42a0645cdd38f2d7ede525768eb21a4cbe08b4d86959cb4eb2349887f2bcf70e
3
- size 1774
 
 
 
 
initialize_db.py DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:08030c4783a86d9a9afb9437b102dde959405b6b2857725eec02b6d9c2699e97
3
- size 2346
 
 
 
 
rag/rag_pipeline.py CHANGED
@@ -1,24 +1,13 @@
 
 
1
  import json
2
- import os
3
- from typing import Dict, Any
4
- from llama_index.core import (
5
- VectorStoreIndex,
6
- Document,
7
- SentenceWindowNodeParser,
8
- )
9
- from llama_index.core.node_parser import (
10
- SentenceSplitter,
11
- )
12
  from llama_index.core import PromptTemplate
13
 
14
 
15
  class RAGPipeline:
16
- def __init__(
17
- self,
18
- study_json,
19
- use_semantic_splitter=False,
20
- ):
21
-
22
  self.study_json = study_json
23
  self.index = None
24
  self.use_semantic_splitter = use_semantic_splitter
@@ -34,10 +23,7 @@ class RAGPipeline:
34
  for index, doc_data in enumerate(self.data):
35
  doc_content = (
36
  f"Title: {doc_data['title']}\n"
37
- f"Abstract: {doc_data['abstract']}\n"
38
  f"Authors: {', '.join(doc_data['authors'])}\n"
39
- f"Year: {doc_data['year']}\n"
40
- f"DOI: {doc_data['doi']}\n"
41
  f"Full Text: {doc_data['full_text']}"
42
  )
43
 
@@ -50,11 +36,7 @@ class RAGPipeline:
50
  }
51
 
52
  self.documents.append(
53
- Document(
54
- text=doc_content,
55
- id_=f"doc_{index}",
56
- metadata=metadata,
57
- )
58
  )
59
 
60
  def build_index(self):
@@ -71,7 +53,6 @@ class RAGPipeline:
71
  )
72
 
73
  nodes = node_parser.get_nodes_from_documents(self.documents)
74
-
75
  self.index = VectorStoreIndex(nodes)
76
 
77
  def query(self, question, prompt_template=None):
@@ -89,8 +70,7 @@ class RAGPipeline:
89
  )
90
 
91
  query_engine = self.index.as_query_engine(
92
- text_qa_template=prompt_template,
93
- similarity_top_k=5,
94
  )
95
  response = query_engine.query(question)
96
 
 
1
+ # rag/rag_pipeline.py
2
+
3
  import json
4
+ from llama_index.core import Document, VectorStoreIndex
5
+ from llama_index.core.node_parser import SentenceWindowNodeParser, SentenceSplitter
 
 
 
 
 
 
 
 
6
  from llama_index.core import PromptTemplate
7
 
8
 
9
  class RAGPipeline:
10
+ def __init__(self, study_json, use_semantic_splitter=False):
 
 
 
 
 
11
  self.study_json = study_json
12
  self.index = None
13
  self.use_semantic_splitter = use_semantic_splitter
 
23
  for index, doc_data in enumerate(self.data):
24
  doc_content = (
25
  f"Title: {doc_data['title']}\n"
 
26
  f"Authors: {', '.join(doc_data['authors'])}\n"
 
 
27
  f"Full Text: {doc_data['full_text']}"
28
  )
29
 
 
36
  }
37
 
38
  self.documents.append(
39
+ Document(text=doc_content, id_=f"doc_{index}", metadata=metadata)
 
 
 
 
40
  )
41
 
42
  def build_index(self):
 
53
  )
54
 
55
  nodes = node_parser.get_nodes_from_documents(self.documents)
 
56
  self.index = VectorStoreIndex(nodes)
57
 
58
  def query(self, question, prompt_template=None):
 
70
  )
71
 
72
  query_engine = self.index.as_query_engine(
73
+ text_qa_template=prompt_template, similarity_top_k=5
 
74
  )
75
  response = query_engine.query(question)
76
 
utils/prompts.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from llama_index.core import PromptTemplate
2
+
3
+ highlight_prompt = PromptTemplate(
4
+ "Context information is below.\n"
5
+ "---------------------\n"
6
+ "{context_str}\n"
7
+ "---------------------\n"
8
+ "Given this information, please answer the question: {query_str}\n"
9
+ "Include all relevant information from the provided context. "
10
+ "Highlight key information by enclosing it in **asterisks**. "
11
+ "When quoting specific information, please use square brackets to indicate the source, e.g. [1], [2], etc."
12
+ )
13
+
14
+ evidence_based_prompt = PromptTemplate(
15
+ "Context information is below.\n"
16
+ "---------------------\n"
17
+ "{context_str}\n"
18
+ "---------------------\n"
19
+ "Given this information, please answer the question: {query_str}\n"
20
+ "Provide an answer to the question using evidence from the context above. "
21
+ "Cite sources using square brackets."
22
+ )