"""Horoscope chatbot demo.

Builds a naive substring-match retriever over the chloeliu/horoscope-chat
dataset and exposes a RAG-based chat function through a Gradio text UI.

NOTE(review): the RAG generation path is fragile — see the comments on
HoroscopeRetriever below.
"""

from datasets import load_dataset
import gradio as gr
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
import torch  # noqa: F401  # kept from original; transformers uses torch at runtime

# Load the horoscope-chat dataset (train split only).
dataset = load_dataset("chloeliu/horoscope-chat", split="train")

# Print the schema so the 'input'/'response' field names used below can be
# verified against the actual dataset.
print(dataset.column_names)


def prepare_docs(dataset):
    """Return a list of {"question", "answer"} dicts extracted from *dataset*.

    Each row's 'input' field becomes "question" and 'response' becomes
    "answer"; missing fields default to "" instead of raising KeyError.
    """
    return [
        {"question": row.get("input", ""), "answer": row.get("response", "")}
        for row in dataset
    ]


# Prepare the documents once at startup.
docs = prepare_docs(dataset)


class HoroscopeRetriever(RagRetriever):
    """Naive retriever: first stored question that contains the query string.

    NOTE(review): __init__ deliberately skips super().__init__(), so this
    object lacks the config/index/tokenizer state a real RagRetriever sets
    up, and retrieve() returns plain strings rather than the document
    tensors the RAG model expects. model.generate() may therefore fail if
    it invokes the standard retrieval hooks — confirm against the
    transformers RagRetriever API before relying on this path.
    """

    def __init__(self, docs):
        # Intentionally not calling super().__init__(): only the in-memory
        # document list is needed for the substring matching below.
        self.docs = docs

    def retrieve(self, question_texts, n_docs=1):
        """Return [answer] for the first doc whose question contains the
        lower-cased query; a fallback apology string if none matches."""
        question = question_texts[0].lower()
        for doc in self.docs:
            if question in doc["question"].lower():
                return [doc["answer"]]
        return ["Sorry, I couldn't find a relevant horoscope."]


# Initialize the custom retriever with the prepared documents.
retriever = HoroscopeRetriever(docs)

# Initialize RAG components (downloads weights on first run).
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
model = RagTokenForGeneration.from_pretrained(
    "facebook/rag-token-base", retriever=retriever
)


def horoscope_chatbot(input_text):
    """Generate a model response for *input_text* and return it as a string."""
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    generated_ids = model.generate(input_ids=input_ids)
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]


# Set up the Gradio interface around the chatbot function.
iface = gr.Interface(
    fn=horoscope_chatbot,
    inputs="text",
    outputs="text",
    title="Horoscope RAG Chatbot",
)

# Guard the launch so importing this module doesn't start a web server.
if __name__ == "__main__":
    iface.launch()