Spaces:
Sleeping
Sleeping
keerthi-balaji
committed on
Commit
•
47465ab
1
Parent(s):
118896e
Update app.py
Browse files
app.py
CHANGED
@@ -24,34 +24,48 @@ docs = prepare_docs(dataset)
|
|
24 |
|
25 |
# Custom Retriever that searches in the dataset
|
26 |
class HoroscopeRetriever(RagRetriever):
|
27 |
-
def __init__(self, docs):
|
28 |
self.docs = docs
|
|
|
29 |
|
30 |
def retrieve(self, question_hidden_states, n_docs=1):
|
31 |
# Convert the question_hidden_states to a text string
|
32 |
-
|
33 |
|
34 |
-
if isinstance(
|
35 |
-
if
|
36 |
-
question =
|
37 |
else:
|
38 |
-
question = str(
|
39 |
else:
|
40 |
-
question = str(
|
41 |
|
42 |
question = question.lower()
|
43 |
|
44 |
-
# Simple retrieval logic:
|
|
|
45 |
for doc in self.docs:
|
46 |
if question in doc["question"].lower():
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
# Initialize the custom retriever with the dataset
|
51 |
-
|
|
|
52 |
|
53 |
# Initialize RAG components
|
54 |
-
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
|
55 |
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-base", retriever=retriever)
|
56 |
|
57 |
# Define the chatbot function
|
|
|
24 |
|
25 |
# Custom Retriever that searches in the dataset
class HoroscopeRetriever(RagRetriever):
    """Toy retriever that matches the user's question against an in-memory
    list of ``{"question": ..., "answer": ...}`` docs instead of a real
    dense index.
    """

    def __init__(self, docs, tokenizer):
        # NOTE(review): RagRetriever.__init__ is deliberately not called;
        # this subclass only overrides retrieve(). Confirm the RAG model
        # does not depend on other retriever attributes.
        self.docs = docs
        self.tokenizer = tokenizer

    def retrieve(self, question_hidden_states, n_docs=1):
        """Return ``(retrieved_doc_embeds, doc_ids, docs)`` for the query.

        Treats ``question_hidden_states[0]`` as the query (or something
        str()-convertible) — TODO confirm; a real RAG pipeline passes an
        embedding array here, not text.
        """
        question = question_hidden_states[0]

        # Normalise the query to a plain string before calling .lower().
        if isinstance(question, np.ndarray):
            if question.size == 1:
                # BUG FIX: .item() returns a numeric scalar which has no
                # .lower(); convert it to str explicitly.
                question = str(question.item())
            else:
                question = str(question[0])  # take the first element of the array
        else:
            question = str(question)

        question = question.lower()

        # Simple retrieval logic: first doc whose question contains the query.
        best_match = None
        for doc in self.docs:
            if question in doc["question"].lower():
                best_match = doc
                break

        # Fake embedding / doc id — RAG expects these shapes. (A real
        # retriever would compute embeddings.) Hoisted out of the branches:
        # both arms assigned identical values.
        retrieved_doc_embeds = torch.zeros((1, 1, 768))  # Example tensor
        doc_ids = ["0"]  # Example document ID
        if best_match:
            docs = [best_match["answer"]]
        else:
            docs = ["Sorry, I couldn't find a relevant horoscope."]

        return retrieved_doc_embeds, doc_ids, docs
# Initialize the custom retriever with the dataset
# NOTE(review): `docs` is assumed to be the prepared dataset from
# prepare_docs() earlier in the file — confirm it is a list of
# {"question": ..., "answer": ...} dicts as HoroscopeRetriever expects.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
retriever = HoroscopeRetriever(docs, tokenizer)

# Initialize RAG components
# The generator model is wired to the custom retriever above.
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-base", retriever=retriever)

# Define the chatbot function
|