|
import gradio as gr |
|
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration |
|
import json |
|
import numpy as np |
|
|
|
|
|
# Load the zodiac-sign -> horoscope-text mapping used by the retriever.
# Explicit UTF-8 avoids platform-dependent default encodings (JSON is UTF-8).
# NOTE(review): assumes the JSON object maps sign names to strings — confirm schema.
with open("horoscope_data.json", "r", encoding="utf-8") as file:

    horoscope_data = json.load(file)
|
|
|
|
|
class CustomHoroscopeRetriever(RagRetriever):
    """Minimal retriever that maps a zodiac-sign string to its horoscope text.

    NOTE(review): ``super().__init__`` is intentionally not called, so this
    object skips RagRetriever's index/tokenizer setup and only supports the
    overridden ``retrieve`` — confirm the model path that exercises it.
    """

    def __init__(self, horoscope_data):
        # Mapping of zodiac sign -> horoscope string (loaded from JSON).
        self.horoscope_data = horoscope_data

    def retrieve(self, question_texts, n_docs=1):
        """Return a one-element list holding the horoscope for the given sign.

        Accepts a plain string, an ndarray, or a (possibly nested) list whose
        first element is the sign; any other input yields an error message.
        ``n_docs`` is accepted for interface compatibility but at most one
        document is ever returned.
        """
        if isinstance(question_texts, np.ndarray):
            question_texts = question_texts.tolist()

        # Unwrap nested lists down to the first scalar element. The original
        # code duplicated a two-level unwrap; the loop handles any depth.
        while isinstance(question_texts, list):
            question_texts = question_texts[0]

        if isinstance(question_texts, str):
            zodiac_sign = question_texts
        else:
            return ["I couldn't process your request. Please try again with a valid zodiac sign."]

        if zodiac_sign in self.horoscope_data:
            return [self.horoscope_data[zodiac_sign]]
        else:
            return ["I couldn't find your zodiac sign. Please try again with a valid one."]
|
|
|
|
|
# Build the custom retriever over the locally loaded horoscope data.
retriever = CustomHoroscopeRetriever(horoscope_data)




# Load the pretrained RAG tokenizer and generator, wiring in our retriever.
# NOTE(review): RagTokenForGeneration normally drives its retriever through
# the retriever __call__ pipeline rather than retrieve() directly — confirm
# the custom retrieve() override is actually invoked during generate().
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")

model = RagTokenForGeneration.from_pretrained("facebook/rag-token-base", retriever=retriever)
|
|
|
|
|
def horoscope_chatbot(input_text):
    """Generate a reply for *input_text* using the global RAG model.

    Tokenizes the user's text, runs generation, and returns the first
    decoded sequence with special tokens stripped.
    """
    encoded = tokenizer(input_text, return_tensors="pt")
    output_ids = model.generate(input_ids=encoded.input_ids)
    decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0]
|
|
|
|
|
# Expose the chatbot through a single text-in / text-out Gradio interface.
iface = gr.Interface(fn=horoscope_chatbot, inputs="text", outputs="text", title="Horoscope RAG Chatbot")




# share=True publishes a temporary public URL in addition to the local server.
iface.launch(share=True)
|
|