# Page residue captured with this file (appears to be a hosting status banner: "Spaces: Sleeping"); not part of the program.
import time | |
import numpy as np | |
import pandas as pd | |
import torch | |
import faiss | |
from sklearn.preprocessing import normalize | |
from transformers import AutoTokenizer, AutoModelForQuestionAnswering | |
from sentence_transformers import SentenceTransformer | |
import pickle | |
import gradio as gr | |
# Report at import time whether a CUDA device is visible to torch.
print(torch.cuda.is_available())

# Model keys offered by this app; every entry is a key of MODEL_DICT below.
# NOTE: this list used to be assigned to ``__all__``. ``__all__`` must contain
# names that exist as module attributes, so ``from <module> import *`` raised
# AttributeError on "mdeberta". It is a plain constant now.
AVAILABLE_MODELS = [
    "mdeberta",
    "wangchanberta-hyp",  # Best model
]

# Names of the retrieval/prediction strategies.
# NOTE(review): nothing in this file reads this list - confirm it is still needed.
predict_method = [
    "faiss",
    "faissWithModel",
    "cosineWithModel",
    "semanticSearchWithModel",
]

# Defaults used when a caller does not pick a model explicitly.
DEFAULT_MODEL = 'wangchanberta-hyp'
DEFAULT_SENTENCE_EMBEDDING_MODEL = 'intfloat/multilingual-e5-base'

# Short model key -> model identifier passed to ``from_pretrained``.
MODEL_DICT = {
    'wangchanberta': 'Chananchida/wangchanberta-th-wiki-qa_ref-params',
    'wangchanberta-hyp': 'Chananchida/wangchanberta-th-wiki-qa_hyp-params',
    'mdeberta': 'Chananchida/mdeberta-v3-th-wiki-qa_ref-params',
    'mdeberta-hyp': 'Chananchida/mdeberta-v3-th-wiki-qa_hyp-params',
}

# Locations of the question/answer workbook and the pickled embeddings.
DATA_PATH = 'models/dataset.xlsx'
EMBEDDINGS_PATH = 'models/embeddings.pkl'
class ChatbotModel:
    """Facade that builds a fully initialised ``Chatbot`` and exposes it.

    Construction loads the dataset, the QA model, the sentence-embedding
    model, the stored embeddings and the FAISS index, in that order.
    """

    def __init__(self, model=DEFAULT_MODEL):
        """Create and initialise the wrapped ``Chatbot``.

        Args:
            model: Key into ``MODEL_DICT`` selecting the QA model to load.
        """
        self._chatbot = Chatbot()
        self._chatbot.load_data()
        self._chatbot.load_model(model)
        self._chatbot.load_embedding_model(DEFAULT_SENTENCE_EMBEDDING_MODEL)
        # The index is built from the vectors, so the vectors must be set first.
        self._chatbot.set_vectors()
        self._chatbot.set_index()

    def chat(self, question):
        """Answer ``question`` and return the result dict from ``Chatbot.predict``."""
        # The previous target, ``Chatbot.answer_question``, is not defined on
        # ``Chatbot`` in this file; ``predict`` is its question-answering entry point.
        return self._chatbot.predict(question)

    def eval(self, model, predict_method):
        """Delegate an evaluation run to the wrapped chatbot.

        Args:
            model: Forwarded as ``model_name``.
            predict_method: Forwarded unchanged.
        """
        # NOTE(review): ``Chatbot`` as defined in this file has no ``eval``
        # method, so this call raises AttributeError - confirm whether the
        # method was removed or still needs to be implemented.
        return self._chatbot.eval(model_name=model, predict_method=predict_method)
class Chatbot:
    """Question answering via FAISS lookup plus an extractive QA model.

    A user question is embedded, matched against the stored embedding vectors
    with a FAISS index, and the context of the best-matching dataset row is
    fed to the QA model to locate an answer span.
    """

    def __init__(self):
        # Everything below is populated later by the load_*/set_* methods.
        self.df = None
        self.test_df = None  # not used by any method in this class
        self.model = None
        self.model_name = None
        self.tokenizer = None
        self.embedding_model = None
        self.vectors = None
        self.index = None
        self.k = 1  # top k most similar

    def load_data(self, path: str = DATA_PATH):
        """Load the dataset workbook into ``self.df``.

        The 'Default' sheet supplies the frame; its 'Context' column is then
        taken from the 'mdeberta' sheet of the same workbook.
        """
        self.df = pd.read_excel(path, sheet_name='Default')
        self.df['Context'] = pd.read_excel(path, sheet_name='mdeberta')['Context']

    def load_model(self, model_name: str = DEFAULT_MODEL):
        """Load the QA model and tokenizer registered under ``model_name``.

        Raises:
            KeyError: If ``model_name`` is not a key of ``MODEL_DICT``.
        """
        checkpoint = MODEL_DICT[model_name]
        self.model = AutoModelForQuestionAnswering.from_pretrained(checkpoint)
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.model_name = model_name

    def load_embedding_model(self, model_name: str = DEFAULT_SENTENCE_EMBEDDING_MODEL):
        """Load the sentence-embedding model, on CUDA when torch reports it available."""
        if torch.cuda.is_available():
            self.embedding_model = SentenceTransformer(model_name, device='cuda')
        else:
            self.embedding_model = SentenceTransformer(model_name)

    def set_vectors(self):
        """Load the pickled embeddings and store them as a normalised float32 matrix."""
        self.vectors = self.prepare_sentences_vector(self.load_embeddings(EMBEDDINGS_PATH))

    def set_index(self):
        """Build a flat L2 FAISS index over ``self.vectors`` (on GPU 0 when CUDA is available)."""
        index = faiss.IndexFlatL2(self.vectors.shape[1])
        if torch.cuda.is_available():
            res = faiss.StandardGpuResources()
            index = faiss.index_cpu_to_gpu(res, 0, index)
        index.add(self.vectors)
        self.index = index

    def get_embeddings(self, text_list):
        """Encode ``text_list`` with the sentence-embedding model."""
        return self.embedding_model.encode(text_list)

    def prepare_sentences_vector(self, encoded_list):
        """Stack embeddings into one float32 matrix and normalise it.

        Each item is reshaped to a single row, the rows are stacked, cast to
        float32 and passed through sklearn's ``normalize`` with its defaults.
        """
        rows = [vector.reshape(1, -1) for vector in encoded_list]
        matrix = np.vstack(rows).astype('float32')
        return normalize(matrix)

    def store_embeddings(self, embeddings):
        """Pickle ``embeddings`` together with the dataset questions to ``EMBEDDINGS_PATH``."""
        # Use the shared constant so the file written here is the one set_vectors() reads.
        with open(EMBEDDINGS_PATH, "wb") as fOut:
            pickle.dump({'sentences': self.df['Question'], 'embeddings': embeddings}, fOut, protocol=pickle.HIGHEST_PROTOCOL)

    def load_embeddings(self, file_path):
        """Return the 'embeddings' entry of the pickle stored at ``file_path``."""
        # SECURITY: pickle.load executes arbitrary code contained in the file;
        # only point this at embedding files produced by store_embeddings().
        with open(file_path, "rb") as fIn:
            stored_data = pickle.load(fIn)
        return stored_data['embeddings']

    def model_pipeline(self, question, similar_context):
        """Run the QA model on (question, context) and return the decoded answer span.

        The span runs from the arg-max start logit to the arg-max end logit,
        inclusive; when the end precedes the start the decoded slice is empty.
        """
        inputs = self.tokenizer(question, similar_context, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
        answer_start_index = outputs.start_logits.argmax()
        answer_end_index = outputs.end_logits.argmax()
        predict_answer_tokens = inputs.input_ids[0, answer_start_index: answer_end_index + 1]
        return self.tokenizer.decode(predict_answer_tokens)

    def faiss_search(self, question_vector):
        """Look up the ``self.k`` nearest stored vectors for ``question_vector``.

        Returns:
            ``(similar_questions, similar_contexts, distances, indices)`` where
            the first two are lists read from ``self.df`` and the last two are
            the raw results of ``index.search``.
        """
        distances, indices = self.index.search(question_vector, self.k)
        similar_questions = [self.df['Question'][row] for row in indices[0]]
        similar_contexts = [self.df['Context'][row] for row in indices[0]]
        return similar_questions, similar_contexts, distances, indices

    def predict(self, message):
        """Answer ``message`` using the best-matching dataset row.

        Returns:
            A dict with the stripped question, the 'Answer' value of the best
            match, its rounded FAISS distance, and ``highlight_start`` /
            ``highlight_end`` offsets of the model's answer span. Both offsets
            are -1 when the decoded span does not occur verbatim in the context.
        """
        message = message.strip()
        question_vector = self.get_embeddings(message)
        question_vector = self.prepare_sentences_vector([question_vector])
        similar_questions, similar_contexts, distances, indices = self.faiss_search(question_vector)

        # faiss_search returns lists; the QA model and str.find both need the
        # single best context string (passing the list raised AttributeError).
        top_context = similar_contexts[0]
        answer = self.model_pipeline(message, top_context)

        start_index = top_context.find(answer)
        # Do not derive an end offset from a failed search (-1 + len(answer)).
        end_index = start_index + len(answer) if start_index >= 0 else -1

        # NOTE(review): the offsets index into the Context text, while "answer"
        # below is the dataset's Answer column - confirm callers highlight the
        # intended string.
        output = {
            "user_question": message,
            "answer": self.df['Answer'][indices[0][0]],
            "distance": round(distances[0][0], 4),
            "highlight_start": start_index,
            "highlight_end": end_index
        }
        return output
def highlight_text(text, start_index, end_index):
    """Wrap ``text[start_index:end_index]`` in ``<mark>`` tags.

    Args:
        text: The string to decorate.
        start_index: Offset of the first highlighted character; negative
            values are clamped to 0.
        end_index: Offset one past the last highlighted character; values
            beyond ``len(text)`` are clamped to ``len(text)``.

    Returns:
        ``text`` with the span wrapped in ``<mark>...</mark>``. When the
        clamped span is empty or inverted, ``text`` is returned unchanged so
        the result never contains an unbalanced tag.
    """
    start = max(start_index, 0)
    end = min(end_index, len(text))
    if start >= end:
        # Nothing to highlight (empty text, inverted range, or start past the end).
        return text
    return f"{text[:start]}<mark>{text[start:end]}</mark>{text[end:]}"
if __name__ == "__main__":
    # Build the whole pipeline once, before the UI starts.
    bot = ChatbotModel()

    def chat_interface(question, history=None):
        """Answer ``question`` and return the answer as HTML with the span highlighted.

        ``history`` is accepted for compatibility but unused. It defaults to
        None because the Interface below declares a single "text" input and
        therefore calls this function with one argument only.
        """
        response = bot._chatbot.predict(question)
        return highlight_text(
            response["answer"],
            response["highlight_start"],
            response["highlight_end"],
        )

    demo = gr.Interface(fn=chat_interface, title="Thai Question Answering System", inputs="text", outputs="html")
    demo.launch()