Shankarm08 commited on
Commit
a52a9bb
1 Parent(s): 7108a73

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -77
app.py CHANGED
@@ -1,110 +1,83 @@
1
  import streamlit as st
2
- import pandas as pd
3
  import torch
4
- import faiss
5
- import numpy as np
6
- from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
7
  import pdfplumber
8
- import pytesseract
9
  from sklearn.metrics.pairwise import cosine_similarity
10
 
11
- # Load the RAG tokenizer and model
12
- tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
13
- retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
14
- model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq")
15
 
16
- # Function to get embeddings for FAISS index
17
- def get_faiss_index(data_chunks):
18
- embeddings = [retriever.question_encoder_tokenizer(chunk, return_tensors="pt").input_ids for chunk in data_chunks]
19
- embeddings = torch.cat(embeddings).detach().numpy()
 
 
20
 
21
- # Build FAISS index
22
- index = faiss.IndexFlatL2(embeddings.shape[1]) # L2 distance
23
- index.add(embeddings)
24
- return index, embeddings
25
-
26
- # Extract text and tables from PDF (with OCR fallback)
27
  def extract_text_from_pdf(pdf_file):
28
- text = ""
29
  with pdfplumber.open(pdf_file) as pdf:
30
- for page_num, page in enumerate(pdf.pages, 1):
 
31
  page_text = page.extract_text()
32
- if page_text:
33
  text += page_text + "\n"
34
- else:
35
- st.warning(f"No extractable text found on page {page_num}. Using OCR...")
36
- page_image = page.to_image().original
37
- ocr_text = pytesseract.image_to_string(page_image)
38
- if ocr_text.strip():
39
- text += ocr_text + "\n"
40
- else:
41
- st.error(f"Even OCR couldn't extract text from page {page_num}.")
42
  return text
43
 
44
- # Function to process input for RAG model
45
- def generate_rag_response(user_input, data_chunks):
46
- inputs = tokenizer([user_input], return_tensors="pt")
47
- retrieved_docs = retriever(input_ids=inputs['input_ids'], n_docs=5)
48
- outputs = model.generate(input_ids=inputs['input_ids'], context_input_ids=retrieved_docs['context_input_ids'])
49
- return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
 
 
50
 
51
- # Streamlit app
52
- st.title("CSV and PDF Chatbot with RAG")
53
 
54
  # CSV file upload
55
  csv_file = st.file_uploader("Upload a CSV file", type=["csv"])
56
- csv_data = None
57
-
58
  if csv_file:
59
  csv_data = pd.read_csv(csv_file)
60
- st.success("CSV loaded successfully!")
61
- st.write("### CSV Data:")
62
  st.write(csv_data)
63
 
64
  # PDF file upload
65
  pdf_file = st.file_uploader("Upload a PDF file", type=["pdf"])
66
- pdf_text = ""
67
- data_chunks = []
68
-
69
  if pdf_file:
70
  pdf_text = extract_text_from_pdf(pdf_file)
71
-
72
- if not pdf_text.strip():
73
- st.error("The extracted PDF text is empty. Please upload a PDF with extractable text.")
74
- else:
75
  st.success("PDF loaded successfully!")
76
- st.write("### Extracted Text:")
77
- st.write(pdf_text)
78
-
79
- # Split the extracted text into chunks for FAISS
80
- data_chunks = pdf_text.split('\n')
81
- st.write("### Extracted Chunks:")
82
- for chunk in data_chunks[:5]: # Display first 5 chunks
83
- st.write(chunk)
84
 
85
  # User input for chatbot
86
- user_input = st.text_input("Ask a question about the CSV or PDF:")
87
 
 
88
  if st.button("Get Response"):
89
- if csv_data is None and not data_chunks:
90
- st.warning("Please upload both a CSV and PDF file first.")
91
- elif not user_input.strip():
92
- st.warning("Please enter a question.")
93
  else:
 
 
 
 
 
 
 
 
94
  try:
95
- if csv_data is not None:
96
- # Check if the query is related to CSV content
97
- csv_response = csv_data[csv_data.apply(lambda row: row.astype(str).str.contains(user_input, case=False).any(), axis=1)]
98
- if not csv_response.empty:
99
- st.write("### CSV Response:")
100
- st.write(csv_response)
101
- else:
102
- st.write("No relevant data found in the CSV.")
103
-
104
- if data_chunks:
105
- # Generate response using RAG for PDF content
106
- response = generate_rag_response(user_input, data_chunks)
107
- st.write("### PDF Response:")
108
- st.write(response)
109
  except Exception as e:
110
- st.error(f"Error while processing user input: {e}")
 
 
1
  import streamlit as st
 
2
  import torch
3
+ from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
4
+ from datasets import load_dataset
5
+ import pandas as pd
6
  import pdfplumber
7
+ import numpy as np
8
  from sklearn.metrics.pairwise import cosine_similarity
9
 
10
+ # Load RAG model, tokenizer, and retriever
11
+ tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
12
+ retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", use_dummy_dataset=True)
13
+ model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)
14
 
15
+ # Function to get RAG embeddings
16
+ def get_rag_embeddings(question, context):
17
+ inputs = tokenizer(question, context, return_tensors="pt", truncation=True)
18
+ with torch.no_grad():
19
+ output = model.generate(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
20
+ return tokenizer.batch_decode(output, skip_special_tokens=True)[0]
21
 
22
+ # Extract text from PDF
 
 
 
 
 
23
  def extract_text_from_pdf(pdf_file):
 
24
  with pdfplumber.open(pdf_file) as pdf:
25
+ text = ""
26
+ for page in pdf.pages:
27
  page_text = page.extract_text()
28
+ if page_text: # Check if the page has extractable text
29
  text += page_text + "\n"
 
 
 
 
 
 
 
 
30
  return text
31
 
32
+ # Load dataset (wiki_dpr) and set trust_remote_code=True
33
+ def load_wiki_dpr():
34
+ return load_dataset('wiki_dpr', trust_remote_code=True)
35
+
36
+ # Store the PDF text and embeddings
37
+ pdf_text = ""
38
+ pdf_embeddings = None
39
+ csv_data = None
40
 
41
+ # Streamlit app UI
42
+ st.title("RAG-Powered PDF & CSV Chatbot")
43
 
44
  # CSV file upload
45
  csv_file = st.file_uploader("Upload a CSV file", type=["csv"])
 
 
46
  if csv_file:
47
  csv_data = pd.read_csv(csv_file)
48
+ st.write("CSV file loaded successfully!")
 
49
  st.write(csv_data)
50
 
51
  # PDF file upload
52
  pdf_file = st.file_uploader("Upload a PDF file", type=["pdf"])
 
 
 
53
  if pdf_file:
54
  pdf_text = extract_text_from_pdf(pdf_file)
55
+ if pdf_text.strip():
 
 
 
56
  st.success("PDF loaded successfully!")
57
+ st.text_area("Extracted Text from PDF", pdf_text, height=200)
58
+ else:
59
+ st.warning("No extractable text found in the PDF.")
 
 
 
 
 
60
 
61
  # User input for chatbot
62
+ user_input = st.text_input("Ask a question related to the PDF or CSV:")
63
 
64
+ # Get response on button click
65
  if st.button("Get Response"):
66
+ if not pdf_text and csv_data is None:
67
+ st.warning("Please upload a PDF or CSV file first.")
 
 
68
  else:
69
+ # Combine PDF text and CSV content for context in RAG
70
+ combined_context = ""
71
+ if pdf_text:
72
+ combined_context += pdf_text
73
+ if csv_data is not None:
74
+ combined_context += "\n" + csv_data.to_string()
75
+
76
+ # Get RAG-generated response
77
  try:
78
+ response = get_rag_embeddings(user_input, combined_context)
79
+ st.write("### Response:")
80
+ st.write(response)
 
 
 
 
 
 
 
 
 
 
 
81
  except Exception as e:
82
+ st.error(f"Error while processing the question: {e}")
83
+