Jiranuwat commited on
Commit
1484ac4
·
1 Parent(s): 4815aaa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -197
app.py CHANGED
@@ -3,22 +3,14 @@ import pickle
3
  import pandas as pd
4
  from sentence_transformers import SentenceTransformer, util
5
  import numpy as np
6
- from pymongo import MongoClient
7
  import urllib.parse
8
  import requests
9
- from bertopic import BERTopic
10
- from sklearn.feature_extraction.text import TfidfVectorizer
11
- from sklearn.metrics.pairwise import cosine_similarity
12
- import string
13
- import deepcut
14
- import unicodedata
15
- from pythainlp.util import normalize
16
- import torch
17
- import csv
18
 
19
  #initial state
20
- if 'state' not in st.session_state:
21
- st.session_state.state = 0
 
 
22
  if 'age' not in st.session_state:
23
  st.session_state.age = 0
24
  if 'weight' not in st.session_state:
@@ -33,26 +25,25 @@ if 'drug_allergy' not in st.session_state:
33
  st.session_state.drug_allergy = 0
34
  if 'congentital_disease' not in st.session_state:
35
  st.session_state.congentital_disease = 0
36
- if 'optional_keyword' not in st.session_state:
37
- st.session_state.optional_keyword = 0
38
- if 'all_recommend' not in st.session_state:
39
- st.session_state.all_recommend = None
40
- if 'true_check' not in st.session_state:
41
- st.session_state.true_check = None
42
  if 'queries' not in st.session_state:
43
  st.session_state.queries = None
44
- if 'string_contain' not in st.session_state:
45
- st.session_state.string_contain = False
46
  if 'sbert_searched_df' not in st.session_state:
47
  st.session_state.sbert_searched_df = None
48
- if 'string_contain_df' not in st.session_state:
49
- st.session_state.string_contain_df = None
 
 
50
  for i in range(10):
51
  if 'score_'+str(i+1) not in st.session_state:
52
  st.session_state['score_'+str(i+1)] = 'NA'
 
 
53
 
54
- def set_state(state):
55
- st.session_state.state = state
 
 
 
56
 
57
  def split_text(text):
58
  return text.split(',')
@@ -60,39 +51,12 @@ def split_text(text):
60
  #import data
61
  sbert_model = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
62
 
63
- with open('corpus_embeddings.pickle', 'rb') as file:
64
- corpus_embeddings = pickle.load(file)
65
- corpus_embeddings = pd.DataFrame(corpus_embeddings)
66
-
67
- topic_model = BERTopic.load("Jiranuwat/topic_model",embedding_model=sbert_model)
68
  data = pd.read_csv('articles_data.csv')
69
- data['child_topic'] = topic_model.topics_[:]
70
 
71
- with open('sensitive_words.txt', 'r',encoding='utf-8') as file:
72
- sensitive_words = file.read()
73
- sensitive_words = sensitive_words.lower().replace('\n','').split(' ')
74
- sensitive_words = list(set(sensitive_words))
75
 
76
  #local function
77
- def save_session_state_data(session_state_data, filename):
78
- with open(filename, 'a', newline='', encoding='utf-8') as file:
79
- writer = csv.DictWriter(file, fieldnames=session_state_data.keys())
80
- if file.tell() == 0:
81
- writer.writeheader()
82
- writer.writerow(session_state_data)
83
-
84
- def deepcut_tokenizer(text,sensitive_words=sensitive_words):
85
- cleanedText = "".join([i for i in text if i not in string.punctuation]).lower()
86
- cleanedText = normalize(unicodedata.normalize('NFKD', cleanedText).replace('\n','').replace('\r','').replace('\t','').replace('“','').replace('”','').replace('.','').replace('–','').replace('‘','').replace('’','').replace('ํา','ำ').replace('...','').replace(',','').replace( 'ี','ี'))
87
- #cleanedText = re.sub(r'\d+', '', cleanedText)
88
- cleanedText = deepcut.tokenize(cleanedText,custom_dict=sensitive_words)
89
- #stopwords = list(thai_stopwords())+'EMagazine GJ international bangkok hospital'.lower().split(' ')
90
- stopwords = 'EMagazine GJ international bangkok hospital'.lower().split(' ')
91
- cleanedText = [i for i in cleanedText if i not in stopwords]
92
- cleanedText = [i.replace(' ','') for i in cleanedText if len(i) != 1 and len(i) !=0]
93
- cleanedText = ','.join(cleanedText)
94
- return cleanedText
95
-
96
  def personal_check(age,weight,height,gender):
97
 
98
  #age check
@@ -129,7 +93,7 @@ def sbert_search(queries,data,embeddiing,sbert_model=sbert_model):
129
  score_lst = []
130
 
131
  query_embedding = sbert_model.encode(queries, convert_to_tensor=True)
132
- hits = util.semantic_search(query_embedding, embeddiing, top_k=15)
133
  hits = hits[0]
134
  for hit in hits:
135
  index_lst.append(hit['corpus_id'])
@@ -140,53 +104,9 @@ def sbert_search(queries,data,embeddiing,sbert_model=sbert_model):
140
 
141
  return sbert_searched
142
 
143
- def sbert_tfidf_search(queries,head,topic_model=topic_model,data=data,corpus_embeddings=corpus_embeddings):
144
-
145
- similar_df = None
146
- text_to_predict_token = deepcut_tokenizer(queries)
147
-
148
- # Find topics
149
- try:
150
- similar_topics, similarity = topic_model.find_topics(text_to_predict_token, top_n=1)
151
- except:
152
- similar_topics, similarity = topic_model.find_topics(queries, top_n=1)
153
-
154
- # Example DataFrame
155
- similar_df = data[data['child_topic'] == similar_topics[0]]
156
-
157
- # TF-IDF vectorizer
158
- vectorizer = TfidfVectorizer(tokenizer=lambda x: x, lowercase=False)
159
- tfidf_matrix = vectorizer.fit_transform(similar_df['text_token'])
160
-
161
- # TF-IDF vector for input text
162
- text_tfidf = vectorizer.transform([text_to_predict_token])
163
-
164
- # Compute cosine similarity
165
- similarity_scores = cosine_similarity(tfidf_matrix, text_tfidf)
166
-
167
- # Add similarity scores to DataFrame
168
- similar_df['score'] = similarity_scores
169
-
170
- similar_df = similar_df.sort_values('score', ascending=False).head(15)
171
-
172
- select_corpus = corpus_embeddings.iloc[similar_df.index.sort_values()]
173
- similar_embedding = torch.tensor(select_corpus.values)
174
- similar_searched = sbert_search(queries,similar_df,similar_embedding)
175
- sbert_searched = sbert_search(queries,data,torch.tensor(corpus_embeddings.values))
176
- combined_searched = pd.concat([similar_searched,sbert_searched])
177
- output = combined_searched.sort_values('score', ascending=False).head(head)
178
-
179
- return output
180
-
181
- def string_contain_search(queries,sample,data=data):
182
- data['all_content'] = data['title']+data['content']
183
- return data[data['all_content'].str.contains(queries,na=False)].sample(sample)
184
-
185
- #main
186
- def main():
187
  #header
188
  st.markdown("<h1 style='text-align: center; color: black;'>---ระบบแนะนำบทความสุขภาพ---</h1>", unsafe_allow_html=True)
189
- st.subheader("ให้คะแนนบทความหน่อยนะครับ😄")
190
 
191
  with st.form('user_info'):
192
 
@@ -207,11 +127,10 @@ def main():
207
  with col5:
208
  drug_allergy = st.selectbox('แพ้ยา?',('ไม่แพ้', 'แพ้ยา'))
209
  congentital_disease = st.text_input('โรคประจำตัวของคุณ (ถ้าหากไม่มี ไม่ต้องกรอก หรือใส่ "ไม่มี")')
210
- optional_keyword = st.text_input('คำค้นหาเพิ่มเติม (ถ้ามี)')
211
 
212
- st.form_submit_button(on_click=set_state,args=(1,))
213
 
214
- if st.session_state.state == 1:
215
 
216
  #asign state
217
  st.session_state.age = age
@@ -221,7 +140,6 @@ def main():
221
  st.session_state.food_allergy = food_allergy
222
  st.session_state.drug_allergy = drug_allergy
223
  st.session_state.congentital_disease = congentital_disease
224
- st.session_state.optional_keyword = optional_keyword
225
 
226
  #algorithm
227
  age,gender,bmi = personal_check(age,weight,height,gender)
@@ -232,115 +150,88 @@ def main():
232
  drug_allergy = ''
233
  if congentital_disease == 'ไม่มี':
234
  congentital_disease = ''
235
- if congentital_disease != '' or optional_keyword != '':
236
- queries = optional_keyword+congentital_disease
 
237
  else:
238
- queries = gender+age+bmi+food_allergy+drug_allergy+congentital_disease+optional_keyword
239
 
240
  #Bertopic search
241
- try:
242
- sbert_searched = sbert_tfidf_search(queries,5)
243
- string_contain = string_contain_search(queries,5)
244
- all_recommend = pd.concat([sbert_searched,string_contain])
245
- all_recommend = all_recommend.drop_duplicates(subset=['url'])
246
-
247
- if len(all_recommend) != 10:
248
- for i in range(3):
249
- if len(all_recommend) < 10:
250
- all_recommend = None
251
- sbert_searched = sbert_tfidf_search(queries,5+i+1)
252
- sbert_searched = sbert_searched.head(5)
253
- string_contain = string_contain_search(queries,5+i+1)
254
- string_contain = string_contain.head(5)
255
- all_recommend = pd.concat([sbert_searched,string_contain])
256
- all_recommend = all_recommend.drop_duplicates(subset=['url'])
257
-
258
- st.session_state.sbert_searched_df = sbert_searched
259
- st.session_state.string_contain_df = string_contain
260
- st.session_state.string_contain = True
261
-
262
- except:
263
- sbert_searched = sbert_tfidf_search(queries,10)
264
- st.session_state.sbert_searched_df = sbert_searched
265
- all_recommend = sbert_searched
266
-
267
- st.session_state.all_recommend = all_recommend
268
  st.session_state.queries = queries
269
- st.session_state.state = 2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
 
271
- if st.session_state.state == 2:
272
- placeholder = st.empty()
 
273
 
274
- #satisfaction
275
- with placeholder.form('Satisfaction Survey'):
276
  st.markdown("<h1 style='text-align: center; color: black;'>📰บทความสำหรับคุณ😆</h1>", unsafe_allow_html=True)
277
- st.header("ระดับความเกี่ยวข้อง")
278
- st.write("😞 หมายถึง ไม่เกี่ยวข้องเลย")
279
- st.write("🙁 หมายถึง เกี่ยวข้องเล็กน้อย")
280
- st.write("😐 หมายถึง เฉยๆ")
281
- st.write("🙂 หมายถึง ค่อนข้างเกี่ยวข้อง")
282
- st.write("😀 หมายถึง เกี่ยวข้องมากที่สุด")
283
  st.write("---------------------------------------------------------------------------------------")
284
 
285
- for i in range(len(st.session_state.all_recommend)):
286
- st.header(str(i+1)+'. '+st.session_state.all_recommend.iloc[i]['title'])
287
- st.markdown(f"[Page source (Click here.)]({st.session_state.all_recommend.iloc[i].url})")
288
 
289
  try:
290
- banner_url = urllib.parse.quote(st.session_state.all_recommend.iloc[i]['banner'], safe=':/')
291
- response = requests.get(banner_url)
292
  st.image(response.content)
293
  except:
294
  st.image('https://icon-library.com/images/no-photo-icon/no-photo-icon-1.jpg')
 
 
 
 
295
 
296
- #satisfaction survey
297
- st.subheader("Satisfaction Survey")
298
- st.write("บทความที่แนะนำเกี่ยวข้องกับคุณมากเพียงใด")
299
- st.radio('ระดับความพึงพอใจ',['NA','😞','🙁','😐','🙂','😀'],horizontal=True,key='score_'+str(i+1))
300
- st.write("---------------------------------------------------------------------------------------")
301
-
302
- if st.form_submit_button("ยืนยันการส่งคำตอบ"):
303
- # Check if all articles have satisfaction levels selected
304
- st.session_state.true_check = []
305
- for satis_val in [st.session_state[i] for i in ['score_' + str(i+1) for i in range(10)]]:
306
- if satis_val != 'NA':
307
- st.session_state.true_check.append(True)
308
- else:
309
- st.session_state.true_check.append(False)
310
-
311
- if np.all(st.session_state.true_check):
312
- st.session_state.state = 3
313
- placeholder.empty()
314
-
315
- else:
316
- idx = []
317
- for i in range(len(st.session_state.true_check)):
318
- if st.session_state.true_check[i] == False:
319
- idx.append(i+1)
320
- article_indexes = ', '.join(map(str, idx))
321
- st.warning(f":red[กรุณาให้คะแนนบทความที่ {article_indexes} ด้วยครับ]")
322
-
323
- if st.session_state.state == 3:
324
- st.success('บันทึกคำตอบแล้ว')
325
-
326
- st.session_state.all_recommend = st.session_state.all_recommend.to_dict(orient='records')
327
- if st.session_state.sbert_searched_df is not None:
328
- st.session_state.sbert_searched_df = st.session_state.sbert_searched_df.to_dict(orient='records')
329
- if st.session_state.string_contain_df is not None:
330
- st.session_state.string_contain_df = st.session_state.string_contain_df.to_dict(orient='records')
331
-
332
- try:
333
- save_session_state_data(st.session_state.to_dict(), 'satisfaction.csv')
334
-
335
- except:
336
- #database insertion
337
- client = MongoClient('mongodb://192.168.1.103:27017/')
338
- database = client['test']
339
- collection = database['satisfy_articles']
340
- collection.insert_one(st.session_state.to_dict())
341
-
342
- finally:
343
- st.session_state.state = 0
344
 
 
 
 
 
 
345
  if __name__ == "__main__":
346
  main()
 
3
  import pandas as pd
4
  from sentence_transformers import SentenceTransformer, util
5
  import numpy as np
 
6
  import urllib.parse
7
  import requests
 
 
 
 
 
 
 
 
 
8
 
9
  #initial state
10
+ if 'state_p1' not in st.session_state:
11
+ st.session_state.state_p1 = 0
12
+ if 'state_p2' not in st.session_state:
13
+ st.session_state.state_p2 = 0
14
  if 'age' not in st.session_state:
15
  st.session_state.age = 0
16
  if 'weight' not in st.session_state:
 
25
  st.session_state.drug_allergy = 0
26
  if 'congentital_disease' not in st.session_state:
27
  st.session_state.congentital_disease = 0
 
 
 
 
 
 
28
  if 'queries' not in st.session_state:
29
  st.session_state.queries = None
 
 
30
  if 'sbert_searched_df' not in st.session_state:
31
  st.session_state.sbert_searched_df = None
32
+ if 'queries_p2' not in st.session_state:
33
+ st.session_state.queries_p2 = None
34
+ if 'sbert_searched_df_p2' not in st.session_state:
35
+ st.session_state.sbert_searched_df_p2 = None
36
  for i in range(10):
37
  if 'score_'+str(i+1) not in st.session_state:
38
  st.session_state['score_'+str(i+1)] = 'NA'
39
+ if 'current_page' not in st.session_state:
40
+ st.session_state.current_page = 1
41
 
42
+ def set_state_p1(state):
43
+ st.session_state.state_p1 = state
44
+
45
+ def set_state_p2(state):
46
+ st.session_state.state_p2 = state
47
 
48
  def split_text(text):
49
  return text.split(',')
 
51
  #import data
52
  sbert_model = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
53
 
 
 
 
 
 
54
  data = pd.read_csv('articles_data.csv')
 
55
 
56
+ with open('corpus_embeddings.pickle', 'rb') as file:
57
+ corpus_embeddings = pickle.load(file)
 
 
58
 
59
  #local function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  def personal_check(age,weight,height,gender):
61
 
62
  #age check
 
93
  score_lst = []
94
 
95
  query_embedding = sbert_model.encode(queries, convert_to_tensor=True)
96
+ hits = util.semantic_search(query_embedding, embeddiing, top_k=10)
97
  hits = hits[0]
98
  for hit in hits:
99
  index_lst.append(hit['corpus_id'])
 
104
 
105
  return sbert_searched
106
 
107
+ def page1_recommendation():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  #header
109
  st.markdown("<h1 style='text-align: center; color: black;'>---ระบบแนะนำบทความสุขภาพ---</h1>", unsafe_allow_html=True)
 
110
 
111
  with st.form('user_info'):
112
 
 
127
  with col5:
128
  drug_allergy = st.selectbox('แพ้ยา?',('ไม่แพ้', 'แพ้ยา'))
129
  congentital_disease = st.text_input('โรคประจำตัวของคุณ (ถ้าหากไม่มี ไม่ต้องกรอก หรือใส่ "ไม่มี")')
 
130
 
131
+ st.form_submit_button(on_click=set_state_p1,args=(1,))
132
 
133
+ if st.session_state.state_p1 == 1:
134
 
135
  #asign state
136
  st.session_state.age = age
 
140
  st.session_state.food_allergy = food_allergy
141
  st.session_state.drug_allergy = drug_allergy
142
  st.session_state.congentital_disease = congentital_disease
 
143
 
144
  #algorithm
145
  age,gender,bmi = personal_check(age,weight,height,gender)
 
150
  drug_allergy = ''
151
  if congentital_disease == 'ไม่มี':
152
  congentital_disease = ''
153
+
154
+ if congentital_disease == '':
155
+ queries = gender+age+bmi+food_allergy+drug_allergy
156
  else:
157
+ queries = congentital_disease
158
 
159
  #Bertopic search
160
+ sbert_searched = sbert_search(queries,data,corpus_embeddings)
161
+
162
+ st.session_state.sbert_searched_df = sbert_searched
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  st.session_state.queries = queries
164
+ st.session_state.state_p1 = 2
165
+
166
+ if st.session_state.state_p1 == 2:
167
+
168
+ with st.form('recommendations'):
169
+ st.markdown("<h1 style='text-align: center; color: black;'>📰บทความสำหรับคุณ😆</h1>", unsafe_allow_html=True)
170
+ st.write("---------------------------------------------------------------------------------------")
171
+
172
+ for i in range(len(st.session_state.sbert_searched_df)):
173
+ st.header(str(i+1)+'. '+st.session_state.sbert_searched_df.iloc[i]['title'])
174
+ st.markdown(f"[Page source (Click here.)]({st.session_state.sbert_searched_df.iloc[i].url})")
175
+
176
+ try:
177
+ banner_url = urllib.parse.quote(st.session_state.sbert_searched_df.iloc[i]['banner'], safe=':/')
178
+ response = requests.get(banner_url,timeout=5)
179
+ st.image(response.content)
180
+ except:
181
+ st.image('https://icon-library.com/images/no-photo-icon/no-photo-icon-1.jpg')
182
+ finally:
183
+ st.write("---------------------------------------------------------------------------------------")
184
+
185
+ st.form_submit_button('Submit',on_click=set_state_p1,args=(0,))
186
+
187
+ def page2_search_engine():
188
+ st.title("Search engine")
189
+
190
+ with st.form('queries'):
191
+ queries = st.text_input('คำหรือหัวข้อที่ต้องการค้นหา')
192
+ st.form_submit_button(on_click=set_state_p2,args=(1,))
193
+
194
+ if st.session_state.state_p2 == 1:
195
+ sbert_searched = sbert_search(queries,data,corpus_embeddings)
196
 
197
+ st.session_state.sbert_searched_df_p2 = sbert_searched
198
+ st.session_state.queries_p2 = queries
199
+ st.session_state.state_p2 = 2
200
 
201
+ if st.session_state.state_p2 == 2:
202
+ with st.form('recommendations'):
203
  st.markdown("<h1 style='text-align: center; color: black;'>📰บทความสำหรับคุณ😆</h1>", unsafe_allow_html=True)
 
 
 
 
 
 
204
  st.write("---------------------------------------------------------------------------------------")
205
 
206
+ for i in range(len(st.session_state.sbert_searched_df_p2)):
207
+ st.header(str(i+1)+'. '+st.session_state.sbert_searched_df_p2.iloc[i]['title'])
208
+ st.markdown(f"[Page source (Click here.)]({st.session_state.sbert_searched_df_p2.iloc[i].url})")
209
 
210
  try:
211
+ banner_url = urllib.parse.quote(st.session_state.sbert_searched_df_p2.iloc[i]['banner'], safe=':/')
212
+ response = requests.get(banner_url,timeout=5)
213
  st.image(response.content)
214
  except:
215
  st.image('https://icon-library.com/images/no-photo-icon/no-photo-icon-1.jpg')
216
+ finally:
217
+ st.write("---------------------------------------------------------------------------------------")
218
+
219
+ st.form_submit_button('Submit',on_click=set_state_p2,args=(0,))
220
 
221
+ #main
222
+ def main():
223
+ st.sidebar.title("Navigation")
224
+ page = st.sidebar.selectbox("Select a page:", ("Recommendation System", "Search Engine"))
225
+
226
+ if page == "Recommendation System":
227
+ st.session_state.current_page = 1
228
+ else:
229
+ st.session_state.current_page = 2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
231
+ if page == "Recommendation System" and st.session_state.current_page == 1:
232
+ page1_recommendation()
233
+ elif page == "Search Engine" and st.session_state.current_page == 2:
234
+ page2_search_engine()
235
+
236
  if __name__ == "__main__":
237
  main()