muhammadshaheryar committed on
Commit
2a01b43
·
verified ·
1 Parent(s): 77a3df6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -0
app.py CHANGED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Required libraries: transformers, faiss-cpu, PyMuPDF, streamlit,
# sentence-transformers.
# NOTE: `!pip install ...` is IPython/notebook shell syntax and is a
# SyntaxError in a plain Python script such as app.py. Declare these packages
# in requirements.txt (or install them before launching the app) instead.
3
+
4
+ import os
5
+ import fitz # PyMuPDF for PDF extraction
6
+ import faiss # for efficient vector search
7
+ import numpy as np
8
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, RagTokenizer, RagRetriever, RagSequenceForGeneration
9
+ from sentence_transformers import SentenceTransformer
10
+ import streamlit as st
11
+
12
# Checkpoint for the RAG generator. It must match the model class used below:
# RagSequenceForGeneration pairs with "facebook/rag-sequence-nq", whereas
# "facebook/rag-token-nq" is the checkpoint published for RagTokenForGeneration.
model_name = "facebook/rag-sequence-nq"


@st.cache_resource
def _load_models(name):
    """Load the RAG tokenizer, RAG model and sentence embedder once.

    Streamlit re-executes this script on every user interaction; caching the
    loaded objects avoids re-reading the model weights on each rerun.

    Args:
        name: Hugging Face checkpoint name for the RAG tokenizer and model.

    Returns:
        A ``(tokenizer, model, embedder)`` tuple.
    """
    return (
        RagTokenizer.from_pretrained(name),
        RagSequenceForGeneration.from_pretrained(name),
        # Sentence-transformer model used to embed document chunks and queries.
        SentenceTransformer('all-MiniLM-L6-v2'),
    )


# Module-level handles used by the functions in this file.
tokenizer, model, embedder = _load_models(model_name)
19
+
20
+ # Function to extract text from a PDF file
21
def extract_text_from_pdf(pdf_file):
    """Return the plain text of every page of a PDF, concatenated in order.

    Args:
        pdf_file: Either a filesystem path, raw PDF bytes, or a file-like
            object exposing ``getvalue()`` or ``read()`` (for example the
            object returned by ``st.file_uploader``).

    Returns:
        The text of all pages joined into a single string.
    """
    # PyMuPDF opens in-memory PDFs via ``stream=`` + ``filetype=``; passing a
    # file-like object as the filename argument is not supported.
    if hasattr(pdf_file, "getvalue"):
        # getvalue() returns the whole buffer regardless of the read position.
        pdf_document = fitz.open(stream=pdf_file.getvalue(), filetype="pdf")
    elif hasattr(pdf_file, "read"):
        pdf_document = fitz.open(stream=pdf_file.read(), filetype="pdf")
    elif isinstance(pdf_file, (bytes, bytearray)):
        pdf_document = fitz.open(stream=bytes(pdf_file), filetype="pdf")
    else:
        pdf_document = fitz.open(pdf_file)

    # The context manager closes the document even if text extraction raises.
    with pdf_document:
        return "".join(page.get_text("text") for page in pdf_document)
28
+
29
+ # Function to create embeddings from text data
30
def create_embeddings(text_data):
    """Encode *text_data* with the module-level sentence embedder.

    Args:
        text_data: The text (or list of texts) handed to ``embedder.encode``.

    Returns:
        The result of ``embedder.encode`` when tensor output is requested.
    """
    return embedder.encode(text_data, convert_to_tensor=True)
33
+
34
+ # Function to create a FAISS index from the embeddings
35
def create_faiss_index(embeddings):
    """Build a flat L2 FAISS index over a matrix of embeddings.

    Args:
        embeddings: A 2-D ``(n_vectors, dim)`` collection of embeddings, given
            either as a NumPy array (or array-like) or as a tensor exposing
            ``detach().cpu().numpy()``.

    Returns:
        A ``faiss.IndexFlatL2`` containing every row of *embeddings*.

    Raises:
        ValueError: If *embeddings* is not two-dimensional.
    """
    # Tensors (e.g. from ``encode(..., convert_to_tensor=True)``) are moved to
    # host memory first; FAISS's Python API works on NumPy arrays.
    if hasattr(embeddings, "detach"):
        embeddings = embeddings.detach().cpu().numpy()

    # FAISS requires C-contiguous float32 data.
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    if vectors.ndim != 2:
        raise ValueError(
            f"embeddings must be a 2-D (n_vectors, dim) matrix, got shape {vectors.shape}"
        )

    index = faiss.IndexFlatL2(vectors.shape[1])  # Using L2 distance metric
    index.add(vectors)
    return index
39
+
40
+ # Function to retrieve the most relevant text using FAISS
41
def retrieve_relevant_text(query, faiss_index, texts, top_k=3):
    """Return the chunks of *texts* closest to *query* in the FAISS index.

    Args:
        query: The question string to embed and search for.
        faiss_index: A FAISS index whose row ``i`` corresponds to ``texts[i]``.
        texts: The text chunks that were indexed, in index order.
        top_k: Maximum number of chunks to return.

    Returns:
        Up to *top_k* chunks, in the order returned by the index search. The
        list is empty when *texts* is empty or *top_k* is not positive.
    """
    if not texts or top_k <= 0:
        return []

    # Never ask for more neighbours than there are chunks.
    k = min(top_k, len(texts))

    # FAISS expects a contiguous float32 NumPy matrix of shape (1, dim).
    query_embedding = embedder.encode([query], convert_to_numpy=True)
    query_embedding = np.ascontiguousarray(query_embedding, dtype=np.float32)

    _distances, indices = faiss_index.search(query_embedding, k)

    # FAISS fills missing results with -1; without this filter a -1 would
    # silently select the last chunk via negative indexing.
    return [texts[i] for i in indices[0] if 0 <= i < len(texts)]
45
+
46
+ # Main function to answer questions based on uploaded PDF
47
def get_answer_from_pdf(pdf_file, query):
    """Answer *query* using passages retrieved from a PDF document.

    The PDF text is split into non-blank lines, the lines are embedded and
    indexed, the closest lines to the query are retrieved, and those passages
    are handed to the RAG generator as its context documents.

    Args:
        pdf_file: The PDF to read; forwarded to ``extract_text_from_pdf``.
        query: The user's question.

    Returns:
        The generated answer string, or a short explanatory message when the
        document yields no usable text or no passage is retrieved.
    """
    import torch  # local import: torch is not imported at file level

    # Step 1: Extract text from the uploaded PDF file
    document_text = extract_text_from_pdf(pdf_file)

    # Step 2: Split the document text into chunks, dropping blank lines so
    # empty strings are never embedded or retrieved.
    text_chunks = [line.strip() for line in document_text.split('\n') if line.strip()]
    if not text_chunks:
        return "No extractable text was found in the document."

    # Step 3: Create embeddings for each chunk of text
    embeddings = create_embeddings(text_chunks)

    # Step 4: Create a FAISS index for efficient retrieval
    faiss_index = create_faiss_index(embeddings)

    # Step 5: Retrieve relevant text from the document based on the query
    relevant_texts = retrieve_relevant_text(query, faiss_index, text_chunks)
    if not relevant_texts:
        return "No relevant passages were found in the document."

    # Step 6: Build one generator input per retrieved passage. The model is
    # loaded without a retriever, so it must be given context_input_ids,
    # context_attention_mask and doc_scores explicitly.
    n_docs = len(relevant_texts)
    doc_sep = getattr(model.config, "doc_sep", " // ")
    max_length = getattr(model.config, "max_combined_length", 300)
    contexts = [f"{passage}{doc_sep}{query}" for passage in relevant_texts]

    # Context ids are consumed by the generator, so they must come from the
    # generator tokenizer rather than the question-encoder tokenizer.
    context_inputs = tokenizer.generator(
        contexts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )

    # Retrieval distances are not available here, so every passage gets the
    # same score (a uniform prior over the retrieved documents).
    doc_scores = torch.zeros((1, n_docs), dtype=torch.float32)

    # NOTE(review): generate()'s internal scoring pass appears to fall back to
    # config.n_docs rather than the n_docs argument; the config value is
    # aligned temporarily so both agree — confirm against the installed
    # transformers version.
    previous_n_docs = model.config.n_docs
    model.config.n_docs = n_docs
    try:
        # Generate the answer (inference only, so no gradients are needed).
        with torch.no_grad():
            outputs = model.generate(
                context_input_ids=context_inputs["input_ids"],
                context_attention_mask=context_inputs["attention_mask"],
                doc_scores=doc_scores,
                n_docs=n_docs,
            )
    finally:
        model.config.n_docs = previous_n_docs

    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return answer
73
+
74
+ # Streamlit UI
75
def main():
    """Render the Streamlit page: PDF upload, question box, and answer."""
    st.title("RAG Application - PDF Q&A")

    # Upload PDF file; nothing else is shown until one is provided.
    uploaded_file = st.file_uploader("Upload a PDF Document", type="pdf")
    if uploaded_file is None:
        return

    # Ask a question from the uploaded PDF; wait until one is entered.
    question = st.text_input("Ask a question based on the document:")
    if not question:
        return

    # Get the answer from the PDF document and display it.
    st.write("Answer: ", get_answer_from_pdf(uploaded_file, question))


if __name__ == "__main__":
    main()