import os
import pandas as pd
import logging
from datasets import load_dataset
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_chroma import Chroma
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# ------------------------------------------------------------------
# 1. Load and Prepare the Bank FAQ Dataset
# ------------------------------------------------------------------
# Load the dataset from Hugging Face (Bank FAQs)
ds = load_dataset("maxpro291/bankfaqs_dataset")
train_ds = ds['train']
data = train_ds[:]  # materialize the entire split as a dict of column lists
# Separate questions and answers from the 'text' field
questions = []
answers = []
for entry in data['text']:
if entry.startswith("Q:"):
questions.append(entry)
elif entry.startswith("A:"):
answers.append(entry)
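# Note: the dataset stores each FAQ as two consecutive rows ("Q: ..." then "A: ..."),
# so the two lists above should stay aligned. If a row were ever missing, the
# DataFrame constructor below would raise a length-mismatch error.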
# Create a DataFrame with questions and answers
Bank_Data = pd.DataFrame({'question': questions, 'answer': answers})
# Build context strings (combining question and answer) for the vector store
context_data = [
    f"Question: {row['question']} Answer: {row['answer']}"
    for _, row in Bank_Data.iterrows()
]
# ------------------------------------------------------------------
# 2. Create the Vector Store for Retrieval
# ------------------------------------------------------------------
# Initialize the embedding model
embed_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
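# all-MiniLM-L6-v2 maps each text to a 384-dimensional vector; it is small and
# fast enough to run on the CPU hardware of a free Space.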
# Create a Chroma vector store from the context data
vectorstore = Chroma.from_texts(
texts=context_data,
embedding=embed_model,
persist_directory="./chroma_db_bank"
)
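# Caveat: from_texts() appends to any collection already persisted under
# ./chroma_db_bank, so restarting the app re-ingests the same texts. If duplicate
# entries become a problem, delete the directory (or check for an existing
# collection) before ingesting.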
# Create a retriever from the vector store
retriever = vectorstore.as_retriever()
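# By default the retriever returns the 4 most similar entries; pass
# search_kwargs={"k": ...} to as_retriever() to tune how much context is retrieved.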
# ------------------------------------------------------------------
# 3. Initialize the LLM for Generation
# ------------------------------------------------------------------
# Note:
# The model "meta-llama/Llama-2-7b-chat-hf" is gated. If you have access,
# authenticate using `huggingface-cli login`. Otherwise, switch to a public model.
model_name = "gpt2" # Replace with "meta-llama/Llama-2-7b-chat-hf" if you are authenticated.
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# Create a text-generation pipeline
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,      # cap only the completion; max_length would also count the (long) RAG prompt
    do_sample=True,          # required for temperature/top_p to take effect
    temperature=0.7,
    top_p=0.95,
    repetition_penalty=1.15,
    return_full_text=False,  # emit only the generated answer, not the echoed prompt
)
# Wrap the pipeline in LangChain's HuggingFacePipeline
huggingface_model = HuggingFacePipeline(pipeline=pipe)
# ------------------------------------------------------------------
# 4. Build the Retrieval-Augmented Generation (RAG) Chain
# ------------------------------------------------------------------
# Define a prompt template that instructs the assistant to use provided context
template = (
    "You are a helpful banking assistant. "
    "Use the provided context if it is relevant to answer the question. "
    "If not, answer using your general banking knowledge.\n"
    "Context: {context}\n"
    "Question: {question}\n"
    "Answer:"
)
rag_prompt = PromptTemplate.from_template(template)
# Collapse the retrieved Document objects into one plain-text block; without this,
# the raw list repr of the Documents would be pasted into the prompt.
def format_docs(docs):
    return "\n".join(doc.page_content for doc in docs)

# Build the RAG chain by piping the retriever, prompt, LLM, and an output parser
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | rag_prompt
    | huggingface_model
    | StrOutputParser()
)
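# Quick smoke test (uncomment to exercise the chain outside Gradio):
# print(rag_chain.invoke("How do I open a savings account?"))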
# ------------------------------------------------------------------
# 5. Set Up the Gradio Chat Interface
# ------------------------------------------------------------------
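# Note: depending on the langchain_huggingface version, HuggingFacePipeline may not
# stream token by token; if it does not, the loop below yields the full answer once.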
def rag_memory_stream(message, history):
partial_text = ""
# Stream the generated answer
for new_text in rag_chain.stream(message):
partial_text += new_text
yield partial_text
# Example questions
examples = [
"I want to open an account",
"What is a savings account?",
"How do I use an ATM?",
"How can I resolve a bank account issue?"
]
title = "Your Personal Banking Assistant 💬"
description = (
"Welcome! I’m here to answer your questions about banking and related topics. "
"Ask me anything, and I’ll do my best to assist you."
)
# Create a chat interface using Gradio
demo = gr.ChatInterface(
fn=rag_memory_stream,
title=title,
description=description,
examples=examples,
theme="glass",
)
# ------------------------------------------------------------------
# 6. Launch the App
# ------------------------------------------------------------------
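# share=True creates a temporary public gradio.live link for local runs; on
# Hugging Face Spaces the flag is ignored, so it is harmless either way.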
if __name__ == "__main__":
demo.launch(share=True)