committed on
update infer
Browse files
@@ -1,3 +1,271 @@
1 |
import streamlit as st
2 |
3 |
1 |
import streamlit as st
2 |
3 |
import numpy as np
4 |
5 |
import torch
6 |
from torch.autograd import Variable
7 |
8 |
import argparse
9 |
import os
10 |
import re
11 |
12 |
from data_preprocessing import remove_xem_them, remove_emojis, remove_stopwords, format_punctuation, remove_punctuation, clean_text, normalize_format, word_segment, format_price, format_price_v2
13 |
14 |
class inferSSCL():
15 |
def __init__(self, args='None'):
    """Create the empty containers the other methods fill in.

    Args:
        args: Options object. Other methods read ``args.smooth_factor``
            and assign ``args.task``.
            NOTE(review): the default is the string 'None', not ``None``;
            an instance built without ``args`` cannot run inference.
    """
    # Results list read back by the inference entry point.
    self.test_data = []
    # Vocabulary, embeddings, centroids and the current batch tensors.
    self.batch_data = {}
    # Torch layers keyed by role ('embedding', 'asp_weight', 'attn_kernel').
    self.base_models = {}
    self.args = args
20 |
21 |
def load_vocab_pretrain(self, file_pretrain_vocab, file_pretrain_vec, pad_tokens=True):
    """Load the vocabulary file and its pretrained embedding matrix.

    Each vocabulary line is split on single spaces and the second field is
    taken as the word. Ids are assigned in file order starting at 1; id 0 is
    reserved for ``'<pad>'``.

    Args:
        file_pretrain_vocab: Path to the UTF-8 vocabulary text file.
        file_pretrain_vec: Path to a ``.npy`` file holding the embedding
            matrix (2-D; one row per vocabulary line is assumed).
        pad_tokens: Unused; kept for interface compatibility.

    Returns:
        Tuple ``(vocab2id, id2vocab, pretrain_vec)`` where ``pretrain_vec``
        has an all-zero row prepended for the ``'<pad>'`` id.

    Raises:
        IndexError: If a line has fewer than two space-separated fields.
    """
    vocab2id = {'<pad>': 0}
    id2vocab = {0: '<pad>'}

    next_id = len(id2vocab)
    with open(file_pretrain_vocab, 'r', encoding='utf-8') as fp:
        for line in fp:
            # Strip only the newline: slicing with [:-1] would eat the last
            # character of the final word when the file lacks a trailing
            # newline.
            arr = re.split(' ', line.rstrip('\n'))
            word = arr[1]
            vocab2id[word] = next_id
            id2vocab[next_id] = word
            next_id += 1

    # Word embeddings, with a zero vector stacked on top for '<pad>'.
    pretrain_vec = np.load(file_pretrain_vec)
    pad_vec = np.zeros([1, pretrain_vec.shape[1]])
    pretrain_vec = np.vstack((pad_vec, pretrain_vec))
    return vocab2id, id2vocab, pretrain_vec
37 |
38 |
def load_vocabulary(self):
    """Load vocabulary, word vectors and aspect centroids into ``batch_data``.

    Reads ``vocab.txt``, ``vectors.npy``, ``aspect_centroid.txt`` and
    ``aspect_mapping.txt`` from the current directory and stores
    ``vocab2id``, ``id2vocab``, ``pretrain_emb``, ``vocab_size``,
    ``aspect_centroid`` and ``n_aspects`` in ``self.batch_data``.

    Centroid rows whose mapping line reads ``none`` (after digits and spaces
    are removed) are zeroed out.
    """
    cluster_dir = './'
    file_wordvec = 'vectors.npy'
    file_vocab = 'vocab.txt'
    file_kmeans_centroid = 'aspect_centroid.txt'
    file_aspect_mapping = 'aspect_mapping.txt'

    vocab2id, id2vocab, pretrain_vec = self.load_vocab_pretrain(
        os.path.join(cluster_dir, file_vocab),
        os.path.join(cluster_dir, file_wordvec),
    )
    vocab_size = len(vocab2id)

    self.batch_data['vocab2id'] = vocab2id
    self.batch_data['id2vocab'] = id2vocab
    self.batch_data['pretrain_emb'] = pretrain_vec
    self.batch_data['vocab_size'] = vocab_size

    aspect_vec = np.loadtxt(os.path.join(cluster_dir, file_kmeans_centroid), dtype=float)

    # One 0/1 flag per mapping line. The context manager closes the file,
    # and the explicit encoding does not depend on the platform default.
    keep = []
    with open(os.path.join(cluster_dir, file_aspect_mapping), 'r', encoding='utf-8') as fp:
        for line in fp:
            label = re.sub(r'[0-9]+', '', line)
            label = label.replace(' ', '').replace('\n', '')
            keep.append(0. if label == "none" else 1.)

    # A column mask broadcasts across each centroid row, so the embedding
    # width does not have to be hard-coded here.
    mask = np.asarray(keep, dtype=float).reshape(-1, 1)
    aspect_vec = aspect_vec * mask
    aspect_vec = torch.FloatTensor(aspect_vec).to(device)
    self.batch_data['aspect_centroid'] = aspect_vec
    self.batch_data['n_aspects'] = aspect_vec.shape[0]
70 |
71 |
def load_models(self):
    """Build the embedding, aspect and attention layers on ``device``.

    Reads ``vocab_size``, ``pretrain_emb`` and ``n_aspects`` from
    ``self.batch_data`` and stores the layers in ``self.base_models`` under
    ``'embedding'``, ``'asp_weight'`` and ``'attn_kernel'``.
    """
    self.base_models['embedding'] = torch.nn.Embedding(
        self.batch_data['vocab_size'], emb_size).to(device)
    # Replace the randomly initialised table with the pretrained vectors.
    emb_para = torch.FloatTensor(self.batch_data['pretrain_emb']).to(device)
    self.base_models['embedding'].weight = torch.nn.Parameter(emb_para)

    # NOTE(review): no saved weights are loaded for 'asp_weight' in the
    # visible code, so this layer keeps its random initialisation — confirm
    # whether a checkpoint load is missing.
    self.base_models['asp_weight'] = torch.nn.Linear(
        emb_size, self.batch_data['n_aspects']).to(device)

    self.base_models['attn_kernel'] = torch.nn.Linear(emb_size, emb_size).to(device)
    # map_location lets a checkpoint saved on another device load onto
    # ``device``. strict=False ignores missing or unexpected keys.
    attn_state = torch.load('./attn_kernel.model', map_location=device)
    self.base_models['attn_kernel'].load_state_dict(attn_state, strict=False)
81 |
82 |
83 |
def build_pipe(self):
84 |
85 |
attn_pos, lbl_pos = self.encoder(
86 |
87 |
88 |
89 |
90 |
outw = np.around(, 4)
91 |
outw = outw.tolist()
92 |
outw = outw[:len(self.batch_data['comment'].split())]
93 |
94 |
asp_weight = self.base_models['asp_weight'](lbl_pos)
95 |
# Attention weight
96 |
asp_weight = torch.softmax(asp_weight, dim=1)
97 |
98 |
return asp_weight
99 |
100 |
def encoder(self, input_, mask_):
    """Attention-pool the word embeddings of each input row.

    Args:
        input_: Token-id tensor passed to the embedding layer
            (presumably (batch, seq_len) — TODO confirm against callers).
        mask_: Tensor aligned with ``input_``; positions equal to 0 are
            treated as padding.

    Returns:
        Tuple ``(attn_, lbl_)``: the softmax attention weights over dim 1,
        and the attention-weighted sum of the masked embeddings.
    """
    embed_layer = self.base_models['embedding']
    kernel = self.base_models['attn_kernel']

    # NOTE(review): this file's indentation was lost; it is assumed that
    # only the embedding lookup runs under no_grad — confirm.
    with torch.no_grad():
        word_emb = embed_layer(input_)

    # Zero the embeddings at padded positions.
    word_emb = word_emb * mask_.unsqueeze(2)

    # Query vector: mean embedding over the unmasked positions. The small
    # epsilon keeps the division defined when a row is entirely padding.
    summed = torch.sum(word_emb, dim=1)
    token_count = torch.sum(mask_, dim=1, keepdim=True) + 1e-20
    query = summed / token_count

    # We Ex + be
    projected = kernel(word_emb)

    # query vector * (We Ex + be)
    scores = torch.matmul(query.unsqueeze(1), projected.transpose(1, 2)).squeeze(1)

    # Alignment score, squashed by tanh and scaled by the smooth factor.
    scores = self.args.smooth_factor * torch.tanh(scores)

    # Padded positions get a very negative score so softmax gives them 0.
    scores = scores.masked_fill(mask_ == 0, -1e20)

    # Attention weights.
    weights = torch.softmax(scores, dim=1)

    pooled = torch.matmul(weights.unsqueeze(1), word_emb).squeeze(1)

    return weights, pooled
135 |
136 |
def build_batch(self, review):
137 |
vocab2id = self.batch_data['vocab2id']
138 |
139 |
sen_text = []
140 |
cmt = []
141 |
# sen_text_len = 0
142 |
sen_text_len = emb_size
143 |
144 |
senid = [vocab2id[wd] for wd in review.split() if wd in vocab2id]
145 |
146 |
147 |
148 |
149 |
# if len(senid) > sen_text_len:
150 |
# sen_text_len = len(senid)
151 |
sen_text_len = min(len(senid), sen_text_len)
152 |
sen_text = [itm[:sen_text_len] + [vocab2id['<pad>'] for _ in range(sen_text_len - len(itm))] for itm in sen_text]
153 |
154 |
sen_text_var = Variable(torch.LongTensor(sen_text)).to(device)
155 |
sen_pad_mask = Variable(torch.LongTensor(sen_text)).to(device)
156 |
sen_pad_mask[sen_pad_mask != vocab2id['<pad>']] = -1
157 |
sen_pad_mask[sen_pad_mask == vocab2id['<pad>']] = 0
158 |
sen_pad_mask = -sen_pad_mask
159 |
160 |
self.batch_data['comment'] = cmt
161 |
162 |
self.batch_data['pos_sen_var'] = sen_text_var
163 |
self.batch_data['pos_pad_mask'] = sen_pad_mask
164 |
165 |
def calculate_atten_weight(self):
166 |
167 |
attn_pos, lbl_pos = self.encoder(
168 |
169 |
170 |
171 |
172 |
173 |
asp_weight = self.base_models['asp_weight'](lbl_pos)
174 |
#print('asp_weight:', asp_weight)
175 |
asp_weight = torch.softmax(asp_weight, dim=1)
176 |
#print('soft_max:', asp_weight)
177 |
178 |
return asp_weight
179 |
180 |
def get_test_data(self):
181 |
asp_weight = self.calculate_atten_weight()
182 |
asp_weight =
183 |
184 |
output = {}
185 |
output['comment'] = self.batch_data['comment']
186 |
output['aspect_weight'] = asp_weight[0]
187 |
188 |
189 |
def select_top(self, data):
    """Score each value by its distance from the median, in MAD units.

    Args:
        data: Array-like of numeric scores.

    Returns:
        Array shaped like ``data`` holding ``|x - median| / MAD``, where MAD
        is the median absolute deviation. When the MAD is 0 the result is
        all zeros, so callers can still index and take ``len()`` of it.
    """
    data = np.asarray(data)
    deviation = np.abs(data - np.median(data))
    mdev = np.median(deviation)

    if not mdev:
        # Returning a bare scalar 0 here would break callers that iterate
        # over the scores; keep the input's shape instead.
        return np.zeros_like(deviation)

    return deviation / mdev
196 |
197 |
def get_predict(self, top_pred, aspect_label, threshold=1):
    """Flag every aspect whose score is strictly above ``threshold``.

    Args:
        top_pred: Sequence of scores, one per aspect position.
        aspect_label: Sequence of label names, indexed like ``top_pred``.
        threshold: Scores strictly greater than this are flagged.

    Returns:
        Dict mapping label name to 0 or 1. It always contains 'none',
        'do_an', 'gia_ca', 'khong_gian' and 'phuc_vu'; any other flagged
        label is added as a new key.
    """
    flags = {'none': 0, 'do_an': 0, 'gia_ca': 0, 'khong_gian': 0, 'phuc_vu': 0}

    for position, score in enumerate(top_pred):
        if score > threshold:
            flags[aspect_label[position]] = 1

    return flags
206 |
207 |
def get_evaluate_result(self, input_):
208 |
209 |
aspect_label = []
210 |
fp = open('./aspect_mapping.txt', 'r', encoding='utf8')
211 |
for line in fp:
212 |
213 |
214 |
215 |
top_score = self.select_top(input_['aspect_weight'])
216 |
217 |
curr_pred = self.get_predict(top_score, aspect_label)
218 |
219 |
aspect_key = []
220 |
for key, value in curr_pred.items():
221 |
if int(value) == 1:
222 |
223 |
224 |
return self.get_aspect(aspect_key)
225 |
226 |
def get_aspect(self, pred, ignore='none'):
227 |
if len(pred) > 1:
228 |
229 |
230 |
231 |
232 |
def infer(self, text=''):
233 |
self.args.task = 'sscl-infer'
234 |
235 |
text = remove_xem_them(text)
236 |
text = remove_emojis(text)
237 |
text = format_punctuation(text)
238 |
text = remove_punctuation(text)
239 |
text = clean_text(text)
240 |
text = normalize_format(text)
241 |
text = word_segment(text)
242 |
text = remove_stopwords(text)
243 |
text = format_price(text)
244 |
input_ = format_price_v2(text)
245 |
246 |
247 |
248 |
249 |
250 |
251 |
252 |
253 |
val_result = self.test_data
254 |
255 |
256 |
257 |
258 |
parser = argparse.ArgumentParser()
259 |
parser.add_argument('--task', default='infer')
260 |
parser.add_argument('--smooth_factor', type=float, default=0.9)
261 |
device = 'cuda:0'
262 |
emb_size = 256
263 |
264 |
args = parser.parse_args(args=[])
265 |
model = inferSSCL(args)
266 |
267 |
cmt = st.text_area('Enter some text: ')
268 |
output = model.infer(cmt)
269 |
270 |
if output:
271 |