Jiranuwat commited on
Commit
3732dc4
·
1 Parent(s): 8ac1481

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +346 -0
app.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pickle
3
+ import pandas as pd
4
+ from sentence_transformers import SentenceTransformer, util
5
+ import numpy as np
6
+ from pymongo import MongoClient
7
+ import urllib.parse
8
+ import requests
9
+ from bertopic import BERTopic
10
+ from sklearn.feature_extraction.text import TfidfVectorizer
11
+ from sklearn.metrics.pairwise import cosine_similarity
12
+ import string
13
+ import deepcut
14
+ import unicodedata
15
+ from pythainlp.util import normalize
16
+ import torch
17
+ import csv
18
+
19
+ #initial state
20
+ if 'state' not in st.session_state:
21
+ st.session_state.state = 0
22
+ if 'age' not in st.session_state:
23
+ st.session_state.age = 0
24
+ if 'weight' not in st.session_state:
25
+ st.session_state.weight = 0
26
+ if 'height' not in st.session_state:
27
+ st.session_state.height = 0
28
+ if 'gender' not in st.session_state:
29
+ st.session_state.gender = 0
30
+ if 'food_allergy' not in st.session_state:
31
+ st.session_state.food_allergy = 0
32
+ if 'drug_allergy' not in st.session_state:
33
+ st.session_state.drug_allergy = 0
34
+ if 'congentital_disease' not in st.session_state:
35
+ st.session_state.congentital_disease = 0
36
+ if 'optional_keyword' not in st.session_state:
37
+ st.session_state.optional_keyword = 0
38
+ if 'all_recommend' not in st.session_state:
39
+ st.session_state.all_recommend = None
40
+ if 'true_check' not in st.session_state:
41
+ st.session_state.true_check = None
42
+ if 'queries' not in st.session_state:
43
+ st.session_state.queries = None
44
+ if 'string_contain' not in st.session_state:
45
+ st.session_state.string_contain = False
46
+ if 'sbert_searched_df' not in st.session_state:
47
+ st.session_state.sbert_searched_df = None
48
+ if 'string_contain_df' not in st.session_state:
49
+ st.session_state.string_contain_df = None
50
+ for i in range(10):
51
+ if 'score_'+str(i+1) not in st.session_state:
52
+ st.session_state['score_'+str(i+1)] = 'NA'
53
+
54
+ def set_state(state):
55
+ st.session_state.state = state
56
+
57
+ def split_text(text):
58
+ return text.split(',')
59
+
60
+ #import data
61
+ sbert_model = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
62
+
63
+ with open('corpus_embeddings.pickle', 'rb') as file:
64
+ corpus_embeddings = pickle.load(file)
65
+ corpus_embeddings = pd.DataFrame(corpus_embeddings)
66
+
67
+ topic_model = BERTopic.load("topic_model.pickle")
68
+ data = pd.read_csv('articles_data.csv')
69
+ data['child_topic'] = topic_model.topics_[:]
70
+
71
+ with open('sensitive_words.txt', 'r',encoding='utf-8') as file:
72
+ sensitive_words = file.read()
73
+ sensitive_words = sensitive_words.lower().replace('\n','').split(' ')
74
+ sensitive_words = list(set(sensitive_words))
75
+
76
+ #local function
77
+ def save_session_state_data(session_state_data, filename):
78
+ with open(filename, 'a', newline='', encoding='utf-8') as file:
79
+ writer = csv.DictWriter(file, fieldnames=session_state_data.keys())
80
+ if file.tell() == 0:
81
+ writer.writeheader()
82
+ writer.writerow(session_state_data)
83
+
84
+ def deepcut_tokenizer(text,sensitive_words=sensitive_words):
85
+ cleanedText = "".join([i for i in text if i not in string.punctuation]).lower()
86
+ cleanedText = normalize(unicodedata.normalize('NFKD', cleanedText).replace('\n','').replace('\r','').replace('\t','').replace('“','').replace('”','').replace('.','').replace('–','').replace('‘','').replace('’','').replace('ํา','ำ').replace('...','').replace(',','').replace( 'ี','ี'))
87
+ #cleanedText = re.sub(r'\d+', '', cleanedText)
88
+ cleanedText = deepcut.tokenize(cleanedText,custom_dict=sensitive_words)
89
+ #stopwords = list(thai_stopwords())+'EMagazine GJ international bangkok hospital'.lower().split(' ')
90
+ stopwords = 'EMagazine GJ international bangkok hospital'.lower().split(' ')
91
+ cleanedText = [i for i in cleanedText if i not in stopwords]
92
+ cleanedText = [i.replace(' ','') for i in cleanedText if len(i) != 1 and len(i) !=0]
93
+ cleanedText = ','.join(cleanedText)
94
+ return cleanedText
95
+
96
+ def personal_check(age,weight,height,gender):
97
+
98
+ #age check
99
+ if age >= 60:
100
+ age = 'ผู้สูงอายุ'
101
+ else:
102
+ age = 'ทำงาน'
103
+
104
+ #gender check
105
+ if gender == 'หญิง':
106
+ gender = 'ผู้หญิง'
107
+ else:
108
+ gender = 'ผู้ชาย'
109
+
110
+ #bmi check
111
+ height_meters = height / 100
112
+
113
+ bmi = weight / (height_meters ** 2)
114
+
115
+ if bmi >= 30:
116
+ bmi = 'อ้วนมาก'
117
+ elif bmi >= 23 and bmi <30:
118
+ bmi = 'อ้วน'
119
+ elif bmi >= 18.5 and bmi <23:
120
+ bmi = ''
121
+ else:
122
+ bmi = 'ผอม'
123
+
124
+ return age,gender,bmi
125
+
126
+ def sbert_search(queries,data,embeddiing,sbert_model=sbert_model):
127
+
128
+ index_lst = []
129
+ score_lst = []
130
+
131
+ query_embedding = sbert_model.encode(queries, convert_to_tensor=True)
132
+ hits = util.semantic_search(query_embedding, embeddiing, top_k=15)
133
+ hits = hits[0]
134
+ for hit in hits:
135
+ index_lst.append(hit['corpus_id'])
136
+ score_lst.append(hit['score'])
137
+
138
+ sbert_searched = data.iloc[index_lst]
139
+ sbert_searched['score'] = score_lst
140
+
141
+ return sbert_searched
142
+
143
+ def sbert_tfidf_search(queries,head,topic_model=topic_model,data=data,corpus_embeddings=corpus_embeddings):
144
+
145
+ similar_df = None
146
+ text_to_predict_token = deepcut_tokenizer(queries)
147
+
148
+ # Find topics
149
+ try:
150
+ similar_topics, similarity = topic_model.find_topics(text_to_predict_token, top_n=1)
151
+ except:
152
+ similar_topics, similarity = topic_model.find_topics(queries, top_n=1)
153
+
154
+ # Example DataFrame
155
+ similar_df = data[data['child_topic'] == similar_topics[0]]
156
+
157
+ # TF-IDF vectorizer
158
+ vectorizer = TfidfVectorizer(tokenizer=lambda x: x, lowercase=False)
159
+ tfidf_matrix = vectorizer.fit_transform(similar_df['text_token'])
160
+
161
+ # TF-IDF vector for input text
162
+ text_tfidf = vectorizer.transform([text_to_predict_token])
163
+
164
+ # Compute cosine similarity
165
+ similarity_scores = cosine_similarity(tfidf_matrix, text_tfidf)
166
+
167
+ # Add similarity scores to DataFrame
168
+ similar_df['score'] = similarity_scores
169
+
170
+ similar_df = similar_df.sort_values('score', ascending=False).head(15)
171
+
172
+ select_corpus = corpus_embeddings.iloc[similar_df.index.sort_values()]
173
+ similar_embedding = torch.tensor(select_corpus.values)
174
+ similar_searched = sbert_search(queries,similar_df,similar_embedding)
175
+ sbert_searched = sbert_search(queries,data,torch.tensor(corpus_embeddings.values))
176
+ combined_searched = pd.concat([similar_searched,sbert_searched])
177
+ output = combined_searched.sort_values('score', ascending=False).head(head)
178
+
179
+ return output
180
+
181
+ def string_contain_search(queries,sample,data=data):
182
+ data['all_content'] = data['title']+data['content']
183
+ return data[data['all_content'].str.contains(queries,na=False)].sample(sample)
184
+
185
+ #main
186
+ def main():
187
+ #header
188
+ st.markdown("<h1 style='text-align: center; color: black;'>---ระบบแนะนำบทความสุขภาพ---</h1>", unsafe_allow_html=True)
189
+ st.subheader("ให้คะแนนบทความหน่อยนะครับ😄")
190
+
191
+ with st.form('user_info'):
192
+
193
+ #personal information input
194
+ age = st.slider("อายุ", 10, 100, 25)
195
+
196
+ col1, col2 = st.columns(2)
197
+ with col1:
198
+ weight = st.number_input("น้ำหนัก (Kg.): ",30.0,120.0,step=1.0,value=50.0)
199
+ with col2:
200
+ height = st.number_input("ส่วนสูง (cm.): ",100.0,250.0,step=1.0,value=150.0)
201
+
202
+ col3, col4, col5 = st.columns(3)
203
+ with col3:
204
+ gender = st.selectbox('เพศ',('ชาย', 'หญิง'))
205
+ with col4:
206
+ food_allergy = st.selectbox('แพ้อาหาร?',('ไม่แพ้', 'แพ้อาหาร'))
207
+ with col5:
208
+ drug_allergy = st.selectbox('แพ้ยา?',('ไม่แพ้', 'แพ้ยา'))
209
+ congentital_disease = st.text_input('โรคประจำตัวของคุณ (ถ้าหากไม่มี ไม่ต้องกรอก หรือใส่ "ไม่มี")')
210
+ optional_keyword = st.text_input('คำค้นหาเพิ่มเติม (ถ้ามี)')
211
+
212
+ st.form_submit_button(on_click=set_state,args=(1,))
213
+
214
+ if st.session_state.state == 1:
215
+
216
+ #asign state
217
+ st.session_state.age = age
218
+ st.session_state.weight = weight
219
+ st.session_state.height = height
220
+ st.session_state.gender = gender
221
+ st.session_state.food_allergy = food_allergy
222
+ st.session_state.drug_allergy = drug_allergy
223
+ st.session_state.congentital_disease = congentital_disease
224
+ st.session_state.optional_keyword = optional_keyword
225
+
226
+ #algorithm
227
+ age,gender,bmi = personal_check(age,weight,height,gender)
228
+
229
+ if food_allergy == 'ไม่แพ้':
230
+ food_allergy = ''
231
+ if drug_allergy == 'ไม่แพ้':
232
+ drug_allergy = ''
233
+ if congentital_disease == 'ไม่มี':
234
+ congentital_disease = ''
235
+ if congentital_disease != '' or optional_keyword != '':
236
+ queries = optional_keyword+congentital_disease
237
+ else:
238
+ queries = gender+age+bmi+food_allergy+drug_allergy+congentital_disease+optional_keyword
239
+
240
+ #Bertopic search
241
+ try:
242
+ sbert_searched = sbert_tfidf_search(queries,5)
243
+ string_contain = string_contain_search(queries,5)
244
+ all_recommend = pd.concat([sbert_searched,string_contain])
245
+ all_recommend = all_recommend.drop_duplicates(subset=['url'])
246
+
247
+ if len(all_recommend) != 10:
248
+ for i in range(3):
249
+ if len(all_recommend) < 10:
250
+ all_recommend = None
251
+ sbert_searched = sbert_tfidf_search(queries,5+i+1)
252
+ sbert_searched = sbert_searched.head(5)
253
+ string_contain = string_contain_search(queries,5+i+1)
254
+ string_contain = string_contain.head(5)
255
+ all_recommend = pd.concat([sbert_searched,string_contain])
256
+ all_recommend = all_recommend.drop_duplicates(subset=['url'])
257
+
258
+ st.session_state.sbert_searched_df = sbert_searched
259
+ st.session_state.string_contain_df = string_contain
260
+ st.session_state.string_contain = True
261
+
262
+ except:
263
+ sbert_searched = sbert_tfidf_search(queries,10)
264
+ st.session_state.sbert_searched_df = sbert_searched
265
+ all_recommend = sbert_searched
266
+
267
+ st.session_state.all_recommend = all_recommend
268
+ st.session_state.queries = queries
269
+ st.session_state.state = 2
270
+
271
+ if st.session_state.state == 2:
272
+ placeholder = st.empty()
273
+
274
+ #satisfaction
275
+ with placeholder.form('Satisfaction Survey'):
276
+ st.markdown("<h1 style='text-align: center; color: black;'>📰บทความสำหรับคุณ😆</h1>", unsafe_allow_html=True)
277
+ st.header("ระดับความเกี่ยวข้อง")
278
+ st.write("😞 หมายถึง ไม่เกี่ยวข้องเลย")
279
+ st.write("🙁 หมายถึง เกี่ยวข้องเล็กน้อย")
280
+ st.write("😐 หมายถึง เฉยๆ")
281
+ st.write("🙂 หมายถึง ค่อนข้างเกี่ยวข้อง")
282
+ st.write("😀 หมายถึง เกี่ยวข้องมากที่สุด")
283
+ st.write("---------------------------------------------------------------------------------------")
284
+
285
+ for i in range(len(st.session_state.all_recommend)):
286
+ st.header(str(i+1)+'. '+st.session_state.all_recommend.iloc[i]['title'])
287
+ st.markdown(f"[Page source (Click here.)]({st.session_state.all_recommend.iloc[i].url})")
288
+
289
+ try:
290
+ banner_url = urllib.parse.quote(st.session_state.all_recommend.iloc[i]['banner'], safe=':/')
291
+ response = requests.get(banner_url)
292
+ st.image(response.content)
293
+ except:
294
+ st.image('https://icon-library.com/images/no-photo-icon/no-photo-icon-1.jpg')
295
+
296
+ #satisfaction survey
297
+ st.subheader("Satisfaction Survey")
298
+ st.write("บทความที่แนะนำเกี่ยวข้องกับคุณมากเพียงใด")
299
+ st.radio('ระดับความพึงพอใจ',['NA','😞','🙁','😐','🙂','😀'],horizontal=True,key='score_'+str(i+1))
300
+ st.write("---------------------------------------------------------------------------------------")
301
+
302
+ if st.form_submit_button("ยืนยันการส่งคำตอบ"):
303
+ # Check if all articles have satisfaction levels selected
304
+ st.session_state.true_check = []
305
+ for satis_val in [st.session_state[i] for i in ['score_' + str(i+1) for i in range(10)]]:
306
+ if satis_val != 'NA':
307
+ st.session_state.true_check.append(True)
308
+ else:
309
+ st.session_state.true_check.append(False)
310
+
311
+ if np.all(st.session_state.true_check):
312
+ st.session_state.state = 3
313
+ placeholder.empty()
314
+
315
+ else:
316
+ idx = []
317
+ for i in range(len(st.session_state.true_check)):
318
+ if st.session_state.true_check[i] == False:
319
+ idx.append(i+1)
320
+ article_indexes = ', '.join(map(str, idx))
321
+ st.warning(f":red[กรุณาให้คะแนนบทความที่ {article_indexes} ด้วยครับ]")
322
+
323
+ if st.session_state.state == 3:
324
+ st.success('บันทึกคำตอบแล้ว')
325
+
326
+ st.session_state.all_recommend = st.session_state.all_recommend.to_dict(orient='records')
327
+ if st.session_state.sbert_searched_df is not None:
328
+ st.session_state.sbert_searched_df = st.session_state.sbert_searched_df.to_dict(orient='records')
329
+ if st.session_state.string_contain_df is not None:
330
+ st.session_state.string_contain_df = st.session_state.string_contain_df.to_dict(orient='records')
331
+
332
+ try:
333
+ save_session_state_data(st.session_state.to_dict(), 'satisfaction.csv')
334
+
335
+ except:
336
+ #database insertion
337
+ client = MongoClient('mongodb://192.168.1.103:27017/')
338
+ database = client['test']
339
+ collection = database['satisfy_articles']
340
+ collection.insert_one(st.session_state.to_dict())
341
+
342
+ finally:
343
+ st.session_state.state = 0
344
+
345
+ if __name__ == "__main__":
346
+ main()