strongpear
committed on
Commit
•
c2b0a49
1
Parent(s):
175db61
update infer
Browse files
app.py
CHANGED
@@ -1,3 +1,271 @@
|
|
1 |
import streamlit as st
|
2 |
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch.autograd import Variable
|
7 |
+
|
8 |
+
import argparse
|
9 |
+
import os
|
10 |
+
import re
|
11 |
+
|
12 |
+
from data_preprocessing import remove_xem_them, remove_emojis, remove_stopwords, format_punctuation, remove_punctuation, clean_text, normalize_format, word_segment, format_price, format_price_v2
|
13 |
+
|
14 |
+
class inferSSCL():
    """Single-review aspect inference on top of pre-trained weights.

    The class loads a vocabulary, word vectors, aspect centroids and two
    linear layers from files in the current working directory, encodes one
    review with an attention layer and maps the resulting aspect scores to
    aspect labels.

    The methods read two module-level globals at call time: ``device`` (the
    torch device string) and ``emb_size`` (embedding width, also used as the
    maximum number of tokens kept per review).
    """

    def __init__(self, args='None'):
        """Store the run configuration and create empty state containers.

        Args:
            args: Namespace-like object. ``encoder`` reads
                ``args.smooth_factor`` and ``infer`` sets ``args.task``, so
                the string default (kept for backward compatibility) only
                works for methods that do not touch ``self.args``.
        """
        self.args = args
        self.base_models = {}
        self.batch_data = {}
        self.test_data = []

    def load_vocab_pretrain(self, file_pretrain_vocab, file_pretrain_vec, pad_tokens=True):
        """Read the vocabulary file and the matching pre-trained vectors.

        Every vocabulary line is split on single spaces and its second field
        is taken as the word. Ids are assigned in file order starting at 1;
        id 0 is reserved for ``'<pad>'``.

        Args:
            file_pretrain_vocab: Path of the UTF-8 vocabulary text file.
            file_pretrain_vec: Path of the ``.npy`` file with the vectors.
            pad_tokens: Unused; kept so existing callers keep working.

        Returns:
            Tuple ``(vocab2id, id2vocab, pretrain_vec)`` where
            ``pretrain_vec`` has an all-zero row prepended for ``'<pad>'``.

        Raises:
            IndexError: If a line has fewer than two space-separated fields.
        """
        vocab2id = {'<pad>': 0}
        id2vocab = {0: '<pad>'}

        with open(file_pretrain_vocab, 'r', encoding='utf-8') as fp:
            for idx, line in enumerate(fp, start=len(id2vocab)):
                # Strip only the newline: slicing off the last character
                # would truncate the word when the final line has no
                # trailing newline.
                word = line.rstrip('\n').split(' ')[1]
                vocab2id[word] = idx
                id2vocab[idx] = word

        # word embedding, with a zero vector for the padding id
        pretrain_vec = np.load(file_pretrain_vec)
        pad_vec = np.zeros([1, pretrain_vec.shape[1]])
        pretrain_vec = np.vstack((pad_vec, pretrain_vec))
        return vocab2id, id2vocab, pretrain_vec

    def load_vocabulary(self):
        """Load vocabulary, embeddings and aspect centroids into ``batch_data``.

        Reads ``vocab.txt``, ``vectors.npy``, ``aspect_centroid.txt`` and
        ``aspect_mapping.txt`` from the current directory. Centroid rows whose
        mapping line reduces to ``"none"`` (after removing digits, spaces and
        the newline) are zeroed out.
        """
        cluster_dir = './'
        file_wordvec = 'vectors.npy'
        file_vocab = 'vocab.txt'
        file_kmeans_centroid = 'aspect_centroid.txt'
        file_aspect_mapping = 'aspect_mapping.txt'

        vocab2id, id2vocab, pretrain_vec = self.load_vocab_pretrain(
            os.path.join(cluster_dir, file_vocab),
            os.path.join(cluster_dir, file_wordvec),
        )

        self.batch_data['vocab2id'] = vocab2id
        self.batch_data['id2vocab'] = id2vocab
        self.batch_data['pretrain_emb'] = pretrain_vec
        self.batch_data['vocab_size'] = len(vocab2id)

        aspect_vec = np.atleast_2d(
            np.loadtxt(os.path.join(cluster_dir, file_kmeans_centroid), dtype=float)
        )

        # One 0/1 flag per mapping line: 0 disables a "none" centroid.
        keep = []
        with open(os.path.join(cluster_dir, file_aspect_mapping), 'r', encoding='utf8') as fp:
            for line in fp:
                if not line.strip():
                    # np.loadtxt skips blank lines too, so rows stay aligned.
                    continue
                label = re.sub(r'[0-9]+', '', line)
                label = label.replace(' ', '').replace('\n', '')
                keep.append(0. if label == "none" else 1.)

        # Broadcast the per-row flag over the vector width instead of
        # hard-coding the width.
        aspect_vec = aspect_vec * np.asarray(keep)[:, None]
        aspect_vec = torch.FloatTensor(aspect_vec).to(device)
        self.batch_data['aspect_centroid'] = aspect_vec
        self.batch_data['n_aspects'] = aspect_vec.shape[0]

    def load_models(self):
        """Build the embedding and the two linear layers and load their weights.

        Requires ``load_vocabulary`` to have filled ``batch_data`` first.
        Weights come from ``./asp_weight.model`` and ``./attn_kernel.model``.
        """
        embedding = torch.nn.Embedding(self.batch_data['vocab_size'], emb_size).to(device)
        emb_para = torch.FloatTensor(self.batch_data['pretrain_emb']).to(device)
        embedding.weight = torch.nn.Parameter(emb_para)
        self.base_models['embedding'] = embedding

        # map_location lets checkpoints saved on a GPU load on whatever
        # device is configured (including CPU-only hosts).
        asp_weight = torch.nn.Linear(emb_size, self.batch_data['n_aspects']).to(device)
        asp_weight.load_state_dict(torch.load('./asp_weight.model', map_location=device))
        self.base_models['asp_weight'] = asp_weight

        attn_kernel = torch.nn.Linear(emb_size, emb_size).to(device)
        attn_kernel.load_state_dict(
            torch.load('./attn_kernel.model', map_location=device), strict=False
        )
        self.base_models['attn_kernel'] = attn_kernel

    def build_pipe(self):
        """Return the softmax-normalised aspect scores for the current batch.

        Same computation as ``calculate_atten_weight``; kept as an alias so
        existing callers keep working.
        """
        return self.calculate_atten_weight()

    def encoder(self, input_, mask_):
        """Encode a batch of token ids with masked self-attention.

        Args:
            input_: Token-id tensor fed to the embedding layer.
            mask_: Tensor of the same leading shape, 1 for real tokens and 0
                for padding.

        Returns:
            Tuple ``(attn_, lbl_)``: the attention weights over tokens and
            the attention-weighted sum of the token embeddings.
        """
        # Inference only: no gradients are needed anywhere in this pass.
        with torch.no_grad():
            emb_ = self.base_models['embedding'](input_)
            emb_ = emb_ * mask_.unsqueeze(2)

            emb_avg = torch.sum(emb_, dim=1)
            # The epsilon keeps the division finite for an all-padding row.
            norm = torch.sum(mask_, dim=1, keepdim=True) + 1e-20

            # query vector: mean of the unmasked embeddings
            enc_ = emb_avg.div(norm.expand_as(emb_avg))

            # We Ex + be
            emb_trn = self.base_models['attn_kernel'](emb_)

            # query vector * (We Ex + be)
            attn_ = enc_.unsqueeze(1) @ emb_trn.transpose(1, 2)
            attn_ = attn_.squeeze(1)

            # alignment score
            attn_ = self.args.smooth_factor * torch.tanh(attn_)
            # Padding positions get a huge negative score so softmax
            # assigns them (effectively) zero weight.
            attn_ = attn_.masked_fill(mask_ == 0, -1e20)

            # attention weight
            attn_ = torch.softmax(attn_, dim=1)

            # attention-weighted sum of embeddings
            lbl_ = attn_.unsqueeze(1) @ emb_
            lbl_ = lbl_.squeeze(1)

        return attn_, lbl_

    def build_batch(self, review):
        """Turn one pre-processed review into id and mask tensors.

        Words missing from the vocabulary are dropped. At most ``emb_size``
        tokens are kept (the module reuses the embedding width as the length
        cap). Results are stored in ``batch_data`` under ``'comment'``,
        ``'pos_sen_var'`` and ``'pos_pad_mask'``.

        Args:
            review: Whitespace-tokenised review text.
        """
        vocab2id = self.batch_data['vocab2id']
        pad_id = vocab2id['<pad>']
        max_len = emb_size

        senid = [vocab2id[wd] for wd in review.split() if wd in vocab2id]
        # A single sentence never needs padding, only truncation.
        senid = senid[:max_len]

        sen_text_var = torch.LongTensor([senid]).to(device)
        # 1 for real tokens, 0 for padding ids.
        sen_pad_mask = (sen_text_var != pad_id).long()

        self.batch_data['comment'] = [review]
        self.batch_data['pos_sen_var'] = sen_text_var
        self.batch_data['pos_pad_mask'] = sen_pad_mask

    def calculate_atten_weight(self):
        """Return softmax-normalised aspect scores for the current batch."""
        _, lbl_pos = self.encoder(
            self.batch_data['pos_sen_var'],
            self.batch_data['pos_pad_mask'],
        )

        with torch.no_grad():
            asp_weight = self.base_models['asp_weight'](lbl_pos)
            asp_weight = torch.softmax(asp_weight, dim=1)

        return asp_weight

    def get_test_data(self):
        """Score the current batch and append the result to ``test_data``."""
        asp_weight = self.calculate_atten_weight()
        asp_weight = asp_weight.data.cpu().numpy().tolist()

        output = {
            'comment': self.batch_data['comment'],
            'aspect_weight': asp_weight[0],
        }
        self.test_data.append(output)

    def select_top(self, data):
        """Score each value by its distance from the median in MAD units.

        Args:
            data: Sequence of numeric aspect weights.

        Returns:
            NumPy array with ``|x - median| / MAD`` per element. When the
            median absolute deviation is zero, an all-zero array of the same
            length is returned so callers can still iterate over it.
        """
        values = np.asarray(data, dtype=float)
        d = np.abs(values - np.median(values))
        mdev = np.median(d)
        if not mdev:
            return np.zeros_like(d)
        return d / mdev

    def get_predict(self, top_pred, aspect_label, threshold=1):
        """Flag every aspect whose score exceeds ``threshold``.

        Args:
            top_pred: Per-aspect scores (a scalar is treated as one score).
            aspect_label: Aspect names aligned with ``top_pred``.
            threshold: Scores strictly above this value are flagged.

        Returns:
            Dict mapping aspect name to 0 or 1.
        """
        pred = {'none': 0, 'do_an': 0, 'gia_ca': 0, 'khong_gian': 0, 'phuc_vu': 0}
        # zip stops at the shorter sequence, so mismatched lengths cannot
        # raise; no blanket exception handler is needed.
        for score, label in zip(np.atleast_1d(top_pred), aspect_label):
            if score > threshold:
                pred[label] = 1
        return pred

    def get_evaluate_result(self, input_):
        """Convert one scored sample into a list of predicted aspect names.

        Args:
            input_: Dict with an ``'aspect_weight'`` sequence, as produced by
                ``get_test_data``.

        Returns:
            List of aspect names, or ``['None']`` when nothing is predicted.
        """
        aspect_label = []
        with open('./aspect_mapping.txt', 'r', encoding='utf8') as fp:
            for line in fp:
                fields = line.split()
                if not fields:
                    continue  # ignore blank lines
                aspect_label.append(fields[1])

        top_score = self.select_top(input_['aspect_weight'])
        curr_pred = self.get_predict(top_score, aspect_label)

        aspect_key = [key for key, value in curr_pred.items() if int(value) == 1]

        return self.get_aspect(aspect_key)

    def get_aspect(self, pred, ignore='none'):
        """Drop the ``ignore`` label from ``pred``.

        Args:
            pred: List of predicted aspect names.
            ignore: Label to filter out.

        Returns:
            The remaining names, or ``['None']`` when none are left.
        """
        # Filter by value rather than by position: the ignored label is not
        # guaranteed to be the first element.
        kept = [aspect for aspect in pred if aspect != ignore]
        return kept if kept else ['None']

    def infer(self, text=''):
        """Pre-process ``text``, run the model and return the aspect names.

        Args:
            text: Raw review text.

        Returns:
            List of predicted aspect names (``['None']`` if there are none).
        """
        self.args.task = 'sscl-infer'

        preprocessing_steps = (
            remove_xem_them,
            remove_emojis,
            format_punctuation,
            remove_punctuation,
            clean_text,
            normalize_format,
            word_segment,
            remove_stopwords,
            format_price,
            format_price_v2,
        )
        for step in preprocessing_steps:
            text = step(text)

        # Files only need to be read once per instance.
        if 'n_aspects' not in self.batch_data:
            self.load_vocabulary()
        if 'attn_kernel' not in self.base_models:
            self.load_models()

        self.build_batch(text)
        self.get_test_data()

        # Use the sample that was just appended, not the first one ever seen.
        return self.get_evaluate_result(self.test_data[-1])
|
256 |
+
|
257 |
+
|
258 |
+
# Script entry: configuration globals, model construction and the Streamlit UI.
parser = argparse.ArgumentParser()
parser.add_argument('--task', default='infer')
parser.add_argument('--smooth_factor', type=float, default=0.9)

# Module-level settings read by inferSSCL at call time.
# Fall back to the CPU when no CUDA device is available so the app still runs.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
emb_size = 256

# args=[] makes argparse ignore sys.argv, so only the defaults above apply.
args = parser.parse_args(args=[])
model = inferSSCL(args)

cmt = st.text_area('Enter some text: ')
# Skip inference while the text area is still empty.
output = model.infer(cmt) if cmt.strip() else None

if output:
    # Join a list of labels into a readable title; pass anything else through.
    st.title(', '.join(output) if isinstance(output, (list, tuple)) else output)
|