Shankarm08 committed on
Commit
93a3da9
1 Parent(s): 80bf310

Update app.py

Files changed (1)
  1. app.py +46 -93
app.py CHANGED
@@ -1,120 +1,73 @@
  import streamlit as st
  import pandas as pd
  import pdfplumber
- import torch
- import faiss
- import numpy as np
- from transformers import pipeline
- from sentence_transformers import SentenceTransformer
-
- # Load the Sentence Transformer model for embeddings
- @st.cache_resource
- def load_embedder():
-     return SentenceTransformer('all-MiniLM-L6-v2')

- embedder = load_embedder()

- # Load a generative model for answer generation
- @st.cache_resource
- def load_generator():
-     return pipeline('text-generation', model='gpt2', tokenizer='gpt2', device=0 if torch.cuda.is_available() else -1)

- generator = load_generator()
-
- # Function to extract text from PDF
  def extract_text_from_pdf(pdf_file):
-     text = ""
      with pdfplumber.open(pdf_file) as pdf:
          for page in pdf.pages:
              page_text = page.extract_text()
-             if page_text:
                  text += page_text + "\n"
-     return text
-
- # Function to split text into chunks
- def split_text(text, chunk_size=500):
-     sentences = text.split('. ')
-     chunks = []
-     current_chunk = ""
-     for sentence in sentences:
-         if len(current_chunk) + len(sentence) <= chunk_size:
-             current_chunk += sentence + ". "
-         else:
-             chunks.append(current_chunk.strip())
-             current_chunk = sentence + ". "
-     if current_chunk:
-         chunks.append(current_chunk.strip())
-     return chunks

- # Function to build FAISS index
- def build_faiss_index(chunks):
-     embeddings = embedder.encode(chunks)
-     embeddings = np.array(embeddings).astype('float32')
-     index = faiss.IndexFlatL2(embeddings.shape[1])
-     index.add(embeddings)
-     return index, embeddings

- # Streamlit app
- st.title("PDF and CSV Chatbot with RAG")

- # Upload CSV file
  csv_file = st.file_uploader("Upload a CSV file", type=["csv"])
- csv_text = ""
  if csv_file:
      csv_data = pd.read_csv(csv_file)
-     st.write("### CSV Data:")
      st.write(csv_data)
-     csv_text = csv_data.to_csv(index=False)

- # Upload PDF file
  pdf_file = st.file_uploader("Upload a PDF file", type=["pdf"])
- pdf_text = ""
  if pdf_file:
      pdf_text = extract_text_from_pdf(pdf_file)
-     if pdf_text.strip():
-         st.write("### PDF Text:")
-         st.write(pdf_text)
      else:
          st.warning("No extractable text found in the PDF.")

- # Combine texts
- combined_text = csv_text + "\n" + pdf_text
- if combined_text.strip():
-     # Split text into chunks
-     chunks = split_text(combined_text)
-
-     # Build FAISS index
-     index, embeddings = build_faiss_index(chunks)
-
-     # Prepare for user input
-     user_input = st.text_input("Ask a question about the uploaded data:")

-     if st.button("Get Response"):
-         if user_input.strip():
-             # Get embedding of user question
-             question_embedding = embedder.encode([user_input])
-             question_embedding = np.array(question_embedding).astype('float32')
-
-             # Search FAISS index
-             k = 3  # number of nearest neighbors
-             distances, indices = index.search(question_embedding, k)
-
-             # Retrieve the most relevant chunks
-             retrieved_chunks = [chunks[idx] for idx in indices[0]]
-
-             # Combine retrieved chunks
-             context = " ".join(retrieved_chunks)
-
-             # Generate answer
-             prompt = context + "\n\nQuestion: " + user_input + "\nAnswer:"
-             response = generator(prompt, max_length=200, num_return_sequences=1)
-
-             # Display response
              st.write("### Response:")
-             st.write(response[0]['generated_text'].split("Answer:")[1].strip())
-         else:
-             st.warning("Please enter a question.")
- else:
-     st.info("Please upload a CSV file or a PDF file to proceed.")
-
-
 
  import streamlit as st
+ import torch
+ from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
+ from datasets import load_dataset
  import pandas as pd
  import pdfplumber

+ # Load RAG model, tokenizer, and retriever
+ tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
+ retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", use_dummy_dataset=True)
+ model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)

+ # Function to get RAG embeddings
+ def get_rag_embeddings(question, context):
+     inputs = tokenizer(question, context, return_tensors="pt", truncation=True)
+     with torch.no_grad():
+         output = model.generate(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
+     return tokenizer.batch_decode(output, skip_special_tokens=True)[0]

+ # Extract text from PDF
  def extract_text_from_pdf(pdf_file):
      with pdfplumber.open(pdf_file) as pdf:
+         text = ""
          for page in pdf.pages:
              page_text = page.extract_text()
+             if page_text:  # Check if the page has extractable text
                  text += page_text + "\n"
+     return text.strip()  # Return stripped text for better formatting

+ # Store the PDF text and embeddings
+ pdf_text = ""
+ csv_data = None

+ # Streamlit app UI
+ st.title("RAG-Powered PDF & CSV Chatbot")

+ # CSV file upload
  csv_file = st.file_uploader("Upload a CSV file", type=["csv"])
  if csv_file:
      csv_data = pd.read_csv(csv_file)
+     st.write("CSV file loaded successfully!")
      st.write(csv_data)

+ # PDF file upload
  pdf_file = st.file_uploader("Upload a PDF file", type=["pdf"])
  if pdf_file:
      pdf_text = extract_text_from_pdf(pdf_file)
+     if pdf_text:
+         st.success("PDF loaded successfully!")
+         st.text_area("Extracted Text from PDF", pdf_text, height=200)
      else:
          st.warning("No extractable text found in the PDF.")

+ # User input for chatbot
+ user_input = st.text_input("Ask a question related to the PDF or CSV:")

+ # Get response on button click
+ if st.button("Get Response"):
+     if not pdf_text and csv_data is None:
+         st.warning("Please upload a PDF or CSV file first.")
+     else:
+         # Combine PDF text and CSV content for context in RAG
+         combined_context = pdf_text
+         if csv_data is not None:
+             combined_context += "\n" + csv_data.to_string()
+
+         # Get RAG-generated response
+         try:
+             response = get_rag_embeddings(user_input, combined_context)
              st.write("### Response:")
+             st.write(response)
+         except Exception as e:
+             st.error(f"Error while processing the question: {e}")
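
For reference, a minimal standalone sketch of the generation path the updated app.py relies on, using the same facebook/rag-sequence-nq checkpoints that appear in the diff. The question string is only a hypothetical example, the snippet assumes transformers, datasets, faiss-cpu, and torch are installed, and it follows the question-only calling pattern from the transformers RAG examples rather than the app's question-plus-context call.

import torch
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration

# Same checkpoints as in the commit; use_dummy_dataset=True retrieves from a small
# dummy wiki_dpr index rather than the full corpus.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", use_dummy_dataset=True)
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)

question = "How many countries are in Europe?"  # hypothetical example question
inputs = tokenizer(question, return_tensors="pt", truncation=True)
with torch.no_grad():
    generated = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])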