Shankarm08 committed
Commit f7f091e
1 Parent(s): b17e18d

Update app.py

Files changed (1)
  1. app.py +15 -29
app.py CHANGED
@@ -1,24 +1,13 @@
 import streamlit as st
 import torch
 from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
-from datasets import load_dataset
 import pandas as pd
 import pdfplumber
 
-# Load RAG model and tokenizer
+# Initialize RAG components
 tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
-retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name="exact")
-model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)
-
-# Load the wiki_dpr dataset with trust_remote_code=True
-dataset = load_dataset("facebook/wiki_dpr", split="train", trust_remote_code=True)
-
-# Function to get RAG embeddings
-def get_rag_embeddings(question, context):
-    inputs = tokenizer(question, context, return_tensors="pt", truncation=True)
-    with torch.no_grad():
-        output = model.generate(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
-    return tokenizer.batch_decode(output, skip_special_tokens=True)[0]
+retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq")
+model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")
 
 # Extract text from PDF
 def extract_text_from_pdf(pdf_file):
@@ -26,19 +15,16 @@ def extract_text_from_pdf(pdf_file):
         text = ""
         for page in pdf.pages:
             page_text = page.extract_text()
-            if page_text:  # Check if the page has extractable text
+            if page_text:
                 text += page_text + "\n"
-    return text.strip()  # Return stripped text for better formatting
+    return text.strip()
 
-# Store the PDF text and CSV data
-pdf_text = ""
-csv_data = None
-
-# Streamlit app UI
+# Streamlit UI
 st.title("RAG-Powered PDF & CSV Chatbot")
 
 # CSV file upload
 csv_file = st.file_uploader("Upload a CSV file", type=["csv"])
+csv_data = None
 if csv_file:
     csv_data = pd.read_csv(csv_file)
     st.write("CSV file loaded successfully!")
@@ -46,6 +32,7 @@ if csv_file:
 
 # PDF file upload
 pdf_file = st.file_uploader("Upload a PDF file", type=["pdf"])
+pdf_text = ""
 if pdf_file:
     pdf_text = extract_text_from_pdf(pdf_file)
     if pdf_text:
@@ -62,15 +49,14 @@ if st.button("Get Response"):
     if not pdf_text and csv_data is None:
         st.warning("Please upload a PDF or CSV file first.")
     else:
-        # Combine PDF text and CSV content for context in RAG
         combined_context = pdf_text
         if csv_data is not None:
             combined_context += "\n" + csv_data.to_string()
 
-        # Get RAG-generated response
-        try:
-            response = get_rag_embeddings(user_input, combined_context)
-            st.write("### Response:")
-            st.write(response)
-        except Exception as e:
-            st.error(f"Error while processing the question: {e}")
+        # Generate response using RAG
+        inputs = tokenizer(user_input, combined_context, return_tensors="pt", truncation=True)
+        with torch.no_grad():
+            output = model.generate(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
+        response = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
+        st.write("### Response:")
+        st.write(response)
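
Reviewer note (not part of the commit): after this change the retriever is created but never attached to the model, and RagRetriever.from_pretrained("facebook/rag-sequence-nq") with no index_name will still try to resolve the full wiki_dpr index that the removed load_dataset call used to provide. Below is a minimal sketch of one way to wire the pieces together, assuming the same facebook/rag-sequence-nq checkpoint and using the small dummy "exact" index so nothing large is downloaded; the helper name answer_question is illustrative only.

import torch
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration

tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")

# Dummy exact index keeps the sketch lightweight; swap in a real index for useful retrieval.
retriever = RagRetriever.from_pretrained(
    "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
)

# Passing the retriever here is what lets generate() fetch supporting passages.
model = RagSequenceForGeneration.from_pretrained(
    "facebook/rag-sequence-nq", retriever=retriever
)

def answer_question(question: str) -> str:
    # Only the question is encoded; the retriever supplies the context documents.
    inputs = tokenizer(question, return_tensors="pt", truncation=True)
    with torch.no_grad():
        output = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )
    return tokenizer.batch_decode(output, skip_special_tokens=True)[0]

Even with this wiring, the uploaded PDF and CSV text is not part of the retrieval index; the app only concatenates it into the query, so making those documents retrievable would require building a custom index for the retriever, which is beyond what this commit changes.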