Jiranuwat's picture
Upload 2 files
ab98363
raw
history blame
4.95 kB
import streamlit as st
import pickle
from pymongo import MongoClient
import pandas as pd
from sentence_transformers import SentenceTransformer, util
import requests
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from io import BytesIO
import urllib.parse
import math
# Multilingual sentence-embedding model used to encode search queries.
sbert_model = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')

# Connect to MongoDB and resolve the collection that stores the articles.
try:
    client = MongoClient('mongodb://192.168.1.103:27017/')
    # MongoClient connects lazily: the constructor succeeds even when the
    # server is unreachable. A cheap ping forces a real round trip so the
    # success message below (and the except branch) actually mean something.
    client.admin.command('ping')
    print("---Connenction Successful---")
    Recommendation_elderly = client['Recommendation_elderly']
    healthcare_articles = Recommendation_elderly['token']
except Exception as err:
    # Keep the exception type this script has always raised, but chain the
    # underlying error so the real cause shows up in the traceback.
    raise KeyError('Connection Fail') from err

# Load every article, then drop duplicate URLs and rows without a title.
data = pd.DataFrame(list(healthcare_articles.find()))
data = data.drop_duplicates(subset=['url'])
data = data[data['title'] != '']
# Discard the Mongo `_id` column and renumber rows 0..n-1. sbert_search()
# looks rows up by position (`data.iloc[corpus_id]`), so this row order is
# presumably expected to match the order of the pickled embeddings - verify
# whenever the collection or the pickle is regenerated.
data = data.drop(columns=['_id']).reset_index(drop=True)

# NOTE(security): pickle.load can execute arbitrary code from the file being
# read. Only ship a corpus_embeddings.pickle produced by a trusted pipeline.
with open('corpus_embeddings.pickle', 'rb') as file:
    corpus_embeddings = pickle.load(file)
def personal_check(age, weight, height, gender):
    """Translate raw profile values into Thai keywords for the search query.

    Args:
        age: Age in years; 60 and above maps to the elderly keyword.
        weight: Body weight in kilograms.
        height: Height in centimetres (must be non-zero).
        gender: ``'Female'`` or ``'หญิง'`` selects the female keyword; any
            other value selects the male keyword.

    Returns:
        A tuple ``(age_keyword, gender_keyword, bmi_keyword)`` of strings.
        ``bmi_keyword`` is the empty string when the BMI falls in the
        18.5 <= BMI < 23 band.
    """
    # age check
    age_keyword = 'ผู้สูงอายุ' if age >= 60 else 'วัยทำงาน'

    # gender check
    # main() offers the Thai labels 'ชาย' / 'หญิง' in its selectbox, so the
    # Thai female label has to be recognised alongside the English one;
    # comparing with 'Female' only classified every UI user as male.
    if gender in ('Female', 'หญิง'):
        gender_keyword = 'ผู้หญิง สตรี'
    else:
        gender_keyword = 'ผู้ชาย'

    # bmi check
    height_meters = height / 100
    bmi = weight / (height_meters ** 2)
    # Branches are tested from the highest band down, so each one only needs
    # its lower bound.
    if bmi >= 30:
        bmi_keyword = 'อ้วนมาก'
    elif bmi >= 23:
        bmi_keyword = 'อ้วน'
    elif bmi >= 18.5:
        bmi_keyword = ''
    else:
        bmi_keyword = 'ผอม'

    return age_keyword, gender_keyword, bmi_keyword
def sbert_search(queries):
    """Return the articles most similar to each query string.

    Every query is encoded with the module-level ``sbert_model`` and compared
    against ``corpus_embeddings``; the 15 best hits per query are collected
    in order.

    Args:
        queries: Iterable of query strings.

    Returns:
        A DataFrame with the columns ``url``, ``title``, ``score`` and
        ``banner``, one row per hit. Rows are taken from the module-level
        ``data`` frame by position, using each hit's ``corpus_id``.
    """
    # No `global` statement is needed: the module-level names are only read.
    index_lst = []
    score_lst = []
    for query in queries:
        query_embedding = sbert_model.encode(query, convert_to_tensor=True)
        # semantic_search returns one hit list per query embedding; a single
        # embedding is passed in, so only the first list is relevant.
        hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=15)[0]
        index_lst.extend(hit['corpus_id'] for hit in hits)
        score_lst.extend(hit['score'] for hit in hits)

    # Take an explicit copy before adding a column; assigning onto the result
    # of an indexing operation can trigger pandas' SettingWithCopyWarning.
    sbert_searched = data.iloc[index_lst].copy()
    sbert_searched['score'] = score_lst
    return sbert_searched[['url', 'title', 'score', 'banner']]
def visualize_articles_images(title, banner):
    """Draw article banners in a grid that is at most three columns wide.

    Args:
        title: Article titles; must support positional access via ``.iloc``
            (e.g. a pandas Series) and line up with ``banner``.
        banner: Sized iterable of banner image URLs. Entries whose string
            form is ``'nan'`` are treated as missing.

    Returns:
        The matplotlib figure. Cells whose image is missing or cannot be
        downloaded/decoded keep their title and are otherwise left empty.
    """
    num_images = len(banner)
    if num_images == 0:
        # plt.subplots() rejects a 0x0 grid, so hand back an empty figure.
        return plt.figure(figsize=(20, 20))

    # Calculate the number of rows and columns for the grid
    num_rows = math.ceil(num_images / 3)
    num_cols = min(num_images, 3)

    # squeeze=False keeps `axs` two-dimensional even for a single row or a
    # single image; otherwise `axs[row, col]` fails for three or fewer images.
    fig, axs = plt.subplots(num_rows, num_cols, figsize=(20, 20), squeeze=False)

    for i, url in enumerate(banner):
        ax = axs[i // num_cols, i % num_cols]
        ax.set_title(title.iloc[i], fontname='Tahoma', fontsize=16)

        if str(url) == 'nan':
            continue

        try:
            # Percent-encode non-ASCII characters in the URL.
            # NOTE(review): safe=':/' also escapes '?', '&' and '=', so URLs
            # that carry a query string are altered - confirm that is intended.
            encoded_url = urllib.parse.quote(url, safe=':/')
            # A timeout stops one unresponsive host from hanging the page.
            response = requests.get(encoded_url, timeout=10)
            img = mpimg.imread(BytesIO(response.content), format='jpg')
            ax.imshow(img)
            ax.axis('off')
        except Exception:
            # Best effort: a banner that cannot be fetched or decoded just
            # leaves its cell blank instead of failing the whole figure.
            continue

    return fig
def main():
    """Render the profile form and, once the button is pressed, show results.

    The form values are turned into a single query string, passed to
    sbert_search(), and the banners of the returned articles are drawn with
    visualize_articles_images().
    """
    # Page header.
    st.title("---ระบบแนะนำบทความสุขภาพ---")
    st.subheader("ให้คะแนนบทความหน่อยนะครับ:smile:")

    # Profile inputs.
    age = st.slider("อายุ", 0, 100, 25)
    weight = st.number_input("น้ำหนัก (Kg.): ", 30, 120, step=1, value=30)
    height = st.number_input("ส่วนสูง (cm.): ", 100, 250, step=1, value=120)
    gender = st.selectbox('เพศ', ('ชาย', 'หญิง'))
    food_allergy = st.selectbox('แพ้อาหาร?', ('ไม่แพ้', 'แพ้อาหาร'))
    drug_allergy = st.selectbox('แพ้ยา?', ('ไม่แพ้', 'แพ้ยา'))
    chronic_condition = st.text_input('โรคประจำตัวของคุณ')

    # Nothing else to do until the user presses the button.
    if not st.button("Click me"):
        return

    age_kw, gender_kw, bmi_kw = personal_check(age, weight, height, gender)

    # The "no allergy" choice contributes nothing to the query.
    food_kw = '' if food_allergy == 'ไม่แพ้' else food_allergy
    drug_kw = '' if drug_allergy == 'ไม่แพ้' else drug_allergy

    # All keywords are concatenated, without separators, into one query.
    query_text = "".join(
        (gender_kw, age_kw, bmi_kw, food_kw, drug_kw, chronic_condition)
    )
    queries = [query_text]

    results = sbert_search(queries)
    st.write(f"{queries}")
    figure = visualize_articles_images(results['title'], results['banner'])
    st.pyplot(figure)


if __name__ == "__main__":
    main()