rtabrizi committed
Commit
aa4fa52
1 Parent(s): 3550b10

initial commit

Files changed (2)
  1. app.py +163 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,163 @@
+ import streamlit as st
+ import torch
+ import faiss
+ import PyPDF2
+ import os
+
+ from transformers import DPRContextEncoder, DPRContextEncoderTokenizer, DPRQuestionEncoder, DPRQuestionEncoderTokenizer
+ from transformers import BartForConditionalGeneration, BartTokenizer
+
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+ # Prefer GPU for inference when one is available.
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+     print("Running on GPU")
+     device = torch.device("cuda:0")
+
+ file_url = "https://arxiv.org/pdf/1706.03762.pdf"
+ file_path = "assets/attention.pdf"
+
+ if not os.path.exists('assets'):
+     os.mkdir('assets')
+
+ if not os.path.isfile(file_path):
+     os.system(f'curl -o {file_path} {file_url}')
+ else:
+     print("File already exists!")
+
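+ # arXiv 1706.03762 is "Attention Is All You Need"; it serves as the demo corpus.
+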
+ class Retriever:
+     """DPR dual-encoder retriever over chunks of a single PDF."""
+
+     def __init__(self, file_path, device, context_model_name, question_model_name):
+         self.file_path = file_path
+         self.device = device
+
+         # Separate DPR encoders: one embeds passages, the other embeds questions,
+         # trained so that relevant pairs score highly under dot-product similarity.
+         self.context_tokenizer = DPRContextEncoderTokenizer.from_pretrained(context_model_name)
+         self.context_model = DPRContextEncoder.from_pretrained(context_model_name).to(device)
+
+         self.question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(question_model_name)
+         self.question_model = DPRQuestionEncoder.from_pretrained(question_model_name).to(device)
+
+     def token_len(self, text):
+         # Length function for the splitter: measures text in DPR tokens, not characters.
+         tokens = self.context_tokenizer.encode(text)
+         return len(tokens)
+
+     def extract_text_from_pdf(self, file_path):
+         with open(file_path, 'rb') as file:
+             reader = PyPDF2.PdfReader(file)
+             text = ''
+             for page in reader.pages:
+                 text += page.extract_text()
+         return text
+
+     def get_text(self):
+         # Convenience wrapper over extract_text_from_pdf for this retriever's own file.
+         return self.extract_text_from_pdf(self.file_path)
+
+     def load_chunks(self):
+         self.text = self.extract_text_from_pdf(self.file_path)
+         text_splitter = RecursiveCharacterTextSplitter(
+             chunk_size=300,
+             chunk_overlap=20,
+             length_function=self.token_len,
+             separators=["\n\n", " ", ".", ""]
+         )
+
+         self.chunks = text_splitter.split_text(self.text)
+
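+     # Note: chunks are built to at most 300 DPR tokens, but load_context_embeddings
+     # below truncates to max_length=100 at encoding time, so only roughly the first
+     # third of a long chunk actually shapes its embedding.
+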
+     def load_context_embeddings(self):
+         encoded_input = self.context_tokenizer(self.chunks, return_tensors='pt', padding=True, truncation=True, max_length=100).to(self.device)
+
+         with torch.no_grad():
+             model_output = self.context_model(**encoded_input)
+             self.token_embeddings = model_output.pooler_output.cpu().detach().numpy()
+
+         # Exact (brute-force) nearest-neighbour index over the chunk embeddings.
+         self.index = faiss.IndexFlatL2(self.token_embeddings.shape[1])
+         self.index.add(self.token_embeddings)
+
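+     # DPR is trained with an inner-product objective, so faiss.IndexFlatIP would
+     # mirror the training metric more closely than IndexFlatL2; with exact search
+     # over a few hundred unnormalized chunk vectors the two rankings can differ.
+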
+     def retrieve_top_k(self, query_prompt, k=10):
+         encoded_query = self.question_tokenizer(query_prompt, return_tensors="pt", truncation=True, padding=True).to(self.device)
+
+         with torch.no_grad():
+             model_output = self.question_model(**encoded_query)
+             query_vector = model_output.pooler_output
+
+         query_vector_np = query_vector.cpu().numpy()
+         D, I = self.index.search(query_vector_np, k)
+
+         retrieved_texts = [self.chunks[i] for i in I[0]]
+         scores = [d for d in D[0]]
+
+         # print("Top k retrieved texts and their associated scores:")
+         # for idx, (text, score) in enumerate(zip(retrieved_texts, scores)):
+         #     print(f"{idx + 1}. Text: {text} \n Score: {score:.4f}\n")
+
+         return retrieved_texts
+
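+ # Illustrative standalone use of the retriever (hypothetical, not part of the
+ # app flow):
+ #   r = Retriever(file_path, device,
+ #                 "facebook/dpr-ctx_encoder-multiset-base",
+ #                 "facebook/dpr-question_encoder-multiset-base")
+ #   r.load_chunks()
+ #   r.load_context_embeddings()
+ #   top_chunks = r.retrieve_top_k("What is multi-head attention?", k=3)
+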
+ class RAG:
+     def __init__(self,
+                  file_path,
+                  device,
+                  context_model_name="facebook/dpr-ctx_encoder-multiset-base",
+                  question_model_name="facebook/dpr-question_encoder-multiset-base",
+                  generator_name="facebook/bart-large"):
+
+         # generator_name is overridden below; earlier candidates kept for reference.
+         # generator_name = "valhalla/bart-large-finetuned-squadv1"
+         # generator_name = "vblagoje/bart_lfqa"
+         generator_name = "a-ware/bart-squadv2"
+
+         self.device = device
+         self.generator_tokenizer = BartTokenizer.from_pretrained(generator_name)
+         self.generator_model = BartForConditionalGeneration.from_pretrained(generator_name).to(device)
+
+         self.retriever = Retriever(file_path, device, context_model_name, question_model_name)
+         self.retriever.load_chunks()
+         self.retriever.load_context_embeddings()
+
+     def get_answer(self, question, context):
+         # Unused by query() below; heuristic extractive-style prompt, since the
+         # exact input format this checkpoint expects is not documented here.
+         input_text = "context: %s question: %s" % (context, question)
+         features = self.generator_tokenizer([input_text], return_tensors='pt')
+         out = self.generator_model.generate(input_ids=features['input_ids'].to(self.device), attention_mask=features['attention_mask'].to(self.device))
+         return self.generator_tokenizer.decode(out[0])
+
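+     # query() below instead concatenates the top-k retrieved chunks into one
+     # free-form "answer:" prompt and decodes it with beam search.
+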
+     def query(self, question):
+         context = self.retriever.retrieve_top_k(question, k=5)
+         # input_text = question + " " + " ".join(context)
+
+         input_text = "answer: " + " ".join(context) + " " + question
+
+         print(input_text)
+
+         inputs = self.generator_tokenizer.encode(input_text, return_tensors='pt', max_length=1024, truncation=True).to(self.device)
+         # Beam search; length_penalty > 1.0 favours longer answers, and
+         # early_stopping ends the search once every beam has finished.
+         outputs = self.generator_model.generate(inputs, max_length=150, min_length=2, length_penalty=2.0, num_beams=4, early_stopping=True)
+
+         answer = self.generator_tokenizer.decode(outputs[0], skip_special_tokens=True)
+         return answer
+
+
+ # DPR encoder choices; the single-nq checkpoint is an alternative to multiset.
+ # context_model_name = "facebook/dpr-ctx_encoder-single-nq-base"
+ context_model_name = "facebook/dpr-ctx_encoder-multiset-base"
+ question_model_name = "facebook/dpr-question_encoder-multiset-base"
+
+ rag = RAG(file_path, device, context_model_name=context_model_name, question_model_name=question_model_name)
+
+ st.title("RAG Model Query Interface")
+
+ query = st.text_input("Enter your question:")
+
+ # If a query is given, get the answer
+ if query:
+     answer = rag.query(query)
+     st.write(f"Answer: {answer}")
+
+ # Streamlit executes this script top to bottom on each interaction, so no
+ # explicit entry point is needed here.
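+ # To run locally (assuming the dependencies below are installed):
+ #   pip install -r requirements.txt
+ #   streamlit run app.py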
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ streamlit
+ torch
+ numpy
+ faiss-cpu
+ PyPDF2
+ transformers
+ langchain
+