gaspar-avit's picture
Upload app.py
74a5af8
raw
history blame
4.8 kB
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 31 17:45:36 2023
@author: Gaspar Avit Ferrero
"""
import os

import pandas as pd
import streamlit as st
from streamlit import session_state as session
from htbuilder import HtmlElement, div, hr, a, p, styles
from htbuilder.units import percent, px
from catboost import CatBoostClassifier
###############################
## ------- FUNCTIONS ------- ##
###############################
def link(link, text, **style):
    """
    Build an anchor element pointing at a URL.

    -param link: target URL, used as the href attribute
    -param text: content placed inside the anchor
    -param style: CSS properties forwarded to htbuilder's `styles`
    -return: htbuilder anchor element with target="_blank"
    """
    css = styles(**style)
    anchor = a(_href=link, _target="_blank", style=css)
    return anchor(text)
def layout(*args):
    """
    Render a fixed footer at the bottom of the page.

    First injects a small CSS snippet (hides the element with id `MainMenu`
    and the `footer` element, and offsets `.stApp`), then renders a
    fixed-position <div> wrapping a single <p> that holds the given content.

    -param args: footer content. `str` and `HtmlElement` items are appended
                 in order; items of any other type are ignored.
    """
    # The id selector must be written without a space after '#'; with a
    # space the selector is invalid and the rule is never applied.
    style = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
.stApp { bottom: 105px; }
</style>
"""

    style_div = styles(
        position="fixed",
        left=0,
        bottom=0,
        margin=px(0, 0, 0, 0),
        width=percent(100),
        height=px(10),
        color="black",
        text_align="center",
        opacity=1
    )

    # `body` is nested inside `foot` up front; children are added to it
    # afterwards, before `foot` is serialised.
    body = p()
    foot = div(
        style=style_div
    )(
        body
    )

    st.markdown(style, unsafe_allow_html=True)

    for arg in args:
        if isinstance(arg, (str, HtmlElement)):
            body(arg)

    st.markdown(str(foot), unsafe_allow_html=True)
def footer():
    """Render the page footer: the text "Made by " followed by a link to the
    author's LinkedIn profile."""
    author = link("https://www.linkedin.com/in/gaspar-avit/", "Gaspar Avit")
    layout("Made by ", author)
def update_prediction(input_data):
    """Callback to automatically update prediction if button has already been
    clicked.

    -param input_data: DataFrame with input data
    """
    # `is_clicked` is the module-level flag assigned in the main section.
    if is_clicked:
        # Same steps as the button handler in the main section:
        # compute the prediction, then display it.
        st.write(generate_prediction(input_data))


def get_input_data():
    """
    Generate input layout and get input values.

    The frame is stored in `session.input_data` and filled column by column
    as each widget is created.

    NOTE(review): Streamlit invokes widget callbacks before the script is
    re-run, so `update_prediction` receives the frame as it was populated
    during the run that registered the callback - confirm this is acceptable.

    -return: single-row DataFrame with one column per input parameter.
    """
    # Start from a one-row frame: assigning a scalar to a column of a
    # completely empty DataFrame creates the column but no rows, which would
    # leave the model with nothing to predict on.
    session.input_data = pd.DataFrame(index=[0])

    # Hand the widgets the callback itself plus its arguments. Writing
    # `on_change=update_prediction(...)` would call the function immediately
    # and pass its return value (None) to the widget instead.
    callback = {
        'on_change': update_prediction,
        'args': (session.input_data,),
    }

    input_expander = st.expander('Input parameters', True)
    with input_expander:
        # Row 1
        col_age, col_sex = st.columns(2)
        with col_age:
            session.input_data['age'] = st.slider(
                'Age', 18, 75, **callback)
        with col_sex:
            session.input_data['sex'] = st.radio(
                'Sex', ['Female', 'Male'], **callback)

        # Row 2
        col_height, col_weight = st.columns(2)
        with col_height:
            session.input_data['height'] = st.slider(
                'Height', 140, 200, **callback)
        with col_weight:
            session.input_data['weight'] = st.slider(
                'Weight', 40, 140, **callback)

        # Row 3
        col_ap_hi, col_ap_lo = st.columns(2)
        with col_ap_hi:
            session.input_data['ap_hi'] = st.slider(
                'Systolic blood pressure', 90, 200, **callback)
        with col_ap_lo:
            session.input_data['ap_lo'] = st.slider(
                'Diastolic blood pressure', 50, 120, **callback)

        st.write("")

    return session.input_data
def generate_prediction(input_data):
    """
    Run the loaded classification model on the input data.

    NOTE(review): the result is whatever `MODEL.predict` returns; if a
    probability is required, confirm that `predict` is the right call.

    -param input_data: DataFrame with input data
    -return: output of `MODEL.predict` for `input_data`
    """
    # MODEL is the module-level classifier loaded in the main section.
    prediction = MODEL.predict(input_data)
    return prediction
###############################
## --------- MAIN ---------- ##
###############################
if __name__ == "__main__":

    ## --- Page config ------------ ##
    # Page title followed by a short description of the app
    st.title("""
Cardiovascular Disease predictor
#### This app aims to give a scoring of how probable is that an individual \
would suffer from a cardiovascular disease given its physical \
characteristics
#### Just enter your info and get a prediction.
""")

    # Page footer (currently disabled)
    # footer()

    # Flag read by update_prediction(); it has to exist before the input
    # widgets are created
    is_clicked = False
    ## --------------------------- ##

    # Load the classification model from ./model.cbm
    MODEL = CatBoostClassifier()
    MODEL.load_model('./model.cbm')

    # Build the input widgets and collect their values
    session.input_data = get_input_data()

    # Put the trigger button in the middle of three columns
    left_pad, button_col, right_pad = st.columns([1.3, 1, 1])
    is_clicked = button_col.button(label="Generate predictions")

    # Vertical spacing
    st.text("")
    st.text("")

    # Compute and display the prediction once the button is clicked
    if is_clicked:
        st.write(generate_prediction(session.input_data))

    # Vertical spacing
    st.text("")
    st.text("")