ak3ra commited on
Commit
bc5a5b2
·
1 Parent(s): daee42b

added rag pipeline

Browse files
Files changed (1) hide show
  1. rag/rag_pipeline.py +53 -39
rag/rag_pipeline.py CHANGED
@@ -1,4 +1,3 @@
1
- import gradio as gr
2
  import json
3
  import os
4
  from typing import Dict, Any
@@ -6,40 +5,42 @@ from llama_index.core import (
6
  SimpleDirectoryReader,
7
  VectorStoreIndex,
8
  Document,
9
- Response,
10
- PromptTemplate
11
  )
12
- from llama_index.core.node_parser import SentenceSplitter
13
  from llama_index.embeddings.openai import OpenAIEmbedding
14
-
15
- # Make sure to set your OpenAI API key in the Hugging Face Spaces secrets
16
- import openai
17
- openai.api_key = os.environ.get('OPENAI_API_KEY')
18
-
19
 
20
 
21
  class RAGPipeline:
22
- def __init__(self, metadata_file, pdf_dir, use_semantic_splitter=False):
 
 
23
  self.metadata_file = metadata_file
24
  self.pdf_dir = pdf_dir
25
- self.index = None
26
  self.use_semantic_splitter = use_semantic_splitter
 
27
  self.load_documents()
28
  self.build_index()
29
 
30
  def load_documents(self):
31
- with open(self.metadata_file, 'r') as f:
32
  self.metadata = json.load(f)
33
 
34
  self.documents = []
35
  for item_key, item_data in self.metadata.items():
36
- metadata = item_data['metadata']
37
- pdf_path = item_data.get('pdf_path')
38
 
39
  if pdf_path:
40
  full_pdf_path = os.path.join(self.pdf_dir, os.path.basename(pdf_path))
41
  if os.path.exists(full_pdf_path):
42
- pdf_content = SimpleDirectoryReader(input_files=[full_pdf_path]).load_data()[0].text
 
 
 
 
43
  else:
44
  pdf_content = "PDF file not found"
45
  else:
@@ -54,18 +55,9 @@ class RAGPipeline:
54
  f"Full Text: {pdf_content}"
55
  )
56
 
57
- self.documents.append(Document(
58
- text=doc_content,
59
- id_=item_key,
60
- metadata={
61
- "title": metadata['title'],
62
- "abstract": metadata['abstract'],
63
- "authors": metadata['authors'],
64
- "year": metadata['year'],
65
- "doi": metadata['doi']
66
- }
67
- ))
68
-
69
 
70
  def build_index(self):
71
  if self.use_semantic_splitter:
@@ -73,7 +65,7 @@ class RAGPipeline:
73
  splitter = SemanticSplitterNodeParser(
74
  buffer_size=1,
75
  breakpoint_percentile_threshold=95,
76
- embed_model=embed_model
77
  )
78
  else:
79
  splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=20)
@@ -81,10 +73,40 @@ class RAGPipeline:
81
  nodes = splitter.get_nodes_from_documents(self.documents)
82
  self.index = VectorStoreIndex(nodes)
83
 
 
 
 
 
 
 
 
 
 
84
 
85
- def query(self, question, prompt_template=None):
86
- if prompt_template is None:
87
- prompt_template = PromptTemplate(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  "Context information is below.\n"
89
  "---------------------\n"
90
  "{context_str}\n"
@@ -95,11 +117,3 @@ class RAGPipeline:
95
  "If the information is not available in the context, please state that clearly. "
96
  "When quoting specific information, please use square brackets to indicate the source, e.g. [1], [2], etc."
97
  )
98
-
99
- query_engine = self.index.as_query_engine(
100
- text_qa_template=prompt_template,
101
- similarity_top_k=5
102
- )
103
- response = query_engine.query(question)
104
-
105
- return response
 
 
1
  import json
2
  import os
3
  from typing import Dict, Any
 
5
  SimpleDirectoryReader,
6
  VectorStoreIndex,
7
  Document,
8
+ StorageContext,
9
+ load_index_from_storage,
10
  )
11
+ from llama_index.core.node_parser import SentenceSplitter, SemanticSplitterNodeParser
12
  from llama_index.embeddings.openai import OpenAIEmbedding
13
+ from llama_index.core import PromptTemplate
 
 
 
 
14
 
15
 
16
  class RAGPipeline:
17
+ def __init__(
18
+ self, metadata_file: str, pdf_dir: str, use_semantic_splitter: bool = False
19
+ ):
20
  self.metadata_file = metadata_file
21
  self.pdf_dir = pdf_dir
 
22
  self.use_semantic_splitter = use_semantic_splitter
23
+ self.index = None
24
  self.load_documents()
25
  self.build_index()
26
 
27
  def load_documents(self):
28
+ with open(self.metadata_file, "r") as f:
29
  self.metadata = json.load(f)
30
 
31
  self.documents = []
32
  for item_key, item_data in self.metadata.items():
33
+ metadata = item_data["metadata"]
34
+ pdf_path = item_data.get("pdf_path")
35
 
36
  if pdf_path:
37
  full_pdf_path = os.path.join(self.pdf_dir, os.path.basename(pdf_path))
38
  if os.path.exists(full_pdf_path):
39
+ pdf_content = (
40
+ SimpleDirectoryReader(input_files=[full_pdf_path])
41
+ .load_data()[0]
42
+ .text
43
+ )
44
  else:
45
  pdf_content = "PDF file not found"
46
  else:
 
55
  f"Full Text: {pdf_content}"
56
  )
57
 
58
+ self.documents.append(
59
+ Document(text=doc_content, id_=item_key, metadata=metadata)
60
+ )
 
 
 
 
 
 
 
 
 
61
 
62
  def build_index(self):
63
  if self.use_semantic_splitter:
 
65
  splitter = SemanticSplitterNodeParser(
66
  buffer_size=1,
67
  breakpoint_percentile_threshold=95,
68
+ embed_model=embed_model,
69
  )
70
  else:
71
  splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=20)
 
73
  nodes = splitter.get_nodes_from_documents(self.documents)
74
  self.index = VectorStoreIndex(nodes)
75
 
76
+ def query(self, question: str, prompt_type: str = "default") -> Dict[str, Any]:
77
+ prompt_template = self._get_prompt_template(prompt_type)
78
+
79
+ query_engine = self.index.as_query_engine(
80
+ text_qa_template=prompt_template, similarity_top_k=5
81
+ )
82
+ response = query_engine.query(question)
83
+
84
+ return response
85
 
86
+ def _get_prompt_template(self, prompt_type: str) -> PromptTemplate:
87
+ if prompt_type == "highlight":
88
+ return PromptTemplate(
89
+ "Context information is below.\n"
90
+ "---------------------\n"
91
+ "{context_str}\n"
92
+ "---------------------\n"
93
+ "Given this information, please answer the question: {query_str}\n"
94
+ "Include all relevant information from the provided context. "
95
+ "Highlight key information by enclosing it in **asterisks**. "
96
+ "When quoting specific information, please use square brackets to indicate the source, e.g. [1], [2], etc."
97
+ )
98
+ elif prompt_type == "evidence_based":
99
+ return PromptTemplate(
100
+ "Context information is below.\n"
101
+ "---------------------\n"
102
+ "{context_str}\n"
103
+ "---------------------\n"
104
+ "Given this information, please answer the question: {query_str}\n"
105
+ "Provide an answer to the question using evidence from the context above. "
106
+ "Cite sources using square brackets."
107
+ )
108
+ else:
109
+ return PromptTemplate(
110
  "Context information is below.\n"
111
  "---------------------\n"
112
  "{context_str}\n"
 
117
  "If the information is not available in the context, please state that clearly. "
118
  "When quoting specific information, please use square brackets to indicate the source, e.g. [1], [2], etc."
119
  )