Shankarm08 committed on
Commit
4995935
·
verified ·
1 Parent(s): 930f177

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -53
app.py CHANGED
@@ -1,87 +1,120 @@
1
  import streamlit as st
2
- import torch
3
- from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
4
- from datasets import load_dataset
5
  import pandas as pd
6
  import pdfplumber
 
 
7
  import numpy as np
8
- from sklearn.metrics.pairwise import cosine_similarity
 
 
 
 
 
 
 
 
9
 
10
- # Load RAG model, tokenizer, and retriever
11
- tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
12
- retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", use_dummy_dataset=True)
13
- model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)
14
 
15
- # Function to get RAG embeddings
16
- def get_rag_embeddings(question, context):
17
- inputs = tokenizer(question, context, return_tensors="pt", truncation=True)
18
- with torch.no_grad():
19
- output = model.generate(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
20
- return tokenizer.batch_decode(output, skip_special_tokens=True)[0]
21
 
22
- # Extract text from PDF
23
  def extract_text_from_pdf(pdf_file):
 
24
  with pdfplumber.open(pdf_file) as pdf:
25
- text = ""
26
  for page in pdf.pages:
27
  page_text = page.extract_text()
28
- if page_text: # Check if the page has extractable text
29
  text += page_text + "\n"
30
  return text
31
 
32
- # Load dataset (using SQuAD v2 as a placeholder)
33
- def load_squad_v2():
34
- return load_dataset('squad_v2')
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- # Store the PDF text and embeddings
37
- pdf_text = ""
38
- pdf_embeddings = None
39
- csv_data = None
 
 
 
40
 
41
- # Streamlit app UI
42
- st.title("RAG-Powered PDF & CSV Chatbot")
43
 
44
- # CSV file upload
45
  csv_file = st.file_uploader("Upload a CSV file", type=["csv"])
 
46
  if csv_file:
47
  csv_data = pd.read_csv(csv_file)
48
- st.write("CSV file loaded successfully!")
49
  st.write(csv_data)
 
50
 
51
- # PDF file upload
52
  pdf_file = st.file_uploader("Upload a PDF file", type=["pdf"])
 
53
  if pdf_file:
54
  pdf_text = extract_text_from_pdf(pdf_file)
55
  if pdf_text.strip():
56
- st.success("PDF loaded successfully!")
57
- st.text_area("Extracted Text from PDF", pdf_text, height=200)
58
  else:
59
  st.warning("No extractable text found in the PDF.")
60
 
61
- # Load the SQuAD v2 dataset as an example for RAG retrieval
62
- dataset = load_squad_v2()
 
 
 
63
 
64
- # User input for chatbot
65
- user_input = st.text_input("Ask a question related to the PDF or CSV:")
66
 
67
- # Get response on button click
68
- if st.button("Get Response"):
69
- if not pdf_text and csv_data is None:
70
- st.warning("Please upload a PDF or CSV file first.")
71
- else:
72
- # Combine PDF text and CSV content for context in RAG
73
- combined_context = ""
74
- if pdf_text:
75
- combined_context += pdf_text
76
- if csv_data is not None:
77
- combined_context += "\n" + csv_data.to_string()
78
-
79
- # Get RAG-generated response
80
- try:
81
- response = get_rag_embeddings(user_input, combined_context)
 
 
 
 
 
 
 
 
 
82
  st.write("### Response:")
83
- st.write(response)
84
- except Exception as e:
85
- st.error(f"Error while processing the question: {e}")
 
 
86
 
87
 
 
1
  import streamlit as st
 
 
 
2
  import pandas as pd
3
  import pdfplumber
4
+ import torch
5
+ import faiss
6
  import numpy as np
7
+ from transformers import pipeline
8
+ from sentence_transformers import SentenceTransformer
9
+
10
# Sentence-embedding model used for both document chunks and user questions.
@st.cache_resource
def load_embedder():
    """Return the ``all-MiniLM-L6-v2`` SentenceTransformer model.

    Decorated with ``st.cache_resource`` so Streamlit can reuse the loaded
    model instead of constructing it again on each script run.
    """
    model_name = 'all-MiniLM-L6-v2'
    return SentenceTransformer(model_name)


embedder = load_embedder()
16
 
17
# Text-generation pipeline used to phrase the final answer.
@st.cache_resource
def load_generator():
    """Return a GPT-2 ``text-generation`` pipeline.

    The pipeline is placed on CUDA device 0 when ``torch`` reports CUDA as
    available and on the CPU (``device=-1``) otherwise. Decorated with
    ``st.cache_resource`` so Streamlit can reuse the loaded pipeline.
    """
    if torch.cuda.is_available():
        target_device = 0
    else:
        target_device = -1
    return pipeline(
        'text-generation',
        model='gpt2',
        tokenizer='gpt2',
        device=target_device,
    )


generator = load_generator()
 
 
 
 
 
23
 
24
# Pull the text layer out of an uploaded PDF
def extract_text_from_pdf(pdf_file):
    """Return the extractable text of every page in ``pdf_file``.

    Args:
        pdf_file: Whatever ``pdfplumber.open`` accepts; the app passes the
            object returned by the Streamlit PDF uploader.

    Returns:
        The text of each page that yielded any, each followed by a newline,
        concatenated in page order. Pages for which ``extract_text()`` returns
        nothing are skipped, so the result is ``""`` when no page has text.
    """
    with pdfplumber.open(pdf_file) as document:
        page_bodies = [page.extract_text() for page in document.pages]
    return "".join(body + "\n" for body in page_bodies if body)
33
 
34
# Function to split text into chunks
def split_text(text, chunk_size=500):
    """Split ``text`` into chunks of at most ``chunk_size`` characters.

    The text is cut after sentence-ending punctuation (``.``, ``!``, ``?``)
    that is followed by whitespace, and consecutive sentences are packed
    greedily into a chunk, joined by a single space. A sentence longer than
    ``chunk_size`` (for example CSV content, which has no sentence breaks) is
    sliced into ``chunk_size``-character fragments so no chunk can grow
    without bound.

    Args:
        text: The text to split.
        chunk_size: Maximum length of each returned chunk, in characters.

    Returns:
        A list of non-empty chunk strings; ``[]`` when ``text`` is empty or
        contains only whitespace.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    import re  # local import: ``re`` is not in the file's import header

    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    # Break the text into pieces that each fit in a chunk on their own.
    # Splitting *after* the punctuation keeps each sentence's own terminator,
    # so nothing has to be re-appended (which used to double the final period).
    pieces = []
    for sentence in re.split(r'(?<=[.!?])\s+', text.strip()):
        sentence = sentence.strip()
        for start in range(0, len(sentence), chunk_size):
            fragment = sentence[start:start + chunk_size].strip()
            if fragment:
                pieces.append(fragment)

    # Greedily pack pieces; the joining space is counted against the limit.
    chunks = []
    current = ""
    for piece in pieces:
        candidate = current + " " + piece if current else piece
        if len(candidate) <= chunk_size:
            current = candidate
        else:
            # ``current`` is never empty here: a lone piece always fits.
            chunks.append(current)
            current = piece
    if current:
        chunks.append(current)
    return chunks
48
 
49
# Function to build FAISS index
def build_faiss_index(chunks):
    """Embed ``chunks`` and load the vectors into a flat L2 FAISS index.

    Args:
        chunks: The text chunks to index. Row ``i`` of the index corresponds
            to ``chunks[i]``.

    Returns:
        A ``(index, embeddings)`` tuple: the populated ``faiss.IndexFlatL2``
        and the float32 embedding matrix that was added to it.

    Raises:
        ValueError: If ``chunks`` is empty. Without this guard the failure
            surfaces as an ``IndexError`` when reading the embedding
            dimension from ``embeddings.shape[1]``.
    """
    chunks = list(chunks)
    if not chunks:
        raise ValueError("Cannot build a FAISS index from an empty list of chunks.")

    embeddings = embedder.encode(chunks)
    embeddings = np.array(embeddings).astype('float32')
    # The index dimensionality is taken from the embedding width.
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index, embeddings
56
 
57
# Streamlit app
st.title("PDF and CSV Chatbot with RAG")

# Upload CSV file
csv_file = st.file_uploader("Upload a CSV file", type=["csv"])
csv_text = ""
if csv_file:
    csv_data = pd.read_csv(csv_file)
    st.write("### CSV Data:")
    st.write(csv_data)
    # Serialise the table back to CSV text so it can be chunked and embedded.
    csv_text = csv_data.to_csv(index=False)

# Upload PDF file
pdf_file = st.file_uploader("Upload a PDF file", type=["pdf"])
pdf_text = ""
if pdf_file:
    pdf_text = extract_text_from_pdf(pdf_file)
    if pdf_text.strip():
        st.write("### PDF Text:")
        st.write(pdf_text)
    else:
        st.warning("No extractable text found in the PDF.")

# Combine texts
combined_text = csv_text + "\n" + pdf_text
if combined_text.strip():
    # Streamlit re-executes this script on every widget interaction, so the
    # chunks and index are kept in session state and rebuilt only when the
    # uploaded content changes, not on every keystroke or button click.
    if st.session_state.get("indexed_text") != combined_text:
        fresh_chunks = split_text(combined_text)
        fresh_index, _ = build_faiss_index(fresh_chunks)
        st.session_state["chunks"] = fresh_chunks
        st.session_state["index"] = fresh_index
        # Set last, so a failed build never leaves a stale "up to date" marker.
        st.session_state["indexed_text"] = combined_text
    chunks = st.session_state["chunks"]
    index = st.session_state["index"]

    # Prepare for user input
    user_input = st.text_input("Ask a question about the uploaded data:")

    if st.button("Get Response"):
        if user_input.strip():
            # Get embedding of user question
            question_embedding = embedder.encode([user_input])
            question_embedding = np.array(question_embedding).astype('float32')

            # Never ask for more neighbours than there are chunks: FAISS pads
            # missing results with -1, which would silently select chunks[-1].
            k = min(3, len(chunks))
            _, indices = index.search(question_embedding, k)

            # Retrieve the most relevant chunks, ignoring any -1 padding
            retrieved_chunks = [chunks[idx] for idx in indices[0] if idx >= 0]

            # Combine retrieved chunks
            context = " ".join(retrieved_chunks)

            # Generate answer
            prompt = context + "\n\nQuestion: " + user_input + "\nAnswer:"
            response = generator(
                prompt,
                # Budget for the answer only; max_length would also count the
                # prompt and leave no room once the context is long.
                max_new_tokens=100,
                num_return_sequences=1,
                # Return just the continuation, so the answer is not located
                # by searching for "Answer:", which the context may contain.
                return_full_text=False,
                # Left-truncate an over-long prompt so the question and the
                # trailing "Answer:" cue survive within the model's window.
                handle_long_generation="hole",
            )
            # Cut the continuation where the model starts a new "Answer:".
            answer = response[0]['generated_text'].split("Answer:")[0].strip()

            # Display response
            st.write("### Response:")
            st.write(answer)
        else:
            st.warning("Please enter a question.")
else:
    st.info("Please upload a CSV file or a PDF file to proceed.")
119
 
120