keerthi-balaji's picture
Update app.py
2483b3f verified
raw
history blame contribute delete
No virus
2.21 kB
import gradio as gr
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
import json
import numpy as np
# Load horoscope data.
# JSON text is UTF-8, so decode it explicitly instead of relying on the
# platform's default locale encoding.
# NOTE(review): the retriever indexes this with a zodiac-sign string, so it is
# presumably a JSON object mapping sign names to horoscope text — confirm.
with open("horoscope_data.json", encoding="utf-8") as file:
    horoscope_data = json.load(file)
# Custom Retriever that looks up horoscopes
class CustomHoroscopeRetriever(RagRetriever):
    """Retriever that answers by looking a zodiac sign up in ``horoscope_data``.

    NOTE(review): ``RagRetriever.__init__`` is not called here (it requires a
    RAG config and tokenizers), so attributes the base class would normally
    set up are absent on instances. In addition, ``RagTokenForGeneration``
    normally invokes its retriever through ``__call__`` with question-encoder
    hidden states (a float array) and expects a ``(doc_embeds, doc_ids, docs)``
    result, not text in / list-of-strings out — confirm that this
    ``retrieve`` is actually reached with zodiac-sign text.
    """

    # User-facing replies, kept byte-identical to the original messages.
    _INVALID_REQUEST_MSG = (
        "I couldn't process your request. Please try again with a valid zodiac sign."
    )
    _UNKNOWN_SIGN_MSG = (
        "I couldn't find your zodiac sign. Please try again with a valid one."
    )

    def __init__(self, horoscope_data):
        """Store the sign -> horoscope mapping used by :meth:`retrieve`.

        Args:
            horoscope_data: Mapping from zodiac-sign name to horoscope text
                (assumed to be the dict parsed from ``horoscope_data.json`` —
                confirm).
        """
        self.horoscope_data = horoscope_data

    def retrieve(self, question_texts, n_docs=1):
        """Return a one-element list holding the horoscope for the given sign.

        Args:
            question_texts: The zodiac sign as a ``str``, or that string
                wrapped in up to two levels of list / numpy-array nesting.
                Only the first element at each level is used.
            n_docs: Accepted for signature compatibility. It is unused because
                each sign maps to exactly one horoscope.

        Returns:
            ``[horoscope]`` when the sign is found, matching exactly first and
            then ignoring surrounding whitespace and letter case. Otherwise,
            a one-element list holding a user-facing error message.
        """
        # Convert numpy arrays to (possibly nested) lists so they unwrap below.
        if isinstance(question_texts, np.ndarray):
            question_texts = question_texts.tolist()

        # Unwrap at most two levels of nesting (e.g. batch -> inner list),
        # taking the first element each time. An empty level is reported as
        # invalid input instead of raising IndexError.
        for _ in range(2):
            if not isinstance(question_texts, list):
                break
            if not question_texts:
                return [self._INVALID_REQUEST_MSG]
            question_texts = question_texts[0]

        if not isinstance(question_texts, str):
            return [self._INVALID_REQUEST_MSG]
        zodiac_sign = question_texts

        # Exact match first, so existing keys behave exactly as before.
        if zodiac_sign in self.horoscope_data:
            return [self.horoscope_data[zodiac_sign]]

        # Fall back to a whitespace- and case-insensitive match, so inputs
        # such as " aries" still find an "Aries" entry.
        wanted = zodiac_sign.strip().casefold()
        for key, horoscope in self.horoscope_data.items():
            if isinstance(key, str) and key.strip().casefold() == wanted:
                return [horoscope]
        return [self._UNKNOWN_SIGN_MSG]
# Initialize the custom retriever with the horoscope data
retriever = CustomHoroscopeRetriever(horoscope_data)

# Single checkpoint name shared by the tokenizer and the generator, so the
# two cannot accidentally drift apart.
MODEL_NAME = "facebook/rag-token-base"

# Initialize RAG components. The checkpoint is loaded when this module runs.
tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)
model = RagTokenForGeneration.from_pretrained(MODEL_NAME, retriever=retriever)
# Define the chatbot function
def horoscope_chatbot(input_text):
    """Generate a reply to ``input_text`` with the RAG model.

    Args:
        input_text: Text from the Gradio textbox (expected to name a zodiac
            sign). ``None`` or whitespace-only input is rejected up front.

    Returns:
        The decoded text of the first generated sequence, or a prompt asking
        for a zodiac sign when the input is empty.
    """
    # Skip the tokenize + generate round trip when there is nothing to ask.
    if input_text is None or not input_text.strip():
        return "Please enter your zodiac sign."
    input_ids = tokenizer(input_text.strip(), return_tensors="pt").input_ids
    generated_ids = model.generate(input_ids=input_ids)
    # batch_decode returns one string per generated sequence. The first entry
    # corresponds to the single prompt encoded above.
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
# Set up Gradio interface
iface = gr.Interface(
    fn=horoscope_chatbot,
    inputs="text",
    outputs="text",
    title="Horoscope RAG Chatbot",
)

# Launch only when run as a script (e.g. `python app.py`), so importing this
# module for tests or tooling does not start a server.
if __name__ == "__main__":
    # share=True additionally requests a temporary public link.
    iface.launch(share=True)