from datasets import load_dataset
import gradio as gr
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
import torch
# Load the horoscope-chat dataset
dataset = load_dataset("chloeliu/horoscope-chat", split="train")
# Check dataset structure
print(dataset.column_names)
# Assuming 'input' and 'response' are the correct keys:
def prepare_docs(dataset):
    docs = []
    for data in dataset:
        question = data.get('input', '')    # Safely access the 'input' field
        answer = data.get('response', '')   # Safely access the 'response' field
        docs.append({
            "question": question,
            "answer": answer
        })
    return docs
# Prepare the documents
docs = prepare_docs(dataset)
# Custom retriever that searches the horoscope dataset directly.
# Note: it only overrides `retrieve` with a simple string lookup and does not
# implement the embedding-based RagRetriever interface, so the chatbot below
# calls it directly instead of going through model.generate().
class HoroscopeRetriever(RagRetriever):
    def __init__(self, docs):
        self.docs = docs

    def retrieve(self, question_texts, n_docs=1):
        # Simple retrieval logic: return the answer of the first document whose
        # stored question contains the user's question (case-insensitive).
        question = question_texts[0].lower()
        for doc in self.docs:
            if question in doc["question"].lower():
                return [doc["answer"]]
        return ["Sorry, I couldn't find a relevant horoscope."]
# Initialize the custom retriever with the dataset
retriever = HoroscopeRetriever(docs)
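# Example (hypothetical query): retrieval is a case-insensitive substring match,
# so retriever.retrieve(["aries"], n_docs=1) returns a one-element list holding
# the first matching answer, or the fallback message if nothing matches.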
# Initialize the RAG tokenizer and generator.
# Because HoroscopeRetriever only provides the simplified lookup above, the
# chatbot answers from the retriever directly rather than through
# model.generate(), which requires a full embedding-based retriever.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-base", retriever=retriever)
# Define the chatbot function
def horoscope_chatbot(input_text):
    # Look up the best-matching horoscope for the user's question. The custom
    # retriever already returns the final answer text (or a fallback message),
    # so it is used as the response.
    return retriever.retrieve([input_text], n_docs=1)[0]
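# Quick local sanity check (hypothetical query):
#   print(horoscope_chatbot("what is the horoscope for aries"))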
# Set up Gradio interface
iface = gr.Interface(fn=horoscope_chatbot, inputs="text", outputs="text", title="Horoscope RAG Chatbot")
# Launch the interface
iface.launch()
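# launch() serves the app at http://127.0.0.1:7860 by default when run locally;
# on Hugging Face Spaces the same call exposes the hosted demo. Pass
# share=True for a temporary public link when running outside Spaces.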