import streamlit as st
import pickle
import pandas as pd
from sentence_transformers import SentenceTransformer, util
import numpy as np
from pymongo import MongoClient
import urllib.parse
import requests
from bertopic import BERTopic
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import string
import deepcut
import unicodedata
from pythainlp.util import normalize
import torch
import csv

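# ---------------------------------------------------------------------------
# Session-state defaults. Streamlit reruns this script on every interaction,
# so each key is initialised once here and then persists across reruns.
# ---------------------------------------------------------------------------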
|
if 'state' not in st.session_state:
    st.session_state.state = 0
if 'age' not in st.session_state:
    st.session_state.age = 0
if 'weight' not in st.session_state:
    st.session_state.weight = 0
if 'height' not in st.session_state:
    st.session_state.height = 0
if 'gender' not in st.session_state:
    st.session_state.gender = 0
if 'food_allergy' not in st.session_state:
    st.session_state.food_allergy = 0
if 'drug_allergy' not in st.session_state:
    st.session_state.drug_allergy = 0
if 'congentital_disease' not in st.session_state:
    st.session_state.congentital_disease = 0
if 'optional_keyword' not in st.session_state:
    st.session_state.optional_keyword = 0
if 'all_recommend' not in st.session_state:
    st.session_state.all_recommend = None
if 'true_check' not in st.session_state:
    st.session_state.true_check = None
if 'queries' not in st.session_state:
    st.session_state.queries = None
if 'string_contain' not in st.session_state:
    st.session_state.string_contain = False
if 'sbert_searched_df' not in st.session_state:
    st.session_state.sbert_searched_df = None
if 'string_contain_df' not in st.session_state:
    st.session_state.string_contain_df = None
for i in range(10):
    if 'score_' + str(i + 1) not in st.session_state:
        st.session_state['score_' + str(i + 1)] = 'NA'

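# Small helpers: set_state is used as a form callback to advance the page
# state machine; split_text is currently unused in this file.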
|
def set_state(state):
    st.session_state.state = state


def split_text(text):
    return text.split(',')

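# ---------------------------------------------------------------------------
# Models and data, loaded at import time. The script expects these files next
# to it: corpus_embeddings.pickle (pre-computed SBERT embeddings for the
# articles), topic_model.pickle (a fitted BERTopic model), articles_data.csv
# (the article corpus) and sensitive_words.txt (whitespace-separated domain
# terms). Wrapping these loads in st.cache_resource would avoid re-loading
# the models on every rerun.
# ---------------------------------------------------------------------------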
|
sbert_model = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')

with open('corpus_embeddings.pickle', 'rb') as file:
    corpus_embeddings = pickle.load(file)
corpus_embeddings = pd.DataFrame(corpus_embeddings)

topic_model = BERTopic.load("topic_model.pickle")
data = pd.read_csv('articles_data.csv')
data['child_topic'] = topic_model.topics_

with open('sensitive_words.txt', 'r', encoding='utf-8') as file:
    sensitive_words = file.read()
# Replace newlines with spaces (rather than deleting them) so that words on
# different lines are not fused together, then deduplicate.
sensitive_words = sensitive_words.lower().replace('\n', ' ').split(' ')
sensitive_words = list(set(word for word in sensitive_words if word))

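# Append one survey response per row; the CSV header is written only when the
# file is empty.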
|
def save_session_state_data(session_state_data, filename):
    with open(filename, 'a', newline='', encoding='utf-8') as file:
        writer = csv.DictWriter(file, fieldnames=session_state_data.keys())
        if file.tell() == 0:
            writer.writeheader()
        writer.writerow(session_state_data)

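# Clean and tokenise Thai text: strip punctuation, normalise Unicode, tokenise
# with deepcut (protecting domain terms via custom_dict), drop stopwords and
# single-character tokens, and return the tokens as one comma-joined string.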
|
def deepcut_tokenizer(text, sensitive_words=sensitive_words):
    # Strip punctuation and lowercase, then normalise the Thai text.
    cleanedText = "".join([i for i in text if i not in string.punctuation]).lower()
    cleanedText = unicodedata.normalize('NFKD', cleanedText)
    for junk in ('\n', '\r', '\t', '“', '”', '.', '–', '‘', '’', '...', ','):
        cleanedText = cleanedText.replace(junk, '')
    # Fold decomposed Thai character sequences into their composed forms.
    cleanedText = normalize(cleanedText.replace('ํา', 'ำ').replace('ี', 'ี'))

    # Tokenise, keeping domain terms intact via deepcut's custom dictionary.
    cleanedText = deepcut.tokenize(cleanedText, custom_dict=sensitive_words)

    # Drop publisher boilerplate and single-character tokens.
    stopwords = 'EMagazine GJ international bangkok hospital'.lower().split(' ')
    cleanedText = [i for i in cleanedText if i not in stopwords]
    cleanedText = [i.replace(' ', '') for i in cleanedText if len(i) > 1]
    return ','.join(cleanedText)

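# Map the raw profile onto Thai search keywords: a life stage from age, the
# corpus wording for gender, and a BMI bucket (cutoffs 18.5 / 23 / 30, with
# the normal range contributing no keyword).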
|
def personal_check(age, weight, height, gender):
    # Life stage: 'ผู้สูงอายุ' (elderly) or 'ทำงาน' (working age).
    if age >= 60:
        age = 'ผู้สูงอายุ'
    else:
        age = 'ทำงาน'

    # Use the wording that appears in the articles: 'ผู้หญิง' / 'ผู้ชาย'.
    if gender == 'หญิง':
        gender = 'ผู้หญิง'
    else:
        gender = 'ผู้ชาย'

    # BMI from weight (kg) and height (cm).
    height_meters = height / 100
    bmi = weight / (height_meters ** 2)

    if bmi >= 30:
        bmi = 'อ้วนมาก'  # severely obese
    elif 23 <= bmi < 30:
        bmi = 'อ้วน'  # overweight
    elif 18.5 <= bmi < 23:
        bmi = ''  # normal range: no keyword
    else:
        bmi = 'ผอม'  # underweight

    return age, gender, bmi

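# Dense retrieval: embed the query with SBERT and return the top-15 articles
# by semantic similarity, with their scores attached.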
|
def sbert_search(queries, data, embedding, sbert_model=sbert_model):
    index_lst = []
    score_lst = []

    query_embedding = sbert_model.encode(queries, convert_to_tensor=True)
    hits = util.semantic_search(query_embedding, embedding, top_k=15)
    hits = hits[0]  # hits for the single query
    for hit in hits:
        index_lst.append(hit['corpus_id'])
        score_lst.append(hit['score'])

    sbert_searched = data.iloc[index_lst].copy()
    sbert_searched['score'] = score_lst

    return sbert_searched

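# Hybrid retrieval: find the query's closest BERTopic topic, rank that topic's
# articles by TF-IDF cosine similarity and keep the top 15, re-score those 15
# with SBERT, also run SBERT over the full corpus, then merge both result sets
# and return the `head` highest-scoring articles.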
|
def sbert_tfidf_search(queries, head, topic_model=topic_model, data=data, corpus_embeddings=corpus_embeddings):
    text_to_predict_token = deepcut_tokenizer(queries)

    # Fall back to the raw query if the tokenised text cannot be matched.
    try:
        similar_topics, similarity = topic_model.find_topics(text_to_predict_token, top_n=1)
    except Exception:
        similar_topics, similarity = topic_model.find_topics(queries, top_n=1)

    similar_df = data[data['child_topic'] == similar_topics[0]].copy()

    # TF-IDF over the topic's articles. The tokenizer assumes 'text_token'
    # holds comma-joined tokens, the format deepcut_tokenizer produces.
    vectorizer = TfidfVectorizer(tokenizer=lambda x: x.split(','), lowercase=False)
    tfidf_matrix = vectorizer.fit_transform(similar_df['text_token'])
    text_tfidf = vectorizer.transform([text_to_predict_token])

    similarity_scores = cosine_similarity(tfidf_matrix, text_tfidf)
    similar_df['score'] = similarity_scores.ravel()
    similar_df = similar_df.sort_values('score', ascending=False).head(15)

    # Keep the rows in index order so the positional ids returned by
    # semantic_search line up with the selected embedding rows.
    similar_df = similar_df.sort_index()
    select_corpus = corpus_embeddings.iloc[similar_df.index]
    similar_embedding = torch.tensor(select_corpus.values)

    similar_searched = sbert_search(queries, similar_df, similar_embedding)
    sbert_searched = sbert_search(queries, data, torch.tensor(corpus_embeddings.values))
    combined_searched = pd.concat([similar_searched, sbert_searched])
    output = combined_searched.sort_values('score', ascending=False).head(head)

    return output

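# Keyword fallback: return a random sample of articles whose title or body
# contains the query string.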
|
def string_contain_search(queries, sample, data=data):
    data['all_content'] = data['title'] + data['content']
    # Note: sample() raises ValueError when fewer than `sample` rows match;
    # main() catches this and falls back to semantic search only.
    return data[data['all_content'].str.contains(queries, na=False)].sample(sample)

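# ---------------------------------------------------------------------------
# Page logic. st.session_state.state acts as a small state machine:
#   0 = collect the user profile, 1 = build the query and search,
#   2 = show recommendations and the survey, 3 = persist the answers.
# ---------------------------------------------------------------------------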
|
def main():
    # Page title: "Health Article Recommendation System"; the subheader asks
    # the user to rate the articles.
    st.markdown("<h1 style='text-align: center; color: black;'>---ระบบแนะนำบทความสุขภาพ---</h1>", unsafe_allow_html=True)
    st.subheader("ให้คะแนนบทความหน่อยนะครับ😄")

    # State 0: profile form (age, weight, height, gender, food/drug allergy,
    # underlying disease, optional search keywords).
    with st.form('user_info'):
        age = st.slider("อายุ", 10, 100, 25)

        col1, col2 = st.columns(2)
        with col1:
            weight = st.number_input("น้ำหนัก (Kg.): ", 30.0, 120.0, step=1.0, value=50.0)
        with col2:
            height = st.number_input("ส่วนสูง (cm.): ", 100.0, 250.0, step=1.0, value=150.0)

        col3, col4, col5 = st.columns(3)
        with col3:
            gender = st.selectbox('เพศ', ('ชาย', 'หญิง'))
        with col4:
            food_allergy = st.selectbox('แพ้อาหาร?', ('ไม่แพ้', 'แพ้อาหาร'))
        with col5:
            drug_allergy = st.selectbox('แพ้ยา?', ('ไม่แพ้', 'แพ้ยา'))
        congentital_disease = st.text_input('โรคประจำตัวของคุณ (ถ้าหากไม่มี ไม่ต้องกรอก หรือใส่ "ไม่มี")')
        optional_keyword = st.text_input('คำค้นหาเพิ่มเติม (ถ้ามี)')

        st.form_submit_button(on_click=set_state, args=(1,))

|
    # State 1: store the profile, build the search query, retrieve articles.
    if st.session_state.state == 1:
        st.session_state.age = age
        st.session_state.weight = weight
        st.session_state.height = height
        st.session_state.gender = gender
        st.session_state.food_allergy = food_allergy
        st.session_state.drug_allergy = drug_allergy
        st.session_state.congentital_disease = congentital_disease
        st.session_state.optional_keyword = optional_keyword

        age, gender, bmi = personal_check(age, weight, height, gender)

        # Blank out "none" answers so they do not pollute the query.
        if food_allergy == 'ไม่แพ้':
            food_allergy = ''
        if drug_allergy == 'ไม่แพ้':
            drug_allergy = ''
        if congentital_disease == 'ไม่มี':
            congentital_disease = ''

        # Prefer the user's explicit disease/keyword input; otherwise search
        # on the keywords derived from the profile.
        if congentital_disease != '' or optional_keyword != '':
            queries = optional_keyword + congentital_disease
        else:
            queries = gender + age + bmi + food_allergy + drug_allergy + congentital_disease + optional_keyword

        try:
            sbert_searched = sbert_tfidf_search(queries, 5)
            string_contain = string_contain_search(queries, 5)
            all_recommend = pd.concat([sbert_searched, string_contain])
            all_recommend = all_recommend.drop_duplicates(subset=['url'])

            # If de-duplication left fewer than 10 articles, retry with
            # progressively larger candidate pools (up to three times).
            for i in range(3):
                if len(all_recommend) < 10:
                    sbert_searched = sbert_tfidf_search(queries, 5 + i + 1).head(5)
                    string_contain = string_contain_search(queries, 5 + i + 1).head(5)
                    all_recommend = pd.concat([sbert_searched, string_contain])
                    all_recommend = all_recommend.drop_duplicates(subset=['url'])

            st.session_state.sbert_searched_df = sbert_searched
            st.session_state.string_contain_df = string_contain
            st.session_state.string_contain = True

        except Exception:
            # Substring search failed (e.g. no or too few matches), so fall
            # back to semantic search alone.
            sbert_searched = sbert_tfidf_search(queries, 10)
            st.session_state.sbert_searched_df = sbert_searched
            all_recommend = sbert_searched

        st.session_state.all_recommend = all_recommend
        st.session_state.queries = queries
        st.session_state.state = 2

|
    # State 2: show the recommendations and collect relevance ratings.
    if st.session_state.state == 2:
        placeholder = st.empty()

        with placeholder.form('Satisfaction Survey'):
            # Header: "Articles for you", followed by the rating legend
            # (😞 not relevant at all ... 😀 most relevant).
            st.markdown("<h1 style='text-align: center; color: black;'>📰บทความสำหรับคุณ😆</h1>", unsafe_allow_html=True)
            st.header("ระดับความเกี่ยวข้อง")
            st.write("😞 หมายถึง ไม่เกี่ยวข้องเลย")
            st.write("🙁 หมายถึง เกี่ยวข้องเล็กน้อย")
            st.write("😐 หมายถึง เฉยๆ")
            st.write("🙂 หมายถึง ค่อนข้างเกี่ยวข้อง")
            st.write("😀 หมายถึง เกี่ยวข้องมากที่สุด")
            st.write("---------------------------------------------------------------------------------------")

            for i in range(len(st.session_state.all_recommend)):
                st.header(str(i + 1) + '. ' + st.session_state.all_recommend.iloc[i]['title'])
                st.markdown(f"[Page source (Click here.)]({st.session_state.all_recommend.iloc[i].url})")

                # Show the article banner, or a placeholder image if the
                # banner cannot be fetched.
                try:
                    banner_url = urllib.parse.quote(st.session_state.all_recommend.iloc[i]['banner'], safe=':/')
                    response = requests.get(banner_url)
                    st.image(response.content)
                except Exception:
                    st.image('https://icon-library.com/images/no-photo-icon/no-photo-icon-1.jpg')

                st.subheader("Satisfaction Survey")
                st.write("บทความที่แนะนำเกี่ยวข้องกับคุณมากเพียงใด")
                st.radio('ระดับความพึงพอใจ', ['NA', '😞', '🙁', '😐', '🙂', '😀'], horizontal=True, key='score_' + str(i + 1))
                st.write("---------------------------------------------------------------------------------------")

            if st.form_submit_button("ยืนยันการส่งคำตอบ"):
                # Check only the articles actually shown; de-duplication can
                # leave fewer than the 10 pre-initialised score slots.
                n_shown = len(st.session_state.all_recommend)
                st.session_state.true_check = [
                    st.session_state['score_' + str(j + 1)] != 'NA' for j in range(n_shown)
                ]

                if np.all(st.session_state.true_check):
                    st.session_state.state = 3
                    placeholder.empty()
                else:
                    # Warn which article numbers are still unrated.
                    idx = [j + 1 for j, rated in enumerate(st.session_state.true_check) if not rated]
                    article_indexes = ', '.join(map(str, idx))
                    st.warning(f":red[กรุณาให้คะแนนบทความที่ {article_indexes} ด้วยครับ]")

|
    # State 3: persist the answers, then reset for the next participant.
    if st.session_state.state == 3:
        st.success('บันทึกคำตอบแล้ว')  # "Your answers have been saved."

        # Convert DataFrames to plain records so they serialise cleanly.
        st.session_state.all_recommend = st.session_state.all_recommend.to_dict(orient='records')
        if st.session_state.sbert_searched_df is not None:
            st.session_state.sbert_searched_df = st.session_state.sbert_searched_df.to_dict(orient='records')
        if st.session_state.string_contain_df is not None:
            st.session_state.string_contain_df = st.session_state.string_contain_df.to_dict(orient='records')

        try:
            save_session_state_data(st.session_state.to_dict(), 'satisfaction.csv')
        except Exception:
            # If writing the CSV fails, fall back to the MongoDB instance.
            client = MongoClient('mongodb://192.168.1.103:27017/')
            database = client['test']
            collection = database['satisfy_articles']
            collection.insert_one(st.session_state.to_dict())
        finally:
            st.session_state.state = 0


if __name__ == "__main__":
    main()
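# To try the app locally (assuming this file is saved as app.py and the data
# files listed above are present):
#   streamlit run app.py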