gaspar-avit's picture
Upload app.py
60319f9
raw
history blame
7.85 kB
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 31 17:45:36 2023.
@author: Gaspar Avit Ferrero
"""
import os
import streamlit as st
import pandas as pd
from streamlit import session_state as session
from htbuilder import HtmlElement, div, hr, a, p, styles
from htbuilder.units import percent, px
from catboost import CatBoostClassifier
from datetime import datetime
###############################
## ------- FUNCTIONS ------- ##
###############################
def link(link, text, **style):
    """
    Build an HTML anchor that opens its target in a new tab.

    -param link: URL used as the anchor's href (shadows this function's name
        inside the body)
    -param text: content placed inside the anchor
    -param style: CSS properties forwarded as keyword arguments to styles()
    -return: result of calling the htbuilder anchor element with text
    """
    inline_css = styles(**style)
    anchor = a(_href=link, _target="_blank", style=inline_css)
    return anchor(text)
def layout(*args):
    """
    Render a fixed footer at the bottom of the page.

    First injects a small CSS snippet into the page, then a fixed-position
    div wrapping a single paragraph that holds the given items.

    -param args: footer content; str and HtmlElement items are added to the
        paragraph in the order given, any other type is ignored.
    """
    # An id selector must not have whitespace after '#': "# MainMenu" is not
    # a valid selector and makes the browser discard the whole rule.
    style = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
.stApp { bottom: 105px; }
</style>
"""

    style_div = styles(
        position="fixed",
        left=0,
        bottom=0,
        margin=px(0, 0, 0, 0),
        width=percent(100),
        height=px(10),
        color="black",
        text_align="center",
        opacity=1
    )

    body = p()
    foot = div(
        style=style_div
    )(
        body
    )

    st.markdown(style, unsafe_allow_html=True)

    for arg in args:
        # Plain strings and htbuilder elements are handled the same way.
        if isinstance(arg, (str, HtmlElement)):
            body(arg)

    st.markdown(str(foot), unsafe_allow_html=True)
def footer():
    """Render the page footer: a 'Made by' credit linking to LinkedIn."""
    author_link = link(
        "https://www.linkedin.com/in/gaspar-avit/", "Gaspar Avit")
    layout("Made by ", author_link)
def update_prediction(input_data):
    """
    Callback to automatically update the prediction if the button has
    already been clicked.

    -param input_data: DataFrame with input data, forwarded unchanged to
        generate_prediction
    """
    # IS_CLICKED is only assigned by the script section of this module, so
    # treat a missing flag as "not clicked" instead of raising NameError.
    if globals().get("IS_CLICKED", False):
        # generate_prediction is the prediction entry point of this module.
        generate_prediction(input_data)
def get_input_data():
    """
    Generate input layout and get input values.

    Draws the input widgets inside an 'Input parameters' expander and stores
    the selected values as row 0 of a fresh DataFrame in session.input_data.
    Columns are created in this order: age, gender, height, weight, ap_hi,
    ap_lo, cholesterol, gluc, alco, smoke, active, followed by the derived
    columns bmi (weight divided by the square of height/100) and bad_habits
    (smoke AND alco).

    -return: DataFrame with input data.
    """
    # Radio label -> integer level stored for cholesterol and glucose.
    # An explicit mapping avoids relying on case-sensitive substring checks.
    level_by_label = {'Normal': 1, 'Above normal': 2, 'Well above normal': 3}
    level_labels = list(level_by_label)

    session.input_data = pd.DataFrame()

    input_expander = st.expander('Input parameters', True)
    with input_expander:
        # Row 1
        col_age, col_sex = st.columns(2)
        with col_age:
            session.input_data.loc[0, 'age'] = st.slider('Age', 18, 75)
        with col_sex:
            session.input_data.loc[0, 'gender'] = st.radio(
                'Sex', ['Female', 'Male'])
            session.input_data["gender"] = session.input_data["gender"].astype(
                'category')

        # Row 2
        col_height, col_weight = st.columns(2)
        with col_height:
            session.input_data.loc[0, 'height'] = st.slider('Height', 140, 200)
            session.input_data["height"] = session.input_data["height"].astype(
                int)
        with col_weight:
            session.input_data.loc[0, 'weight'] = st.slider('Weight', 40, 140)
            session.input_data["weight"] = session.input_data["weight"].astype(
                int)

        # Row 3
        col_ap_hi, col_ap_lo = st.columns(2)
        with col_ap_hi:
            session.input_data.loc[0, 'ap_hi'] = st.slider(
                'Systolic blood pressure', 90, 200)
            session.input_data["ap_hi"] = session.input_data["ap_hi"].astype(
                int)
        with col_ap_lo:
            session.input_data.loc[0, 'ap_lo'] = st.slider(
                'Diastolic blood pressure', 50, 120)
            session.input_data["ap_lo"] = session.input_data["ap_lo"].astype(
                int)

        # Row 4
        col_chole, col_gluc = st.columns(2)
        with col_chole:
            cholest = st.radio('Cholesterol', level_labels)
            session.input_data.loc[0, 'cholesterol'] = level_by_label[cholest]
            session.input_data["cholesterol"] = (session
                                                 .input_data["cholesterol"]
                                                 .astype(int)
                                                 .astype('category')
                                                 )
        with col_gluc:
            gluc = st.radio('Glucose', level_labels)
            session.input_data.loc[0, 'gluc'] = level_by_label[gluc]
            session.input_data["gluc"] = (session
                                          .input_data["gluc"]
                                          .astype(int)
                                          .astype('category')
                                          )

        # Row 5 (both radios default to 'No', the option at index 1)
        col_alco, col_smk = st.columns(2)
        with col_alco:
            alco = st.radio('Alcohol intake', ['Yes', 'No'], 1)
            session.input_data.loc[0, 'alco'] = int(alco == 'Yes')
            session.input_data["alco"] = session.input_data["alco"].astype(
                bool)
        with col_smk:
            smoke = st.radio('Smoking', ['Yes', 'No'], 1)
            session.input_data.loc[0, 'smoke'] = int(smoke == 'Yes')
            session.input_data["smoke"] = session.input_data["smoke"].astype(
                bool)

        # Row 6
        active = st.radio('Physical activity', ['Yes', 'No'])
        session.input_data.loc[0, 'active'] = int(active == 'Yes')
        session.input_data["active"] = session.input_data["active"].astype(
            bool)

    # Blank spacer below the inputs
    st.write("")

    # Compute extra features
    session.input_data["bmi"] = session.input_data["weight"] \
        / (session.input_data["height"]/100)**2
    session.input_data["bad_habits"] = session.input_data["smoke"] \
        & session.input_data["alco"]

    return session.input_data
def generate_prediction(input_data):
    """
    Run the loaded model on the input data and show the raw result.

    Writes three items to the page: the full probability array, its element
    at [0, 1] and its shape.

    -param input_data: DataFrame with input data
    -return: the array returned by MODEL.predict_proba(input_data)
    """
    proba = MODEL.predict_proba(input_data)

    # Temporary raw output of the model result.
    for item in (proba, proba[0, 1], proba.shape):
        st.write(item)

    # TODO: replace the raw output with a formatted probability (centered
    # heading plus st.progress bar); note the value lives at proba[0, 1].

    return proba
###############################
## --------- MAIN ---------- ##
###############################
if __name__ == "__main__":

    ## --- Page config ------------ ##
    # Page title, followed by a short description of the app
    st.title("""
Cardiovascular Disease predictor
#### This app aims to give a scoring of how probable is that an individual \
would suffer from a cardiovascular disease given its physical \
characteristics
#### Just enter your info and get a prediction.
""")

    # Page footer (currently disabled)
    # footer()

    # The clicking flag starts as False until the button below is drawn
    IS_CLICKED = False
    ## --------------------------- ##

    # Load the classification model from the local model file
    MODEL = CatBoostClassifier()
    MODEL.load_model('./model.cbm')

    # Draw the input widgets and keep the resulting data in the session
    session.input_data = get_input_data()

    # Button that triggers the prediction, placed in the middle column
    left_pad, button_col, right_pad = st.columns([1.3, 1, 1])
    IS_CLICKED = button_col.button(label="Generate predictions")

    # Vertical spacing
    for _ in range(2):
        st.text("")

    # Generate prediction only on the run in which the button was clicked
    if IS_CLICKED:
        st.write('Work in progress!')
        prediction = generate_prediction(session.input_data)

    # Vertical spacing
    for _ in range(2):
        st.text("")