Spaces:
Sleeping
Sleeping
Chananchida
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,3 @@
|
|
1 |
-
# -*- coding: utf-8 -*-
|
2 |
-
|
3 |
-
#@title scirpts
|
4 |
import time
|
5 |
import numpy as np
|
6 |
import pandas as pd
|
@@ -8,8 +5,7 @@ import torch
|
|
8 |
import faiss
|
9 |
from sklearn.preprocessing import normalize
|
10 |
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
|
11 |
-
from sentence_transformers import SentenceTransformer
|
12 |
-
from pythainlp import Tokenizer
|
13 |
import pickle
|
14 |
import gradio as gr
|
15 |
|
@@ -17,7 +13,7 @@ print(torch.cuda.is_available())
|
|
17 |
|
18 |
__all__ = [
|
19 |
"mdeberta",
|
20 |
-
"wangchanberta-hyp",
|
21 |
]
|
22 |
|
23 |
predict_method = [
|
@@ -27,8 +23,8 @@ predict_method = [
|
|
27 |
"semanticSearchWithModel",
|
28 |
]
|
29 |
|
30 |
-
DEFAULT_MODEL='wangchanberta-hyp'
|
31 |
-
DEFAULT_SENTENCE_EMBEDDING_MODEL='intfloat/multilingual-e5-base'
|
32 |
|
33 |
MODEL_DICT = {
|
34 |
'wangchanberta': 'Chananchida/wangchanberta-th-wiki-qa_ref-params',
|
@@ -37,8 +33,8 @@ MODEL_DICT = {
|
|
37 |
'mdeberta-hyp': 'Chananchida/mdeberta-v3-th-wiki-qa_hyp-params',
|
38 |
}
|
39 |
|
40 |
-
DATA_PATH='models/dataset.xlsx'
|
41 |
-
EMBEDDINGS_PATH='models/embeddings.pkl'
|
42 |
|
43 |
|
44 |
class ChatbotModel:
|
@@ -50,12 +46,12 @@ class ChatbotModel:
|
|
50 |
self._chatbot.set_vectors()
|
51 |
self._chatbot.set_index()
|
52 |
|
53 |
-
|
54 |
def chat(self, question):
|
55 |
return self._chatbot.answer_question(question)
|
56 |
|
57 |
-
def eval(self,model,predict_method):
|
58 |
-
return self._chatbot.eval(model_name=model,predict_method=predict_method)
|
|
|
59 |
|
60 |
class Chatbot:
|
61 |
def __init__(self):
|
@@ -73,31 +69,29 @@ class Chatbot:
|
|
73 |
def load_data(self, path: str = DATA_PATH):
|
74 |
self.df = pd.read_excel(path, sheet_name='Default')
|
75 |
self.df['Context'] = pd.read_excel(path, sheet_name='mdeberta')['Context']
|
76 |
-
# print('Load data done')
|
77 |
|
78 |
def load_model(self, model_name: str = DEFAULT_MODEL):
|
79 |
self.model = AutoModelForQuestionAnswering.from_pretrained(MODEL_DICT[model_name])
|
80 |
self.tokenizer = AutoTokenizer.from_pretrained(MODEL_DICT[model_name])
|
81 |
self.model_name = model_name
|
82 |
-
# print('Load model done')
|
83 |
|
84 |
def load_embedding_model(self, model_name: str = DEFAULT_SENTENCE_EMBEDDING_MODEL):
|
85 |
-
if torch.cuda.is_available():
|
86 |
-
self.embedding_model = SentenceTransformer(model_name, device='
|
87 |
-
else:
|
88 |
-
|
89 |
|
90 |
def set_vectors(self):
|
91 |
self.vectors = self.prepare_sentences_vector(self.load_embeddings(EMBEDDINGS_PATH))
|
92 |
|
93 |
def set_index(self):
|
94 |
-
if torch.cuda.is_available():
|
95 |
res = faiss.StandardGpuResources()
|
96 |
self.index = faiss.IndexFlatL2(self.vectors.shape[1])
|
97 |
gpu_index_flat = faiss.index_cpu_to_gpu(res, 0, self.index)
|
98 |
gpu_index_flat.add(self.vectors)
|
99 |
self.index = gpu_index_flat
|
100 |
-
else:
|
101 |
self.index = faiss.IndexFlatL2(self.vectors.shape[1])
|
102 |
self.index.add(self.vectors)
|
103 |
|
@@ -110,18 +104,15 @@ class Chatbot:
|
|
110 |
encoded_list = normalize(encoded_list)
|
111 |
return encoded_list
|
112 |
|
113 |
-
|
114 |
def store_embeddings(self, embeddings):
|
115 |
with open('models/embeddings.pkl', "wb") as fOut:
|
116 |
pickle.dump({'sentences': self.df['Question'], 'embeddings': embeddings}, fOut, protocol=pickle.HIGHEST_PROTOCOL)
|
117 |
-
print('Store embeddings done')
|
118 |
|
119 |
def load_embeddings(self, file_path):
|
120 |
with open(file_path, "rb") as fIn:
|
121 |
stored_data = pickle.load(fIn)
|
122 |
stored_sentences = stored_data['sentences']
|
123 |
stored_embeddings = stored_data['embeddings']
|
124 |
-
print('Load (questions) embeddings done')
|
125 |
return stored_embeddings
|
126 |
|
127 |
def model_pipeline(self, question, similar_context):
|
@@ -140,25 +131,24 @@ class Chatbot:
|
|
140 |
similar_contexts = [self.df['Context'][indices[0][i]] for i in range(self.k)]
|
141 |
return similar_questions, similar_contexts, distances, indices
|
142 |
|
143 |
-
|
144 |
-
def predict(self,message):
|
145 |
message = message.strip()
|
146 |
question_vector = self.get_embeddings(message)
|
147 |
-
question_vector=self.prepare_sentences_vector([question_vector])
|
148 |
-
similar_questions, similar_contexts, distances,indices = self.faiss_search(question_vector)
|
149 |
Answer = self.model_pipeline(message, similar_contexts)
|
150 |
start_index = similar_contexts.find(Answer)
|
151 |
end_index = start_index + len(Answer)
|
152 |
-
_time = time.time() - t
|
153 |
output = {
|
154 |
"user_question": message,
|
155 |
-
"answer": df['Answer'][indices[0][0]],
|
156 |
-
"totaltime": round(_time, 3),
|
157 |
"distance": round(distances[0][0], 4),
|
158 |
"highlight_start": start_index,
|
159 |
"highlight_end": end_index
|
160 |
}
|
161 |
return output
|
|
|
|
|
162 |
def highlight_text(text, start_index, end_index):
|
163 |
if start_index < 0:
|
164 |
start_index = 0
|
@@ -166,21 +156,21 @@ def highlight_text(text, start_index, end_index):
|
|
166 |
end_index = len(text)
|
167 |
highlighted_text = ""
|
168 |
for i, char in enumerate(text):
|
169 |
-
|
170 |
highlighted_text += "<mark>"
|
171 |
highlighted_text += char
|
172 |
if i == end_index - 1:
|
173 |
highlighted_text += "</mark>"
|
174 |
return highlighted_text
|
175 |
|
176 |
-
|
177 |
-
""
|
178 |
-
if __name__ == "__main__":
|
179 |
bot = ChatbotModel()
|
|
|
180 |
def chat_interface(question, history):
|
181 |
-
response = bot._chatbot.predict(
|
182 |
highlighted_answer = highlight_text(response["answer"], response["highlight_start"], response["highlight_end"])
|
183 |
return highlighted_answer
|
184 |
-
|
185 |
-
demo = gr.
|
186 |
demo.launch()
|
|
|
|
|
|
|
|
|
1 |
import time
|
2 |
import numpy as np
|
3 |
import pandas as pd
|
|
|
5 |
import faiss
|
6 |
from sklearn.preprocessing import normalize
|
7 |
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
|
8 |
+
from sentence_transformers import SentenceTransformer
|
|
|
9 |
import pickle
|
10 |
import gradio as gr
|
11 |
|
|
|
13 |
|
14 |
__all__ = [
|
15 |
"mdeberta",
|
16 |
+
"wangchanberta-hyp", # Best model
|
17 |
]
|
18 |
|
19 |
predict_method = [
|
|
|
23 |
"semanticSearchWithModel",
|
24 |
]
|
25 |
|
26 |
+
DEFAULT_MODEL = 'wangchanberta-hyp'
|
27 |
+
DEFAULT_SENTENCE_EMBEDDING_MODEL = 'intfloat/multilingual-e5-base'
|
28 |
|
29 |
MODEL_DICT = {
|
30 |
'wangchanberta': 'Chananchida/wangchanberta-th-wiki-qa_ref-params',
|
|
|
33 |
'mdeberta-hyp': 'Chananchida/mdeberta-v3-th-wiki-qa_hyp-params',
|
34 |
}
|
35 |
|
36 |
+
DATA_PATH = 'models/dataset.xlsx'
|
37 |
+
EMBEDDINGS_PATH = 'models/embeddings.pkl'
|
38 |
|
39 |
|
40 |
class ChatbotModel:
|
|
|
46 |
self._chatbot.set_vectors()
|
47 |
self._chatbot.set_index()
|
48 |
|
|
|
49 |
def chat(self, question):
|
50 |
return self._chatbot.answer_question(question)
|
51 |
|
52 |
+
def eval(self, model, predict_method):
|
53 |
+
return self._chatbot.eval(model_name=model, predict_method=predict_method)
|
54 |
+
|
55 |
|
56 |
class Chatbot:
|
57 |
def __init__(self):
|
|
|
69 |
def load_data(self, path: str = DATA_PATH):
|
70 |
self.df = pd.read_excel(path, sheet_name='Default')
|
71 |
self.df['Context'] = pd.read_excel(path, sheet_name='mdeberta')['Context']
|
|
|
72 |
|
73 |
def load_model(self, model_name: str = DEFAULT_MODEL):
|
74 |
self.model = AutoModelForQuestionAnswering.from_pretrained(MODEL_DICT[model_name])
|
75 |
self.tokenizer = AutoTokenizer.from_pretrained(MODEL_DICT[model_name])
|
76 |
self.model_name = model_name
|
|
|
77 |
|
78 |
def load_embedding_model(self, model_name: str = DEFAULT_SENTENCE_EMBEDDING_MODEL):
|
79 |
+
if torch.cuda.is_available():
|
80 |
+
self.embedding_model = SentenceTransformer(model_name, device='cuda')
|
81 |
+
else:
|
82 |
+
self.embedding_model = SentenceTransformer(model_name)
|
83 |
|
84 |
def set_vectors(self):
|
85 |
self.vectors = self.prepare_sentences_vector(self.load_embeddings(EMBEDDINGS_PATH))
|
86 |
|
87 |
def set_index(self):
|
88 |
+
if torch.cuda.is_available():
|
89 |
res = faiss.StandardGpuResources()
|
90 |
self.index = faiss.IndexFlatL2(self.vectors.shape[1])
|
91 |
gpu_index_flat = faiss.index_cpu_to_gpu(res, 0, self.index)
|
92 |
gpu_index_flat.add(self.vectors)
|
93 |
self.index = gpu_index_flat
|
94 |
+
else:
|
95 |
self.index = faiss.IndexFlatL2(self.vectors.shape[1])
|
96 |
self.index.add(self.vectors)
|
97 |
|
|
|
104 |
encoded_list = normalize(encoded_list)
|
105 |
return encoded_list
|
106 |
|
|
|
107 |
def store_embeddings(self, embeddings):
|
108 |
with open('models/embeddings.pkl', "wb") as fOut:
|
109 |
pickle.dump({'sentences': self.df['Question'], 'embeddings': embeddings}, fOut, protocol=pickle.HIGHEST_PROTOCOL)
|
|
|
110 |
|
111 |
def load_embeddings(self, file_path):
|
112 |
with open(file_path, "rb") as fIn:
|
113 |
stored_data = pickle.load(fIn)
|
114 |
stored_sentences = stored_data['sentences']
|
115 |
stored_embeddings = stored_data['embeddings']
|
|
|
116 |
return stored_embeddings
|
117 |
|
118 |
def model_pipeline(self, question, similar_context):
|
|
|
131 |
similar_contexts = [self.df['Context'][indices[0][i]] for i in range(self.k)]
|
132 |
return similar_questions, similar_contexts, distances, indices
|
133 |
|
134 |
+
def predict(self, message):
|
|
|
135 |
message = message.strip()
|
136 |
question_vector = self.get_embeddings(message)
|
137 |
+
question_vector = self.prepare_sentences_vector([question_vector])
|
138 |
+
similar_questions, similar_contexts, distances, indices = self.faiss_search(question_vector)
|
139 |
Answer = self.model_pipeline(message, similar_contexts)
|
140 |
start_index = similar_contexts.find(Answer)
|
141 |
end_index = start_index + len(Answer)
|
|
|
142 |
output = {
|
143 |
"user_question": message,
|
144 |
+
"answer": self.df['Answer'][indices[0][0]],
|
|
|
145 |
"distance": round(distances[0][0], 4),
|
146 |
"highlight_start": start_index,
|
147 |
"highlight_end": end_index
|
148 |
}
|
149 |
return output
|
150 |
+
|
151 |
+
|
152 |
def highlight_text(text, start_index, end_index):
|
153 |
if start_index < 0:
|
154 |
start_index = 0
|
|
|
156 |
end_index = len(text)
|
157 |
highlighted_text = ""
|
158 |
for i, char in enumerate(text):
|
159 |
+
if i == start_index:
|
160 |
highlighted_text += "<mark>"
|
161 |
highlighted_text += char
|
162 |
if i == end_index - 1:
|
163 |
highlighted_text += "</mark>"
|
164 |
return highlighted_text
|
165 |
|
166 |
+
|
167 |
+
if __name__ == "__main__":
|
|
|
168 |
bot = ChatbotModel()
|
169 |
+
|
170 |
def chat_interface(question, history):
|
171 |
+
response = bot._chatbot.predict(question)
|
172 |
highlighted_answer = highlight_text(response["answer"], response["highlight_start"], response["highlight_end"])
|
173 |
return highlighted_answer
|
174 |
+
|
175 |
+
demo = gr.Interface(fn=chat_interface, title="Thai Question Answering System", inputs="text", outputs="html")
|
176 |
demo.launch()
|