File size: 2,211 Bytes
1b124c2
 
71790db
2483b3f
1b124c2
71790db
 
 
1b124c2
71790db
 
 
 
1b124c2
71790db
2483b3f
 
 
 
89a4d2e
 
 
 
 
 
2483b3f
 
 
 
 
 
71790db
 
47465ab
71790db
1b124c2
71790db
 
1b124c2
 
71790db
1b124c2
 
 
 
 
 
 
 
 
 
 
 
89a4d2e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import gradio as gr
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
import json
import numpy as np

# Load horoscope data.
# The file is decoded explicitly as UTF-8 (the encoding JSON text is defined
# to use) so the result does not depend on the platform's default locale
# encoding, which would otherwise be picked up by open().
with open("horoscope_data.json", "r", encoding="utf-8") as file:
    horoscope_data = json.load(file)

# Custom Retriever that looks up horoscopes
class CustomHoroscopeRetriever(RagRetriever):
    """Retriever that answers a query by looking a zodiac sign up in a mapping.

    NOTE(review): ``__init__`` does not invoke the base-class initializer, so
    no base-class state is set up on instances, and ``retrieve`` returns a
    one-element list rather than whatever the base class's ``retrieve``
    returns. Confirm this is compatible with how the RAG model calls its
    retriever.
    """

    # Fallback replies returned (wrapped in a list) when no horoscope is found.
    _UNPROCESSABLE = (
        "I couldn't process your request. "
        "Please try again with a valid zodiac sign."
    )
    _UNKNOWN_SIGN = (
        "I couldn't find your zodiac sign. "
        "Please try again with a valid one."
    )

    def __init__(self, horoscope_data):
        """Keep a reference to the data that ``retrieve`` looks signs up in.

        Args:
            horoscope_data: Container indexed by zodiac sign; presumably the
                dict parsed from the horoscope JSON file (verify the schema).
        """
        self.horoscope_data = horoscope_data

    def retrieve(self, question_texts, n_docs=1):
        """Return a one-element list holding the horoscope for the query.

        Args:
            question_texts: The zodiac sign as a string, or a list / NumPy
                array (nested at most two levels deep) whose first element is
                that string.
            n_docs: Unused by this implementation.

        Returns:
            A list with exactly one item: the stored entry for the sign, or
            an explanatory message when the input is not a string, is an
            empty sequence, or names a sign that is not in the data.
        """
        # NumPy arrays are converted so the unwrapping below only sees lists.
        if isinstance(question_texts, np.ndarray):
            question_texts = question_texts.tolist()

        # Peel off up to two levels of list nesting, taking the first element
        # each time. An empty list has no first element, so it is reported as
        # unprocessable instead of raising IndexError.
        for _ in range(2):
            if not isinstance(question_texts, list):
                break
            if not question_texts:
                return [self._UNPROCESSABLE]
            question_texts = question_texts[0]

        if not isinstance(question_texts, str):
            return [self._UNPROCESSABLE]

        data = self.horoscope_data

        # An exact key match always wins.
        if question_texts in data:
            return [data[question_texts]]

        # Otherwise tolerate stray whitespace and differences in letter case,
        # e.g. " aries " matching a stored "Aries".
        wanted = question_texts.strip().casefold()
        for key in data:
            if isinstance(key, str) and key.strip().casefold() == wanted:
                return [data[key]]

        return [self._UNKNOWN_SIGN]

# Name of the pretrained checkpoint used for both the tokenizer and the model.
RAG_CHECKPOINT = "facebook/rag-token-base"

# Wrap the loaded horoscope data in the lookup-based retriever.
retriever = CustomHoroscopeRetriever(horoscope_data)

# Load the pretrained tokenizer, then the generator with the custom retriever
# passed in as its retriever.
tokenizer = RagTokenizer.from_pretrained(RAG_CHECKPOINT)
model = RagTokenForGeneration.from_pretrained(
    RAG_CHECKPOINT,
    retriever=retriever,
)

# Define the chatbot function
def horoscope_chatbot(input_text):
    """Generate a text reply for ``input_text`` using the module-level model.

    The input is tokenized into PyTorch tensors, its token ids are passed to
    ``model.generate``, and the generated ids are decoded with special tokens
    skipped.

    Args:
        input_text: The text entered by the user.

    Returns:
        The first decoded sequence from the generated batch.
    """
    encoded = tokenizer(input_text, return_tensors="pt")
    output_ids = model.generate(input_ids=encoded.input_ids)
    decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0]

# Build the Gradio UI: a single text input and a single text output, both
# wired to the chatbot function.
iface = gr.Interface(
    fn=horoscope_chatbot,
    inputs="text",
    outputs="text",
    title="Horoscope RAG Chatbot",
)

# Start the app, requesting a public share link.
iface.launch(share=True)