JustKiddo commited on
Commit
289913c
1 Parent(s): 80eee0f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -7
app.py CHANGED
@@ -33,7 +33,7 @@ class VietnameseChatbot:
33
 
34
  # Pre-compute embeddings for faster response generation
35
  print("Pre-computing conversation embeddings...")
36
- self.conversation_embeddings = self._precompute_embeddings()
37
 
38
  def _load_conversation_data(self):
39
  """
@@ -50,7 +50,7 @@ class VietnameseChatbot:
50
  {"query": "Bạn từ đâu đến?", "response": "Tôi được phát triển bởi một nhóm kỹ sư AI, và tôn chỉ của tôi là hỗ trợ con người."},
51
 
52
  # Small talk
53
- {"query": "Bạn thích gì?", "response": "Tôi thích học hỏi và giú đỡ mọi người. Mỗi cuộc trò chuyện là một cơ hội để tôi phát triển."},
54
  {"query": "Bạn có thể làm gì?", "response": "Tôi có thể trò chuyện, trả lời câu hỏi, và hỗ trợ bạn trong nhiều tình huống khác nhau."},
55
 
56
  # Weather and time
@@ -67,16 +67,39 @@ class VietnameseChatbot:
67
  ]
68
 
69
  @st.cache_data
70
- def _precompute_embeddings(self):
71
  """
72
- Pre-compute embeddings for all conversation queries
73
  Cached to avoid recomputing on every run
74
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  embeddings = []
76
- for item in self.conversation_data:
77
- embedding = self.embed_text(item['query'])
78
  if embedding is not None:
79
- embeddings.append(embedding[0])
80
  return np.array(embeddings)
81
 
82
  def embed_text(self, text):
 
33
 
34
  # Pre-compute embeddings for faster response generation
35
  print("Pre-computing conversation embeddings...")
36
+ self.conversation_embeddings = self._compute_embeddings()
37
 
38
  def _load_conversation_data(self):
39
  """
 
50
  {"query": "Bạn từ đâu đến?", "response": "Tôi được phát triển bởi một nhóm kỹ sư AI, và tôn chỉ của tôi là hỗ trợ con người."},
51
 
52
  # Small talk
53
+ {"query": "Bạn thích gì?", "response": "Tôi thích học hỏi và giúp đỡ mọi người. Mỗi cuộc trò chuyện là một cơ hội để tôi phát triển."},
54
  {"query": "Bạn có thể làm gì?", "response": "Tôi có thể trò chuyện, trả lời câu hỏi, và hỗ trợ bạn trong nhiều tình huống khác nhau."},
55
 
56
  # Weather and time
 
67
  ]
68
 
69
  @st.cache_data
70
+ def _compute_embeddings(queries):
71
  """
72
+ Pre-compute embeddings for conversation queries
73
  Cached to avoid recomputing on every run
74
  """
75
+ def embed_single_text(text, tokenizer, model):
76
+ try:
77
+ # Tokenize and generate embeddings
78
+ inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
79
+
80
+ with torch.no_grad():
81
+ model_output = model(**inputs)
82
+
83
+ # Mean pooling
84
+ token_embeddings = model_output[0]
85
+ input_mask_expanded = inputs['attention_mask'].unsqueeze(-1).expand(token_embeddings.size()).float()
86
+ embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
87
+
88
+ return embeddings.numpy()[0]
89
+ except Exception as e:
90
+ print(f"Embedding error: {e}")
91
+ return None
92
+
93
+ # Import these arguments to make the function self-contained
94
+ from transformers import AutoTokenizer, AutoModel
95
+ tokenizer = AutoTokenizer.from_pretrained('intfloat/multilingual-e5-small')
96
+ model = AutoModel.from_pretrained('intfloat/multilingual-e5-small', torch_dtype=torch.float16)
97
+
98
  embeddings = []
99
+ for query in queries:
100
+ embedding = embed_single_text(query['query'], tokenizer, model)
101
  if embedding is not None:
102
+ embeddings.append(embedding)
103
  return np.array(embeddings)
104
 
105
  def embed_text(self, text):