SirinootKK committed on
Commit
610dfca
·
1 Parent(s): 6354c71
Files changed (1) hide show
  1. app.py +18 -8
app.py CHANGED
@@ -1,11 +1,4 @@
1
  # -*- coding: utf-8 -*-
2
- """gradio_wangchanberta
3
-
4
- Automatically generated by Colaboratory.
5
-
6
- Original file is located at
7
- https://colab.research.google.com/drive/1Kw2k1oymhq4ZAcy4oBYOlIg4bBU-HlVr
8
- """
9
 
10
  #@title scripts
11
  import time
@@ -177,6 +170,17 @@ class Chatbot:
177
  # score="{:.4f})".format(hit['score'])
178
  # Answer = self.model_pipeline(message, context)
179
  # return Answer
 
 
 
 
 
 
 
 
 
 
 
180
  def predict_semantic_search(self, message):
181
  message = message.strip()
182
  query_embedding = self.embedding_model.encode([message], convert_to_tensor=True)[0] # Fix here
@@ -221,4 +225,10 @@ demoSemantic = gr.Interface(fn=bot._chatbot.predict_semantic_search, inputs="tex
221
  demoWithoutFiss = gr.Interface(fn=bot._chatbot.predict_without_faiss, inputs="text", outputs="text",examples=EXAMPLE_PATH, title="TH wiki (just Model)")
222
 
223
  demo = gr.TabbedInterface([demoFaiss, demoWithoutFiss, demoBert, demoSemantic], ["Faiss", "Model", "Faiss & Model", "Semantic Search & Model"])
224
- demo.launch()
 
 
 
 
 
 
 
1
  # -*- coding: utf-8 -*-
 
 
 
 
 
 
 
2
 
3
  #@title scripts
4
  import time
 
170
  # score="{:.4f})".format(hit['score'])
171
  # Answer = self.model_pipeline(message, context)
172
  # return Answer
173
def predict_semantic_search(self, message):
    """Answer a user question via semantic search over the question corpus.

    Encodes the query with the sentence-embedding model, retrieves the single
    most similar stored question using ``util.semantic_search``, and runs the
    QA pipeline on that question's context.

    Args:
        message: Raw user question; leading/trailing whitespace is stripped.

    Returns:
        The result of ``self.model_pipeline(message, context)`` — presumably
        the extracted answer text; confirm against the pipeline definition.

    NOTE(review): the class defines a second ``predict_semantic_search`` later
    in the file, which silently shadows this one — confirm which is intended.
    """
    message = message.strip()
    query_embedding = self.embedding_model.encode([message], convert_to_tensor=True)[0]

    # Bug fix (performance): the original re-encoded the ENTIRE question
    # corpus on every call — O(corpus) embedding work per query for an
    # identical result. Encode once and cache on the instance.
    corpus_embeddings = getattr(self, "_corpus_embeddings", None)
    if corpus_embeddings is None:
        corpus_embeddings = self.embedding_model.encode(
            self.df['Question'].tolist(), convert_to_tensor=True
        )
        self._corpus_embeddings = corpus_embeddings

    # top_k=1: we only ever use the single best hit below.
    hits = util.semantic_search(query_embedding.unsqueeze(0), corpus_embeddings, top_k=1)
    best_hit = hits[0][0]
    context = self.df['Context'][best_hit['corpus_id']]
    return self.model_pipeline(message, context)
184
  def predict_semantic_search(self, message):
185
  message = message.strip()
186
  query_embedding = self.embedding_model.encode([message], convert_to_tensor=True)[0] # Fix here
 
225
# Plain-model tab: answers without FAISS retrieval.
demoWithoutFiss = gr.Interface(fn=bot._chatbot.predict_without_faiss, inputs="text", outputs="text", examples=EXAMPLE_PATH, title="TH wiki (just Model)")

# Combine all four QA strategies into one tabbed app.
demo = gr.TabbedInterface([demoFaiss, demoWithoutFiss, demoBert, demoSemantic], ["Faiss", "Model", "Faiss & Model", "Semantic Search & Model"])

if __name__ == "__main__":
    # NOTE(review): these loaders rebind module-level names that nothing below
    # reads — the interfaces above were already built from `bot`. Confirm
    # whether they are still needed or are leftovers from an earlier layout.
    df = load_data()
    model, tokenizer = load_model('wangchanberta-hyp')
    embedding_model = load_embedding_model()
    index = set_index(prepare_sentences_vector(load_embeddings(EMBEDDINGS_PATH)))

    # Bug fixes: the original called `interface.launch()` — a NameError, since
    # the app object is named `demo` — and it did so after an unconditional
    # module-level `demo.launch()` that blocked before this guard could run.
    # Launch the tabbed app exactly once, under the main guard.
    demo.launch()