"""Streamlit app: upload a PDF and/or CSV, ask a question, get a RAG answer.

The app extracts text from an uploaded PDF with pdfplumber, previews an
uploaded CSV with pandas, and sends the user's question to a Hugging Face
RAG sequence model for generation.
"""
import os

import pandas as pd
import pdfplumber
import streamlit as st
import torch
from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer

# Ensure you are logged in with `huggingface-cli login`, or export HF_TOKEN.
# A literal placeholder string must never be sent as a credential, so the
# token is read from the environment and is None when it is not set
# (optional if you have logged in via the CLI).
token = os.environ.get("HF_TOKEN") or None

# Tokenizer, retriever and generator are all loaded from one checkpoint so
# their configurations agree.
RAG_CHECKPOINT = "facebook/rag-sequence-nq"

# The dummy index is the small sample index used in the Transformers RAG
# documentation example. Set to False to use the full index instead
# (NOTE(review): the full index is a very large download - confirm before
# enabling).
USE_DUMMY_INDEX = True


@st.cache_resource
def load_rag_components(auth_token):
    """Load the RAG tokenizer, retriever and model once per server process.

    Streamlit re-executes this script on every widget interaction; caching
    the loaded objects avoids reloading the model on each rerun.

    Args:
        auth_token: Hugging Face access token, or None to rely on locally
            stored credentials.

    Returns:
        A ``(tokenizer, retriever, model)`` tuple.
    """
    rag_tokenizer = RagTokenizer.from_pretrained(
        RAG_CHECKPOINT, use_auth_token=auth_token
    )
    rag_retriever = RagRetriever.from_pretrained(
        RAG_CHECKPOINT,
        index_name="exact",
        use_dummy_dataset=USE_DUMMY_INDEX,
        use_auth_token=auth_token,
    )
    rag_model = RagSequenceForGeneration.from_pretrained(
        RAG_CHECKPOINT, retriever=rag_retriever, use_auth_token=auth_token
    )
    return rag_tokenizer, rag_retriever, rag_model


# Load the tokenizer and model for RAG
tokenizer, retriever, model = load_rag_components(token)


# Function to extract text from a PDF file
def extract_text_from_pdf(pdf_file):
    """Return the concatenated text of all pages in *pdf_file*.

    Pages for which pdfplumber returns no text are skipped. Page texts are
    separated by a newline and the result is stripped of surrounding
    whitespace, so a PDF with no extractable text yields ``""``.

    Args:
        pdf_file: A path or file-like object accepted by ``pdfplumber.open``.

    Returns:
        The extracted text as a single string.
    """
    with pdfplumber.open(pdf_file) as pdf:
        page_texts = [page.extract_text() for page in pdf.pages]
    return "\n".join(text for text in page_texts if text).strip()


# Streamlit app
st.title("RAG-Powered PDF & CSV Chatbot")

# CSV file upload
csv_file = st.file_uploader("Upload a CSV file", type=["csv"])
csv_data = None
if csv_file:
    try:
        csv_data = pd.read_csv(csv_file)
    except (
        pd.errors.EmptyDataError,
        pd.errors.ParserError,
        UnicodeDecodeError,
    ) as exc:
        # Report an unreadable upload in the UI instead of crashing the run;
        # csv_data stays None so the app behaves as if no CSV was provided.
        st.error(f"Could not read the CSV file: {exc}")
    else:
        st.write("CSV file loaded successfully!")
        st.write(csv_data)

# PDF file upload
pdf_file = st.file_uploader("Upload a PDF file", type=["pdf"])
pdf_text = ""
if pdf_file:
    pdf_text = extract_text_from_pdf(pdf_file)
    if pdf_text:
        st.success("PDF loaded successfully!")
        st.text_area("Extracted Text from PDF", pdf_text, height=200)
    else:
        st.warning("No extractable text found in the PDF.")

# User input for chatbot
user_input = st.text_input("Ask a question related to the PDF or CSV:")

# Get response on button click
if st.button("Get Response"):
    if not pdf_text and csv_data is None:
        st.warning("Please upload a PDF or CSV file first.")
    elif not user_input.strip():
        # Nothing to ask; do not run the model on an empty prompt.
        st.warning("Please enter a question.")
    else:
        # NOTE(review): the uploaded PDF/CSV content is not passed to the
        # model. `generate` receives only the tokenized question, so the
        # supporting passages come from the retriever's own index. Answering
        # from the uploaded documents would require indexing their text for
        # the retriever (or supplying context inputs to `generate`); that is
        # not implemented here.

        # Generate response using RAG
        inputs = tokenizer(user_input, return_tensors="pt", truncation=True)
        with torch.no_grad():
            output = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
            )
        response = tokenizer.batch_decode(output, skip_special_tokens=True)[0]

        st.write("### Response:")
        st.write(response)