Chananchida committed on
Commit
9d5b2f3
·
verified ·
1 Parent(s): 7501763

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -38
app.py CHANGED
@@ -1,6 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
-
3
- #@title scirpts
4
  import time
5
  import numpy as np
6
  import pandas as pd
@@ -8,8 +5,7 @@ import torch
8
  import faiss
9
  from sklearn.preprocessing import normalize
10
  from transformers import AutoTokenizer, AutoModelForQuestionAnswering
11
- from sentence_transformers import SentenceTransformer,util
12
- from pythainlp import Tokenizer
13
  import pickle
14
  import gradio as gr
15
 
@@ -17,7 +13,7 @@ print(torch.cuda.is_available())
17
 
18
  __all__ = [
19
  "mdeberta",
20
- "wangchanberta-hyp", # Best model
21
  ]
22
 
23
  predict_method = [
@@ -27,8 +23,8 @@ predict_method = [
27
  "semanticSearchWithModel",
28
  ]
29
 
30
- DEFAULT_MODEL='wangchanberta-hyp'
31
- DEFAULT_SENTENCE_EMBEDDING_MODEL='intfloat/multilingual-e5-base'
32
 
33
  MODEL_DICT = {
34
  'wangchanberta': 'Chananchida/wangchanberta-th-wiki-qa_ref-params',
@@ -37,8 +33,8 @@ MODEL_DICT = {
37
  'mdeberta-hyp': 'Chananchida/mdeberta-v3-th-wiki-qa_hyp-params',
38
  }
39
 
40
- DATA_PATH='models/dataset.xlsx'
41
- EMBEDDINGS_PATH='models/embeddings.pkl'
42
 
43
 
44
  class ChatbotModel:
@@ -50,12 +46,12 @@ class ChatbotModel:
50
  self._chatbot.set_vectors()
51
  self._chatbot.set_index()
52
 
53
-
54
  def chat(self, question):
55
  return self._chatbot.answer_question(question)
56
 
57
- def eval(self,model,predict_method):
58
- return self._chatbot.eval(model_name=model,predict_method=predict_method)
 
59
 
60
  class Chatbot:
61
  def __init__(self):
@@ -73,31 +69,29 @@ class Chatbot:
73
  def load_data(self, path: str = DATA_PATH):
74
  self.df = pd.read_excel(path, sheet_name='Default')
75
  self.df['Context'] = pd.read_excel(path, sheet_name='mdeberta')['Context']
76
- # print('Load data done')
77
 
78
  def load_model(self, model_name: str = DEFAULT_MODEL):
79
  self.model = AutoModelForQuestionAnswering.from_pretrained(MODEL_DICT[model_name])
80
  self.tokenizer = AutoTokenizer.from_pretrained(MODEL_DICT[model_name])
81
  self.model_name = model_name
82
- # print('Load model done')
83
 
84
  def load_embedding_model(self, model_name: str = DEFAULT_SENTENCE_EMBEDDING_MODEL):
85
- if torch.cuda.is_available(): # Check if GPU is available
86
- self.embedding_model = SentenceTransformer(model_name, device='cpu')
87
- else: self.embedding_model = SentenceTransformer(model_name)
88
- # print('Load sentence embedding model done')
89
 
90
  def set_vectors(self):
91
  self.vectors = self.prepare_sentences_vector(self.load_embeddings(EMBEDDINGS_PATH))
92
 
93
  def set_index(self):
94
- if torch.cuda.is_available(): # Check if GPU is available
95
  res = faiss.StandardGpuResources()
96
  self.index = faiss.IndexFlatL2(self.vectors.shape[1])
97
  gpu_index_flat = faiss.index_cpu_to_gpu(res, 0, self.index)
98
  gpu_index_flat.add(self.vectors)
99
  self.index = gpu_index_flat
100
- else: # If GPU is not available, use CPU-based Faiss index
101
  self.index = faiss.IndexFlatL2(self.vectors.shape[1])
102
  self.index.add(self.vectors)
103
 
@@ -110,18 +104,15 @@ class Chatbot:
110
  encoded_list = normalize(encoded_list)
111
  return encoded_list
112
 
113
-
114
  def store_embeddings(self, embeddings):
115
  with open('models/embeddings.pkl', "wb") as fOut:
116
  pickle.dump({'sentences': self.df['Question'], 'embeddings': embeddings}, fOut, protocol=pickle.HIGHEST_PROTOCOL)
117
- print('Store embeddings done')
118
 
119
  def load_embeddings(self, file_path):
120
  with open(file_path, "rb") as fIn:
121
  stored_data = pickle.load(fIn)
122
  stored_sentences = stored_data['sentences']
123
  stored_embeddings = stored_data['embeddings']
124
- print('Load (questions) embeddings done')
125
  return stored_embeddings
126
 
127
  def model_pipeline(self, question, similar_context):
@@ -140,25 +131,24 @@ class Chatbot:
140
  similar_contexts = [self.df['Context'][indices[0][i]] for i in range(self.k)]
141
  return similar_questions, similar_contexts, distances, indices
142
 
143
-
144
- def predict(self,message):
145
  message = message.strip()
146
  question_vector = self.get_embeddings(message)
147
- question_vector=self.prepare_sentences_vector([question_vector])
148
- similar_questions, similar_contexts, distances,indices = self.faiss_search(question_vector)
149
  Answer = self.model_pipeline(message, similar_contexts)
150
  start_index = similar_contexts.find(Answer)
151
  end_index = start_index + len(Answer)
152
- _time = time.time() - t
153
  output = {
154
  "user_question": message,
155
- "answer": df['Answer'][indices[0][0]],
156
- "totaltime": round(_time, 3),
157
  "distance": round(distances[0][0], 4),
158
  "highlight_start": start_index,
159
  "highlight_end": end_index
160
  }
161
  return output
 
 
162
  def highlight_text(text, start_index, end_index):
163
  if start_index < 0:
164
  start_index = 0
@@ -166,21 +156,21 @@ def highlight_text(text, start_index, end_index):
166
  end_index = len(text)
167
  highlighted_text = ""
168
  for i, char in enumerate(text):
169
- if i == start_index:
170
  highlighted_text += "<mark>"
171
  highlighted_text += char
172
  if i == end_index - 1:
173
  highlighted_text += "</mark>"
174
  return highlighted_text
175
 
176
-
177
- """#Gradio"""
178
- if __name__ == "__main__":
179
  bot = ChatbotModel()
 
180
  def chat_interface(question, history):
181
- response = bot._chatbot.predict(model, tokenizer, embedding_model, df, question, index)
182
  highlighted_answer = highlight_text(response["answer"], response["highlight_start"], response["highlight_end"])
183
  return highlighted_answer
184
- # EXAMPLE = ["หลิน ไห่เฟิง มีชื่อเรียกอีกชื่อว่าอะไร" , "ใครเป็นผู้ตั้งสภาเศรษฐกิจโลกขึ้นในปี พ.ศ. 2514 โดยทุกปีจะมีการประชุมที่ประเทศสวิตเซอร์แลนด์", "โปรดิวเซอร์ของอัลบั้มตลอดกาล ของวงคีรีบูนคือใคร", "สกุลเดิมของหม่อมครูนุ่ม นวรัตน ณ อยุธยา คืออะไร"]
185
- demo = gr.ChatInterface(fn=chat_interface, title="CE66-04_Thai Question Answering System by using Deep Learning")
186
  demo.launch()
 
 
 
 
1
  import time
2
  import numpy as np
3
  import pandas as pd
 
5
  import faiss
6
  from sklearn.preprocessing import normalize
7
  from transformers import AutoTokenizer, AutoModelForQuestionAnswering
8
+ from sentence_transformers import SentenceTransformer
 
9
  import pickle
10
  import gradio as gr
11
 
 
13
 
14
  __all__ = [
15
  "mdeberta",
16
+ "wangchanberta-hyp", # Best model
17
  ]
18
 
19
  predict_method = [
 
23
  "semanticSearchWithModel",
24
  ]
25
 
26
+ DEFAULT_MODEL = 'wangchanberta-hyp'
27
+ DEFAULT_SENTENCE_EMBEDDING_MODEL = 'intfloat/multilingual-e5-base'
28
 
29
  MODEL_DICT = {
30
  'wangchanberta': 'Chananchida/wangchanberta-th-wiki-qa_ref-params',
 
33
  'mdeberta-hyp': 'Chananchida/mdeberta-v3-th-wiki-qa_hyp-params',
34
  }
35
 
36
+ DATA_PATH = 'models/dataset.xlsx'
37
+ EMBEDDINGS_PATH = 'models/embeddings.pkl'
38
 
39
 
40
  class ChatbotModel:
 
46
  self._chatbot.set_vectors()
47
  self._chatbot.set_index()
48
 
 
49
  def chat(self, question):
50
  return self._chatbot.answer_question(question)
51
 
52
def eval(self, model, predict_method):
    """Forward an evaluation request to the wrapped chatbot.

    ``model`` is passed on as ``model_name`` and ``predict_method`` under
    its own name; the chatbot's result is returned unchanged.
    """
    outcome = self._chatbot.eval(predict_method=predict_method, model_name=model)
    return outcome
54
+
55
 
56
  class Chatbot:
57
  def __init__(self):
 
69
  def load_data(self, path: str = DATA_PATH):
70
  self.df = pd.read_excel(path, sheet_name='Default')
71
  self.df['Context'] = pd.read_excel(path, sheet_name='mdeberta')['Context']
 
72
 
73
  def load_model(self, model_name: str = DEFAULT_MODEL):
74
  self.model = AutoModelForQuestionAnswering.from_pretrained(MODEL_DICT[model_name])
75
  self.tokenizer = AutoTokenizer.from_pretrained(MODEL_DICT[model_name])
76
  self.model_name = model_name
 
77
 
78
  def load_embedding_model(self, model_name: str = DEFAULT_SENTENCE_EMBEDDING_MODEL):
79
+ if torch.cuda.is_available():
80
+ self.embedding_model = SentenceTransformer(model_name, device='cuda')
81
+ else:
82
+ self.embedding_model = SentenceTransformer(model_name)
83
 
84
  def set_vectors(self):
85
  self.vectors = self.prepare_sentences_vector(self.load_embeddings(EMBEDDINGS_PATH))
86
 
87
  def set_index(self):
88
+ if torch.cuda.is_available():
89
  res = faiss.StandardGpuResources()
90
  self.index = faiss.IndexFlatL2(self.vectors.shape[1])
91
  gpu_index_flat = faiss.index_cpu_to_gpu(res, 0, self.index)
92
  gpu_index_flat.add(self.vectors)
93
  self.index = gpu_index_flat
94
+ else:
95
  self.index = faiss.IndexFlatL2(self.vectors.shape[1])
96
  self.index.add(self.vectors)
97
 
 
104
  encoded_list = normalize(encoded_list)
105
  return encoded_list
106
 
 
107
  def store_embeddings(self, embeddings):
108
  with open('models/embeddings.pkl', "wb") as fOut:
109
  pickle.dump({'sentences': self.df['Question'], 'embeddings': embeddings}, fOut, protocol=pickle.HIGHEST_PROTOCOL)
 
110
 
111
  def load_embeddings(self, file_path):
112
  with open(file_path, "rb") as fIn:
113
  stored_data = pickle.load(fIn)
114
  stored_sentences = stored_data['sentences']
115
  stored_embeddings = stored_data['embeddings']
 
116
  return stored_embeddings
117
 
118
  def model_pipeline(self, question, similar_context):
 
131
  similar_contexts = [self.df['Context'][indices[0][i]] for i in range(self.k)]
132
  return similar_questions, similar_contexts, distances, indices
133
 
134
def predict(self, message):
    """Answer *message* and report where the extracted span sits in its context.

    The stripped message is embedded, the nearest stored question is found
    via ``faiss_search``, and ``model_pipeline`` extracts an answer span.

    Returns a dict with keys:
        user_question: the stripped input message.
        answer: the 'Answer' cell of the best-matching row in ``self.df``.
        distance: the first distance from ``faiss_search``, rounded to 4 places.
        highlight_start / highlight_end: offsets of the extracted span inside
            the top-ranked context (``highlight_start`` is -1 when the span
            does not occur there, as with ``str.find``).
    """
    message = message.strip()
    question_vector = self.get_embeddings(message)
    question_vector = self.prepare_sentences_vector([question_vector])
    _, similar_contexts, distances, indices = self.faiss_search(question_vector)
    Answer = self.model_pipeline(message, similar_contexts)

    # faiss_search returns the contexts as a list, which has no .find();
    # locate the span inside the top-ranked context string instead.
    if isinstance(similar_contexts, str):
        top_context = similar_contexts
    else:
        top_context = similar_contexts[0]
    start_index = top_context.find(Answer)
    end_index = start_index + len(Answer)

    output = {
        "user_question": message,
        "answer": self.df['Answer'][indices[0][0]],
        "distance": round(distances[0][0], 4),
        "highlight_start": start_index,
        "highlight_end": end_index
    }
    return output
150
+
151
+
152
  def highlight_text(text, start_index, end_index):
153
  if start_index < 0:
154
  start_index = 0
 
156
  end_index = len(text)
157
  highlighted_text = ""
158
  for i, char in enumerate(text):
159
+ if i == start_index:
160
  highlighted_text += "<mark>"
161
  highlighted_text += char
162
  if i == end_index - 1:
163
  highlighted_text += "</mark>"
164
  return highlighted_text
165
 
166
+
167
if __name__ == "__main__":
    bot = ChatbotModel()

    def chat_interface(question, history=None):
        """Return the chatbot's answer to *question* as HTML with the span marked.

        ``history`` is optional: ``gr.Interface`` with a single text input
        calls this function with one argument only, so a required second
        parameter would raise ``TypeError`` on every request. It is kept so
        chat-style callers that pass a history still work.
        """
        response = bot._chatbot.predict(question)
        return highlight_text(
            response["answer"],
            response["highlight_start"],
            response["highlight_end"],
        )

    demo = gr.Interface(fn=chat_interface, title="Thai Question Answering System", inputs="text", outputs="html")
    demo.launch()