# --- Optional alternative: Annoy-based retrieval instead of FAISS ---
from annoy import AnnoyIndex

# Build an Annoy index over the embeddings (angular distance approximates
# cosine similarity)
def create_annoy_index(embeddings, num_trees=10):
    index = AnnoyIndex(embeddings.shape[1], 'angular')
    for i, emb in enumerate(embeddings):
        index.add_item(i, emb)
    index.build(num_trees)
    return index

# Retrieve the top_k most relevant chunks from the Annoy index
# (named distinctly so it does not shadow the FAISS version below;
# relies on the `embedder` defined further down)
def retrieve_relevant_text_annoy(query, annoy_index, texts, top_k=3):
    query_embedding = embedder.encode([query])[0]  # 1-D numpy vector
    indices = annoy_index.get_nns_by_vector(query_embedding, top_k)
    return [texts[i] for i in indices]
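
# Minimal usage sketch for the Annoy path (assumes the `embedder` defined
# below; `sample_chunks` is a hypothetical placeholder):
#
#   sample_chunks = ["chunk one", "chunk two", "chunk three"]
#   chunk_embeddings = embedder.encode(sample_chunks)  # (n, dim) numpy array
#   annoy_index = create_annoy_index(chunk_embeddings)
#   hits = retrieve_relevant_text_annoy("a question", annoy_index, sample_chunks)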


import fitz  # PyMuPDF for PDF extraction
import faiss  # for efficient vector search
import numpy as np
import torch
from transformers import RagTokenizer, RagTokenForGeneration
from sentence_transformers import SentenceTransformer
import streamlit as st

# Load the pre-trained RAG model and tokenizer
# ("rag-token" checkpoints pair with RagTokenForGeneration; use
# RagSequenceForGeneration only with a "rag-sequence" checkpoint)
model_name = "facebook/rag-token-nq"
tokenizer = RagTokenizer.from_pretrained(model_name)
model = RagTokenForGeneration.from_pretrained(model_name)

# Initialize the sentence transformer model for embeddings
embedder = SentenceTransformer('all-MiniLM-L6-v2')
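# e.g. embedder.encode(["hello world"]) returns a numpy array of shape (1, 384),
# since all-MiniLM-L6-v2 produces 384-dimensional sentence embeddings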

# Function to extract text from a PDF file
def extract_text_from_pdf(pdf_file):
    # Streamlit hands us a file-like object, so open PyMuPDF from raw bytes
    pdf_file.seek(0)
    pdf_document = fitz.open(stream=pdf_file.read(), filetype="pdf")
    text = ""
    for page_num in range(pdf_document.page_count):
        page = pdf_document.load_page(page_num)
        text += page.get_text("text")
    pdf_document.close()
    return text

# Function to create embeddings from text data
def create_embeddings(text_data):
    # FAISS needs a float32 numpy array, so skip the torch-tensor conversion
    embeddings = embedder.encode(text_data)
    return np.asarray(embeddings, dtype="float32")

# Function to create a FAISS index from the embeddings
def create_faiss_index(embeddings):
    index = faiss.IndexFlatL2(embeddings.shape[1])  # Using L2 distance metric
    index.add(embeddings)
    return index
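
# A common variant (a sketch, not used by the app below): cosine similarity
# via inner product over L2-normalized vectors with FAISS's IndexFlatIP
def create_faiss_cosine_index(embeddings):
    faiss.normalize_L2(embeddings)  # normalizes the rows in place
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    return index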

# Function to retrieve the most relevant text using FAISS
def retrieve_relevant_text(query, faiss_index, texts, top_k=3):
    # Encode the query as a (1, dim) float32 numpy array for FAISS
    query_embedding = np.asarray(embedder.encode([query]), dtype="float32")
    D, I = faiss_index.search(query_embedding, top_k)  # D: distances, I: indices
    return [texts[i] for i in I[0]]

# Main function to answer questions based on uploaded PDF
def get_answer_from_pdf(pdf_file, query):
    # Step 1: Extract text from the uploaded PDF file
    document_text = extract_text_from_pdf(pdf_file)
    
    # Step 2: Split the document text into chunks, dropping empty lines
    # (a windowed alternative is sketched after this function)
    text_chunks = [line.strip() for line in document_text.split('\n') if line.strip()]
    
    # Step 3: Create embeddings for each chunk of text
    embeddings = create_embeddings(text_chunks)
    
    # Step 4: Create a FAISS index for efficient retrieval
    faiss_index = create_faiss_index(embeddings)
    
    # Step 5: Retrieve relevant text from the document based on the query
    relevant_texts = retrieve_relevant_text(query, faiss_index, text_chunks)
    
    # Step 6: Combine the relevant text and pass it to the RAG model.
    # With the built-in retriever bypassed, generate() needs the passages as
    # context_input_ids ("doc // question", generator tokenizer) plus matching
    # doc_scores; the combined context is one pseudo-document here (n_docs=1).
    context = " ".join(relevant_texts)
    context_inputs = tokenizer.generator(
        [f"{context} // {query}"], return_tensors="pt", padding=True, truncation=True
    )
    doc_scores = torch.ones((1, 1))

    # Generate the answer
    outputs = model.generate(
        context_input_ids=context_inputs["input_ids"],
        context_attention_mask=context_inputs["attention_mask"],
        doc_scores=doc_scores,
        n_docs=1,
    )
    answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

    return answer
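
# Sketch of the windowed alternative to line-based chunking mentioned in Step 2
# (a hypothetical helper, not wired into the app; chunk_size and overlap are
# illustrative defaults, not tuned values)
def chunk_text_windowed(text, chunk_size=500, overlap=100):
    chunks = []
    step = chunk_size - overlap
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
    return chunks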

# Streamlit UI
def main():
    st.title("RAG Application - PDF Q&A")
    
    # Upload PDF file
    uploaded_file = st.file_uploader("Upload a PDF Document", type="pdf")
    
    if uploaded_file is not None:
        # Ask a question from the uploaded PDF
        question = st.text_input("Ask a question based on the document:")
        
        if question:
            # Get the answer from the PDF document
            answer = get_answer_from_pdf(uploaded_file, question)
            
            # Display the answer
            st.write("Answer: ", answer)

if __name__ == "__main__":
    main()
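
# To launch the app (assuming this file is saved as app.py):
#   streamlit run app.py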