import argparse
import os
import re

import numpy as np
import streamlit as st
import torch

from data_preprocessing import (
    remove_xem_them, remove_emojis, remove_stopwords, format_punctuation,
    remove_punctuation, clean_text, normalize_format, word_segment,
    format_price, format_price_v2)
class inferSSCL():
    def __init__(self, args=None):
        self.args = args
        self.base_models = {}
        self.batch_data = {}
        self.test_data = []
        self.output = []
    def load_vocab_pretrain(self, file_pretrain_vocab, file_pretrain_vec, pad_tokens=True):
        # Build vocab <-> id maps, reserving id 0 for the <pad> token.
        vocab2id = {'<pad>': 0}
        id2vocab = {0: '<pad>'}
        cnt = len(id2vocab)
        with open(file_pretrain_vocab, 'r', encoding='utf-8') as fp:
            for line in fp:
                arr = re.split(' ', line[:-1])
                vocab2id[arr[1]] = cnt
                id2vocab[cnt] = arr[1]
                cnt += 1
        # Prepend an all-zero row for <pad> to the pretrained word embeddings.
        pretrain_vec = np.load(file_pretrain_vec)
        pad_vec = np.zeros([1, pretrain_vec.shape[1]])
        pretrain_vec = np.vstack((pad_vec, pretrain_vec))
        return vocab2id, id2vocab, pretrain_vec
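
    # Sketch of the expected on-disk layout -- an assumption inferred from the
    # `arr[1]` lookup above (the word is the second whitespace-separated field):
    #
    #   vocab.txt     one entry per line, e.g. "0 pho\n1 ngon\n..."
    #   vectors.npy   float array of shape (len(vocab), emb_size), row i = word i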
    def load_vocabulary(self):
        cluster_dir = './'
        file_wordvec = 'vectors.npy'
        file_vocab = 'vocab.txt'
        file_kmeans_centroid = 'aspect_centroid.txt'
        file_aspect_mapping = 'aspect_mapping.txt'
        vocab2id, id2vocab, pretrain_vec = self.load_vocab_pretrain(
            os.path.join(cluster_dir, file_vocab),
            os.path.join(cluster_dir, file_wordvec))
        self.batch_data['vocab2id'] = vocab2id
        self.batch_data['id2vocab'] = id2vocab
        self.batch_data['pretrain_emb'] = pretrain_vec
        self.batch_data['vocab_size'] = len(vocab2id)
        # Zero out the centroids of clusters mapped to "none" so they cannot
        # contribute; every other centroid is kept unchanged.
        aspect_vec = np.loadtxt(os.path.join(cluster_dir, file_kmeans_centroid), dtype=float)
        mask = []
        with open(os.path.join(cluster_dir, file_aspect_mapping), 'r') as fp:
            for line in fp:
                label = re.sub(r'[0-9]+', '', line).replace(' ', '').replace('\n', '')
                mask.append([0.] * emb_size if label == 'none' else [1.] * emb_size)
        aspect_vec = aspect_vec * np.array(mask)
        aspect_vec = torch.FloatTensor(aspect_vec).to(device)
        self.batch_data['aspect_centroid'] = aspect_vec
        self.batch_data['n_aspects'] = aspect_vec.shape[0]
    def load_models(self):
        # Embedding layer initialised from the pretrained vectors.
        self.base_models['embedding'] = torch.nn.Embedding(
            self.batch_data['vocab_size'], emb_size).to(device)
        emb_para = torch.FloatTensor(self.batch_data['pretrain_emb']).to(device)
        self.base_models['embedding'].weight = torch.nn.Parameter(emb_para)
        # Linear head that maps a sentence vector to aspect logits.
        self.base_models['asp_weight'] = torch.nn.Linear(
            emb_size, self.batch_data['n_aspects']).to(device)
        self.base_models['asp_weight'].load_state_dict(
            torch.load('./asp_weight.model', map_location=torch.device('cpu')))
        # Attention kernel (W_e x + b_e) used by the encoder.
        self.base_models['attn_kernel'] = torch.nn.Linear(emb_size, emb_size).to(device)
        self.base_models['attn_kernel'].load_state_dict(
            torch.load('./attn_kernel.model', map_location=torch.device('cpu')),
            strict=False)
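
    # Both checkpoints are plain state_dicts; map_location='cpu' lets weights
    # trained on GPU load in this CPU-only app, and strict=False tolerates key
    # mismatches in the attn_kernel checkpoint.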
    def build_pipe(self):
        attn_pos, lbl_pos = self.encoder(
            self.batch_data['pos_sen_var'],
            self.batch_data['pos_pad_mask']
        )
        # Per-token attention weights, rounded and truncated to the comment
        # length (currently computed but unused).
        outw = np.around(attn_pos.data.cpu().numpy(), 4)[0].tolist()
        outw = outw[:len(self.batch_data['comment'][0].split())]
        # Aspect distribution for the sentence vector.
        asp_weight = self.base_models['asp_weight'](lbl_pos)
        asp_weight = torch.softmax(asp_weight, dim=1)
        return asp_weight
    def encoder(self, input_, mask_):
        with torch.no_grad():
            emb_ = self.base_models['embedding'](input_)
            emb_ = emb_ * mask_.unsqueeze(2)
            # Query vector: average of the non-pad token embeddings.
            emb_avg = torch.sum(emb_, dim=1)
            norm = torch.sum(mask_, dim=1, keepdim=True) + 1e-20
            enc_ = emb_avg.div(norm.expand_as(emb_avg))
            # W_e x + b_e
            emb_trn = self.base_models['attn_kernel'](emb_)
            # Alignment scores: query . (W_e x + b_e)
            attn_ = enc_.unsqueeze(1) @ emb_trn.transpose(1, 2)
            attn_ = attn_.squeeze(1)
            attn_ = self.args.smooth_factor * torch.tanh(attn_)
            attn_ = attn_.masked_fill(mask_ == 0, -1e20)
            # Attention weights over tokens.
            attn_ = torch.softmax(attn_, dim=1)
            # Sentence vector: attention-weighted sum of token embeddings.
            lbl_ = attn_.unsqueeze(1) @ emb_
            lbl_ = lbl_.squeeze(1)
        return attn_, lbl_
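
    # The encoder computes, for token embeddings e_i and query q (the mean of
    # the non-pad embeddings):
    #
    #   a_i = softmax_i( s * tanh( q . (W_e e_i + b_e) ) )   with s = args.smooth_factor
    #   z   = sum_i a_i * e_i
    #
    # i.e. single-head attention whose output z is the sentence vector fed to
    # the aspect classifier head.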
    def build_batch(self, review):
        vocab2id = self.batch_data['vocab2id']
        # Map in-vocabulary tokens to ids; out-of-vocabulary tokens are dropped.
        senid = [vocab2id[wd] for wd in review.split() if wd in vocab2id]
        # Truncate to at most emb_size tokens, then pad up to that length.
        sen_text_len = min(len(senid), emb_size)
        senid = senid[:sen_text_len] + [vocab2id['<pad>']] * (sen_text_len - len(senid))
        sen_text_var = torch.LongTensor([senid]).to(device)
        # 1 for real tokens, 0 for padding.
        sen_pad_mask = (sen_text_var != vocab2id['<pad>']).long()
        self.batch_data['comment'] = [review]
        self.batch_data['pos_sen_var'] = sen_text_var
        self.batch_data['pos_pad_mask'] = sen_pad_mask
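
    # Worked example (hypothetical ids): for review "pho ngon" with
    # vocab2id = {'<pad>': 0, 'pho': 5, 'ngon': 9}, build_batch produces
    #   pos_sen_var  = [[5, 9]]
    #   pos_pad_mask = [[1, 1]]
    # Padding only comes into play when batching sentences of different lengths.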
    def calculate_atten_weight(self):
        attn_pos, lbl_pos = self.encoder(
            self.batch_data['pos_sen_var'],
            self.batch_data['pos_pad_mask']
        )
        # Softmax over the aspect logits of the sentence vector.
        asp_weight = self.base_models['asp_weight'](lbl_pos)
        asp_weight = torch.softmax(asp_weight, dim=1)
        return asp_weight
    def get_test_data(self):
        asp_weight = self.calculate_atten_weight()
        asp_weight = asp_weight.data.cpu().numpy().tolist()
        output = {}
        output['comment'] = self.batch_data['comment']
        output['aspect_weight'] = asp_weight[0]
        self.test_data.append(output)
    def select_top(self, data):
        # Modified z-score: distance from the median, scaled by the median
        # absolute deviation (MAD).
        data = np.asarray(data)
        d = np.abs(data - np.median(data))
        mdev = np.median(d)
        return d / mdev if mdev else np.zeros_like(d)
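
    # select_top is the usual MAD-based outlier score,
    #   s_i = |x_i - median(x)| / median_j(|x_j - median(x)|),
    # so an aspect whose weight is a strong outlier (s_i > threshold in
    # get_predict below) is treated as present.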
    def get_predict(self, top_pred, aspect_label, threshold=3):
        pred = {'none': 0, 'do_an': 0, 'gia_ca': 0, 'khong_gian': 0, 'phuc_vu': 0}
        # zip() tolerates a length mismatch; unknown labels are skipped.
        for score, label in zip(top_pred, aspect_label):
            if score > threshold and label in pred:
                pred[label] = 1
        return pred
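
    # Illustrative example (made-up scores, assuming one label per cluster):
    # with threshold=3, top_pred=[0.4, 3.7, 1.1, 0.2, 5.2] and
    # aspect_label=['none', 'do_an', 'gia_ca', 'khong_gian', 'phuc_vu'], this
    # yields pred = {'none': 0, 'do_an': 1, 'gia_ca': 0, 'khong_gian': 0, 'phuc_vu': 1}.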
    def get_evaluate_result(self, input_):
        # Second column of aspect_mapping.txt holds the aspect label per cluster.
        aspect_label = []
        with open('./aspect_mapping.txt', 'r', encoding='utf8') as fp:
            for line in fp:
                aspect_label.append(line.split()[1])
        top_score = self.select_top(input_['aspect_weight'])
        curr_pred = self.get_predict(top_score, aspect_label)
        aspect_key = [key for key, value in curr_pred.items() if int(value) == 1]
        return self.get_aspect(aspect_key)
    def get_aspect(self, pred, ignore='none'):
        # Keep every predicted aspect except the ignored label.
        aspects = [p for p in pred if p != ignore]
        self.output.append(aspects if aspects else ['None'])
    def infer(self, text=''):
        self.args.task = 'sscl-infer'
        # Text-cleaning pipeline from data_preprocessing.
        text = remove_xem_them(text)
        text = remove_emojis(text)
        text = format_punctuation(text)
        text = remove_punctuation(text)
        text = clean_text(text)
        text = normalize_format(text)
        text = word_segment(text)
        text = remove_stopwords(text)
        text = format_price(text)
        input_ = format_price_v2(text)
        self.load_vocabulary()
        self.load_models()
        self.build_batch(input_)
        self.get_test_data()
        self.get_evaluate_result(self.test_data[0])
# Script entry point: Streamlit re-runs this top to bottom on every interaction.
parser = argparse.ArgumentParser()
parser.add_argument('--task', default='infer')
parser.add_argument('--smooth_factor', type=float, default=0.9)
device = 'cpu'
emb_size = 256
args = parser.parse_args(args=[])

model = inferSSCL(args)
cmt = st.text_area('Nhập nhận xét của bạn vào đây:')  # "Enter your review here:"
if cmt == '':
    st.title('Nội dung bình luận của bạn!')  # "Your comment's content!"
else:
    model.infer(cmt)
    outputs = model.output[0]
    if outputs:
        for output in outputs:
            if output == 'do_an':
                st.title(':blue[Đồ ăn]')       # food
            elif output == 'gia_ca':
                st.title(':blue[Giá cả]')      # price
            elif output == 'khong_gian':
                st.title(':blue[Không gian]')  # ambience
            elif output == 'phuc_vu':
                st.title(':blue[Phục vụ]')     # service
            else:
                st.title('None')
    st.divider()
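
# Minimal usage sketch (the file name app.py is an assumption, not fixed by
# this script):
#
#   streamlit run app.py
#
# The app expects vocab.txt, vectors.npy, aspect_centroid.txt,
# aspect_mapping.txt, asp_weight.model and attn_kernel.model in the
# working directory.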