|
from datasets import load_dataset |
|
import gradio as gr |
|
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration |
|
import torch |
|
|
|
|
|
# Training split of the horoscope Q&A dataset on the Hugging Face Hub.
dataset = load_dataset(path="chloeliu/horoscope-chat", split="train")

# Show the available columns so the keys read in prepare_docs can be checked.
print(dataset.column_names)
|
|
|
|
|
def prepare_docs(dataset, question_key="input", answer_key="response"):
    """Convert dataset rows into ``{"question": ..., "answer": ...}`` dicts.

    Args:
        dataset: An iterable of mapping rows (such as the rows yielded by
            iterating a Hugging Face ``Dataset``).
        question_key: Row key holding the user question. Defaults to ``"input"``.
        answer_key: Row key holding the reply. Defaults to ``"response"``.

    Returns:
        A list with one dict per row, in input order. Missing keys and ``None``
        values become ``""`` so downstream string operations (``.lower()``,
        formatting) never receive ``None``.
    """
    return [
        {
            # `or ""` also covers keys that are present but hold None.
            "question": row.get(question_key) or "",
            "answer": row.get(answer_key) or "",
        }
        for row in dataset
    ]
|
|
|
|
|
# Flatten the dataset rows into question/answer dicts for the retriever.
docs = prepare_docs(dataset=dataset)
|
|
|
|
|
class HoroscopeRetriever(RagRetriever):
    """Substring retriever over the question/answer dicts built by ``prepare_docs``.

    NOTE(review): this subclasses ``RagRetriever`` (presumably so that
    ``RagTokenForGeneration`` accepts it as its ``retriever``) but does not
    call ``RagRetriever.__init__``, and ``retrieve`` returns plain answer
    strings. It therefore does not follow the ``RagRetriever`` protocol.
    Call ``retrieve`` directly instead of relying on the model to invoke it.
    """

    # Returned when the query is blank or no stored question contains it.
    NO_MATCH = "Sorry, I couldn't find a relevant horoscope."

    def __init__(self, docs):
        """Keep *docs*, a list of dicts with "question" and "answer" keys.

        ``RagRetriever.__init__`` is intentionally not called: it needs a RAG
        config and tokenizers that this plain lookup does not use.
        """
        self.docs = docs

    def retrieve(self, question_texts, n_docs=1):
        """Return answers whose stored question contains the query.

        Matching is a case-insensitive substring test of the stripped query
        against each document's question, scanning ``self.docs`` in order.

        Args:
            question_texts: A list of queries (only the first is used) or a
                single query string.
            n_docs: Maximum number of answers to return; at least one is
                always returned.

        Returns:
            A non-empty list of strings: up to ``n_docs`` matching answers,
            or ``[NO_MATCH]`` when the query is blank or nothing matches.
        """
        # A bare string would otherwise be indexed down to its first character.
        if isinstance(question_texts, str):
            question_texts = [question_texts]
        first = question_texts[0] if question_texts else None
        query = (first or "").strip().lower()
        # "" is a substring of every question, so a blank query must not match.
        if not query:
            return [self.NO_MATCH]

        limit = max(n_docs, 1)
        answers = []
        for doc in self.docs:
            question = (doc.get("question") or "").lower()
            answer = doc.get("answer")
            # Skip docs without an answer; they cannot serve as context.
            if answer and query in question:
                answers.append(answer)
                if len(answers) >= limit:
                    break
        return answers or [self.NO_MATCH]
|
|
|
|
|
# Keyword retriever over the prepared question/answer documents.
retriever = HoroscopeRetriever(docs)

# The tokenizer and the generation model come from the same pretrained RAG checkpoint.
RAG_CHECKPOINT = "facebook/rag-token-base"
tokenizer = RagTokenizer.from_pretrained(RAG_CHECKPOINT)
model = RagTokenForGeneration.from_pretrained(RAG_CHECKPOINT, retriever=retriever)
|
|
|
|
|
def horoscope_chatbot(input_text):
    """Generate an answer to *input_text* using the retrieved horoscope as context.

    ``HoroscopeRetriever`` skips ``RagRetriever.__init__`` and its ``retrieve``
    returns plain strings. When ``model.generate`` is called with only
    ``input_ids``, the model performs retrieval itself through the
    ``RagRetriever`` protocol, which this retriever does not implement.
    Retrieval is therefore done explicitly here. The passage is handed to
    ``generate`` as ``context_input_ids`` / ``doc_scores``, which makes the model
    skip its own retrieval step.

    Args:
        input_text: The user's question from the Gradio text box.

    Returns:
        The generated answer text, or a short prompt when the input is blank.
    """
    if not input_text or not input_text.strip():
        return "Please ask a horoscope question."

    # Explicit retrieval; the retriever always returns at least one string.
    passages = retriever.retrieve([input_text], n_docs=1)
    # Generator input is "passage // question", similar to the layout RagRetriever builds.
    contexts = [f"{passage} // {input_text}" for passage in passages]
    n_docs = len(contexts)

    question_ids = tokenizer(input_text, return_tensors="pt").input_ids
    context = tokenizer.generator(
        contexts, return_tensors="pt", padding=True, truncation=True
    )
    # Uniform scores for one question (batch of 1): every passage is weighted equally.
    doc_scores = torch.zeros((1, n_docs), dtype=torch.float32)

    generated_ids = model.generate(
        input_ids=question_ids,
        context_input_ids=context.input_ids,
        context_attention_mask=context.attention_mask,
        doc_scores=doc_scores,
        n_docs=n_docs,
    )
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
|
|
|
|
# Gradio UI: one text box in, the chatbot's text reply out.
iface = gr.Interface(
    fn=horoscope_chatbot,
    inputs="text",
    outputs="text",
    title="Horoscope RAG Chatbot",
)

# Start the web server (which blocks the main thread) only when this file is
# run as a script, so importing the module does not launch the app.
if __name__ == "__main__":
    iface.launch()
|
|