keerthi-balaji commited on
Commit
1b124c2
1 Parent(s): 84115ce

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -0
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ import gradio as gr
3
+ from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
4
+
5
+ # Load the horoscope-chat dataset
6
+ dataset = load_dataset("chloeliu/horoscope-chat", split="train")
7
+
8
+ # Convert the dataset to a format that can be used by the retriever
9
+ def prepare_docs(dataset):
10
+ docs = []
11
+ for data in dataset:
12
+ question = data['question']
13
+ answer = data['response']
14
+ docs.append({
15
+ "question": question,
16
+ "answer": answer
17
+ })
18
+ return docs
19
+
20
+ # Prepare the documents
21
+ docs = prepare_docs(dataset)
22
+
23
+ # Custom Retriever that searches in the dataset
24
+ class HoroscopeRetriever(RagRetriever):
25
+ def __init__(self, docs):
26
+ self.docs = docs
27
+
28
+ def retrieve(self, question_texts, n_docs=1):
29
+ # Simple retrieval logic: return the most relevant document based on the question
30
+ question = question_texts[0].lower()
31
+ for doc in self.docs:
32
+ if question in doc["question"].lower():
33
+ return [doc["answer"]]
34
+ return ["Sorry, I couldn't find a relevant horoscope."]
35
+
36
+ # Initialize the custom retriever with the dataset
37
+ retriever = HoroscopeRetriever(docs)
38
+
39
+ # Initialize RAG components
40
+ tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
41
+ model = RagTokenForGeneration.from_pretrained("facebook/rag-token-base", retriever=retriever)
42
+
43
+ # Define the chatbot function
44
+ def horoscope_chatbot(input_text):
45
+ input_ids = tokenizer(input_text, return_tensors="pt").input_ids
46
+ generated_ids = model.generate(input_ids=input_ids)
47
+ generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
48
+ return generated_text
49
+
50
+ # Set up Gradio interface
51
+ iface = gr.Interface(fn=horoscope_chatbot, inputs="text", outputs="text", title="Horoscope RAG Chatbot")
52
+
53
+ # Launch the interface
54
+ iface.launch()