keerthi-balaji committed on
Commit
47465ab
1 Parent(s): 118896e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -12
app.py CHANGED
@@ -24,34 +24,48 @@ docs = prepare_docs(dataset)
24
 
25
  # Custom Retriever that searches in the dataset
26
  class HoroscopeRetriever(RagRetriever):
27
- def __init__(self, docs):
28
  self.docs = docs
 
29
 
30
  def retrieve(self, question_hidden_states, n_docs=1):
31
  # Convert the question_hidden_states to a text string
32
- question_hidden_states = question_hidden_states[0]
33
 
34
- if isinstance(question_hidden_states, np.ndarray):
35
- if question_hidden_states.size == 1:
36
- question = question_hidden_states.item() # Convert single-element array to scalar
37
  else:
38
- question = str(question_hidden_states[0]) # Take the first element of the array
39
  else:
40
- question = str(question_hidden_states)
41
 
42
  question = question.lower()
43
 
44
- # Simple retrieval logic: return the most relevant document based on the question
 
45
  for doc in self.docs:
46
  if question in doc["question"].lower():
47
- return [doc["answer"]]
48
- return ["Sorry, I couldn't find a relevant horoscope."]
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  # Initialize the custom retriever with the dataset
51
- retriever = HoroscopeRetriever(docs)
 
52
 
53
  # Initialize RAG components
54
- tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
55
  model = RagTokenForGeneration.from_pretrained("facebook/rag-token-base", retriever=retriever)
56
 
57
  # Define the chatbot function
 
24
 
25
  # Custom Retriever that searches in the dataset
26
  class HoroscopeRetriever(RagRetriever):
27
+ def __init__(self, docs, tokenizer):
28
  self.docs = docs
29
+ self.tokenizer = tokenizer
30
 
31
  def retrieve(self, question_hidden_states, n_docs=1):
32
  # Convert the question_hidden_states to a text string
33
+ question = question_hidden_states[0]
34
 
35
+ if isinstance(question, np.ndarray):
36
+ if question.size == 1:
37
+ question = question.item() # Convert single-element array to scalar
38
  else:
39
+ question = str(question[0]) # Take the first element of the array
40
  else:
41
+ question = str(question)
42
 
43
  question = question.lower()
44
 
45
+ # Simple retrieval logic: find the most relevant document based on the question
46
+ best_match = None
47
  for doc in self.docs:
48
  if question in doc["question"].lower():
49
+ best_match = doc
50
+ break
51
+
52
+ if best_match:
53
+ # Fake embedding as RAG expects this (In a real case, compute embeddings)
54
+ retrieved_doc_embeds = torch.zeros((1, 1, 768)) # Example tensor
55
+ doc_ids = ["0"] # Example document ID
56
+ docs = [best_match["answer"]]
57
+ else:
58
+ retrieved_doc_embeds = torch.zeros((1, 1, 768)) # Example tensor
59
+ doc_ids = ["0"] # Example document ID
60
+ docs = ["Sorry, I couldn't find a relevant horoscope."]
61
+
62
+ return retrieved_doc_embeds, doc_ids, docs
63
 
64
  # Initialize the custom retriever with the dataset
65
+ tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
66
+ retriever = HoroscopeRetriever(docs, tokenizer)
67
 
68
  # Initialize RAG components
 
69
  model = RagTokenForGeneration.from_pretrained("facebook/rag-token-base", retriever=retriever)
70
 
71
  # Define the chatbot function