import os

import faiss
import PyPDF2
import streamlit as st
import torch

from transformers import (
    BartForConditionalGeneration,
    BartTokenizer,
    DPRContextEncoder,
    DPRContextEncoderTokenizer,
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizer,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Use the GPU when one is available; everything also runs on CPU.
device = torch.device("cpu")
if torch.cuda.is_available():
    print("Running on GPU")
    device = torch.device("cuda:0")

file_url = "https://arxiv.org/pdf/1706.03762.pdf"
file_path = "assets/attention.pdf"

if not os.path.exists('assets'):
    os.mkdir('assets')

if not os.path.isfile(file_path):
    # Fetch the paper once; later runs reuse the local copy.
    os.system(f'curl -o {file_path} {file_url}')
else:
    print("File already exists!")


class Retriever:
    """Splits a PDF into chunks, embeds them with DPR, and serves top-k lookups via FAISS."""

    def __init__(self, file_path, device, context_model_name, question_model_name):
        self.file_path = file_path
        self.device = device

        # DPR uses separate encoders for passages and for questions.
        self.context_tokenizer = DPRContextEncoderTokenizer.from_pretrained(context_model_name)
        self.context_model = DPRContextEncoder.from_pretrained(context_model_name).to(device)

        self.question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(question_model_name)
        self.question_model = DPRQuestionEncoder.from_pretrained(question_model_name).to(device)

    def token_len(self, text):
        # Length function for the text splitter: measure chunks in DPR tokens
        # rather than characters.
        tokens = self.context_tokenizer.encode(text)
        return len(tokens)
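
    # Note: encode() adds the [CLS] and [SEP] special tokens, so each chunk's
    # token budget includes two tokens of overhead.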

    def extract_text_from_pdf(self, file_path):
        with open(file_path, 'rb') as file:
            reader = PyPDF2.PdfReader(file)
            text = ''
            for page in reader.pages:
                text += page.extract_text()
        return text

    def load_chunks(self):
        self.text = self.extract_text_from_pdf(self.file_path)
        # 300-token chunks with a 20-token overlap, so sentences that straddle
        # a boundary still appear intact in at least one chunk.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=300,
            chunk_overlap=20,
            length_function=self.token_len,
            separators=["\n\n", " ", ".", ""]
        )
        self.chunks = text_splitter.split_text(self.text)

    def load_context_embeddings(self):
        # Embed all chunks in one batch. max_length matches the splitter's
        # 300-token chunk_size so chunks are not truncated before embedding.
        encoded_input = self.context_tokenizer(
            self.chunks, return_tensors='pt', padding=True, truncation=True, max_length=300
        ).to(self.device)

        with torch.no_grad():
            model_output = self.context_model(**encoded_input)
            self.token_embeddings = model_output.pooler_output.cpu().numpy()

        # Exact (brute-force) L2 index over the chunk embeddings.
        self.index = faiss.IndexFlatL2(self.token_embeddings.shape[1])
        self.index.add(self.token_embeddings)
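
    # DPR encoders are trained with an inner-product objective, so
    # faiss.IndexFlatIP is arguably the closer match; IndexFlatL2 works in
    # practice and ranks identically once embeddings are L2-normalized.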

    def retrieve_top_k(self, query_prompt, k=10):
        encoded_query = self.question_tokenizer(
            query_prompt, return_tensors="pt", truncation=True, padding=True
        ).to(self.device)

        with torch.no_grad():
            model_output = self.question_model(**encoded_query)
            query_vector = model_output.pooler_output

        # D holds the L2 distances and I the indices of the k nearest chunks.
        query_vector_np = query_vector.cpu().numpy()
        D, I = self.index.search(query_vector_np, k)

        retrieved_texts = [self.chunks[i] for i in I[0]]
        return retrieved_texts
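
    # If callers need confidence scores, return the distances too, e.g.
    # `return retrieved_texts, D[0].tolist()`.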


class RAG:
    def __init__(self,
                 file_path,
                 device,
                 context_model_name="facebook/dpr-ctx_encoder-multiset-base",
                 question_model_name="facebook/dpr-question_encoder-multiset-base",
                 generator_name="a-ware/bart-squadv2"):

        self.device = device

        # BART checkpoint fine-tuned on SQuAD v2, used as the answer generator.
        self.generator_tokenizer = BartTokenizer.from_pretrained(generator_name)
        self.generator_model = BartForConditionalGeneration.from_pretrained(generator_name).to(device)

        self.retriever = Retriever(file_path, device, context_model_name, question_model_name)
        self.retriever.load_chunks()
        self.retriever.load_context_embeddings()

    def get_answer(self, question, context):
        # Single-context QA helper (the Streamlit flow below uses query()
        # instead). The prompt template here is an assumption; check the
        # checkpoint's model card for the exact format it was trained with.
        input_text = "question: %s context: %s" % (question, context)
        features = self.generator_tokenizer([input_text], return_tensors='pt')
        out = self.generator_model.generate(
            input_ids=features['input_ids'].to(self.device),
            attention_mask=features['attention_mask'].to(self.device),
        )
        return self.generator_tokenizer.decode(out[0], skip_special_tokens=True)

    def query(self, question):
        context = self.retriever.retrieve_top_k(question, k=5)

        # Concatenate the retrieved chunks into a single context block.
        input_text = "question: " + question + " context: " + " ".join(context)

        inputs = self.generator_tokenizer.encode(
            input_text, return_tensors='pt', max_length=1024, truncation=True
        ).to(self.device)
        outputs = self.generator_model.generate(
            inputs, max_length=150, min_length=2, length_penalty=2.0,
            num_beams=4, early_stopping=True
        )

        answer = self.generator_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return answer
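
    # Decoding uses beam search: num_beams=4 explores alternative answers, and
    # a length_penalty above 1.0 favors longer sequences; tune these to trade
    # answer quality against latency.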


context_model_name = "facebook/dpr-ctx_encoder-multiset-base"
question_model_name = "facebook/dpr-question_encoder-multiset-base"

rag = RAG(file_path, device, context_model_name, question_model_name)
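
# Streamlit reruns this script on every widget interaction, which reloads the
# models each time. A minimal sketch of the usual fix, assuming a Streamlit
# version that provides st.cache_resource:
#
#     @st.cache_resource
#     def load_rag():
#         return RAG(file_path, device, context_model_name, question_model_name)
#
#     rag = load_rag()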

st.title("RAG Model Query Interface")

query = st.text_input("Enter your question:")

if query:
    answer = rag.query(query)
    st.write(f"Answer: {answer}")
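
    # Optional: show the retrieved chunks so users can judge the answer, e.g.
    #
    #     with st.expander("Retrieved context"):
    #         for chunk in rag.retriever.retrieve_top_k(query, k=5):
    #             st.write(chunk)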

# Streamlit executes this script top to bottom via `streamlit run <this_file>.py`;
# there is no st.run() API, so no __main__ entry point is needed.