# pdfcsvdatarag / app.py — Hugging Face Space by Shankarm08
# Last change: commit 34d53c3 (verified), "Update app.py" (2.36 kB)
import os
import re

import pandas as pd
import pdfplumber
import streamlit as st
import torch
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
# Hugging Face access token, read from the HF_TOKEN environment variable rather
# than hard-coded in source. When it is unset (None) no explicit token is passed,
# and credentials from `huggingface-cli login` (if any) are used instead.
token = os.environ.get("HF_TOKEN")

# Tokenizer, retriever and generator all come from the same checkpoint so that
# their vocabularies and configs agree.
RAG_CHECKPOINT = "facebook/rag-sequence-nq"


@st.cache_resource
def _load_rag_components(checkpoint=RAG_CHECKPOINT, auth_token=None, use_dummy_dataset=True):
    """Load the RAG tokenizer, retriever and generation model for *checkpoint*.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` makes the (large) download/load happen once per
    server process instead of on every rerun.

    Args:
        checkpoint: Hugging Face Hub id of a RAG checkpoint.
        auth_token: Optional Hub access token; ``None`` uses cached CLI login.
        use_dummy_dataset: Load only a small dummy Wikipedia index (as in the
            transformers RAG examples). Set to ``False`` for the full index,
            which is a very large download.

    Returns:
        A ``(tokenizer, retriever, model)`` tuple.
    """
    rag_tokenizer = RagTokenizer.from_pretrained(checkpoint, token=auth_token)
    rag_retriever = RagRetriever.from_pretrained(
        checkpoint,
        index_name="exact",
        use_dummy_dataset=use_dummy_dataset,
        token=auth_token,
    )
    rag_model = RagSequenceForGeneration.from_pretrained(
        checkpoint, retriever=rag_retriever, token=auth_token
    )
    return rag_tokenizer, rag_retriever, rag_model


tokenizer, retriever, model = _load_rag_components(RAG_CHECKPOINT, token)
# Function to extract text from a PDF file
def extract_text_from_pdf(pdf_file):
    """Return the concatenated text of every page of *pdf_file*.

    Args:
        pdf_file: Anything ``pdfplumber.open`` accepts — here a Streamlit
            ``UploadedFile`` (a binary file-like object).

    Returns:
        The text of all pages that yielded any text, joined with newlines and
        stripped of leading/trailing whitespace; ``""`` if no page has text.
    """
    with pdfplumber.open(pdf_file) as pdf:
        # extract_text() may return None or "" for pages without a text layer
        # (e.g. scanned images); those pages are skipped. The join must happen
        # inside the `with` block because the generator reads pages lazily.
        page_texts = (page.extract_text() for page in pdf.pages)
        text = "\n".join(page_text for page_text in page_texts if page_text)
    return text.strip()
# Streamlit app
st.title("RAG-Powered PDF & CSV Chatbot")

# CSV file upload. st.file_uploader returns None until a file is chosen.
csv_file = st.file_uploader("Upload a CSV file", type=["csv"])
csv_data = None
if csv_file is not None:
    try:
        csv_data = pd.read_csv(csv_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        # A malformed upload must not crash the whole app: report it and
        # continue as if no CSV had been provided (csv_data stays None).
        st.error(f"Could not read the CSV file: {err}")
    else:
        st.success("CSV file loaded successfully!")
        st.write(csv_data)

# PDF file upload
pdf_file = st.file_uploader("Upload a PDF file", type=["pdf"])
pdf_text = ""
if pdf_file is not None:
    pdf_text = extract_text_from_pdf(pdf_file)
    if pdf_text:
        st.success("PDF loaded successfully!")
        st.text_area("Extracted Text from PDF", pdf_text, height=200)
    else:
        st.warning("No extractable text found in the PDF.")
# Maximum number of whitespace-separated words per passage cut from the uploads.
PASSAGE_WORDS = 100


def _split_into_passages(text, words_per_passage=PASSAGE_WORDS):
    """Split *text* into consecutive passages of at most *words_per_passage* words."""
    words = text.split()
    return [
        " ".join(words[start:start + words_per_passage])
        for start in range(0, len(words), words_per_passage)
    ]


def _rank_passages(passages, question, k):
    """Return exactly *k* passages, those sharing the most words with *question* first.

    Scoring is a cheap lexical overlap (distinct lowercase ``\\w+`` tokens in
    common with the question); ties keep document order. If there are fewer
    than *k* passages, they are repeated cyclically to fill *k* slots, because
    the generator is fed a fixed number of documents per question. Returns
    ``[]`` when *passages* is empty.
    """
    question_terms = set(re.findall(r"\w+", question.lower()))
    ranked = sorted(
        passages,
        key=lambda passage: len(question_terms & set(re.findall(r"\w+", passage.lower()))),
        reverse=True,
    )[:k]
    if not ranked:
        return []
    return [ranked[i % len(ranked)] for i in range(k)]


def _generate_answer(question, passages):
    """Answer *question* with the RAG generator conditioned on *passages*.

    The uploaded passages are passed to ``model.generate`` directly as
    ``context_input_ids`` / ``context_attention_mask`` / ``doc_scores``, so the
    answer is conditioned on the user's files rather than on the retriever's
    own index.
    """
    n_docs = model.config.n_docs
    max_len = model.config.max_combined_length
    gen_tok = tokenizer.generator
    # Intended to mirror the "title<title_sep>text<doc_sep>question" layout the
    # RAG retriever produces — NOTE(review): confirm against installed transformers.
    suffix = f"{model.config.doc_sep}{question}"
    # Reserve room for the question plus BOS/EOS so only the passage is truncated.
    budget = max(max_len - len(gen_tok(suffix, add_special_tokens=False)["input_ids"]) - 2, 1)
    contexts = []
    for passage in _rank_passages(passages, question, n_docs):
        head_ids = gen_tok(
            f"Uploaded file{model.config.title_sep}{passage}",
            add_special_tokens=False,
            truncation=True,
            max_length=budget,
        )["input_ids"]
        contexts.append(gen_tok.decode(head_ids) + suffix)
    encoded = gen_tok(
        contexts, max_length=max_len, padding=True, truncation=True, return_tensors="pt"
    )
    # No dense relevance scores exist for uploaded passages: weight them uniformly.
    doc_scores = torch.zeros(1, n_docs)
    with torch.no_grad():
        output = model.generate(
            context_input_ids=encoded["input_ids"],
            context_attention_mask=encoded["attention_mask"],
            doc_scores=doc_scores,
        )
    return tokenizer.batch_decode(output, skip_special_tokens=True)[0]


# User input for chatbot
user_input = st.text_input("Ask a question related to the PDF or CSV:")

# Get response on button click
if st.button("Get Response"):
    if not pdf_text and csv_data is None:
        st.warning("Please upload a PDF or CSV file first.")
    elif not user_input.strip():
        st.warning("Please enter a question.")
    else:
        combined_context = pdf_text
        if csv_data is not None:
            combined_context += "\n" + csv_data.to_string()
        # Generate response using RAG, grounded in the uploaded content.
        response = _generate_answer(user_input, _split_into_passages(combined_context))
        st.write("### Response:")
        st.write(response)