File size: 2,013 Bytes
1b124c2
 
 
8e28d8f
1b124c2
 
 
 
8e28d8f
 
 
 
1b124c2
 
 
8e28d8f
 
1b124c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from datasets import load_dataset
import gradio as gr
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
import torch

# Load the horoscope-chat dataset
# Pulls the "train" split of the community horoscope Q&A dataset from the HF Hub.
dataset = load_dataset("chloeliu/horoscope-chat", split="train")

# Check dataset structure
# Printed so the 'input'/'response' field assumption used below can be verified.
print(dataset.column_names)

# Assuming 'input' and 'response' are the correct keys:
def prepare_docs(dataset):
    """Convert dataset records into question/answer documents.

    Each record's ``'input'`` field becomes the ``"question"`` and its
    ``'response'`` field the ``"answer"``; missing fields default to
    empty strings. Returns a list of dicts.
    """
    return [
        {
            "question": record.get('input', ''),
            "answer": record.get('response', ''),
        }
        for record in dataset
    ]

# Prepare the documents
# Materialize the whole dataset into in-memory question/answer dicts for lookup.
docs = prepare_docs(dataset)

# Custom Retriever that searches in the dataset
class HoroscopeRetriever(RagRetriever):
    def __init__(self, docs):
        self.docs = docs

    def retrieve(self, question_texts, n_docs=1):
        # Simple retrieval logic: return the most relevant document based on the question
        question = question_texts[0].lower()
        for doc in self.docs:
            if question in doc["question"].lower():
                return [doc["answer"]]
        return ["Sorry, I couldn't find a relevant horoscope."]

# Initialize the custom retriever with the dataset
retriever = HoroscopeRetriever(docs)

# Initialize RAG components
# NOTE(review): HoroscopeRetriever skips RagRetriever.__init__, so passing it
# as `retriever=` relies on the model never touching the usual index/config
# state — confirm this works with the installed transformers version.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-base", retriever=retriever)

# Define the chatbot function
def horoscope_chatbot(input_text):
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    generated_ids = model.generate(input_ids=input_ids)
    generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text

# Set up Gradio interface
# Single text-in/text-out widget wrapping the chatbot function.
iface = gr.Interface(fn=horoscope_chatbot, inputs="text", outputs="text", title="Horoscope RAG Chatbot")

# Launch the interface
iface.launch()